#!/usr/bin/env python3
"""
Filter tweets from a CSV by:
  - username (name)
  - text/image_description content (content)
  - engagement metrics only (metrics)
  - OR combine name/content with metrics constraints.

Backwards compatible with existing behavior, plus:
  - Operator-based metric constraints (repeatable)
  - Multiple constraints in one arg
  - Range syntax (A-B or A..B)
  - Human numbers (1,200  1.2k  3M)
  - Fail-fast on contradictory constraints (exits with an error instead of silently producing empty results)

USAGE (existing):
  tweet-filter name <username> <csv-file> [-o <output>] [METRIC FLAGS...]
  tweet-filter content <search-string> <csv-file> [-o <output>] [METRIC FLAGS...]

NEW:
  tweet-filter metrics <csv-file> [-o <output>] [METRIC FLAGS...]

METRIC FLAGS (apply to ALL modes):
  Operator expressions (repeatable):
    --likes EXPR
    --retweets EXPR
    --replies EXPR

    EXPR supports:
      - Comparisons: >=500, >500, <=500, <500, =500, ==500, !=500
      - Multiple in one arg: ">=500 <=2000 !=1337"
      - Ranges: 500-2000 or 500..2000 (inclusive)

  Legacy (still supported):
    --min-likes N      --max-likes N
    --min-retweets N   --max-retweets N
    --min-replies N    --max-replies N

Examples:
  tweet-filter name @sama raw_feed.csv
  tweet-filter content "claude code" with_images.csv

  # Multiple filters at once:
  tweet-filter content "agents" raw_feed.csv --likes ">=200" --retweets ">=30" --replies "<=50"

  # Repeat operator flags:
  tweet-filter metrics raw_feed.csv --likes ">=500" --likes "<=2000" --likes "!=1337"

  # Range:
  tweet-filter metrics raw_feed.csv --likes "500-2000"
"""

import argparse
import csv
import re
import sys
from dataclasses import dataclass
from decimal import Decimal, InvalidOperation
from pathlib import Path
from typing import Dict, FrozenSet, List, Optional, Tuple


# -------------------------
# Generic helpers
# -------------------------

def safe_str(value) -> str:
    """Return ``str(value)``, with ``None`` mapped to the empty string."""
    return '' if value is None else str(value)


def normalize_key(key) -> str:
    """Return the stripped, casefolded string form of *key*; ``None`` yields ''."""
    if key is not None:
        return str(key).strip().casefold()
    return ''


def build_key_map(fieldnames) -> Dict[str, str]:
    """
    Map each normalized header name to the original header it came from.

    When several headers normalize to the same key, the earliest one is kept.
    """
    lookup: Dict[str, str] = {}
    for original in fieldnames:
        # setdefault leaves an existing entry alone, so the first header wins.
        lookup.setdefault(normalize_key(original), original)
    return lookup


def get_field(row, key_map, *candidates) -> str:
    """
    Return the first non-empty value found under any of the candidate column
    names (matched case-insensitively through *key_map*), or '' if none has one.
    """
    for name in candidates:
        lookup = normalize_key(name)
        if lookup not in key_map:
            continue
        text = safe_str(row.get(key_map[lookup], ''))
        if text:
            return text
    return ''


def sanitize_filename(s, max_length=120) -> str:
    """
    Make *s* safe to use as a filename stem.

    Every character other than a word character, hyphen, or '@' becomes an
    underscore; runs of underscores are collapsed and leading/trailing ones
    removed. The result is cut to *max_length* characters, and 'filtered' is
    returned when nothing usable is left.
    """
    collapsed = re.sub(r'_+', '_', re.sub(r'[^\w\-@]', '_', s))
    name = collapsed.strip('_')
    if len(name) > max_length:
        # Truncation may expose a trailing underscore; drop it again.
        name = name[:max_length].rstrip('_')
    return name if name else 'filtered'


# -------------------------
# Existing matchers
# -------------------------

def matches_name(row, key_map, username_norm: str) -> bool:
    """
    Return True when the row's username column (username / user / screen_name)
    equals *username_norm* after stripping whitespace, leading '@' characters
    and casefolding.
    """
    handle = get_field(row, key_map, 'username', 'user', 'screen_name')
    return handle.strip().lstrip('@').casefold() == username_norm


def matches_content(row, key_map, search_norm: str) -> bool:
    """
    Return True when *search_norm* occurs in the casefolded tweet text or in
    the casefolded image description of the row.
    """
    haystacks = (
        get_field(row, key_map, 'content', 'text', 'tweet_text'),
        get_field(row, key_map, 'image_description', 'description'),
    )
    return any(search_norm in haystack.casefold() for haystack in haystacks)


# -------------------------
# NEW: robust numeric parsing (data + constraints)
# -------------------------

# Case-insensitive magnitude suffixes accepted by parse_human_int.
_SUFFIX_MULTIPLIERS = {"k": 1_000, "m": 1_000_000, "b": 1_000_000_000}


def parse_human_int(raw: str) -> Optional[int]:
    """
    Parse a non-negative integer from common human-readable forms:
      "1234", "1,234", "1_234", "1.2k", "3M", "4.7b"

    Commas and underscores are ignored. A trailing k/m/b (any case) scales the
    number by 1e3/1e6/1e9. Fractional results are truncated toward zero
    ("1.9999k" -> 1999).

    Inputs containing a '.' are parsed with ``Decimal`` instead of ``float`` so
    that scaling is exact in base 10: with binary floats "1.001k" evaluates to
    1000.9999999999999 and would truncate to 1000 rather than 1001.

    Returns None if the input is empty, negative, non-finite, absurdly large
    (beyond what a float could represent), or otherwise unparseable.
    """
    text = (raw or "").strip()
    if not text:
        return None

    digits = text.replace(",", "").replace("_", "").lower()

    multiplier = _SUFFIX_MULTIPLIERS.get(digits[-1:])
    if multiplier is None:
        multiplier = 1
    else:
        digits = digits[:-1]

    try:
        value = Decimal(digits) if "." in digits else int(digits)
    except (ValueError, InvalidOperation):
        return None

    if isinstance(value, Decimal):
        if not value.is_finite():
            return None
        if value.is_zero():
            # Covers "0.0", "-0.0" and zero with any exponent.
            return 0
        # Exponent notation can describe astronomically large numbers; the old
        # float-based code overflowed on these, so treat them as unparseable
        # instead of materializing a gigantic integer.
        if value.adjusted() > 308:
            return None

    if value < 0:
        return None

    return int(value * multiplier)


def get_metric(row, key_map, metric: str) -> Optional[int]:
    """
    Read the named engagement metric ("likes", "retweets" or "replies") from
    the row, trying each known column alias in turn.

    Returns None for an unknown metric name, or when the value is missing or
    cannot be parsed as a number.
    """
    if metric == "likes":
        aliases = ("likes", "favorite_count", "favourite_count", "favorites", "favourites")
    elif metric == "retweets":
        aliases = ("retweets", "retweet_count", "rts")
    elif metric == "replies":
        aliases = ("replies", "reply_count", "comments")
    else:
        return None
    return parse_human_int(get_field(row, key_map, *aliases))


# -------------------------
# Operator constraints parsing + bulletproof validation
# -------------------------

@dataclass(frozen=True)
class Bound:
    """One end (lower or upper) of the allowed range for a metric value."""
    value: int       # the limit itself
    inclusive: bool  # True for >= / <= (limit allowed), False for > / < (limit excluded)


@dataclass(frozen=True)
class MetricPredicate:
    """
    Validated constraint set for a single metric.

    ``ne`` is annotated with ``typing.FrozenSet`` rather than the builtin
    generic ``frozenset[int]``: dataclass field annotations are evaluated when
    the class is created, and subscripting the builtin raises TypeError on
    Python versions before 3.9.
    """
    eq: Optional[int]                # if set, value must equal this
    ne: FrozenSet[int]               # value must NOT be any of these
    lower: Optional[Bound]           # min bound, or None when unbounded below
    upper: Optional[Bound]           # max bound, or None when unbounded above


@dataclass(frozen=True)
class MetricFilters:
    """Predicates for each supported metric; ``None`` means "no constraint"."""
    likes: Optional[MetricPredicate]
    retweets: Optional[MetricPredicate]
    replies: Optional[MetricPredicate]

    def any_active(self) -> bool:
        """Return True when at least one metric carries a predicate."""
        return bool(self.likes or self.retweets or self.replies)

    def signature(self) -> str:
        """
        Build a deterministic, reasonably short tag describing the active
        predicates, for use in the default output filename ('' when none).
        """
        labelled = (("likes", self.likes), ("rt", self.retweets), ("rep", self.replies))
        return "_".join(
            f"{label}_{predicate_signature(pred)}"
            for label, pred in labelled
            if pred
        )


# Tokenizer for metric expressions. Alternatives are tried left to right, so the
# two-character operators (>=, <=, ==, !=) are listed before >, <, = and the bare
# "!" shorthand; ".." and "-" are the range separators. The last alternative is
# a number: a digit, then digits/commas, an optional ".digits" fraction and an
# optional k/m/b suffix (either case).
TOKEN_RE = re.compile(r"(>=|<=|==|!=|>|<|=|\.\.|-|\!|\d[\d,]*(?:\.\d+)?(?:[kKmMbB])?)")


def is_number_token(tok: str) -> bool:
    """Return True when *tok* is, in its entirety, a numeric literal such as "500", "1,200", "1.2k" or "3M"."""
    return re.fullmatch(r"\d[\d,]*(?:\.\d+)?(?:[kKmMbB])?", tok) is not None


def parse_metric_expressions(exprs: List[str], metric_name: str) -> List[Tuple[str, int]]:
    """
    Turn a list of constraint expressions into a flat list of (op, value) pairs.

    Recognized forms:
      - comparisons, with or without a space after the operator: ">=500", ">= 500"
      - several constraints in one string: ">=500 <=2000 !=1337"
      - inclusive ranges: 500-2000 or 500..2000 (endpoints may be given in either order)
      - a bare number, meaning equality
      - '!' as shorthand for '!=': "! 500" or "!500"

    '=' is normalized to '=='. Empty expressions are skipped.
    Raises ValueError (prefixed with *metric_name*) on anything it cannot parse.
    """
    comparison_ops = (">=", "<=", "==", "!=", ">", "<", "=")
    parsed: List[Tuple[str, int]] = []

    for expr in exprs:
        text = (expr or "").strip()
        if not text:
            continue

        # Whatever survives after deleting every known token must consist of
        # separators (commas / whitespace) only; anything else is garbage.
        leftover = re.sub(r"[,\s]+", "", TOKEN_RE.sub("", text))
        if leftover:
            raise ValueError(f"{metric_name}: invalid characters in expression: '{expr}'")

        tokens = [tok for tok in TOKEN_RE.findall(text) if tok and not tok.isspace()]
        total = len(tokens)
        pos = 0

        while pos < total:
            tok = tokens[pos]
            following = tokens[pos + 1] if pos + 1 < total else None
            number_follows = following is not None and is_number_token(following)

            if tok == "!":
                # Shorthand negation: "!" NUMBER
                if not number_follows:
                    raise ValueError(f"{metric_name}: '!' must be followed by a number in '{expr}'")
                number = parse_human_int(following)
                if number is None:
                    raise ValueError(f"{metric_name}: invalid number after '!': '{following}'")
                parsed.append(("!=", number))
                pos += 2

            elif tok in comparison_ops:
                # OPERATOR NUMBER
                if not number_follows:
                    raise ValueError(f"{metric_name}: operator '{tok}' must be followed by a number in '{expr}'")
                number = parse_human_int(following)
                if number is None:
                    raise ValueError(f"{metric_name}: invalid number '{following}' in '{expr}'")
                parsed.append(("==" if tok == "=" else tok, number))
                pos += 2

            elif is_number_token(tok):
                is_range = (
                    pos + 2 < total
                    and following in ("-", "..")
                    and is_number_token(tokens[pos + 2])
                )
                if is_range:
                    # NUMBER (- | ..) NUMBER => inclusive range
                    first = parse_human_int(tok)
                    second = parse_human_int(tokens[pos + 2])
                    if first is None or second is None:
                        raise ValueError(f"{metric_name}: invalid range numbers in '{expr}'")
                    parsed.append((">=", min(first, second)))
                    parsed.append(("<=", max(first, second)))
                    pos += 3
                else:
                    # Lone number => equality
                    number = parse_human_int(tok)
                    if number is None:
                        raise ValueError(f"{metric_name}: invalid number '{tok}' in '{expr}'")
                    parsed.append(("==", number))
                    pos += 1

            else:
                # E.g. a dangling range separator.
                raise ValueError(f"{metric_name}: could not parse token '{tok}' in '{expr}'")

    return parsed


def canonicalize_conditions(metric_name: str, conds: List[Tuple[str, int]]) -> Optional[MetricPredicate]:
    """
    Convert raw (op, value) conditions into a simplified, validated predicate.

    All conditions are combined with AND. Bounds are reduced to the single
    tightest lower and upper bound; at the same value an exclusive bound
    (> / <) is stricter than an inclusive one (>= / <=) and wins.

    Returns None when *conds* is empty.

    Raises ValueError (fail-fast) when the conditions cannot be satisfied by
    any non-negative integer:
      - a negative value or unknown operator,
      - more than one distinct equality value,
      - lower bound above upper bound,
      - an equality outside the bounds or excluded by '!=',
      - a finite range in which every value is excluded by '!='.
    """
    if not conds:
        return None

    eq_values = set()
    ne_values = set()
    lower: Optional[Bound] = None
    upper: Optional[Bound] = None

    for op, v in conds:
        if v < 0:
            raise ValueError(f"{metric_name}: negative values are not allowed")

        if op == "==":
            eq_values.add(v)
        elif op == "!=":
            ne_values.add(v)
        elif op == ">=":
            # An inclusive bound only tightens the floor when its value is
            # strictly higher; at equal value the existing bound is already
            # at least as strict.
            if lower is None or v > lower.value:
                lower = Bound(v, True)
        elif op == ">":
            if lower is None or v > lower.value or (v == lower.value and lower.inclusive):
                lower = Bound(v, False)
        elif op == "<=":
            if upper is None or v < upper.value:
                upper = Bound(v, True)
        elif op == "<":
            if upper is None or v < upper.value or (v == upper.value and upper.inclusive):
                upper = Bound(v, False)
        else:
            raise ValueError(f"{metric_name}: unsupported operator '{op}'")

    # Multiple distinct equality constraints is almost certainly user error.
    if len(eq_values) > 1:
        raise ValueError(f"{metric_name}: multiple equality constraints given {sorted(eq_values)} (AND semantics makes this impossible).")

    eq = next(iter(eq_values)) if eq_values else None

    # Smallest / largest integer admitted by the bounds. Values are never
    # negative, so the implicit floor is 0. The ceiling is None when unbounded:
    # Python ints have no maximum, so a numeric sentinel would wrongly reject
    # very large lower bounds or equalities.
    if lower is None:
        min_a = 0
    else:
        min_a = lower.value if lower.inclusive else lower.value + 1

    if upper is None:
        max_a: Optional[int] = None
    else:
        max_a = upper.value if upper.inclusive else upper.value - 1

    if max_a is not None and min_a > max_a:
        raise ValueError(f"{metric_name}: contradictory bounds (no integer can satisfy).")

    if eq is not None:
        if eq < min_a or (max_a is not None and eq > max_a):
            raise ValueError(f"{metric_name}: equality {eq} contradicts bounds.")
        if eq in ne_values:
            raise ValueError(f"{metric_name}: equality {eq} contradicts !={eq}.")
    elif max_a is not None:
        # Finite range: it is unsatisfiable when '!=' removes every value in it.
        span = max_a - min_a + 1
        excluded = sum(1 for x in ne_values if min_a <= x <= max_a)
        if excluded == span:
            if span == 1:
                raise ValueError(f"{metric_name}: only possible value {min_a} is excluded by !={min_a}.")
            raise ValueError(f"{metric_name}: every value in {min_a}..{max_a} is excluded by != constraints.")

    return MetricPredicate(eq=eq, ne=frozenset(ne_values), lower=lower, upper=upper)


def predicate_signature(p: MetricPredicate) -> str:
    """
    Render a predicate as a short filename-friendly tag.

    An equality is shown as "eqN" (bounds are then omitted); otherwise the
    bounds appear as gte/gt/lte/lt followed by the value. Up to three excluded
    values are listed individually as "neN"; more than three are summarized
    as "ne<count>". A predicate with nothing to show yields "any".
    """
    if p.eq is not None:
        pieces = [f"eq{p.eq}"]
    else:
        pieces = []
        if p.lower is not None:
            prefix = "gte" if p.lower.inclusive else "gt"
            pieces.append(f"{prefix}{p.lower.value}")
        if p.upper is not None:
            prefix = "lte" if p.upper.inclusive else "lt"
            pieces.append(f"{prefix}{p.upper.value}")

    if p.ne:
        excluded = sorted(p.ne)
        if len(excluded) > 3:
            # Too many to spell out; record only how many there are.
            pieces.append(f"ne{len(excluded)}")
        else:
            pieces += [f"ne{v}" for v in excluded]

    return "_".join(pieces) or "any"


def metrics_pass_value(v: int, p: MetricPredicate) -> bool:
    """
    Return True when *v* satisfies predicate *p*.

    Excluded values always fail. When an equality is set it alone decides the
    outcome (bounds are not consulted); otherwise *v* must lie within the
    lower and upper bounds, honoring their inclusiveness.
    """
    if p.ne and v in p.ne:
        return False

    if p.eq is not None:
        return v == p.eq

    floor, ceiling = p.lower, p.upper
    if floor is not None and (v < floor.value or (v == floor.value and not floor.inclusive)):
        return False
    if ceiling is not None and (v > ceiling.value or (v == ceiling.value and not ceiling.inclusive)):
        return False
    return True


def metrics_pass(row, key_map, mf: MetricFilters) -> bool:
    """
    Return True when the row satisfies every active metric predicate.

    Strict behavior: a metric that has a predicate but whose value is missing
    or unparseable in the row makes the row fail.
    """
    checks = (("likes", mf.likes), ("retweets", mf.retweets), ("replies", mf.replies))
    for metric, predicate in checks:
        if not predicate:
            continue
        value = get_metric(row, key_map, metric)
        if value is None or not metrics_pass_value(value, predicate):
            return False
    return True


def build_default_output_path(mode: str, search: Optional[str], input_path: Path, mf: MetricFilters) -> Path:
    """
    Derive the default output CSV path, placed next to the input file.

    The stem is "tweets-from-@<user>" (name mode), "tweets-containing-<text>"
    (content mode) or "tweets-metrics" (any other mode), followed by "__" and
    the metric signature when there is one; metrics mode without a signature
    uses "no_constraints". The stem is passed through sanitize_filename.
    """
    sig = mf.signature()
    suffix = f"__{sig}" if sig else ""

    if mode == "name":
        handle = sanitize_filename(search or "")
        if not handle.startswith("@"):
            handle = "@" + handle.lstrip('@')
        stem = f"tweets-from-{handle}{suffix}"
    elif mode == "content":
        phrase = sanitize_filename(search or "")
        stem = f"tweets-containing-{phrase}{suffix}"
    else:
        # metrics-only
        stem = f"tweets-metrics__{sig or 'no_constraints'}"

    return input_path.parent / f"{sanitize_filename(stem)}.csv"


def add_metric_args(p: argparse.ArgumentParser) -> None:
    """
    Register the metric flags on parser *p*.

    For each of likes / retweets / replies this adds a repeatable
    operator-expression flag (--likes, ...) followed by the legacy integer
    flags (--min-likes / --max-likes, ...).
    """
    metrics = ('likes', 'retweets', 'replies')
    expr_help = {
        'likes': 'Likes constraint expr (repeatable). e.g. ">=500", "500-2000", "!=1337", ">=500 <=2000"',
        'retweets': 'Retweets constraint expr (repeatable).',
        'replies': 'Replies constraint expr (repeatable).',
    }

    # Operator-style flags first, so they stay grouped together in --help.
    for metric in metrics:
        p.add_argument(f'--{metric}', action='append', default=[], help=expr_help[metric])

    # Legacy range flags (still supported)
    for metric in metrics:
        p.add_argument(f'--min-{metric}', type=int, default=None, help=f'Minimum {metric} (legacy)')
        p.add_argument(f'--max-{metric}', type=int, default=None, help=f'Maximum {metric} (legacy)')


def build_metric_filters(args) -> MetricFilters:
    """
    Merge the operator expressions and the legacy min/max flags found on
    *args* into one canonical predicate per metric.

    Raises ValueError on unparseable expressions, negative legacy values or
    contradictory constraints.
    """
    def legacy_bounds(minimum: Optional[int], maximum: Optional[int]) -> List[Tuple[str, int]]:
        # Legacy --min-X / --max-X translate to inclusive bounds.
        bounds = []
        for label, op, limit in (("min", ">=", minimum), ("max", "<=", maximum)):
            if limit is None:
                continue
            if limit < 0:
                raise ValueError(f"legacy {label} cannot be negative")
            bounds.append((op, limit))
        return bounds

    # Collect the raw conditions for every metric before canonicalizing any of
    # them, so parse errors surface ahead of contradiction errors.
    raw_conditions = {}
    for metric in ("likes", "retweets", "replies"):
        raw_conditions[metric] = (
            parse_metric_expressions(getattr(args, metric), metric)
            + legacy_bounds(getattr(args, f"min_{metric}"), getattr(args, f"max_{metric}"))
        )

    predicates = {
        metric: canonicalize_conditions(metric, conds)
        for metric, conds in raw_conditions.items()
    }
    return MetricFilters(**predicates)


# -------------------------
# CLI + streaming filter
# -------------------------

def main(argv: Optional[List[str]] = None) -> int:
    """
    Command-line entry point: parse arguments, stream the input CSV and write
    the rows that match the selected mode and metric constraints.

    Args:
        argv: Argument list to parse instead of the process command line.
            ``None`` (the default) keeps the normal command-line behavior.

    Returns:
        0 on success, 1 on any validation, CSV or I/O error (the message is
        printed to stderr).
    """
    parser = argparse.ArgumentParser(
        description='Filter tweets from CSV by username, content, and/or engagement metrics (operator constraints supported)',
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    subparsers = parser.add_subparsers(dest='mode', required=True)

    p_name = subparsers.add_parser('name', help='Filter by username (optionally also by metrics)')
    p_name.add_argument('search', help='Username to match (with or without @)')
    p_name.add_argument('csv_file', help='Input CSV file')
    p_name.add_argument('-o', '--output', help='Output file (default: auto-generated name)')
    add_metric_args(p_name)

    p_content = subparsers.add_parser('content', help='Filter by content/image_description (optionally also by metrics)')
    p_content.add_argument('search', help='Text to search for (case-insensitive)')
    p_content.add_argument('csv_file', help='Input CSV file')
    p_content.add_argument('-o', '--output', help='Output file (default: auto-generated name)')
    add_metric_args(p_content)

    p_metrics = subparsers.add_parser('metrics', help='Filter by metrics only (likes/retweets/replies)')
    p_metrics.add_argument('csv_file', help='Input CSV file')
    p_metrics.add_argument('-o', '--output', help='Output file (default: auto-generated name)')
    add_metric_args(p_metrics)

    args = parser.parse_args(argv)

    # Normalize search
    search_normalized: Optional[str] = None
    if args.mode == 'name':
        search_normalized = args.search.strip().lstrip('@').casefold()
        if not search_normalized:
            print("Error: Username cannot be empty", file=sys.stderr)
            return 1
    elif args.mode == 'content':
        search_normalized = args.search.casefold()
        if not search_normalized.strip():
            print("Error: Search string cannot be empty", file=sys.stderr)
            return 1

    # Validate input
    input_path = Path(args.csv_file)
    if not input_path.exists():
        print(f"Error: File not found: {args.csv_file}", file=sys.stderr)
        return 1

    # Build metric filters (fail-fast on contradictions)
    try:
        mf = build_metric_filters(args)
    except ValueError as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1

    # The filters are immutable, so decide once whether rows need metric checks.
    metrics_active = mf.any_active()

    # Metrics-only mode MUST have at least one active constraint
    if args.mode == 'metrics' and not metrics_active:
        print("Error: metrics mode requires at least one constraint, e.g. --likes '>=100'", file=sys.stderr)
        return 1

    # Determine output path early
    if args.output:
        output_path = Path(args.output)
    else:
        output_path = build_default_output_path(args.mode, getattr(args, "search", None), input_path, mf)

    # Refuse to overwrite the file we are about to read from.
    if output_path.resolve() == input_path.resolve():
        print("Error: Output file cannot be the same as input file", file=sys.stderr)
        return 1

    output_dir = output_path.parent
    if output_dir and str(output_dir) != '.' and not output_dir.exists():
        try:
            output_dir.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            # Report like every other failure instead of dying with a traceback.
            print(f"Error: Could not create output directory '{output_dir}': {e}", file=sys.stderr)
            return 1

    input_count = 0
    output_count = 0

    try:
        # utf-8-sig transparently drops a leading BOM if the file has one.
        with open(input_path, 'r', newline='', encoding='utf-8-sig') as infile:
            reader = csv.DictReader(infile)
            fieldnames = reader.fieldnames
            if not fieldnames:
                print("Error: CSV file has no header row", file=sys.stderr)
                return 1

            key_map = build_key_map(fieldnames)

            with open(output_path, 'w', newline='', encoding='utf-8') as outfile:
                writer = csv.DictWriter(
                    outfile,
                    fieldnames=fieldnames,
                    quoting=csv.QUOTE_MINIMAL,
                    extrasaction='ignore'
                )
                writer.writeheader()

                for row in reader:
                    input_count += 1

                    # Base match (metrics-only mode has no base string filter)
                    if args.mode == 'name':
                        if not matches_name(row, key_map, search_normalized):  # type: ignore[arg-type]
                            continue
                    elif args.mode == 'content':
                        if not matches_content(row, key_map, search_normalized):  # type: ignore[arg-type]
                            continue

                    # Metrics constraints (if any)
                    if metrics_active and not metrics_pass(row, key_map, mf):
                        continue

                    writer.writerow(row)
                    output_count += 1

    except UnicodeDecodeError as e:
        print(f"Error: Failed to decode file (check encoding): {e}", file=sys.stderr)
        return 1
    except csv.Error as e:
        print(f"Error: Failed to parse CSV: {e}", file=sys.stderr)
        return 1
    except IOError as e:
        print(f"Error: File I/O error: {e}", file=sys.stderr)
        return 1

    # Summary
    search_term = getattr(args, 'search', None)
    base_desc = {
        "name": f"username '{search_term}'",
        "content": f"content containing '{search_term}'",
        "metrics": "metrics-only",
    }[args.mode]

    sig = mf.signature()
    if sig:
        base_desc += f" + {sig}"

    print(f"Filtered by {base_desc}")
    print(f"  Input:  {input_count} tweets")
    print(f"  Output: {output_count} tweets")
    print(f"  File:   {output_path}")

    return 0


if __name__ == '__main__':
    raise SystemExit(main())
