"""Shared common helpers for generators.

Synthetic demo data inspired by real Baltic geography, MMSI/OUI conventions,
and infrastructure. Not real observations.
"""
from __future__ import annotations

import csv
import json
import math
import random
from dataclasses import dataclass
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Any, Iterable

DISCLAIMER = " ".join(
    (
        "Synthetic demo data inspired by real Baltic geography, MMSI/OUI conventions,",
        "and infrastructure. Not real observations.",
    )
)


def disclaimer_record(dataset: str) -> dict[str, Any]:
    """Build the meta/disclaimer record that heads every generated file.

    ``__meta__`` is deliberately the first key. ``json.dumps`` keeps dict
    insertion order, so the serialized record begins with ``{"__meta__"``.
    ``is_meta_record()`` checks for exactly that leading key with a cheap
    prefix test.

    The default encoder puts a space after each colon
    (``{"__meta__": "synthetic", ...}``). Consumers should therefore not rely
    on byte-exact formatting beyond the leading key.
    """
    return dict(
        __meta__="synthetic",
        disclaimer=DISCLAIMER,
        dataset=dataset,
        version="1.0",
    )


def is_meta_record(line: str) -> bool:
    """Return True if ``line`` is a disclaimer / meta JSON record.

    This is a cheap prefix test with no JSON parse. It matches lines whose
    first JSON key is ``"__meta__"``, and tolerates whitespace both before
    the opening ``{`` and between ``{`` and the key. Encoders may emit
    ``{ "__meta__": ...}`` as well as ``{"__meta__": ...}``.

    Because it only inspects the first key, it relies on ``__meta__`` being
    written first. ``disclaimer_record()`` builds its dict with ``__meta__``
    first, and ``json.dumps`` preserves insertion order.
    """
    s = line.lstrip()
    if not s.startswith("{"):
        return False
    # The closing quote in the pattern prevents matching keys such as "__meta__x".
    return s[1:].lstrip().startswith('"__meta__"')


# ---------- Time helpers ----------

def iso_utc(t: datetime) -> str:
    """Format ``t`` as ISO-8601 UTC with six fractional digits and a ``Z`` suffix.

    Naive datetimes are assumed to already be UTC. Aware datetimes are
    converted to UTC first.

    The microsecond field is taken from the converted value. UTC offsets may
    carry a sub-second component, and reading it from the input would then
    produce the wrong fraction.

    Example: ``2024-01-02T03:04:05.000006Z``.
    """
    if t.tzinfo is None:
        t = t.replace(tzinfo=timezone.utc)
    # %f is always six zero-padded digits.
    return t.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f") + "Z"


def epoch_ms(t: datetime) -> int:
    """Return ``t`` as integer milliseconds since the Unix epoch.

    Naive datetimes are assumed to be UTC, consistent with ``iso_utc``.

    The value is computed with exact ``timedelta`` arithmetic. The previous
    ``int(t.timestamp() * 1000)`` round-tripped through a float and could land
    one millisecond low for some values. Sub-millisecond remainders are
    floored, i.e. rounded toward negative infinity.
    """
    if t.tzinfo is None:
        t = t.replace(tzinfo=timezone.utc)
    epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
    # timedelta // timedelta yields an exact int (floor division).
    return (t - epoch) // timedelta(milliseconds=1)


def time_range(start: datetime, end: datetime, step_s: float) -> Iterable[datetime]:
    """Yield ``start, start + step, start + 2*step, ...`` while strictly before ``end``.

    ``step_s`` is converted to a ``timedelta``, which has microsecond
    resolution.

    Raises:
        ValueError: on first iteration, if the step is not a positive duration.
            This covers zero, negative, and steps that round to 0 µs. Such
            steps previously made the loop run forever.
    """
    delta = timedelta(seconds=step_s)
    if delta <= timedelta(0):
        raise ValueError(f"time_range: step_s must be a positive duration, got {step_s!r}")
    t = start
    while t < end:
        yield t
        t = t + delta


# ---------- Geo helpers ----------

EARTH_R_M = 6_371_000.0  # spherical-Earth radius used by the geo helpers, metres
NM_M = 1852.0  # one nautical mile, metres
KN_MS = 0.514444  # knots -> m/s (exact value is 1852/3600 = 0.51444...; truncated here)


def haversine_m(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Great-circle distance in metres between two points given in degrees.

    Uses the haversine formula on a sphere of radius ``EARTH_R_M``.
    """
    p1, p2 = math.radians(lat1), math.radians(lat2)
    dp = math.radians(lat2 - lat1)
    dl = math.radians(lon2 - lon1)
    a = math.sin(dp / 2) ** 2 + math.cos(p1) * math.cos(p2) * math.sin(dl / 2) ** 2
    # Rounding can push ``a`` fractionally above 1 for (near-)antipodal points,
    # and asin(sqrt(a)) would then raise a math domain error. Clamp to [0, 1].
    a = min(1.0, max(0.0, a))
    return 2 * EARTH_R_M * math.asin(math.sqrt(a))


def bearing_deg(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Initial great-circle bearing from point 1 towards point 2.

    Inputs are in degrees. The result is in degrees clockwise from north,
    normalised into [0, 360).
    """
    phi_a = math.radians(lat1)
    phi_b = math.radians(lat2)
    d_lambda = math.radians(lon2 - lon1)
    east = math.sin(d_lambda) * math.cos(phi_b)
    north = (
        math.cos(phi_a) * math.sin(phi_b)
        - math.sin(phi_a) * math.cos(phi_b) * math.cos(d_lambda)
    )
    heading = math.degrees(math.atan2(east, north))
    return (heading + 360) % 360


def move_along(lat: float, lon: float, bearing_deg_: float, dist_m: float) -> tuple[float, float]:
    """Destination reached by travelling along a great circle.

    Starts at ``(lat, lon)`` and travels ``dist_m`` metres on an initial
    bearing of ``bearing_deg_`` degrees, on a sphere of radius ``EARTH_R_M``.

    Returns ``(lat, lon)`` in degrees, with longitude wrapped into [-180, 180).
    """
    ang = dist_m / EARTH_R_M  # angular distance, radians
    theta = math.radians(bearing_deg_)
    phi0 = math.radians(lat)
    lam0 = math.radians(lon)
    sin_phi0, cos_phi0 = math.sin(phi0), math.cos(phi0)
    sin_ang, cos_ang = math.sin(ang), math.cos(ang)
    phi1 = math.asin(sin_phi0 * cos_ang + cos_phi0 * sin_ang * math.cos(theta))
    lam1 = lam0 + math.atan2(
        math.sin(theta) * sin_ang * cos_phi0,
        cos_ang - sin_phi0 * math.sin(phi1),
    )
    return math.degrees(phi1), (math.degrees(lam1) + 540) % 360 - 180


def interp_waypoints(waypoints: list[tuple[datetime, float, float]], t: datetime) -> tuple[float, float, float, float]:
    """Position and motion at time ``t`` along a piecewise-linear track.

    ``waypoints`` is a time-ordered list of ``(time, lat, lon)``. The method:

    * Picks the segment containing ``t``. Before the first waypoint the first
      segment is used; after the last waypoint, the last segment.
    * Interpolates lat/lon linearly in degrees. The fraction is clamped to
      [0, 1], so outside the track's span the position is pinned to the
      first or last waypoint.
    * Reports speed over ground (knots) and course (initial great-circle
      bearing, degrees) for the chosen segment. A segment whose two
      timestamps coincide reports speed 0.

    Returns:
        ``(lat, lon, sog_kn, cog_deg)``. A single waypoint yields its own
        position with ``0.0`` speed and course.

    Raises:
        ValueError: if ``waypoints`` is empty, or if ``t`` lies within the
            overall span but inside no segment. The latter is only possible
            when the list is not sorted by time.
    """
    if not waypoints:
        raise ValueError("interp_waypoints: waypoints must be non-empty")
    if len(waypoints) == 1:
        # Nothing to derive motion from.
        sole = waypoints[0]
        return sole[1], sole[2], 0.0, 0.0

    if t <= waypoints[0][0]:
        seg_start, seg_end = waypoints[0], waypoints[1]
    elif t >= waypoints[-1][0]:
        seg_start, seg_end = waypoints[-2], waypoints[-1]
    else:
        found = next(
            (pair for pair in zip(waypoints, waypoints[1:]) if pair[0][0] <= t <= pair[1][0]),
            None,
        )
        if found is None:
            raise ValueError(
                "interp_waypoints: time t did not match any waypoint interval"
            )
        seg_start, seg_end = found

    span_s = (seg_end[0] - seg_start[0]).total_seconds()
    if span_s <= 0:
        frac = 0.0
    else:
        frac = (t - seg_start[0]).total_seconds() / span_s
    frac = max(0.0, min(1.0, frac))

    lat = seg_start[1] + frac * (seg_end[1] - seg_start[1])
    lon = seg_start[2] + frac * (seg_end[2] - seg_start[2])
    seg_len_m = haversine_m(seg_start[1], seg_start[2], seg_end[1], seg_end[2])
    speed_ms = 0.0 if span_s <= 0 else seg_len_m / span_s
    course = bearing_deg(seg_start[1], seg_start[2], seg_end[1], seg_end[2])
    return lat, lon, speed_ms / KN_MS, course


# ---------- Signal helpers ----------

def rssi_from_distance(dist_m: float, *, tx_dbm: float = -15.0, n: float = 2.2, noise_db: float = 3.0,
                       rng: random.Random | None = None, floor_dbm: float = -110.0) -> float:
    """Simulated RSSI (dBm) for a transmitter ``dist_m`` metres away.

    Uses a log-distance path-loss model with Gaussian noise::

        rssi = tx_dbm - 10 * n * log10(max(dist_m, 1)) + gauss(0, noise_db)

    The distance is clamped to at least 1 m so ``log10`` stays finite. The
    result is clamped below at ``floor_dbm``. ``rng`` falls back to the
    module-level ``random`` generator.

    The defaults are tuned for open-water propagation from ship-mounted
    devices to coastal sensors (effective range ~5–18 km).
    """
    source = rng if rng else random
    path_loss_db = 10.0 * n * math.log10(max(dist_m, 1.0))
    jitter_db = source.gauss(0.0, noise_db)
    return max(floor_dbm, tx_dbm - path_loss_db + jitter_db)


# ---------- Writers ----------

def _iso_string_to_epoch_ms(value: str) -> float | None:
    """Parse an ISO-8601 string to epoch ms; offset-less values are taken as UTC.

    Returns None if ``value`` is not parseable. The ``Z`` suffix is rewritten
    to ``+00:00`` for Pythons whose ``fromisoformat`` rejects ``Z``.
    """
    try:
        dt = datetime.fromisoformat(value.replace("Z", "+00:00"))
    except ValueError:
        return None
    if dt.tzinfo is None:
        # Without this, .timestamp() would use the host's local timezone,
        # making the sort order machine-dependent.
        dt = dt.replace(tzinfo=timezone.utc)
    return dt.timestamp() * 1000.0


def _record_sort_key(rec: dict[str, Any]) -> tuple[int, float]:
    """Best-available timestamp (epoch ms) for sorting NDJSON records.

    Fields are tried in this order:

    * ``ts_epoch_ms`` (number),
    * ``processingTimestamp`` (number, or an ISO string),
    * ``timestamp`` (ISO string),
    * ``sessionEnd`` (ISO string).

    ISO strings without an offset are interpreted as UTC, matching
    ``iso_utc``/``epoch_ms``. Non-finite numbers (NaN/inf) are treated as
    unusable, because NaN keys make sort order undefined.

    Returns ``(0, epoch_ms)`` for the first usable field. Otherwise it returns
    ``(1, 0.0)``, so such records sort to the end while a stable sort keeps
    their caller-given relative order.
    """
    def usable_number(value: Any) -> bool:
        return isinstance(value, (int, float)) and math.isfinite(value)

    v = rec.get("ts_epoch_ms")
    if usable_number(v):
        return (0, float(v))
    v = rec.get("processingTimestamp")
    if usable_number(v):
        return (0, float(v))
    if isinstance(v, str):
        ms = _iso_string_to_epoch_ms(v)
        if ms is not None:
            return (0, ms)
    for key in ("timestamp", "sessionEnd"):
        v = rec.get(key)
        if isinstance(v, str):
            ms = _iso_string_to_epoch_ms(v)
            if ms is not None:
                return (0, ms)
    return (1, 0.0)


def write_ndjson(path: str | Path, records: Iterable[dict[str, Any]], dataset: str,
                 *, sort: bool = True) -> int:
    """Write ``records`` as NDJSON (one JSON object per line).

    Line 1 is always the disclaimer record from ``disclaimer_record``. With
    ``sort=True`` (the default), records are stably sorted by
    ``_record_sort_key`` before writing, so consumers see no backward time
    jumps. Records without a usable timestamp keep their caller-given order
    at the end. Datetimes inside records are serialized via ``_json_default``.
    Parent directories are created if needed.

    Returns the number of lines written, including the disclaimer line.
    """
    out_path = Path(path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    ordered = sorted(records, key=_record_sort_key) if sort else list(records)
    with out_path.open("w", encoding="utf-8") as fh:
        fh.write(json.dumps(disclaimer_record(dataset)) + "\n")
        written = 1
        for rec in ordered:
            fh.write(json.dumps(rec, default=_json_default) + "\n")
            written += 1
    return written


def _json_default(obj: Any) -> Any:
    """``default=`` hook for ``json.dumps``.

    Datetimes are rendered with ``iso_utc``. Any other unsupported type is
    rejected with ``TypeError``.
    """
    if not isinstance(obj, datetime):
        raise TypeError(f"Type {type(obj)} not JSON-serializable")
    return iso_utc(obj)


def _csv_row_sort_key(row: list[Any], idx: int) -> tuple[int, float]:
    """Sort key for a CSV row, taken from the value in column ``idx``.

    Values are handled as follows:

    * Numbers and numeric strings are used as-is; presumably these are
      epoch ms, which is what ISO values are converted to.
    * Other strings are parsed as ISO-8601 and converted to epoch
      milliseconds. Offset-less strings are interpreted as UTC, consistent
      with ``iso_utc``/``epoch_ms``, rather than as host-local time.

    Returns ``(0, value)`` when usable. A missing column, an unparseable
    value, or a non-finite number (NaN/inf, which would make sort order
    undefined) yields ``(1, 0.0)``, so those rows sort last.
    """
    if not 0 <= idx < len(row):
        return (1, 0.0)
    v = row[idx]
    if isinstance(v, (int, float)):
        num = float(v)
    elif isinstance(v, str):
        try:
            num = float(v)
        except ValueError:
            try:
                dt = datetime.fromisoformat(v.replace("Z", "+00:00"))
            except ValueError:
                return (1, 0.0)
            if dt.tzinfo is None:
                dt = dt.replace(tzinfo=timezone.utc)
            num = dt.timestamp() * 1000.0
    else:
        return (1, 0.0)
    return (0, num) if math.isfinite(num) else (1, 0.0)


def write_csv(path: str | Path, header: list[str], rows: Iterable[list[Any]], dataset: str,
              *, sort: bool = True) -> int:
    """Write a CSV file preceded by a ``#``-prefixed disclaimer comment.

    File layout:

    * Line 1 is ``# `` followed by the JSON disclaimer record.
    * Line 2 is ``header``.
    * The data rows follow.

    When ``sort`` is true and ``header`` contains an ``ingestion_ts`` column,
    rows are stably sorted by that column via ``_csv_row_sort_key``.

    The comment line ends with a bare LF. The header and rows use
    ``csv.writer``'s default CRLF terminator.

    Returns the number of data rows written, excluding the comment and the
    header.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    materialized = list(rows)
    ts_column = "ingestion_ts"
    if sort and ts_column in header:
        col = header.index(ts_column)
        materialized.sort(key=lambda r_: _csv_row_sort_key(r_, col))
    with target.open("w", encoding="utf-8", newline="") as fh:
        fh.write(f"# {json.dumps(disclaimer_record(dataset))}\n")
        writer = csv.writer(fh)
        writer.writerow(header)
        writer.writerows(materialized)
    return len(materialized)


def write_geojson(path: str | Path, features: list[dict[str, Any]], dataset: str) -> int:
    """Write ``features`` as a single GeoJSON FeatureCollection.

    The disclaimer record is embedded as a top-level ``_meta`` member, placed
    between ``type`` and ``features``. Datetimes inside features are
    serialized via ``_json_default``. Parent directories are created if
    needed.

    Returns the number of features written.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    collection: dict[str, Any] = {"type": "FeatureCollection"}
    collection["_meta"] = disclaimer_record(dataset)
    collection["features"] = features
    with target.open("w", encoding="utf-8") as fh:
        json.dump(collection, fh, default=_json_default)
    return len(features)


# ---------- Catalog loaders ----------

# Catalog files are read from a ``catalogs`` directory that sits two levels
# above the directory containing this module (``parents[2]`` of the resolved
# module path). For example, ``<root>/x/y/this_module.py`` -> ``<root>/catalogs``.
_CATALOG_DIR = Path(__file__).resolve().parents[2] / "catalogs"

# Module-level caches so repeated calls (very common in scenario generators)
# don't re-parse the JSON files on every call. ``None`` means "not loaded yet";
# ``clear_catalog_cache()`` resets every cache back to ``None``. The loaders
# return these cached objects directly (no copy), so callers should treat them
# as read-only: a mutation would be visible to every later caller.
_PERSONAS_CACHE: dict[str, Any] | None = None  # parsed personas.json
_SENSORS_CACHE: dict[str, Any] | None = None  # parsed sensors.geojson
_INFRASTRUCTURE_CACHE: dict[str, Any] | None = None  # parsed infrastructure.geojson
_SENSOR_LOOKUP_CACHE: dict[str, dict[str, Any]] | None = None  # sensorId -> properties + lat/lon
_VESSEL_LOOKUP_CACHE: dict[int, dict[str, Any]] | None = None  # mmsi -> vessel record


def clear_catalog_cache() -> None:
    """Drop every cached catalog so the next loader call re-reads from disk.

    Intended for tests that mutate catalog files.
    """
    global _PERSONAS_CACHE, _SENSORS_CACHE, _INFRASTRUCTURE_CACHE
    global _SENSOR_LOOKUP_CACHE, _VESSEL_LOOKUP_CACHE
    _PERSONAS_CACHE = _SENSORS_CACHE = _INFRASTRUCTURE_CACHE = None
    _SENSOR_LOOKUP_CACHE = _VESSEL_LOOKUP_CACHE = None


def load_personas() -> dict[str, Any]:
    """Parsed ``personas.json`` from the catalog directory.

    The file is read once and then served from cache. The cached object
    itself is returned, so treat it as read-only.
    """
    global _PERSONAS_CACHE
    if _PERSONAS_CACHE is not None:
        return _PERSONAS_CACHE
    text = (_CATALOG_DIR / "personas.json").read_text(encoding="utf-8")
    _PERSONAS_CACHE = json.loads(text)
    return _PERSONAS_CACHE


def load_sensors() -> dict[str, Any]:
    """Parsed ``sensors.geojson`` from the catalog directory.

    The file is read once and then served from cache. The cached object
    itself is returned, so treat it as read-only.
    """
    global _SENSORS_CACHE
    if _SENSORS_CACHE is not None:
        return _SENSORS_CACHE
    text = (_CATALOG_DIR / "sensors.geojson").read_text(encoding="utf-8")
    _SENSORS_CACHE = json.loads(text)
    return _SENSORS_CACHE


def load_infrastructure() -> dict[str, Any]:
    """Parsed ``infrastructure.geojson`` from the catalog directory.

    The file is read once and then served from cache. The cached object
    itself is returned, so treat it as read-only.
    """
    global _INFRASTRUCTURE_CACHE
    if _INFRASTRUCTURE_CACHE is not None:
        return _INFRASTRUCTURE_CACHE
    text = (_CATALOG_DIR / "infrastructure.geojson").read_text(encoding="utf-8")
    _INFRASTRUCTURE_CACHE = json.loads(text)
    return _INFRASTRUCTURE_CACHE


def sensor_lookup() -> dict[str, dict[str, Any]]:
    """Map ``sensorId`` to that sensor's properties plus ``lat``/``lon`` (cached).

    Built from the features of ``sensors.geojson``. Each feature's
    ``geometry.coordinates`` is a GeoJSON position, ``[lon, lat]`` with an
    optional third element (altitude). Only the first two elements are used,
    so 3-D positions no longer fail with a tuple-unpacking ``ValueError``.

    If two features share a ``sensorId``, the later one wins.
    """
    global _SENSOR_LOOKUP_CACHE
    if _SENSOR_LOOKUP_CACHE is None:
        out: dict[str, dict[str, Any]] = {}
        for feat in load_sensors()["features"]:
            props = feat["properties"]
            coords = feat["geometry"]["coordinates"]
            lon, lat = coords[0], coords[1]
            out[props["sensorId"]] = {**props, "lat": lat, "lon": lon}
        _SENSOR_LOOKUP_CACHE = out
    return _SENSOR_LOOKUP_CACHE


def vessel_lookup() -> dict[int, dict[str, Any]]:
    """Map MMSI to vessel record, from ``personas.json``'s ``vessels`` list (cached).

    Vessels whose ``mmsi`` is missing or falsy (e.g. ``0`` or ``None``) are
    skipped. If two vessels share an MMSI, the later one wins.
    """
    global _VESSEL_LOOKUP_CACHE
    if _VESSEL_LOOKUP_CACHE is not None:
        return _VESSEL_LOOKUP_CACHE
    _VESSEL_LOOKUP_CACHE = {
        vessel["mmsi"]: vessel
        for vessel in load_personas()["vessels"]
        if vessel.get("mmsi")
    }
    return _VESSEL_LOOKUP_CACHE


def crew_by_ship(ship_name: str) -> list[dict[str, Any]]:
    """Return every person in ``personas.json`` whose ``ship`` equals ``ship_name``.

    Input order is preserved. A person record without a ``ship`` key raises
    ``KeyError``.
    """
    persons = load_personas()["persons"]
    return list(filter(lambda person: person["ship"] == ship_name, persons))


# ---------- Decimation helpers ----------

def _epoch_minute(rec: dict[str, Any], ts_field: str) -> int | None:
    """Return ``rec[ts_field]`` floored to the minute, as epoch milliseconds.

    The field may be a number (treated as epoch ms) or an ISO-8601 string.
    Offset-less ISO strings are interpreted as UTC, consistent with
    ``iso_utc``/``epoch_ms``, rather than as host-local time.

    Returns None if the field is missing, of another type, unparseable, or a
    non-finite number (``int()`` of NaN/inf would raise).
    """
    v = rec.get(ts_field)
    if isinstance(v, (int, float)):
        if not math.isfinite(v):
            return None
        return int(v // 60_000) * 60_000
    if isinstance(v, str):
        try:
            dt = datetime.fromisoformat(v.replace("Z", "+00:00"))
        except ValueError:
            return None
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
        # Exact integer floor to whole minutes (no float round-trip).
        return (dt - epoch) // timedelta(minutes=1) * 60_000
    return None


def decimate_by_minute(records: Iterable[dict[str, Any]],
                       *, key_field: str = "mmsi",
                       ts_field: str = "ts_epoch_ms",
                       project_fields: list[str] | None = None) -> list[dict[str, Any]]:
    """Keep one record per (key, minute) bin: the one with the latest timestamp.

    Binning: records are grouped by ``rec[key_field]`` and the minute-floored
    value of ``rec[ts_field]`` (see ``_epoch_minute``). Records with a
    ``None`` key or no usable timestamp are dropped.

    Winner selection: within a bin, the record with the greatest full
    timestamp wins, and on ties the first one seen is kept. Offset-less ISO
    timestamps are compared as UTC.

    Ordering: output is stably sorted by (minute bin, ``str(key)``), i.e. by
    the minute rather than the full timestamp.

    Projection: if ``project_fields`` is given, each output record is a new
    dict holding only those fields plus ``key_field`` and ``ts_field``. The
    input dicts are never modified.
    """
    latest: dict[tuple[Any, int], tuple[float, dict[str, Any]]] = {}
    for r in records:
        k = r.get(key_field)
        if k is None:
            continue
        m = _epoch_minute(r, ts_field)
        if m is None:
            continue
        # Full-precision timestamp orders records within the minute bin.
        ts_val = r.get(ts_field)
        ts_num = float(m)
        if isinstance(ts_val, str):
            try:
                dt = datetime.fromisoformat(ts_val.replace("Z", "+00:00"))
            except ValueError:
                pass
            else:
                if dt.tzinfo is None:
                    dt = dt.replace(tzinfo=timezone.utc)
                ts_num = dt.timestamp() * 1000.0
        elif isinstance(ts_val, (int, float)):
            ts_num = float(ts_val)
        key = (k, m)
        prev = latest.get(key)
        if prev is None or ts_num > prev[0]:
            latest[key] = (ts_num, r)

    # Reuse the bin's minute instead of re-parsing each winner's timestamp.
    binned = [(minute, rec) for (_, minute), (_, rec) in latest.items()]
    binned.sort(key=lambda item: (item[0], str(item[1].get(key_field))))
    if project_fields is None:
        return [rec for _, rec in binned]
    keep = set(project_fields) | {key_field, ts_field}
    return [{f: v for f, v in rec.items() if f in keep} for _, rec in binned]


def decimate_mac_by_minute(records: Iterable[dict[str, Any]],
                           *, sensor_field: str = "deviceId",
                           ts_field: str = "processingTimestamp",
                           mac_field: str = "macAddress",
                           sample_size: int = 12) -> list[dict[str, Any]]:
    """Collapse MAC observations into one summary row per (sensor, minute).

    Records with a falsy ``sensor_field`` or no usable ``ts_field`` (per
    ``_epoch_minute``) are ignored. Each output row contains, in this order:

    * ``<sensor_field>``: the sensor id.
    * ``minute_ts_epoch_ms``: start of the minute, in epoch ms.
    * ``minute_ts``: the same instant as an ISO-8601 UTC string.
    * ``observation_count``: number of records in the bin.
    * ``distinct_mac_count``: number of distinct truthy ``mac_field`` values.
    * ``sample_macs``: the first ``sample_size`` distinct MACs, in sorted order.
    * ``ts_epoch_ms``: a copy of ``minute_ts_epoch_ms``, so that
      ``write_ndjson``'s timestamp sort recognises the row.

    Rows are sorted by (minute, sensor id).
    """
    counts: dict[tuple[str, int], int] = {}
    macs_by_bin: dict[tuple[str, int], set[Any]] = {}
    for rec in records:
        sensor = rec.get(sensor_field)
        if not sensor:
            continue
        minute = _epoch_minute(rec, ts_field)
        if minute is None:
            continue
        bin_key = (sensor, minute)
        counts[bin_key] = counts.get(bin_key, 0) + 1
        seen = macs_by_bin.setdefault(bin_key, set())
        mac = rec.get(mac_field)
        if mac:
            seen.add(mac)

    summary: list[dict[str, Any]] = []
    for (sensor, minute), n_obs in counts.items():
        ordered_macs = sorted(macs_by_bin[(sensor, minute)])
        row: dict[str, Any] = {
            sensor_field: sensor,
            "minute_ts_epoch_ms": minute,
            "minute_ts": iso_utc(datetime.fromtimestamp(minute / 1000.0, tz=timezone.utc)),
            "observation_count": n_obs,
        }
        row["distinct_mac_count"] = len(ordered_macs)
        row["sample_macs"] = ordered_macs[:sample_size]
        row["ts_epoch_ms"] = row["minute_ts_epoch_ms"]
        summary.append(row)
    summary.sort(key=lambda r_: (r_["minute_ts_epoch_ms"], r_[sensor_field]))
    return summary


def _ndjson_size_if_oversized(p: Path, size_threshold_bytes: int) -> int | None:
    """Return ``p``'s size in bytes if it exists and exceeds the threshold; else None."""
    if not p.exists():
        return None
    size = p.stat().st_size
    return size if size > size_threshold_bytes else None


def _load_ndjson_for_decimation(p: Path) -> tuple[list[dict[str, Any]], str]:
    """Read an NDJSON file (as written by ``write_ndjson``) for decimation.

    Returns ``(records, dataset)``:

    * ``records`` holds every JSON-object line that is not a meta record.
    * ``dataset`` is the string ``dataset`` value of the meta record, or
      ``"unknown"`` when it is absent or not a string.

    Blank lines, malformed JSON, and non-object values are skipped, because
    the decimators call ``.get()`` on each record. Since every line is parsed
    anyway, meta records are recognised by the presence of a ``__meta__``
    key, which is more robust than a textual prefix match.
    """
    records: list[dict[str, Any]] = []
    dataset = "unknown"
    with p.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(obj, dict):
                continue
            if "__meta__" in obj:
                name = obj.get("dataset")
                if isinstance(name, str):
                    dataset = name
                continue
            records.append(obj)
    return records, dataset


def _decimation_report(source: Path, source_bytes: int, out_path: Path, rows: int) -> dict[str, Any]:
    """Build the summary dict returned by the ``maybe_decimate_*`` helpers."""
    return {
        "source": str(source),
        "decimated": str(out_path),
        "source_bytes": source_bytes,
        "decimated_bytes": out_path.stat().st_size,
        "rows": rows,
    }


def maybe_decimate_ndjson(source_path: str | Path, *, key_field: str = "mmsi",
                          ts_field: str = "ts_epoch_ms",
                          project_fields: list[str] | None = None,
                          dataset_suffix: str = "_decimated",
                          size_threshold_bytes: int = 20 * 1024 * 1024
                          ) -> dict[str, Any] | None:
    """Write a per-minute decimated companion of a large NDJSON file.

    If ``source_path`` exists and is larger than ``size_threshold_bytes``,
    it is reduced with ``decimate_by_minute``: one record per key per minute,
    optionally projected to ``project_fields``. The result is written next to
    the source as ``<stem><dataset_suffix><suffix>``. The output's
    disclaimer dataset name is the source's name plus ``dataset_suffix``.

    Returns a report dict (source/decimated paths, byte sizes, row count),
    or None if the file is missing or small enough to leave alone.
    """
    p = Path(source_path)
    size = _ndjson_size_if_oversized(p, size_threshold_bytes)
    if size is None:
        return None
    records, meta_dataset = _load_ndjson_for_decimation(p)
    decimated = decimate_by_minute(records, key_field=key_field, ts_field=ts_field,
                                   project_fields=project_fields)
    out_path = p.with_name(p.stem + dataset_suffix + p.suffix)
    # decimate_by_minute already returns rows in time order.
    write_ndjson(out_path, decimated, meta_dataset + dataset_suffix, sort=False)
    return _decimation_report(p, size, out_path, len(decimated))


def maybe_decimate_mac_ndjson(source_path: str | Path,
                              *, size_threshold_bytes: int = 20 * 1024 * 1024
                              ) -> dict[str, Any] | None:
    """Write a per-minute MAC summary companion of a large NDJSON file.

    If ``source_path`` exists and is larger than ``size_threshold_bytes``,
    its records are summarised with ``decimate_mac_by_minute`` (one row per
    sensor per minute). The result is written as
    ``<stem>_decimated<suffix>`` next to the source.

    Returns a report dict, or None if nothing was written.
    """
    p = Path(source_path)
    size = _ndjson_size_if_oversized(p, size_threshold_bytes)
    if size is None:
        return None
    records, meta_dataset = _load_ndjson_for_decimation(p)
    decimated = decimate_mac_by_minute(records)
    out_path = p.with_name(p.stem + "_decimated" + p.suffix)
    write_ndjson(out_path, decimated, meta_dataset + "_decimated", sort=False)
    return _decimation_report(p, size, out_path, len(decimated))


# ---------- Synthetic ambient MMSI helper ----------

def ambient_mmsi(rng: random.Random, flag: str = "FI") -> int:
    """Draw a random ambient/decoy MMSI for background traffic.

    Ranges come from ``catalogs/personas.json -> ambient_mmsi_blocks.ranges``.
    That is a mapping from ``flag`` to a two-element ``[lo, hi]`` pair.
    Bounds are inclusive, because the draw uses ``rng.randint``.

    A flag with no (or an empty) range falls back to the ``OTHER`` entry. Per
    the catalog's design, these are synthetic 9XX… numbers intended to stay
    outside the 2XX–7XX MID space used by real ship stations.

    Raises:
        RuntimeError: if neither ``flag`` nor ``OTHER`` has a configured range.
    """
    ranges = load_personas().get("ambient_mmsi_blocks", {}).get("ranges", {})
    bounds = ranges.get(flag) or ranges.get("OTHER")
    if not bounds:
        raise RuntimeError(
            "ambient_mmsi_blocks not configured in catalogs/personas.json"
        )
    low, high = bounds
    return rng.randint(int(low), int(high))
