#!/usr/bin/env python3
"""
BalticFusion · Fabric RTI streamer
===================================
Pushes events to a Fabric Eventstream custom HTTPS endpoint in one of three
switchable modes, selected by the ``MODE`` environment variable:

  replay  — replays pre-generated scenario NDJSON files in event-time order
  live    — polls real Digitraffic AIS + OpenSky aircraft + synthetic drones
  mixed   — replay base + real AIS/aircraft injected on top

Every event is a flat JSON object that is also GeoJSON-compatible:
  • ``lat`` / ``lon`` top-level float fields  → KQL table columns for Fabric Maps
  • ``geometry`` field (GeoJSON Point JSON string) → for geo-aware tooling
  • ``_stream`` routing key  → Eventstream filter → separate KQL tables
  • ``_mode``    → which mode produced this event
  • ``_emit_ts`` → wall-clock UTC at emission

Environment variables
---------------------
Required (unless DRY_RUN is enabled):
  EVENTSTREAM_URL        Fabric Eventstream custom-endpoint ingest URL

Optional (with defaults):
  MODE                   replay | live | mixed                (default: replay)
  SCENARIO_ID            e.g. 08-red-vessel-escalation       (default: 08-red-vessel-escalation, replay/mixed only)
  SPEED_MULTIPLIER       numeric, scenario-time speedup       (default: 1.0 → realtime)
                         0 = as-fast-as-possible, >1 = faster than real, <1 = slower
  LOOP                   1 | true  — loop replay indefinitely (default: 0, replay/mixed only)
  DRONE_COUNT            number of synthetic drones in live/mixed (default: 3)
  POLL_AIS_S             AIS poll interval in seconds         (default: 30)
  POLL_OSK_S             OpenSky poll interval in seconds     (default: 15)
  OSK_PROXY_URL          URL of the OpenSky proxy (default: http://localhost:8000/proxy/opensky/states/all?lamin=58&lomin=18&lamax=62&lomax=30)
  AIS_URL                Digitraffic AIS endpoint (default: https://meri.digitraffic.fi/api/ais/v1/locations)
  STREAMS                comma-separated NDJSON basenames to replay (default: all)
  DRY_RUN                1 | true  — print events to stdout instead of posting
  VERBOSE                1 | true  — print each event to stderr

Usage (container):
  docker run -e MODE=replay \\
             -e SCENARIO_ID=08-red-vessel-escalation \\
             -e SPEED_MULTIPLIER=10 \\
             -e LOOP=1 \\
             -e EVENTSTREAM_URL=https://... \\
             localhost/balticfusion:latest streamer

Usage (local):
  MODE=live EVENTSTREAM_URL=https://... python streaming/streamer.py
"""

from __future__ import annotations

import heapq
import json
import math
import os
import random
import signal
import sys
import threading
import time
import urllib.error
import urllib.parse
import urllib.request
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Any

# REPO_ROOT is the parent of this file's directory. It anchors the scenario data paths
# and is prepended to sys.path (only if absent), presumably so repo-level modules are
# importable when this file is run directly as a script.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path[:0] = [str(REPO_ROOT)]

# ---------------------------------------------------------------------------
# Config from environment
# ---------------------------------------------------------------------------

def _env_bool(key: str, default: bool = False) -> bool:
    """Interpret environment variable *key* as an on/off flag.

    Unset or whitespace-only values yield *default*. Any other value is on only
    when it is ``1``, ``true`` or ``yes`` (case-insensitive, surrounding
    whitespace ignored); everything else is off.
    """
    raw = os.environ.get(key, "").strip().lower()
    if not raw:
        return default
    return raw in {"1", "true", "yes"}

def _env_float(key: str, default: float) -> float:
    """Read environment variable *key* as a finite float.

    Returns *default* when the variable is unset, is not a number, or is
    non-finite (``nan``, ``inf``, ``-inf``). Non-finite values are rejected
    because the config section feeds this into ``int(...)``, which raises for
    them. For the speed multiplier they would also silently switch pacing to
    as-fast-as-possible.
    """
    try:
        value = float(os.environ[key])
    except (KeyError, ValueError):
        return default
    return value if math.isfinite(value) else default

# Mode and scenario selection
MODE = os.environ.get("MODE", "replay").strip().lower()
SCENARIO_ID = os.environ.get("SCENARIO_ID", "08-red-vessel-escalation").strip()
STREAMS_ENV = os.environ.get("STREAMS", "").strip()  # comma-separated basenames; empty -> DEFAULT_STREAMS

# Replay pacing: 0 = as fast as possible, 1 = real time, 10 = 10x faster
SPEED = _env_float("SPEED_MULTIPLIER", 1.0)
LOOP = _env_bool("LOOP", False)

# Live / mixed tuning (int() truncates fractional values toward zero)
DRONE_COUNT = int(_env_float("DRONE_COUNT", 3))
POLL_AIS_S = int(_env_float("POLL_AIS_S", 30))
POLL_OSK_S = int(_env_float("POLL_OSK_S", 15))
AIS_URL = os.environ.get("AIS_URL", "https://meri.digitraffic.fi/api/ais/v1/locations")
OSK_PROXY_URL = os.environ.get(
    "OSK_PROXY_URL",
    "http://localhost:8000/proxy/opensky/states/all?lamin=58&lomin=18&lamax=62&lomax=30",
)

# Output sink
EVENTSTREAM_URL = os.environ.get("EVENTSTREAM_URL", "").strip()
DRY_RUN = _env_bool("DRY_RUN", False)
VERBOSE = _env_bool("VERBOSE", False)

# ---------------------------------------------------------------------------
# Global stop flag
# ---------------------------------------------------------------------------
# Cooperative shutdown flag: set by the SIGINT/SIGTERM handler installed below and
# polled by the replay and live loops (including the replay thread in mixed mode).
_stop = False


def _install_signal_handlers() -> None:
    """Route SIGINT and SIGTERM to a handler that sets the module-level ``_stop`` flag.

    Registration errors are ignored. ``signal.signal`` raises ValueError when
    called off the main thread, and OSError is tolerated as well. In that case
    the process keeps the previous handlers.
    """
    def _request_stop(signum, frame):
        global _stop
        _stop = True
        print("[streamer] stop requested", file=sys.stderr)

    for signum in (signal.SIGINT, signal.SIGTERM):
        try:
            signal.signal(signum, _request_stop)
        except (ValueError, OSError):
            continue

# ---------------------------------------------------------------------------
# Emitter — posts one event to Eventstream or stdout
# ---------------------------------------------------------------------------

def _emit(event: dict[str, Any]) -> None:
    """Send one event to the configured sink.

    The event is serialised as compact JSON. With DRY_RUN set (or no
    EVENTSTREAM_URL) it is written as a single line to stdout and flushed.
    Otherwise it is POSTed to EVENTSTREAM_URL with a 30 s timeout. HTTP and
    transport failures are logged to stderr and the event is dropped. With
    VERBOSE set, the event is also echoed to stderr.
    """
    compact = json.dumps(event, separators=(",", ":"))
    if VERBOSE:
        print(json.dumps(event), file=sys.stderr)

    to_stdout = DRY_RUN or not EVENTSTREAM_URL
    if to_stdout:
        sys.stdout.write(compact + "\n")
        sys.stdout.flush()
        return

    request = urllib.request.Request(
        EVENTSTREAM_URL,
        data=compact.encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(request, timeout=30):
            pass  # the response body is not used; leaving the context closes it
    except urllib.error.HTTPError as err:
        # Report at most the first 200 bytes of the error body for diagnosis
        detail = err.read(200).decode("utf-8", "replace")
        print(f"[streamer] POST failed {err.code}: {detail}", file=sys.stderr)
    except Exception as err:
        print(f"[streamer] POST error: {err}", file=sys.stderr)


def _now_iso() -> str:
    """Return the current UTC wall-clock time as ``YYYY-MM-DDTHH:MM:SS.ffffffZ``."""
    return format(datetime.now(timezone.utc), "%Y-%m-%dT%H:%M:%S.%fZ")


def _geojson_point(lon: float, lat: float) -> str:
    """Return a compact GeoJSON Point string for (*lon*, *lat*).

    Coordinates are emitted longitude first, as GeoJSON requires. Each is
    rounded to 6 decimal places.
    """
    point = {
        "type": "Point",
        "coordinates": [round(lon, 6), round(lat, 6)],
    }
    return json.dumps(point, separators=(",", ":"))


# ---------------------------------------------------------------------------
# Replay mode — wraps existing eventstream_replay logic
# ---------------------------------------------------------------------------

def _resolve_event_time(evt: dict[str, Any]) -> datetime | None:
    """Return the event time of an NDJSON record as an aware UTC datetime, or None.

    Lookup order:
      1. ``timestamp`` then ``processingTimestamp``, if they are ISO-8601 strings.
         A trailing ``Z`` is accepted. Naive values are taken to be UTC; aware
         values are converted to UTC.
      2. ``ts_epoch_ms``, a numeric epoch in milliseconds (booleans are ignored).

    Unparseable, non-finite or out-of-range values are treated as missing
    instead of raising. One bad record must not abort a whole replay.
    """
    for key in ("timestamp", "processingTimestamp"):
        value = evt.get(key)
        if not isinstance(value, str):
            continue
        try:
            dt = datetime.fromisoformat(value.replace("Z", "+00:00"))
            if dt.tzinfo is None:
                return dt.replace(tzinfo=timezone.utc)
            # astimezone can overflow for dates at the edge of the supported range
            return dt.astimezone(timezone.utc)
        except (ValueError, OverflowError):
            continue

    epoch_ms = evt.get("ts_epoch_ms")
    # bool is an int subclass; a JSON true/false is not a timestamp
    if isinstance(epoch_ms, (int, float)) and not isinstance(epoch_ms, bool):
        try:
            return datetime.fromtimestamp(epoch_ms / 1000.0, tz=timezone.utc)
        except (OverflowError, OSError, ValueError):
            # nan/inf or outside the platform's representable range
            return None
    return None


def _iter_ndjson(path: Path, stream_name: str):
    """Yield ``(event_time, stream_name, event)`` for each usable record in an NDJSON file.

    Skipped lines:
      * blank lines and ``{"__meta__"...`` metadata lines (wherever they appear);
      * lines that are not valid JSON;
      * JSON values that are not objects (arrays, numbers, strings);
      * records whose event time cannot be resolved.

    The file is read lazily and stays open until the generator is exhausted or closed.
    """
    with path.open("r", encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line or line.startswith('{"__meta__"'):
                continue
            try:
                evt = json.loads(line)
            except json.JSONDecodeError:
                continue
            # Only objects can be events; anything else would fail on .get() downstream
            if not isinstance(evt, dict):
                continue
            ts = _resolve_event_time(evt)
            if ts is None:
                continue
            yield ts, stream_name, evt


def _merged_ndjson(scenario_id: str, stream_names: list[str]):
    """Yield ``(event_time, stream_name, event)`` across all requested streams in event-time order.

    Each stream is read from
    ``REPO_ROOT/scenarios/<scenario_id>/data/realtime/<name>.ndjson``. Missing
    files are reported on stderr and skipped. Because this is a generator, the
    FileNotFoundError raised when none of the files exist surfaces on first
    iteration, not at call time.

    ``heapq.merge`` only interleaves already-sorted inputs. Each file is
    therefore assumed to be sorted by event time. Equal times keep the order of
    *stream_names*.
    """
    base = REPO_ROOT / "scenarios" / scenario_id / "data" / "realtime"
    sources = []
    for name in stream_names:
        path = base / f"{name}.ndjson"
        if not path.is_file():
            print(f"[streamer] warning: stream file not found: {path}", file=sys.stderr)
            continue
        sources.append(_iter_ndjson(path, name))
    if not sources:
        raise FileNotFoundError(f"No NDJSON files found in {base} for streams: {stream_names}")
    yield from heapq.merge(*sources, key=lambda item: item[0])


def _to_geojson_event(evt: dict[str, Any], stream: str, scenario: str, mode: str) -> dict[str, Any]:
    """Return a copy of an NDJSON record enriched with routing metadata and GeoJSON geometry.

    Always adds ``_stream``, ``_scenario``, ``_mode`` and ``_emit_ts``. Latitude
    is taken from ``lat`` or ``latitude``, longitude from ``lon`` or
    ``longitude``. When both convert to finite floats, they are written back as
    float ``lat``/``lon`` plus a ``geometry`` Point string. Otherwise the record
    is forwarded without geometry. *evt* itself is not modified (shallow copy).
    """
    # Explicit None checks: `a or b` would discard a legitimate 0.0 coordinate.
    lat = evt.get("lat")
    if lat is None:
        lat = evt.get("latitude")
    lon = evt.get("lon")
    if lon is None:
        lon = evt.get("longitude")

    out = dict(evt)
    out["_stream"] = stream
    out["_scenario"] = scenario
    out["_mode"] = mode
    out["_emit_ts"] = _now_iso()

    if lat is None or lon is None:
        return out
    try:
        lat_f = float(lat)
        lon_f = float(lon)
    except (TypeError, ValueError):
        # Unparseable coordinate: keep the event rather than abort the whole replay.
        return out
    if not (math.isfinite(lat_f) and math.isfinite(lon_f)):
        # NaN/inf would be serialised as non-standard JSON tokens.
        return out

    # Top-level float lat/lon become KQL columns for Fabric Maps
    out["lat"] = lat_f
    out["lon"] = lon_f
    out["geometry"] = _geojson_point(lon_f, lat_f)
    return out


# Stream basenames used when STREAMS is unset. Each is read from
# scenarios/<SCENARIO_ID>/data/realtime/<name>.ndjson during replay.
DEFAULT_STREAMS = [
    "ais",
    "mac",
    "plane_radar",
    "drone_radar",
    "coastal_radar",
]


def _sleep_unless_stopped(seconds: float, step: float = 0.25) -> None:
    """Sleep up to *seconds* in slices of at most *step*, returning early once ``_stop`` is set.

    Since Python 3.5 a single ``time.sleep`` resumes after a signal handler runs
    (PEP 475). A long event-time gap would otherwise delay a SIGTERM-driven
    shutdown by the whole gap.
    """
    deadline = time.monotonic() + seconds
    while not _stop:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        time.sleep(min(remaining, step))


def run_replay(scenario_id: str, stream_names: list[str], loop: bool, speed: float, mode: str = "replay") -> None:
    """Replay a scenario's NDJSON streams, merged in event-time order, until done or stopped.

    Args:
        scenario_id: scenario directory under ``scenarios/``.
        stream_names: NDJSON basenames to merge.
        loop: restart from the beginning after each pass.
        speed: event-time speed-up. ``<= 0`` emits as fast as possible, ``1``
            paces in real time, ``N`` paces N times faster.
        mode: value written to every event's ``_mode`` field.

    Pacing is anchored to the first event of each pass on a monotonic clock,
    so sleep jitter and wall-clock adjustments do not accumulate. A pass that
    emits nothing ends the replay even with *loop* set; repeating it would only
    spin. Raises FileNotFoundError when none of the stream files exist.
    """
    pace = "asfast" if speed <= 0 else ("realtime" if speed == 1.0 else f"{speed:.2f}x")
    print(f"[streamer] mode={mode} scenario={scenario_id} streams={stream_names} pace={pace} loop={loop}",
          file=sys.stderr)
    iteration = 0
    while not _stop:
        iteration += 1
        first_ts: datetime | None = None
        anchor = 0.0
        count = 0
        for ts, stream, evt in _merged_ndjson(scenario_id, stream_names):
            if _stop:
                break
            if first_ts is None:
                first_ts = ts
                anchor = time.monotonic()

            if speed > 0:
                target = (ts - first_ts).total_seconds() / speed
                delay = target - (time.monotonic() - anchor)
                if delay > 0.001:
                    _sleep_unless_stopped(delay)
                    if _stop:
                        # Don't emit an event after a stop arrived during the wait
                        break

            _emit(_to_geojson_event(evt, stream, scenario_id, mode))
            count += 1

        print(f"[streamer] replay iteration {iteration} done, emitted {count} events", file=sys.stderr)
        if not loop or _stop:
            break
        if count == 0:
            print("[streamer] no replayable events found, not looping", file=sys.stderr)
            break
        print("[streamer] looping…", file=sys.stderr)


# ---------------------------------------------------------------------------
# Live data fetching helpers
# ---------------------------------------------------------------------------

def _http_get_json(url: str, headers: dict | None = None, timeout: int = 15) -> dict | list | None:
    """GET *url* and decode the response body as JSON.

    *headers* defaults to ``Accept: application/json`` when omitted or empty.
    Any error while opening, reading or decoding is logged to stderr and yields
    ``None``. The request object is built before the guarded section, so a URL
    that urllib cannot parse raises ValueError to the caller.
    """
    request = urllib.request.Request(url, headers=headers if headers else {"Accept": "application/json"})
    try:
        with urllib.request.urlopen(request, timeout=timeout) as response:
            payload = response.read()
            return json.loads(payload)
    except Exception as exc:
        print(f"[streamer] fetch error {url}: {exc}", file=sys.stderr)
        return None


def _fetch_ais() -> list[dict[str, Any]]:
    """Fetch current vessel positions from AIS_URL and convert them to flat ``ais`` events.

    The response is read as an object whose ``features`` each carry a Point
    ``geometry`` (``coordinates`` = ``[lon, lat, ...]``). ``properties`` supply
    ``mmsi``, ``sog``, ``cog`` and ``navStat``, and ``mmsi`` may also sit on the
    feature itself.

    Features are skipped when they are malformed or their coordinates are
    non-numeric or outside valid latitude/longitude ranges. One bad record
    therefore cannot abort the live loop. ``lat``/``lon`` are emitted as floats.
    All events from one poll share a single emission timestamp. Returns ``[]``
    when the fetch fails.
    """
    data = _http_get_json(
        AIS_URL,
        headers={
            "Accept": "application/json",
            "Digitraffic-User": "r-mac-data-scenarios/1.0",
        },
    )
    if not isinstance(data, dict):
        return []
    features = data.get("features") or []
    ts = _now_iso()
    out = []
    for feat in features:
        if not isinstance(feat, dict):
            continue
        geometry = feat.get("geometry")
        coords = geometry.get("coordinates") if isinstance(geometry, dict) else None
        if not isinstance(coords, (list, tuple)) or len(coords) < 2:
            continue
        try:
            lon, lat = float(coords[0]), float(coords[1])
        except (TypeError, ValueError):
            continue
        # Chained comparisons are False for NaN, so this also drops non-finite values.
        if not (-90.0 <= lat <= 90.0 and -180.0 <= lon <= 180.0):
            continue
        props = feat.get("properties")
        if not isinstance(props, dict):
            props = {}
        out.append({
            "_stream": "ais",
            "_scenario": "live",
            "_mode": "live",
            "_emit_ts": ts,
            "timestamp": ts,
            "event_type": "ais_position",
            "lat": round(lat, 6),
            "lon": round(lon, 6),
            "geometry": _geojson_point(lon, lat),
            "mmsi": feat.get("mmsi") or props.get("mmsi"),
            "sog_kn": props.get("sog"),
            "cog_deg": props.get("cog"),
            "nav_status": props.get("navStat"),
            "source": "digitraffic",
        })
    return out


def _fetch_opensky() -> list[dict[str, Any]]:
    """Fetch airborne aircraft from OSK_PROXY_URL and convert them to flat ``plane_radar`` events.

    Expects an object with a ``states`` list of positional state vectors. Fields
    read here: 0 track id, 1 callsign, 5 lon, 6 lat, 7 altitude, 8 on-ground
    flag, and 13 fallback altitude. In OpenSky's ``/states/all`` layout, 7 is
    baro_altitude and 13 is geo_altitude; confirm the proxy passes them through
    unchanged.

    Rows are skipped when they are too short, on the ground, missing a
    position, non-numeric or out of range, so one bad row cannot abort the live
    loop. Returns ``[]`` when the fetch fails.
    """
    data = _http_get_json(OSK_PROXY_URL)
    if not isinstance(data, dict):
        return []
    rows = data.get("states") or []
    out = []
    ts = _now_iso()
    for s in rows:
        # Index 13 is the highest position read below; shorter rows are malformed.
        if not isinstance(s, (list, tuple)) or len(s) < 14:
            continue
        if s[5] is None or s[6] is None:
            continue
        if s[8]:  # on_ground
            continue
        try:
            lon, lat = float(s[5]), float(s[6])
            alt = float(s[7] or s[13] or 0.0)
        except (TypeError, ValueError):
            continue
        if not (-90.0 <= lat <= 90.0 and -180.0 <= lon <= 180.0):
            continue
        out.append({
            "_stream": "plane_radar",
            "_scenario": "live",
            "_mode": "live",
            "_emit_ts": ts,
            "timestamp": ts,
            "event_type": "aircraft_position",
            "lat": round(lat, 6),
            "lon": round(lon, 6),
            "geometry": _geojson_point(lon, lat),
            "track_id": str(s[0]),
            "callsign": str(s[1] or "").strip(),
            "alt_m": round(alt, 1),
            "on_ground": False,
            "classification": "airborne",
            "source": "opensky",
        })
    return out


# Orbit centres, as (lat, lon) pairs, for the synthetic live drones. Each drone
# flies a simple circular orbit around one of these fixed waypoints.
_LIVE_DRONE_WAYPOINTS = [
    (60.30, 25.55),  # near Kilpilahti
    (60.15, 24.97),  # near Helsinki harbour
    (59.95, 23.82),  # Estlink corridor
]

def _gen_synthetic_drones(n: int, rng: random.Random) -> list[dict[str, Any]]:
    """Generate one position event for each of *n* synthetic drones orbiting fixed waypoints.

    Drone ``i`` orbits ``_LIVE_DRONE_WAYPOINTS[i % len(_LIVE_DRONE_WAYPOINTS)]``,
    so DRONE_COUNT is honoured even when it exceeds the number of waypoints.
    Each further pass over the waypoints uses a 50 % wider orbit so reused
    waypoints don't stack drones on the same circle. For *n* up to the waypoint
    count the output matches a one-drone-per-waypoint layout.

    Positions and altitude are deterministic functions of wall-clock time.
    ``rcs_m2`` and ``confidence`` are drawn from *rng*. Returns ``[]`` for
    *n* <= 0 or when no waypoints are defined.
    """
    ts = _now_iso()
    epoch_s = time.time()
    waypoints = _LIVE_DRONE_WAYPOINTS
    out: list[dict[str, Any]] = []
    if not waypoints:
        return out
    for i in range(n):
        clat, clon = waypoints[i % len(waypoints)]
        ring = i // len(waypoints)  # 0 on the first pass over the waypoints
        angle = (epoch_s / 120.0 + i * 2.1) * (2 * math.pi)  # full circle every 2 min
        r_deg = 0.008 * (1 + 0.5 * ring)  # 0.008° of latitude ≈ 890 m on the first ring
        lat = clat + r_deg * math.sin(angle)
        # Divide by cos(lat) so the longitude offset covers the same ground distance
        lon = clon + r_deg * math.cos(angle) / math.cos(math.radians(clat))
        alt = 80 + 20 * math.sin(epoch_s / 60.0 + i)
        out.append({
            "_stream": "drone_radar",
            "_scenario": "live_synthetic",
            "_mode": "live",
            "_emit_ts": ts,
            "timestamp": ts,
            "event_type": "drone_position",
            "lat": round(lat, 6),
            "lon": round(lon, 6),
            "geometry": _geojson_point(lon, lat),
            "track_id": f"T-SYNTH-DRN-{i+1:02d}",
            "sensor_id": "RAD-PLN-01",
            "alt_m": round(alt, 1),
            "classification": "drone_small",
            "rcs_m2": round(rng.gauss(0.025, 0.005), 4),
            "confidence": round(rng.uniform(0.70, 0.92), 3),
            "kind": "airborne",
            "source": "synthetic",
        })
    return out


def run_live(drone_count: int) -> None:
    """Poll live sources until a stop is requested, emitting each event as it is produced.

    AIS is polled every POLL_AIS_S seconds and OpenSky every POLL_OSK_S
    seconds. Every OpenSky poll is followed by one round of *drone_count*
    synthetic drone positions. The loop ticks about once per second. A stop
    request is checked before every emitted event and ends the function
    immediately.
    """
    print(f"[streamer] mode=live drones={drone_count} poll_ais={POLL_AIS_S}s poll_osk={POLL_OSK_S}s",
          file=sys.stderr)
    rng = random.Random()

    def _forward(events) -> int | None:
        """Emit *events* in order; return how many were sent, or None if a stop was requested."""
        sent = 0
        for event in events:
            if _stop:
                return None
            _emit(event)
            sent += 1
        return sent

    last_ais = last_osk = 0.0
    while not _stop:
        tick = time.time()
        batch_total = 0

        if tick - last_ais >= POLL_AIS_S:
            sent = _forward(_fetch_ais())
            if sent is None:
                return
            batch_total += sent
            last_ais = time.time()

        if tick - last_osk >= POLL_OSK_S:
            sent = _forward(_fetch_opensky())
            if sent is None:
                return
            batch_total += sent
            # Synthetic drones are generated only after the aircraft have been sent
            sent = _forward(_gen_synthetic_drones(drone_count, rng))
            if sent is None:
                return
            batch_total += sent
            last_osk = time.time()

        if batch_total:
            print(f"[streamer] live batch emitted {batch_total} events", file=sys.stderr)

        if not _stop:
            time.sleep(1)


def run_mixed(scenario_id: str, stream_names: list[str], speed: float, loop: bool, drone_count: int) -> None:
    """Replay a scenario on a background thread while polling live sources on the calling thread.

    Replay events are tagged ``_mode="mixed"``. Once the live loop returns (on
    a stop request), this waits up to 5 s for the replay thread. The thread is
    a daemon, so it does not keep the interpreter alive after that.
    """
    print(f"[streamer] mode=mixed scenario={scenario_id}", file=sys.stderr)
    replay = threading.Thread(
        target=run_replay,
        args=(scenario_id, stream_names, loop, speed),
        kwargs={"mode": "mixed"},
        name="replay",
        daemon=True,
    )
    replay.start()
    run_live(drone_count)
    replay.join(timeout=5)


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def main() -> int:
    """Run the streamer in the configured MODE and return a process exit code.

    Returns 1 in three cases: EVENTSTREAM_URL is missing without DRY_RUN, MODE
    is unknown, or the selected mode raises. Otherwise returns 0, including
    after a stop request.
    """
    _install_signal_handlers()

    if not EVENTSTREAM_URL and not DRY_RUN:
        print(
            "[streamer] EVENTSTREAM_URL is not set — set DRY_RUN=1 for stdout or provide the endpoint URL.",
            file=sys.stderr,
        )
        return 1

    # A value like "," parses to no names; fall back to the default (all streams)
    stream_names = [s.strip() for s in STREAMS_ENV.split(",") if s.strip()] or DEFAULT_STREAMS

    if DRY_RUN:
        target = "stdout"
    else:
        # Log only scheme and host: the path or query of an ingest URL may carry
        # credentials, which a raw prefix of the URL could leak into logs.
        try:
            parts = urllib.parse.urlsplit(EVENTSTREAM_URL)
            target = f"{parts.scheme}://{parts.hostname or '?'}/…"
        except ValueError:
            target = "<unparseable URL>"
    print(f"[streamer] starting  mode={MODE}  dry_run={DRY_RUN}  target={target}",
          file=sys.stderr)

    try:
        if MODE == "replay":
            run_replay(SCENARIO_ID, stream_names, LOOP, SPEED)
        elif MODE == "live":
            run_live(DRONE_COUNT)
        elif MODE == "mixed":
            run_mixed(SCENARIO_ID, stream_names, SPEED, LOOP, DRONE_COUNT)
        else:
            print(f"[streamer] unknown MODE={MODE!r}, expected replay|live|mixed", file=sys.stderr)
            return 1
    except KeyboardInterrupt:
        # SIGINT normally just sets the stop flag; this covers the case where the
        # handler could not be installed.
        print("[streamer] interrupted", file=sys.stderr)
    except Exception as e:
        print(f"[streamer] fatal: {type(e).__name__}: {e}", file=sys.stderr)
        return 1

    print("[streamer] done", file=sys.stderr)
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
