#!/usr/bin/env python3
"""
Merge multiple tweet CSVs into one, deduplicating by tweet ID (the first non-empty of the tweet_id, id, or id_str columns).

Usage:
  tweet-csv-merge <csv> [csv...] [-o <output>] [--keep {first,last}] [--no-sort]

Examples:
  tweet-csv-merge day1.csv day2.csv -o week.csv
  tweet-csv-merge PHASES/PHASE-01-INGEST/tweets/archive/*/raw_feed.csv -o all_tweets.csv
"""

import argparse
import csv
import sys
from pathlib import Path
from datetime import datetime


def safe_str(value):
    """Return ``str(value)``, mapping ``None`` to the empty string."""
    return '' if value is None else str(value)


def normalize_key(key):
    """Return the canonical lookup form of a column name.

    Surrounding whitespace is stripped and the text is case-folded, so
    ``' Tweet_ID '`` and ``'tweet_id'`` compare equal. ``None`` becomes ``''``.
    """
    return '' if key is None else str(key).strip().casefold()


def build_key_map(fieldnames):
    """
    Map each normalized field name to the original name that first produced it.

    Names that normalize to an already-seen key are skipped, so when a header
    contains e.g. both ``id`` and ``ID`` the earlier spelling is kept.
    """
    mapping = {}
    for original in fieldnames:
        # setdefault only inserts when the normalized key is new.
        mapping.setdefault(normalize_key(original), original)
    return mapping


def get_field(row, key_map, *candidates):
    """
    Return the first non-empty value among several candidate column names.

    Each candidate is normalized and resolved through ``key_map``
    (normalized name -> actual key in ``row``). Candidates are tried in order,
    and those absent from ``key_map`` are skipped. Values are stringified with
    ``safe_str``. Returns ``''`` when no candidate yields a non-empty value.
    """
    # Lazily evaluated: row lookups stop at the first non-empty hit.
    resolved_values = (
        safe_str(row.get(key_map[normalized], ''))
        for normalized in map(normalize_key, candidates)
        if normalized in key_map
    )
    return next((text for text in resolved_values if text), '')


def _parse_args():
    """Build the command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description='Merge multiple tweet CSVs, deduplicating by tweet_id',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  tweet-csv-merge day1.csv day2.csv -o week.csv
  tweet-csv-merge feed1.csv feed2.csv feed3.csv
  tweet-csv-merge PHASES/PHASE-01-INGEST/tweets/archive/2025-01-*/raw_feed.csv -o january.csv
        """
    )
    parser.add_argument('csv_files', nargs='+', help='CSV files to merge (shell expands globs)')
    parser.add_argument('-o', '--output', help='Output file (default: merged_TIMESTAMP.csv)')
    parser.add_argument('--keep', choices=['first', 'last'], default='first',
                        help='When duplicate tweet_id found, keep first or last occurrence (default: first)')
    parser.add_argument('--no-sort', action='store_true',
                        help='Disable sorting (faster, preserves input order)')
    return parser.parse_args()


def _canonicalize_row(row, rename):
    """Re-key one ``csv.DictReader`` row to the merged (canonical) column names.

    ``rename`` maps every header name of the row's source file to its
    canonical name. The ``None`` key is dropped; DictReader uses it for values
    beyond the header. Values are coerced with ``safe_str``, so ``None``
    becomes ``''``. When several source columns collapse onto one canonical
    name, the first non-empty value wins.
    """
    canonical_row = {}
    for key, value in row.items():
        if key is None:
            continue  # surplus columns beyond the header are ignored
        name = rename[key]
        if not canonical_row.get(name):
            canonical_row[name] = safe_str(value)
    return canonical_row


def _id_sort_key(tweet_id):
    """Return a sort key that orders all-digit tweet IDs numerically.

    Snowflake IDs differ in length between eras, so plain string comparison
    would rank ``'999'`` above ``'1000'``. Non-numeric (or empty) IDs sort
    below every numeric ID and are compared as text among themselves.
    """
    if tweet_id.isascii() and tweet_id.isdigit():
        return (1, int(tweet_id), '')
    return (0, 0, tweet_id)


def _write_rows(output_path, fieldnames, rows):
    """Write ``rows`` (dicts keyed by ``fieldnames``) to ``output_path`` as UTF-8 CSV.

    Columns missing from a row are written as empty strings. Parent
    directories are created as needed. Exits with status 1 if the directory
    or file cannot be written.
    """
    try:
        # No-op when the parent already exists (including the current directory).
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(
                f,
                fieldnames=fieldnames,
                restval='',              # fill columns a row doesn't have
                quoting=csv.QUOTE_MINIMAL,
                extrasaction='ignore',   # defensive: ignore keys not in fieldnames
            )
            writer.writeheader()
            writer.writerows(rows)
    except OSError as e:
        print(f"Error: Failed to write output file: {e}", file=sys.stderr)
        sys.exit(1)


def main():
    """Merge the CSV files named on the command line into one deduplicated CSV.

    Columns from all inputs are unioned. Header names are matched
    case-insensitively and ignoring surrounding whitespace. Every row is
    re-keyed to the first-seen spelling, so dedup, sorting and writing all
    see one consistent set of column names. Rows are deduplicated by the first
    non-empty of ``tweet_id``/``id``/``id_str`` (whitespace-stripped). Rows
    with no ID are all kept. ``--keep`` chooses whether the first or last
    duplicate survives.

    Unless ``--no-sort`` is given, rows are ordered newest first by
    ``timestamp``/``created_at``/``date``. Those values are compared as text:
    correct for ISO-8601, but not for formats like ``'Wed Oct 10 ...'``.
    Ties are broken by tweet ID, numerically when all digits.

    Exits with status 1 in three cases: no input is usable, the output path
    is one of the inputs, or the output cannot be written.
    """
    args = _parse_args()
    id_fields = ('tweet_id', 'id', 'id_str')

    # Validate input files: only regular files can be opened as CSV.
    valid_files = []
    for csv_file in args.csv_files:
        path = Path(csv_file)
        if path.is_file():
            valid_files.append(path)
        elif path.exists():
            print(f"Warning: Skipping non-file path: {csv_file}", file=sys.stderr)
        else:
            print(f"Warning: Skipping non-existent file: {csv_file}", file=sys.stderr)

    if not valid_files:
        print("Error: No valid input files found", file=sys.stderr)
        sys.exit(1)

    # Determine output path early so it can be checked against the inputs.
    if args.output:
        output_path = Path(args.output)
    else:
        output_path = Path(f"merged_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv")

    # Refuse to overwrite one of the inputs.
    output_resolved = output_path.resolve()
    for vf in valid_files:
        if vf.resolve() == output_resolved:
            print(f"Error: Output file cannot be one of the input files: {vf}", file=sys.stderr)
            sys.exit(1)

    all_rows = {}  # dedup key -> canonical row; dict keeps first-insertion order
    all_fieldnames = []  # union of all columns, in first-seen order
    canonical_by_norm = {}  # normalized name -> canonical (first-seen) name
    files_processed = 0
    total_rows_read = 0
    rows_with_extra_fields = 0

    for csv_file in valid_files:
        try:
            # utf-8-sig strips a leading BOM if present.
            with open(csv_file, 'r', newline='', encoding='utf-8-sig') as f:
                reader = csv.DictReader(f)
                file_fieldnames = reader.fieldnames

                if not file_fieldnames:
                    print(f"Warning: Skipping file with no header: {csv_file}", file=sys.stderr)
                    continue

                # Union the columns and record how this file's header names
                # map onto the canonical names used for every row.
                rename = {}
                for fn in file_fieldnames:
                    norm = normalize_key(fn)
                    if norm not in canonical_by_norm:
                        canonical_by_norm[norm] = fn
                        all_fieldnames.append(fn)
                    rename[fn] = canonical_by_norm[norm]

                file_count = 0
                for raw_row in reader:
                    total_rows_read += 1
                    file_count += 1

                    # DictReader stores values beyond the header under None.
                    if None in raw_row:
                        rows_with_extra_fields += 1

                    row = _canonicalize_row(raw_row, rename)
                    tweet_id = get_field(row, canonical_by_norm, *id_fields).strip()

                    if not tweet_id:
                        # Keep ID-less rows. A tuple key can never collide with a real ID.
                        all_rows[('__no_id__', total_rows_read)] = row
                    elif tweet_id not in all_rows or args.keep == 'last':
                        all_rows[tweet_id] = row
                    # else: keep first (already stored)

                files_processed += 1
                print(f"  Read {file_count} rows from {csv_file.name}")

        except UnicodeDecodeError as e:
            print(f"Warning: Encoding error reading {csv_file}: {e}", file=sys.stderr)
        except csv.Error as e:
            print(f"Warning: CSV error reading {csv_file}: {e}", file=sys.stderr)
        except OSError as e:
            print(f"Warning: IO error reading {csv_file}: {e}", file=sys.stderr)

    if not all_fieldnames:
        print("Error: No valid CSV headers found in any file", file=sys.stderr)
        sys.exit(1)

    if not all_rows:
        print("Warning: No data found in any input files", file=sys.stderr)

    if rows_with_extra_fields > 0:
        print(f"Warning: {rows_with_extra_fields} rows had more columns than header (extras ignored)", file=sys.stderr)

    # Rows are keyed by canonical names, so one key map serves every row.
    output_key_map = build_key_map(all_fieldnames)

    if args.no_sort:
        sorted_rows = list(all_rows.values())
    else:
        def sort_key(row):
            ts = get_field(row, output_key_map, 'timestamp', 'created_at', 'date').strip()
            tweet_id = get_field(row, output_key_map, *id_fields).strip()
            return (ts, _id_sort_key(tweet_id))

        # Newest first; rows without a timestamp ('' sorts lowest) go last.
        sorted_rows = sorted(all_rows.values(), key=sort_key, reverse=True)

    _write_rows(output_path, all_fieldnames, sorted_rows)

    duplicates = total_rows_read - len(sorted_rows)

    print("\nMerge complete:")
    print(f"  Files:      {files_processed}")
    print(f"  Total rows: {total_rows_read}")
    print(f"  Duplicates: {duplicates}")
    print(f"  Unique:     {len(sorted_rows)}")
    print(f"  Columns:    {len(all_fieldnames)}")
    print(f"  Output:     {output_path}")


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
