"""Replay scenario NDJSON files to Fabric Eventstream, Event Hubs, or stdout.

Usage examples:
    # Stdout dry run
    python eventstream_replay.py --scenario 01-ais-dark-near-cable \
        --target stdout --stream ais --pace asfast

    # Fabric Eventstream custom endpoint (HTTP)
    python eventstream_replay.py --scenario 02-ship-to-ship-rendezvous \
        --target eventstream --endpoint-url "https://<eventstream-host>/..." \
        --pace accelerated 10

    # Event Hubs with Managed Identity / DefaultAzureCredential
    python eventstream_replay.py --scenario 03-loitering-critical-infra \
        --target eventhub --namespace my-ns.servicebus.windows.net --hub r-mac-events \
        --pace realtime --loop
"""

from __future__ import annotations

import argparse
import heapq
import json
import signal
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Iterable, Iterator


def _is_meta_line(line: str) -> bool:
    """Return True when *line* is a disclaimer/meta record.

    A meta record is a JSON object whose first key is ``"__meta__"``.
    Whitespace before the opening brace and between the brace and the key is
    ignored, so padded encoder output such as ``{ "__meta__": ...}`` is
    recognised as well as the compact ``{"__meta__":...}`` form.
    See generators.common.is_meta_record.
    """
    s = line.lstrip()
    if not s.startswith("{"):
        return False
    # Drop the brace plus any padding, then compare the first key. The closing
    # quote is part of the match so keys like "__meta__x" are not mistaken.
    return s[1:].lstrip().startswith('"__meta__"')


# Literal prefix of a meta/disclaimer record as emitted by a compact encoder.
# NOTE(review): nothing in the visible part of this file reads this constant —
# confirm whether other modules still import it.
META_PREFIX = '{"__meta__"'
# Stream names probed (as "<name>.ndjson") when the caller names no streams.
DEFAULT_STREAMS = ("ais", "mac", "plane_radar", "drone_radar", "coastal_radar")

# Set to True by the signal handler; polled by the replay loop to stop early.
_stop_requested = False


def _install_signal_handlers() -> None:
    """Route SIGINT (and SIGTERM, where defined) to the graceful-stop flag.

    The handler only sets the module-level ``_stop_requested`` flag and prints
    a notice on stderr; failures to register a handler are ignored.
    """

    def _on_signal(signum, frame):  # noqa: ARG001
        global _stop_requested
        _stop_requested = True
        print("[replay] stop requested, finishing current event...", file=sys.stderr)

    wanted = [signal.SIGINT]
    if hasattr(signal, "SIGTERM"):
        wanted.append(signal.SIGTERM)

    for signum in wanted:
        try:
            signal.signal(signum, _on_signal)
        except (ValueError, OSError):
            # Registration is best effort (it fails, for instance, outside the
            # main thread); replay still works without graceful stop.
            continue


def _scenario_dir(scenario: str) -> Path:
    """Return ``<root>/scenarios/<scenario>/data/realtime``.

    ``<root>`` is the parent of the directory that contains this file.
    """
    root = Path(__file__).resolve().parent.parent
    return root.joinpath("scenarios", scenario, "data", "realtime")


def _discover_streams(scenario: str, requested: list[str] | None) -> list[tuple[str, Path]]:
    """Locate the NDJSON stream files of *scenario*.

    Args:
        scenario: Scenario id whose realtime data directory is searched.
        requested: Stream names to use; when empty or None every name in
            ``DEFAULT_STREAMS`` is probed and missing ones are skipped.

    Returns:
        ``(stream name, file path)`` pairs in candidate order.

    Raises:
        FileNotFoundError: The scenario directory is missing, an explicitly
            requested stream file is missing, or no stream file exists at all.
    """
    realtime_dir = _scenario_dir(scenario)
    if not realtime_dir.is_dir():
        raise FileNotFoundError(f"Scenario realtime dir not found: {realtime_dir}")

    explicit = bool(requested)
    names = requested if explicit else list(DEFAULT_STREAMS)

    located: list[tuple[str, Path]] = []
    for stream_name in names:
        candidate = realtime_dir / f"{stream_name}.ndjson"
        if not candidate.is_file():
            if explicit:
                # Only an explicit request makes a missing file fatal.
                raise FileNotFoundError(f"Requested stream not found: {candidate}")
            continue
        located.append((stream_name, candidate))

    if located:
        return located
    raise FileNotFoundError(f"No NDJSON streams found under {realtime_dir}")


def _resolve_event_time(evt: dict[str, Any]) -> datetime | None:
    """Extract the event time of *evt* as a timezone-aware UTC datetime.

    The ISO 8601 string fields ``timestamp`` and ``processingTimestamp`` are
    tried in that order; naive values are taken to be UTC. If neither parses,
    the numeric ``ts_epoch_ms`` field (milliseconds since the Unix epoch) is
    used.

    Returns:
        The resolved UTC datetime, or None when no field holds a usable value.
        Malformed or out-of-range values yield None rather than an exception,
        so a single bad record cannot abort a replay.
    """
    for key in ("timestamp", "processingTimestamp"):
        raw = evt.get(key)
        if not isinstance(raw, str):
            continue
        try:
            parsed = datetime.fromisoformat(raw.replace("Z", "+00:00"))
            if parsed.tzinfo is None:
                parsed = parsed.replace(tzinfo=timezone.utc)
            # astimezone() overflows for values at the edge of the datetime
            # range, so it stays inside the guarded region.
            return parsed.astimezone(timezone.utc)
        except (ValueError, OverflowError):
            continue

    millis = evt.get("ts_epoch_ms")
    # bool is an int subclass; True/False are never meaningful epoch values.
    if isinstance(millis, bool) or not isinstance(millis, (int, float)):
        return None
    try:
        return datetime.fromtimestamp(millis / 1000.0, tz=timezone.utc)
    except (ValueError, OverflowError, OSError):
        # NaN, infinity, or a value outside the platform's timestamp range.
        return None


def _iter_stream_file(stream: str, path: Path) -> Iterator[tuple[datetime, str, dict[str, Any]]]:
    """Yield ``(event time, stream name, event)`` for each usable line of *path*.

    Skipped lines: blank lines, meta/disclaimer records (on any line), lines
    that are not valid JSON, JSON values that are not objects, and events for
    which no timestamp can be resolved. Invalid-JSON and non-object lines are
    reported on stderr.

    Args:
        stream: Name attached to every yielded event.
        path: NDJSON file to read.
    """
    # "utf-8-sig" reads plain UTF-8 and additionally drops a leading BOM, which
    # would otherwise hide a first-line meta record and break JSON parsing.
    with path.open("r", encoding="utf-8-sig") as fh:
        for line_no, raw in enumerate(fh, start=1):
            line = raw.strip()
            if not line or _is_meta_line(line):
                continue
            try:
                evt = json.loads(line)
            except json.JSONDecodeError as exc:
                print(f"[replay] skip bad line {path}:{line_no}: {exc}", file=sys.stderr)
                continue
            if not isinstance(evt, dict):
                # A bare list/number/string is valid JSON but not an event.
                print(f"[replay] skip non-object line {path}:{line_no}", file=sys.stderr)
                continue
            ts = _resolve_event_time(evt)
            if ts is None:
                continue
            yield ts, stream, evt

def _parse_iso(s: str | None) -> datetime | None:
    """Parse an ISO 8601 string into an aware UTC datetime.

    Returns None for None or an empty string. A trailing ``Z`` is accepted and
    a value without an offset is taken to be UTC. Invalid text raises the
    ValueError produced by ``datetime.fromisoformat``.
    """
    if not s:
        return None
    parsed = datetime.fromisoformat(s.replace("Z", "+00:00"))
    aware = parsed if parsed.tzinfo is not None else parsed.replace(tzinfo=timezone.utc)
    return aware.astimezone(timezone.utc)


def _merged_events(
    streams: list[tuple[str, Path]],
    start: datetime | None,
    end: datetime | None,
) -> Iterator[tuple[datetime, str, dict[str, Any]]]:
    """Yield the events of all *streams* merged by event time.

    Events earlier than *start* are skipped. Reading of a stream stops at its
    first event later than *end*. Both bounds are inclusive and optional. The
    merge is done with ``heapq.merge``, which expects each stream to already
    be ordered by time.
    """

    def _windowed(source):
        # Apply the [start, end] window to one stream's events.
        for ts, stream_name, event in source:
            if start is not None and ts < start:
                continue
            if end is not None and ts > end:
                return
            yield ts, stream_name, event

    per_stream = [_windowed(_iter_stream_file(name, path)) for name, path in streams]
    yield from heapq.merge(*per_stream, key=lambda item: item[0])


class _Sink:
    """Minimal interface implemented by every replay destination."""

    def send(self, payload: dict[str, Any]) -> None:  # pragma: no cover
        """Deliver one event; concrete sinks must override this."""
        raise NotImplementedError

    def close(self) -> None:
        """Release sink resources; the base class holds none."""
        return None


class _StdoutSink(_Sink):
    """Sink that prints each event as one compact JSON line on stdout."""

    def send(self, payload: dict[str, Any]) -> None:
        """Write *payload* as a single NDJSON line and flush right away."""
        encoded = json.dumps(payload, separators=(",", ":"))
        out = sys.stdout
        out.write(f"{encoded}\n")
        out.flush()


class _EventstreamHttpSink(_Sink):
    """Sink that POSTs each event as JSON to an HTTP endpoint.

    ``requests`` is imported on construction so the dependency is only needed
    when this sink is selected. One session is reused for all posts.
    """

    def __init__(self, endpoint_url: str) -> None:
        import requests  # lazy

        self._requests = requests
        self._url = endpoint_url
        self._session = requests.Session()

    def send(self, payload: dict[str, Any]) -> None:
        """POST *payload*; raise RuntimeError on any status of 300 or above."""
        body = json.dumps(payload)
        resp = self._session.post(
            self._url,
            data=body,
            headers={"Content-Type": "application/json"},
            timeout=30,
        )
        status = resp.status_code
        if status < 300:
            return
        raise RuntimeError(
            f"Eventstream POST failed {status}: {resp.text[:200]}"
        )

    def close(self) -> None:
        """Close the underlying HTTP session."""
        self._session.close()


class _EventHubSink(_Sink):
    """Sink that publishes each event to an Azure Event Hub.

    Authentication uses ``DefaultAzureCredential``. The Azure SDK modules are
    imported on construction so they are only needed for this sink.
    """

    def __init__(self, namespace: str, hub: str) -> None:
        from azure.eventhub import EventData, EventHubProducerClient  # lazy
        from azure.identity import DefaultAzureCredential  # lazy

        self._EventData = EventData
        credential = DefaultAzureCredential()
        # A name without a dot is a short namespace; expand it to the FQDN.
        host = namespace if "." in namespace else f"{namespace}.servicebus.windows.net"
        self._producer = EventHubProducerClient(
            fully_qualified_namespace=host,
            eventhub_name=hub,
            credential=credential,
        )

    def send(self, payload: dict[str, Any]) -> None:
        """Send *payload* as the only event of a freshly created batch."""
        single = self._producer.create_batch()
        single.add(self._EventData(json.dumps(payload)))
        self._producer.send_batch(single)

    def close(self) -> None:
        """Close the producer, ignoring any error raised while closing."""
        try:
            self._producer.close()
        except Exception:  # pragma: no cover
            pass


def _build_sink(args: argparse.Namespace) -> _Sink:
    """Create the sink selected by ``args.target``.

    Raises:
        SystemExit: A required option for the chosen target is missing, or
            the target name is not recognised.
    """
    target = args.target
    if target == "stdout":
        return _StdoutSink()
    if target == "eventstream":
        if args.endpoint_url:
            return _EventstreamHttpSink(args.endpoint_url)
        raise SystemExit("--endpoint-url is required for --target eventstream")
    if target == "eventhub":
        if args.namespace and args.hub:
            return _EventHubSink(args.namespace, args.hub)
        raise SystemExit("--namespace and --hub are required for --target eventhub")
    raise SystemExit(f"Unknown target: {target}")


def _pace_sleep(
    pace: str,
    factor: float,
    first_event_ts: datetime,
    wall_start: float,
    event_ts: datetime,
) -> None:
    """Block until *event_ts* is due according to the pacing mode.

    Args:
        pace: ``"asfast"`` (never wait), ``"realtime"`` or ``"accelerated"``.
        factor: Speed-up divisor applied in accelerated mode; ignored unless
            it is greater than zero.
        first_event_ts: Event time of the first replayed event.
        wall_start: ``time.time()`` value taken when that first event was seen.
        event_ts: Event time of the event about to be emitted.

    The wait ends early once a stop has been requested through the
    module-level ``_stop_requested`` flag.
    """
    if pace == "asfast":
        return
    scenario_elapsed = (event_ts - first_event_ts).total_seconds()
    if pace == "accelerated" and factor > 0:
        scenario_elapsed /= factor

    # A single long time.sleep() is resumed after the signal handler returns,
    # so a stop request would go unnoticed until the event is due. Sleeping in
    # short slices lets the flag be polled in between.
    max_slice = 0.25
    while True:
        remaining = scenario_elapsed - (time.time() - wall_start)
        if remaining <= 0 or _stop_requested:
            return
        time.sleep(min(remaining, max_slice))


def _run_once(args: argparse.Namespace, sink: _Sink) -> int:
    """Replay the selected scenario streams once through *sink*.

    Each emitted event is a copy of the source record with ``_stream``,
    ``_scenario`` and ``_emit_ts`` (UTC emission time) added. Progress is
    reported on stderr every 1000 events.

    Args:
        args: Parsed command-line options (scenario, stream, start, end,
            pace, target).
        sink: Destination that receives every event.

    Returns:
        Number of events handed to the sink.

    Raises:
        SystemExit: The accelerated factor is not a number or is not
            greater than zero.
    """
    streams = _discover_streams(args.scenario, args.stream)
    start = _parse_iso(args.start)
    end = _parse_iso(args.end)

    pace_spec = args.pace if isinstance(args.pace, list) else [args.pace]
    pace = pace_spec[0]
    factor = 1.0
    if len(pace_spec) > 1:
        try:
            factor = float(pace_spec[1])
        except ValueError as exc:
            raise SystemExit(f"Invalid accelerated factor: {pace_spec[1]}") from exc
    # A zero, negative or NaN factor would silently fall back to real-time
    # pacing; reject it instead ("not >" also catches NaN).
    if pace == "accelerated" and not factor > 0:
        raise SystemExit(f"Accelerated factor must be greater than 0: {factor}")

    print(
        f"[replay] scenario={args.scenario} streams={[n for n,_ in streams]} "
        f"target={args.target} pace={pace}"
        + (f" factor={factor}" if pace == "accelerated" else ""),
        file=sys.stderr,
    )

    count = 0
    first_event_ts: datetime | None = None
    wall_start = time.time()

    for ts, name, evt in _merged_events(streams, start, end):
        if _stop_requested:
            break
        if first_event_ts is None:
            # Pacing is measured from the moment the first event is seen.
            first_event_ts = ts
            wall_start = time.time()
        _pace_sleep(pace, factor, first_event_ts, wall_start, ts)
        # The pacing wait can be long; do not emit an event if a stop request
        # arrived while waiting for it.
        if _stop_requested:
            break

        evt_out = dict(evt)
        evt_out["_stream"] = name
        evt_out["_scenario"] = args.scenario
        evt_out["_emit_ts"] = datetime.now(timezone.utc).strftime(
            "%Y-%m-%dT%H:%M:%S.%fZ"
        )
        sink.send(evt_out)
        count += 1
        if count % 1000 == 0:
            print(
                f"[replay] emitted {count} events (last ts={ts.isoformat()})",
                file=sys.stderr,
            )

    print(f"[replay] done, emitted {count} events", file=sys.stderr)
    return count


def parse_args(argv: Iterable[str] | None = None) -> argparse.Namespace:
    """Parse and validate the command line.

    Args:
        argv: Argument list; ``sys.argv[1:]`` is used when None.

    Returns:
        Namespace in which ``stream`` is None or a list of stream names and
        ``pace`` is a list: ``[mode]`` or ``["accelerated", "<factor>"]``
        (the factor is kept as the original string).

    Invalid input ends the process through ``argparse``'s error handling
    (usage message on stderr, exit status 2).
    """
    p = argparse.ArgumentParser(description="Replay scenario NDJSON to Fabric streams")
    p.add_argument("--scenario", required=True, help="Scenario id, e.g. 01-ais-dark-near-cable")
    p.add_argument(
        "--stream",
        default=None,
        help="Comma-separated stream names (default: every known stream present in the scenario)",
    )
    p.add_argument(
        "--target",
        required=True,
        choices=("eventhub", "eventstream", "stdout"),
    )
    p.add_argument("--namespace", help="Event Hubs namespace (FQDN or short name)")
    p.add_argument("--hub", help="Event Hub name")
    p.add_argument("--endpoint-url", help="Fabric Eventstream custom endpoint URL")
    p.add_argument(
        "--pace",
        nargs="+",
        default=["asfast"],
        help="Pacing mode: realtime | accelerated <factor> | asfast",
    )
    p.add_argument("--start", help="UTC start ISO 8601 (inclusive)")
    p.add_argument("--end", help="UTC end ISO 8601 (inclusive)")
    p.add_argument("--loop", action="store_true", help="Loop indefinitely")

    ns = p.parse_args(list(argv) if argv is not None else None)

    if ns.stream:
        ns.stream = [s.strip() for s in ns.stream.split(",") if s.strip()]

    if isinstance(ns.pace, list):
        valid = {"realtime", "accelerated", "asfast"}
        if not ns.pace or ns.pace[0] not in valid:
            p.error(f"--pace must be one of {sorted(valid)}")
        mode, extra = ns.pace[0], ns.pace[1:]
        if mode == "accelerated":
            if len(extra) != 1:
                p.error("--pace accelerated requires a numeric factor, e.g. --pace accelerated 10")
            # Validate the factor here so a bad value is reported as a usage
            # error before any sink is opened.
            try:
                factor = float(extra[0])
            except ValueError:
                factor = float("nan")
            if not factor > 0:  # also rejects NaN
                p.error(
                    "--pace accelerated requires a numeric factor greater than 0, "
                    "e.g. --pace accelerated 10"
                )
        elif extra:
            p.error(f"--pace {mode} does not take extra args")

    return ns


def main(argv: Iterable[str] | None = None) -> int:
    """Run the replay command and return the process exit status.

    Args:
        argv: Argument list; ``sys.argv[1:]`` is used when None.

    Returns:
        0 after a completed or gracefully stopped replay, 130 when a
        KeyboardInterrupt reaches this function.

    The sink is always closed before returning.
    """
    args = parse_args(argv)
    _install_signal_handlers()
    sink = _build_sink(args)
    try:
        while True:
            emitted = _run_once(args, sink)
            if not args.loop or _stop_requested:
                break
            if emitted == 0:
                # With nothing to replay, looping would spin at full speed
                # re-reading the files forever; stop instead.
                print("[replay] no events emitted, not looping", file=sys.stderr)
                break
            print("[replay] loop restart", file=sys.stderr)
        return 0
    except KeyboardInterrupt:
        print("[replay] interrupted", file=sys.stderr)
        return 130
    finally:
        sink.close()

# Script entry point: run main() and use its return value as the exit status.
if __name__ == "__main__":
    sys.exit(main())
