﻿"""Scenario 06 — Drone Swarm from Ship (Coastal Attack Pattern): data generator.

Story: MV POHJANTUULI (a gray-zone research/utility vessel, MMSI 230991601) approaches
from the east and anchors ~23 NM (~42 km) south of Kilpilahti (Porvoo) in open water. Five minutes
before launching its payload, the ship goes AIS-dark. Then six small drones lift off
simultaneously in a coordinated fan-pattern sweep toward Finnish coastal infrastructure:
two toward Kilpilahti petrochemical complex, two toward Porvoo harbor, and two toward
the Sipoo/Helsinki area. Each drone carries a DJI-OUI MAC that coastal and airborne
sensors detect as the swarm crosses the 12 NM mark. Patrol plane RAD-PLN-01 tracks
all six contacts simultaneously — the first time a single platform sees a multi-track
coordinated drone launch in the data record.

Pipeline:
  1. MV POHJANTUULI transits GoF westbound, holds anchor at (59.92°N, 25.55°E), then
     resumes after AIS reappears.
  2. 6 drones (T-DRN-SW-01 … -06) lift off from ship deck simultaneously at T+0,
     climb to 85–105 m, diverge toward six coastal targets in a ~55° arc.
  3. Each drone has a unique DJI-OUI MAC. Coastal sensors MAC-PRV-COAST-01/-02 and
     MAC-HEL-COAST-01 pick up individual MACs as drones enter range.
  4. Plane radar (RAD-PLN-01) acquires all six contacts within 90 s of takeoff.
  5. Patrol drone (RAD-DRN-PAT-01) from Helsinki-Malmi scrambles to intercept.
  6. No return flight — drones do not recover to the ship (expendable payload).
"""
from __future__ import annotations

import json
import math
import random
import sys
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any

REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from generators.common import (  # noqa: E402
    KN_MS,
    ambient_mmsi,
    haversine_m,
    iso_utc,
    load_infrastructure,
    load_sensors,
    maybe_decimate_mac_ndjson,
    maybe_decimate_ndjson,
    rssi_from_distance,
    sensor_lookup,
    write_csv,
    write_geojson,
    write_ndjson,
)
from generators.ais_generator import AisTrack, ais_snapshot_geojson, emit_ais  # noqa: E402
from generators.mac_generator import (  # noqa: E402
    MAC_CSV_HEADER,
    MacObservation,
    MovingMacEmitter,
    generate_background_macs,
    simulate_moving_mac,
)
from generators.radar_generator import RadarTrack, emit_drone_radar, emit_radar  # noqa: E402

UTC = timezone.utc
SCENARIO_DIR = Path(__file__).resolve().parent
# Output layers, all under <scenario>/data/
OUT_REALTIME = SCENARIO_DIR / "data" / "realtime"
OUT_STATIC = SCENARIO_DIR / "data" / "static"
OUT_HISTORICAL = SCENARIO_DIR / "data" / "historical"

# ---------------------------------------------------------------------------
# Time anchors (all UTC, 12 June 2025)
# ---------------------------------------------------------------------------
WINDOW_OPEN  = datetime(2025, 6, 12,  9,  0, 0, tzinfo=UTC)
WINDOW_CLOSE = datetime(2025, 6, 12, 11,  0, 0, tzinfo=UTC)

# The mother ship goes AIS-dark 5 min before the swarm launches and stays dark
# for 42 min in total.
AIS_DARK_START  = datetime(2025, 6, 12,  9, 28, 0, tzinfo=UTC)
SWARM_LAUNCH    = datetime(2025, 6, 12,  9, 33, 0, tzinfo=UTC)
AIS_DARK_END    = datetime(2025, 6, 12, 10, 10, 0, tzinfo=UTC)

# Patrol drone scrambles 9 min after launch.
PAT_SCRAMBLE    = datetime(2025, 6, 12,  9, 42, 0, tzinfo=UTC)
PAT_LAND        = datetime(2025, 6, 12, 10, 45, 0, tzinfo=UTC)

# ---------------------------------------------------------------------------
# Geographic anchors
# ---------------------------------------------------------------------------
# Mother ship anchor position (open Gulf, ~42 km / ~23 NM south of Kilpilahti)
SHIP_LAT, SHIP_LON = 59.920, 25.555

# MV POHJANTUULI waypoints — must be strictly chronological
# Fast gray-zone utility vessel: enters ~19 NM (~36 km) east of the anchor at
# window open, sprints to the anchor (~46 kn), holds there, launches the swarm
# while AIS-dark, then departs westward (~26 kn) after AIS reappears.
POHJANT_WAYPOINTS: list[tuple[datetime, float, float]] = [
    (WINDOW_OPEN,                                           59.910, 26.200),  # entry ~19 NM east of anchor
    (AIS_DARK_START - timedelta(minutes=3),                59.920, 25.555),  # anchor on-station (09:25)
    # AIS dark 09:28 → 10:10 (42 min); drones launch at 09:33 from anchor
    (AIS_DARK_END + timedelta(minutes=5),                  59.920, 25.550),  # AIS reappears, still at anchor (10:15)
    (WINDOW_CLOSE,                                         59.905, 24.900),  # gradual westward departure ~26 kn over 45 min
]

# Six drone targets along the Finnish coast (fan pattern, ~27–42 km from ship)
# Each tuple: (track_id, target_lat, target_lon, alt_cruise_m, mac_addr, seed)
SWARM: list[tuple[str, float, float, float, str, int]] = [
    ("T-DRN-SW-01", 60.300, 25.555,  95.0, "5C:E2:8C:AA:01:01", 601),  # Kilpilahti N
    ("T-DRN-SW-02", 60.280, 25.640,  90.0, "5C:E2:8C:AA:02:02", 602),  # Kilpilahti NE
    ("T-DRN-SW-03", 60.225, 25.660, 100.0, "5C:E2:8C:AA:03:03", 603),  # Porvoo harbor W
    ("T-DRN-SW-04", 60.215, 25.700, 105.0, "5C:E2:8C:AA:04:04", 604),  # Porvoo harbor E
    ("T-DRN-SW-05", 60.140, 25.350,  85.0, "5C:E2:8C:AA:05:05", 605),  # Sipoo coast W
    ("T-DRN-SW-06", 60.110, 25.200,  88.0, "5C:E2:8C:AA:06:06", 606),  # Sipoo coast NW
]

# Patrol drone (Border Guard, Helsinki-Malmi → intercept arc)
# Each tuple: (time, lat, lon, alt_m); takes off from and lands at the same point.
PATROL_PATH_3D: list[tuple[datetime, float, float, float]] = [
    (PAT_SCRAMBLE,                                           60.254, 25.041,   0.0),
    (PAT_SCRAMBLE + timedelta(minutes=6),                    60.220, 25.250, 220.0),
    (PAT_SCRAMBLE + timedelta(minutes=14),                   60.180, 25.430, 240.0),
    (PAT_SCRAMBLE + timedelta(minutes=22),                   60.110, 25.500, 235.0),  # intercept swarm
    (PAT_SCRAMBLE + timedelta(minutes=30),                   60.050, 25.540, 230.0),  # descend toward ship area
    (PAT_SCRAMBLE + timedelta(minutes=40),                   59.930, 25.555, 200.0),  # over mother ship
    (PAT_LAND,                                               60.254, 25.041,   0.0),
]
# Split views of the patrol path: (time, lat, lon) and (time, alt_m).
PATROL_WAYPOINTS_2D = [(t, la, lo) for (t, la, lo, _) in PATROL_PATH_3D]
PATROL_ALTITUDES    = [(t, al) for (t, _la, _lo, al) in PATROL_PATH_3D]

# Sensor IDs copied into the static layer (sensors_used.geojson).
SENSORS_USED_IDS = {
    "MAC-PRV-COAST-01", "MAC-PRV-COAST-02",
    "MAC-HEL-COAST-01",
    "MAC-AIR-PLN-01", "MAC-AIR-DRN-01",
    "RAD-PLN-01", "RAD-DRN-PAT-01",
}

# Infrastructure feature IDs copied into the static layer (infrastructure_used.geojson).
INFRA_USED_IDS = {
    "site-kilpilahti", "port-kilpilahti",
    "port-porvoo",
    "shipping-lane-eb", "shipping-lane-wb",
}


def _drone_profile(swarm_entry: tuple, n_ambient: int, seed_base: int,
                   *, cruise_speed_mps: float = 18.0):
    """Build the 3D path and altitude profile for one swarm drone.

    Every waypoint lies on the straight (lat/lon-interpolated) line from the
    mother ship's anchor (SHIP_LAT, SHIP_LON) to the drone's target:

    * SWARM_LAUNCH          lift-off from the deck, 0 m
    * +30 s                 still over the ship, 60 % of cruise altitude
    * +90 s                 5 % of the way to the target, cruise altitude
    * arrival - 30 s        2 % short of the target, 80 % of cruise altitude
    * arrival               over the target, 20 m

    Arrival is SWARM_LAUNCH plus launch-to-target distance / ``cruise_speed_mps``.
    It is clamped to at least 121 s so that the fixed 90 s climb-out and 30 s
    approach phases can never produce out-of-order waypoints for a nearby target.

    Args:
        swarm_entry: one SWARM tuple
            (track_id, target_lat, target_lon, alt_cruise_m, mac_addr, seed).
        n_ambient: unused; kept for call-site compatibility.
        seed_base: unused; kept for call-site compatibility (the seed returned
            is the one stored in ``swarm_entry``).
        cruise_speed_mps: nominal ground speed used to size the flight duration.

    Returns:
        ``(path_3d, waypoints_2d, altitudes, track_id, mac, seed)`` where
        ``path_3d`` items are ``(time, lat, lon, alt_m)``, ``waypoints_2d``
        items are ``(time, lat, lon)`` and ``altitudes`` items are
        ``(time, alt_m)``.

    Raises:
        ValueError: if ``cruise_speed_mps`` is not positive.
    """
    track_id, tgt_lat, tgt_lon, alt_m, mac, seed = swarm_entry
    if cruise_speed_mps <= 0:
        raise ValueError(f"cruise_speed_mps must be positive, got {cruise_speed_mps!r}")

    climb_s, approach_s = 90.0, 30.0
    dist_m = haversine_m(SHIP_LAT, SHIP_LON, tgt_lat, tgt_lon)
    # Clamp so the +90 s waypoint always precedes the arrival-30 s waypoint.
    # Targets further than ~2.2 km (at 18 m/s) are unaffected, which covers
    # every entry in SWARM, so existing output does not change.
    cruise_s = max(dist_m / cruise_speed_mps, climb_s + approach_s + 1.0)

    d_lat = tgt_lat - SHIP_LAT
    d_lon = tgt_lon - SHIP_LON
    t0 = SWARM_LAUNCH
    t_arrive = t0 + timedelta(seconds=cruise_s)

    # 3D path: lift → cruise → arrival
    path_3d = [
        (t0,                                      SHIP_LAT, SHIP_LON, 0.0),
        (t0 + timedelta(seconds=30),              SHIP_LAT, SHIP_LON, alt_m * 0.6),
        (t0 + timedelta(seconds=climb_s),         SHIP_LAT + d_lat * 0.05,
                                                  SHIP_LON + d_lon * 0.05, alt_m),
        (t_arrive - timedelta(seconds=approach_s), tgt_lat - d_lat * 0.02,
                                                   tgt_lon - d_lon * 0.02, alt_m * 0.8),
        (t_arrive,                                tgt_lat, tgt_lon, 20.0),
    ]
    waypoints_2d = [(t, la, lo) for (t, la, lo, _) in path_3d]
    altitudes    = [(t, al) for (t, _la, _lo, al) in path_3d]
    return path_3d, waypoints_2d, altitudes, track_id, mac, seed


def build_ambient_ais(n_ships: int, seed: int, *,
                      sog_range_kn: tuple[float, float] = (10.0, 18.0)) -> list[dict[str, Any]]:
    """Generate AIS messages for background (ambient) Gulf of Finland traffic.

    Each ship belongs to a straight east-west lane spanning 22.5°E–27.5°E at a
    random latitude (59.65–60.30°N, drifting up to ±0.1° end to end).
    Eastbound ships are bound for FIHEL and westbound ships for EETLL. A ship
    appears at a random point along its lane 0–60 min after WINDOW_OPEN. It
    then sails at a constant speed drawn uniformly from ``sog_range_kn`` for
    60–120 min, and its track is cut short at WINDOW_CLOSE or at the end of
    the lane.

    Args:
        n_ships: number of ambient ships to attempt.
        seed: RNG seed; ship ``i`` also gets track seed ``seed + i``.
        sog_range_kn: (min, max) speed over ground in knots.

    Returns:
        All emitted AIS messages, ship by ship.

    Raises:
        ValueError: if ``sog_range_kn`` is not ``0 < min <= max``.
    """
    lo_kn, hi_kn = sog_range_kn
    if not 0 < lo_kn <= hi_kn:
        raise ValueError(f"sog_range_kn must satisfy 0 < min <= max, got {sog_range_kn!r}")
    knot_ms = 1852.0 / 3600.0  # 1 knot in m/s

    rng = random.Random(seed)
    out: list[dict[str, Any]] = []
    for i in range(n_ships):
        eastbound = rng.random() < 0.5
        # NOTE(review): the upper latitudes near 25–26°E lie close to / on the
        # Finnish coast — confirm this band is intended.
        lane_lat0 = rng.uniform(59.65, 60.30)
        lane_lat1 = lane_lat0 + rng.uniform(-0.10, 0.10)
        lane_lon0, lane_lon1 = (22.5, 27.5) if eastbound else (27.5, 22.5)
        t_start = WINDOW_OPEN + timedelta(minutes=rng.uniform(0, 60))
        t_end = min(WINDOW_CLOSE, t_start + timedelta(minutes=rng.uniform(60, 120)))
        if t_end <= t_start:
            continue
        flag_roll = rng.random()
        flag = "FI" if flag_roll < 0.7 else ("EE" if flag_roll < 0.9 else "OTHER")
        mmsi = ambient_mmsi(rng, flag)

        # Why: previously every ship covered the whole ~280 km lane in
        # 60–120 min (~75–150 kn). Instead, start the ship part-way along its
        # lane and advance it at a realistic merchant-vessel speed.
        frac0 = rng.random()
        sog_ms = rng.uniform(lo_kn, hi_kn) * knot_ms
        lane_m = haversine_m(lane_lat0, lane_lon0, lane_lat1, lane_lon1)
        duration_s = (t_end - t_start).total_seconds()
        frac1 = frac0 + sog_ms * duration_s / lane_m
        if frac1 > 1.0:
            # The lane ends before the track would: stop the ship at the lane end.
            duration_s *= (1.0 - frac0) / (frac1 - frac0)
            t_end = t_start + timedelta(seconds=duration_s)
            frac1 = 1.0
            if t_end <= t_start:
                continue

        lat0 = lane_lat0 + (lane_lat1 - lane_lat0) * frac0
        lon0 = lane_lon0 + (lane_lon1 - lane_lon0) * frac0
        lat1 = lane_lat0 + (lane_lat1 - lane_lat0) * frac1
        lon1 = lane_lon0 + (lane_lon1 - lane_lon0) * frac1
        track = AisTrack(
            mmsi=mmsi,
            waypoints=[(t_start, lat0, lon0), (t_end, lat1, lon1)],
            cadence_s=15.0,
            destination="FIHEL" if eastbound else "EETLL",
            seed=seed + i,
        )
        out.extend(emit_ais(track))
    return out


def generate_realtime() -> dict[str, int]:
    """Write every realtime-layer file for the scenario and return their counts.

    Under OUT_REALTIME this writes:
    * AIS for MV POHJANTUULI (dark during AIS_DARK_START..AIS_DARK_END) plus
      800 ambient ships, and an AIS snapshot GeoJSON;
    * drone radar: the six swarm tracks (tagged with the swarm id) and the
      patrol drone track;
    * plane radar: a surface track following the mother ship's waypoints,
      with ``mmsi_hint`` cleared while the ship is AIS-dark;
    * MAC observations (swarm drone emitters + background) as NDJSON and CSV;
    * whatever decimated companion files the common helpers decide to create.

    Returns:
        Mapping of output file name to the value returned by the writer.
        Decimated companions report their ``rows`` + 1.
    """
    sensor_index = sensor_lookup()
    written: dict[str, int] = {}

    # --- AIS: mother ship (goes dark around the launch) + ambient traffic ---
    mother_ship = AisTrack(
        mmsi=230991601,
        waypoints=POHJANT_WAYPOINTS,
        cadence_s=3.0,
        dark_windows=[(AIS_DARK_START, AIS_DARK_END)],
        destination="FIHAN",
        nav_status=1,  # at anchor
        seed=601,
    )
    ais_messages = emit_ais(mother_ship) + build_ambient_ais(n_ships=800, seed=700)
    written["ais.ndjson"] = write_ndjson(
        OUT_REALTIME / "ais.ndjson", ais_messages, "s6-drone-swarm/ais")
    written["ais_snapshot.geojson"] = write_geojson(
        OUT_REALTIME / "ais_snapshot.geojson",
        ais_snapshot_geojson(ais_messages),
        "s6-drone-swarm/ais_snapshot")

    # --- Swarm: one radar track and one DJI MAC emitter per drone ---
    radar_records: list[dict[str, Any]] = []
    mac_observations: list[Any] = []
    for swarm_entry in SWARM:
        _path_3d, flat_path, alt_profile, track_id, mac_addr, drone_seed = _drone_profile(
            swarm_entry, 0, swarm_entry[5])

        swarm_track = RadarTrack(
            track_id=track_id,
            sensor_id="RAD-PLN-01",
            waypoints=flat_path,
            cadence_s=2.0,
            classification="airborne_small",
            rcs_m2=0.025,
            confidence=0.72,
            seed=drone_seed,
        )
        swarm_recs = emit_drone_radar(swarm_track, alt_profile)
        for rec in swarm_recs:
            rec.update(kind="airborne", swarm_id="SW-2025-0612-01")
        radar_records += swarm_recs

        # The emitter is active for the drone's whole flight (first→last waypoint).
        launch_t, arrive_t = flat_path[0][0], flat_path[-1][0]
        emitter = MovingMacEmitter(
            mac=mac_addr,
            manufacturer="DJI",
            waypoints=flat_path,
            active_windows=[(launch_t, arrive_t)],
            seed=drone_seed + 1,
        )
        mac_observations += simulate_moving_mac(emitter, sensor_index, rssi_threshold=-100.0)

    # --- Patrol drone ---
    patrol_recs = emit_drone_radar(
        RadarTrack(
            track_id="T-PAT-DRN-S6-01",
            sensor_id="RAD-DRN-PAT-01",
            waypoints=PATROL_WAYPOINTS_2D,
            cadence_s=3.0,
            classification="airborne_medium",
            rcs_m2=0.15,
            confidence=0.94,
            seed=699,
        ),
        PATROL_ALTITUDES,
    )
    for rec in patrol_recs:
        rec["kind"] = "patrol"
    radar_records += patrol_recs
    written["drone_radar.ndjson"] = write_ndjson(
        OUT_REALTIME / "drone_radar.ndjson", radar_records, "s6-drone-swarm/drone_radar")

    # --- Plane radar: surface track of the mother ship ---
    surface_track = RadarTrack(
        track_id="T-PLN-S6-SURF-01",
        sensor_id="RAD-PLN-01",
        waypoints=list(POHJANT_WAYPOINTS),
        cadence_s=4.0,
        classification="surface_large",
        rcs_m2=4500.0,
        confidence=0.91,
        seed=650,
    )
    plane_recs = emit_radar(surface_track)
    # The MMSI can only be attributed while the ship's AIS is transmitting.
    for rec in plane_recs:
        seen_at = datetime.strptime(rec["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=UTC)
        ais_dark = AIS_DARK_START <= seen_at < AIS_DARK_END
        rec["mmsi_hint"] = None if ais_dark else 230991601
        rec["platform"] = "MAC-AIR-PLN-01"
    written["plane_radar.ndjson"] = write_ndjson(
        OUT_REALTIME / "plane_radar.ndjson", plane_recs, "s6-drone-swarm/plane_radar")

    # --- MAC: swarm emitters + background devices ---
    mac_observations += generate_background_macs(
        sensor_index, WINDOW_OPEN, WINDOW_CLOSE, mac_count=100, cadence_s=120.0, seed=42)
    written["mac.ndjson"] = write_ndjson(
        OUT_REALTIME / "mac.ndjson",
        [obs.to_ndjson() for obs in mac_observations],
        "s6-drone-swarm/mac")
    written["mac.csv"] = write_csv(
        OUT_REALTIME / "mac.csv",
        MAC_CSV_HEADER,
        [obs.to_csv_row() for obs in mac_observations],
        "s6-drone-swarm/mac_sessions")

    # --- Decimated companions ---
    ais_fields = ["timestamp", "lat", "lon", "sog_kn", "cog_deg", "nav_status"]
    drone_fields = ["timestamp", "lat", "lon", "alt_m", "speed_mps", "heading_deg",
                    "rcs_m2", "classification", "kind", "swarm_id"]
    decimation_jobs = (
        ("ais.ndjson", "mmsi", ais_fields),
        ("drone_radar.ndjson", "track_id", drone_fields),
    )
    reports = []
    for file_name, key, fields in decimation_jobs:
        report = maybe_decimate_ndjson(
            OUT_REALTIME / file_name,
            key_field=key,
            ts_field="ts_epoch_ms",
            project_fields=fields,
        )
        if report:
            reports.append(report)
            written[Path(report["decimated"]).name] = report["rows"] + 1
    mac_report = maybe_decimate_mac_ndjson(OUT_REALTIME / "mac.ndjson")
    if mac_report:
        reports.append(mac_report)
        written[Path(mac_report["decimated"]).name] = mac_report["rows"] + 1

    if reports:
        print("[S6] decimated companion files:")
        for report in reports:
            src_mb = report["source_bytes"] / 1024 / 1024
            dst_mb = report["decimated_bytes"] / 1024 / 1024
            print(f"  {Path(report['decimated']).name}  "
                  f"{src_mb:.1f}MB → {dst_mb:.1f}MB"
                  f"  ({report['rows']} rows)")
    return written


def generate_static() -> dict[str, int]:
    counts: dict[str, int] = {}

    # Area of interest polygon — covers ship anchor + all six drone endpoints
    aoi = {
        "type": "Feature",
        "properties": {"featureId": "s6-aoi", "name": "S6 Area of Interest",
                       "note": "Drone swarm fan: ship anchor + 6 coastal targets"},
        "geometry": {"type": "Polygon", "coordinates": [[
            [25.0, 59.85], [27.7, 59.85], [27.7, 60.45],
            [25.0, 60.45], [25.0, 59.85],
        ]]},
    }
    counts["area_of_interest.geojson"] = write_geojson(
        OUT_STATIC / "area_of_interest.geojson", [aoi], "s6-drone-swarm/aoi")

    # Swarm fan polygon (connecting launch origin to all six target points)
    fan_coords = [[SHIP_LON, SHIP_LAT]]
    for _, tgt_lat, tgt_lon, _, _, _ in SWARM:
        fan_coords.append([tgt_lon, tgt_lat])
    fan_coords.append([SHIP_LON, SHIP_LAT])
    fan = {
        "type": "Feature",
        "properties": {"featureId": "s6-swarm-fan", "name": "Swarm fan envelope",
                       "note": "Convex envelope of all 6 drone paths"},
        "geometry": {"type": "Polygon", "coordinates": [fan_coords]},
    }
    counts["swarm_fan.geojson"] = write_geojson(
        OUT_STATIC / "swarm_fan.geojson", [fan], "s6-drone-swarm/swarm_fan")

    sensors_fc = load_sensors()
    sensor_feats = [f for f in sensors_fc["features"]
                    if f["properties"]["sensorId"] in SENSORS_USED_IDS]
    counts["sensors_used.geojson"] = write_geojson(
        OUT_STATIC / "sensors_used.geojson", sensor_feats, "s6-drone-swarm/sensors_used")

    infra_fc = load_infrastructure()
    infra_feats = [f for f in infra_fc["features"]
                   if f["properties"]["featureId"] in INFRA_USED_IDS]
    counts["infrastructure_used.geojson"] = write_geojson(
        OUT_STATIC / "infrastructure_used.geojson", infra_feats, "s6-drone-swarm/infra_used")
    return counts


def generate_historical() -> dict[str, int]:
    """Write a six-day AIS baseline for MV POHJANTUULI and return its row count.

    For each of the 1–6 days before WINDOW_OPEN, the realtime waypoint schedule
    is replayed with the same whole-second offsets from 09:00 UTC. Each replay
    uses a 30 s cadence and no dark window (presumably a "normal behaviour"
    reference for the attack day).
    """
    baseline_msgs: list[dict[str, Any]] = []
    for days_back in range(1, 7):
        day = WINDOW_OPEN - timedelta(days=days_back)
        nine_utc = day.replace(hour=9, minute=0, second=0, microsecond=0)
        shifted_wp = []
        for when, lat, lon in POHJANT_WAYPOINTS:
            offset_s = int((when - WINDOW_OPEN).total_seconds())
            shifted_wp.append((nine_utc + timedelta(seconds=offset_s), lat, lon))
        replay = AisTrack(mmsi=230991601, waypoints=shifted_wp, cadence_s=30.0,
                          destination="FIHAN", seed=8000 + day.day)
        baseline_msgs.extend(emit_ais(replay))
    return {
        "ais_baseline.ndjson": write_ndjson(
            OUT_HISTORICAL / "ais_baseline.ndjson", baseline_msgs,
            "s6-drone-swarm/ais_baseline"),
    }


def dir_size_bytes(p: Path) -> int:
    """Return the total size in bytes of every regular file below *p* (recursive)."""
    total = 0
    for entry in p.rglob("*"):
        if entry.is_file():
            total += entry.stat().st_size
    return total


def main() -> int:
    """Generate the realtime, static and historical layers and print a summary.

    Creates the output directories, runs the three generators, prints per-file
    counts and the realtime layer's size on disk, and writes
    ``data/_generation_summary.json`` in the scenario directory.

    Returns:
        0, used as the process exit status.
    """
    for out_dir in (OUT_REALTIME, OUT_STATIC, OUT_HISTORICAL):
        out_dir.mkdir(parents=True, exist_ok=True)

    print("[S6] generating realtime layer …")
    rt = generate_realtime()
    print("[S6] generating static layer …")
    st = generate_static()
    print("[S6] generating historical layer …")
    hi = generate_historical()

    print("\n===== Scenario 06 — Drone Swarm from Ship: generation summary =====")
    for section, data in [("realtime", rt), ("static", st), ("historical", hi)]:
        print(f"\n[{section}]")
        for k, v in data.items():
            print(f"  {k:<34} rows/features={v:>8}")

    rt_bytes = dir_size_bytes(OUT_REALTIME)
    print(f"\n[on disk] realtime {rt_bytes/1024/1024:.2f} MB")
    data_dir = SCENARIO_DIR / "data"
    summary = {"scenario": "s6-drone-swarm-from-ship", "realtime": rt, "static": st, "historical": hi}
    (data_dir / "_generation_summary.json").write_text(
        json.dumps(summary, indent=2), encoding="utf-8")
    # Derive the reported location from SCENARIO_DIR rather than hard-coding a
    # directory name, so the message always matches where files were written.
    # (SCENARIO_DIR and REPO_ROOT both come from this file's resolved path,
    # so relative_to cannot fail.)
    print(f"[done] All files written under {data_dir.relative_to(REPO_ROOT).as_posix()}/")
    return 0


if __name__ == "__main__":
    # Run as a script: main()'s return value becomes the process exit status.
    sys.exit(main())
