"""
Synthetic satellite image generator for Scenario S8: Red Vessel Escalation.
Generates 3 PNG images simulating Sentinel-2-style maritime surveillance imagery.
"""

import os
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import Rectangle, FancyArrowPatch
from matplotlib.patheffects import withStroke

# One seeded generator shared by every drawing helper: the rendered scenes
# are reproducible as long as the order of RNG calls stays the same.
RNG = np.random.default_rng(42)

# PNGs land in data/static/ beside this script.
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "data", "static")

# Base ocean colour (axes, figure patch and savefig background).
SEA_COLOR = "#1a4a5c"

# Canvas geometry: the figure is W/DPI x H/DPI inches at DPI dots per inch,
# i.e. W x H canvas pixels, and the axes limits are (0, W) x (0, H), so one
# data unit corresponds to one canvas pixel.
DPI = 100
W = 1200
H = 900


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def make_fig():
    """Create a margin-free W x H pixel figure with one axes covering it.

    The axes uses data limits (0, W) x (0, H) with equal aspect and is turned
    off (no ticks, labels or spines). The sea colour is applied to both the
    figure patch and the axes.

    Returns:
        (fig, ax): the new Figure and its single Axes.
    """
    size_inches = (W / DPI, H / DPI)
    fig, ax = plt.subplots(figsize=size_inches, dpi=DPI)
    # Zero margins so the axes occupies the whole canvas.
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1)
    fig.patch.set_facecolor(SEA_COLOR)
    ax.set(xlim=(0, W), ylim=(0, H), aspect="equal", facecolor=SEA_COLOR)
    ax.axis("off")
    return fig, ax


def add_sea_noise(ax):
    """Overlay two faint full-canvas textures on the sea.

    1. Uniform random speckle on a quarter-resolution grid, bilinearly
       interpolated (alpha 0.12, "Blues").
    2. Very faint horizontal banding from one sine value per row
       (alpha 0.04, "GnBu").

    Both layers sit at zorder 1. Draws one array from the module-level RNG,
    so it advances the shared random state.
    """
    rows, cols = H // 4, W // 4
    shared = dict(extent=[0, W, 0, H], origin="lower", aspect="auto", zorder=1)

    speckle = RNG.random((rows, cols))
    ax.imshow(
        speckle, alpha=0.12, cmap="Blues", interpolation="bilinear", **shared
    )

    # Subtle wave-like banding: each row is constant across all columns.
    row_wave = np.sin(np.linspace(0, 30, rows))
    band_grid = row_wave[:, np.newaxis] * np.ones((1, cols))
    ax.imshow(band_grid * 0.5 + 0.5, alpha=0.04, cmap="GnBu", **shared)


def add_coastline(ax):
    """Paint a faint tan land strip across the top of the canvas.

    The shoreline is a random walk around ``H - 55``, scaled down and clipped
    to ``[H - 75, H + 20]``. Land is filled from the shoreline up to ``H``,
    and a darker, more transparent 6-unit band is drawn just below the
    shoreline. Consumes 300 normal draws from the module-level RNG.
    """
    n_pts = 300
    base_y = H - 55
    xs = np.linspace(0, W, n_pts)
    drift = RNG.normal(0, 6, n_pts).cumsum() * 0.08
    shore = np.clip(base_y + drift, base_y - 20, base_y + 35)

    # Land mass above the shoreline.
    ax.fill_between(xs, shore, H, color="#c8b88a", alpha=0.55, zorder=2)
    # Thin darker fringe immediately below the shoreline.
    ax.fill_between(xs, shore - 6, shore, color="#a09070", alpha=0.25, zorder=2)


def draw_ship(ax, cx, cy, angle_deg=0):
    """Draw a stylised old cargo ship centred at (cx, cy).

    At ``angle_deg=0`` the ship lies along the x axis with the bow pointing
    toward -x (left) and the stern toward +x (right). A non-zero
    ``angle_deg`` rotates every part of the ship as one rigid unit,
    counter-clockwise by that many degrees about (cx, cy).

    Parts drawn:
        - a dark grey 60 x 18 hull rectangle (zorder 5);
        - a 12-unit triangular bow taper on the left end (zorder 5);
        - a lighter bridge block near the stern (zorder 6);
        - six short rust-coloured streaks at random spots on the hull
          (zorder 6).

    Args:
        ax: Axes to draw on, in data coordinates.
        cx, cy: Hull centre in data coordinates.
        angle_deg: Counter-clockwise rotation about (cx, cy), in degrees.

    Note:
        Draws 24 values from the module-level ``RNG`` (four per streak), so
        every call advances the shared random state.
    """
    import matplotlib.transforms as mtransforms

    ship_w, ship_h = 60, 18  # length x beam
    left, bottom = cx - ship_w / 2, cy - ship_h / 2

    # All parts share one transform so the ship rotates rigidly. At 0 degrees
    # plain transData is used, which keeps the default rendering identical to
    # the unrotated layout with no extra identity transform in the chain.
    if angle_deg:
        tf = (
            mtransforms.Affine2D().rotate_deg_around(cx, cy, angle_deg)
            + ax.transData
        )
    else:
        tf = ax.transData

    # Hull — dark gray elongated rectangle
    hull = Rectangle(
        (left, bottom),
        ship_w,
        ship_h,
        linewidth=0.5,
        edgecolor="#2a2a2a",
        facecolor="#4a4a4a",
        zorder=5,
        transform=tf,
    )
    ax.add_patch(hull)

    # Bow taper (triangle at left end), tip 12 units beyond the hull
    bow = plt.Polygon(
        [
            [left, bottom],
            [left, cy + ship_h / 2],
            [left - 12, cy],
        ],
        closed=True,
        facecolor="#3a3a3a",
        edgecolor="#2a2a2a",
        linewidth=0.5,
        zorder=5,
        transform=tf,
    )
    ax.add_patch(bow)

    # Bridge superstructure at stern (right end), inset 2 units all round
    bridge_x = cx + ship_w / 2 - 16
    bridge = Rectangle(
        (bridge_x, bottom + 2),
        14,
        ship_h - 4,
        linewidth=0.4,
        edgecolor="#555555",
        facecolor="#6a6a6a",
        zorder=6,
        transform=tf,
    )
    ax.add_patch(bridge)

    # Rust streak lines. The four RNG draws per streak happen in the same
    # order as before (x, y, dx, dy), which keeps seeded output stable.
    for _ in range(6):
        rx = left + RNG.uniform(5, ship_w - 5)
        ry = bottom + RNG.uniform(1, ship_h - 1)
        rx_end = rx + RNG.uniform(-2, 2)
        ry_end = ry - RNG.uniform(1, 4)
        ax.plot(
            [rx, rx_end],
            [ry, ry_end],
            color="#8b4513",
            alpha=0.45,
            linewidth=0.8,
            zorder=6,
            transform=tf,
        )


def draw_wake(ax, cx, cy, length=140):
    """Trail faint wake lines to the right of a ship centred at (cx, cy).

    Every line starts at the stern, 30 units right of ``cx`` (half of the
    60-unit hull used by draw_ship). Eight mirrored pairs are drawn, and
    pair ``k`` (0..7):
        - sits ``2.5 * k`` units above and below ``cy``;
        - extends ``(k + 1) * length / 8`` units past the stern;
        - fades from alpha 0.35 (k=0) to 0.05 (k=7);
        - thins from 1.5 toward a 0.4 linewidth floor.
    For k=0 both lines of the pair coincide on the centreline. All lines use
    zorder 4, below the ship.
    """
    stern_x = cx + 30
    step = length / 8
    fades = np.linspace(0.35, 0.05, 8)

    for k in range(8):
        offset = k * 2.5
        tip_x = stern_x + k * step + step
        width = max(0.4, 1.5 - k * 0.15)
        # Upper line first, then its mirror below the centreline.
        for sign in (1, -1):
            y = cy + sign * offset
            ax.plot(
                [stern_x, tip_x],
                [y, y],
                color="#a8d8ea",
                alpha=fades[k],
                linewidth=width,
                zorder=4,
            )


def add_caption(ax, text):
    """Stamp ``text`` near the lower-left corner as a white monospace caption.

    The text is anchored at data point (18, 16) with its bottom edge there.
    A 2-pt black outline keeps it legible on the sea, and zorder 10 draws it
    above the sea (1), coast (2) and ship (5-6) layers.
    """
    outline = withStroke(linewidth=2, foreground="black")
    style = {
        "color": "white",
        "fontsize": 8.5,
        "fontfamily": "monospace",
        "verticalalignment": "bottom",
        "zorder": 10,
        "path_effects": [outline],
    }
    ax.text(18, 16, text, **style)


def save(fig, name):
    """Write ``fig`` to ``OUTPUT_DIR/name`` as an image, close it, report size.

    The figure is saved at its native canvas size (W x H pixels at DPI),
    with the sea colour as background. ``bbox_inches="tight"`` is
    deliberately not used: it pads the output by ``savefig.pad_inches``
    (0.1 inch on every side by default). That would enlarge the image beyond
    W x H and break the one-data-unit-per-pixel layout make_fig sets up.

    The output directory is created if missing. The figure is closed even
    when saving raises, so a failed write does not leave a pyplot figure
    open.

    Args:
        fig: Figure to save.
        name: File name relative to OUTPUT_DIR (extension selects format).
    """
    path = os.path.join(OUTPUT_DIR, name)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    try:
        fig.savefig(path, dpi=DPI, facecolor=SEA_COLOR)
    finally:
        plt.close(fig)
    print(f"  Saved: {path}  ({os.path.getsize(path) // 1024} KB)")


# ---------------------------------------------------------------------------
# Image 1: Vessel at hold position
# ---------------------------------------------------------------------------

def generate_image_01():
    """Render frame 1: MV KASPIYSK at its hold position, mid-frame.

    Layers, bottom to top: sea texture, coastline, wake, ship, caption.
    Writes ``satellite_01_vessel_holding.png`` via save().
    """
    fig, ax = make_fig()
    for background in (add_sea_noise, add_coastline):
        background(ax)

    centre = (600, 420)
    draw_wake(ax, *centre)
    draw_ship(ax, *centre)

    add_caption(ax, "2025-07-03 09:00 UTC \u2014 MV KASPIYSK at hold position")
    save(fig, "satellite_01_vessel_holding.png")


# ---------------------------------------------------------------------------
# Image 2: Drones launching from bow
# ---------------------------------------------------------------------------

# Frame-2 drone positions as (dx, dy) offsets from the hull centre; all sit
# around dx = -42, the tip of the bow taper drawn by draw_ship.
DRONE_OFFSETS_CLOSE = [(-38, 6), (-44, 0), (-40, -7), (-34, 10), (-46, -3)]

# One orange shade per drone, paired with the offset lists via zip().
DRONE_COLORS = [
    "#ff8c00",
    "#ffa500",
    "#ff7f00",
    "#ffb347",
    "#e8780a",
]


def generate_image_02():
    """Render frame 2: five drone contacts clustered at the ship's bow.

    The sea, coast, wake and ship match frame 1. Each drone is an orange dot
    (zorder 8) with a short faint streak through it (zorder 7) suggesting
    motion blur. Writes ``satellite_02_drones_launching.png`` via save().
    """
    fig, ax = make_fig()
    add_sea_noise(ax)
    add_coastline(ax)

    cx, cy = 600, 420
    draw_wake(ax, cx, cy)
    draw_ship(ax, cx, cy)

    # Offsets are relative to the hull centre; the bow tip is at cx - 42.
    for (dx, dy), colour in zip(DRONE_OFFSETS_CLOSE, DRONE_COLORS):
        px, py = cx + dx, cy + dy
        blur_xs = [px - 3, px + 1]
        blur_ys = [py, py + 0.5]
        ax.plot(blur_xs, blur_ys, color=colour, alpha=0.5, linewidth=1, zorder=7)
        ax.add_patch(plt.Circle((px, py), 4, color=colour, zorder=8))

    add_caption(ax, "2025-07-03 09:32 UTC \u2014 5 drone contacts detected at vessel")
    save(fig, "satellite_02_drones_launching.png")


# ---------------------------------------------------------------------------
# Image 3: Drones dispersing, vessel departing south
# ---------------------------------------------------------------------------

# Frame-3 drone end points as (dx, dy) offsets from the bow tip. Every dx is
# negative and every dy positive, so the swarm fans up and to the left,
# toward the coastline drawn along the top edge.
DRONE_OFFSETS_FAN = [
    (-80, 110), (-130, 80), (-50, 140), (-160, 50), (-100, 55),
]


def generate_image_03():
    """Render frame 3: the drone swarm dispersing while the vessel departs.

    The ship is drawn 30 units lower than in frames 1-2. Since y grows upward
    on this canvas, that reads as moving south. The ship also gets a shorter
    wake.

    Each drone in DRONE_OFFSETS_FAN is drawn from the bow tip with:
        - a dashed green track (zorder 6);
        - a red analysis ring (zorder 9);
        - an orange dot (zorder 10);
        - a ``D1``..``D5`` label (zorder 11).

    Writes ``satellite_03_drones_dispersing.png`` via save().
    """
    fig, ax = make_fig()
    add_sea_noise(ax)
    add_coastline(ax)

    cx, cy = 600, 390
    draw_wake(ax, cx, cy, length=100)
    draw_ship(ax, cx, cy)

    # Bow tip of the unrotated ship: 30-unit half-hull plus 12-unit taper.
    bow_x, bow_y = cx - 42, cy

    drones = zip(DRONE_OFFSETS_FAN, DRONE_COLORS)
    for label_no, ((dx, dy), colour) in enumerate(drones, start=1):
        tx, ty = bow_x + dx, bow_y + dy

        track_style = dict(
            color="#00cc44", alpha=0.55, linewidth=0.9, linestyle="--", zorder=6
        )
        ax.plot([bow_x, tx], [bow_y, ty], **track_style)

        ax.add_patch(
            plt.Circle(
                (tx, ty),
                10,
                color="#ff2222",
                fill=False,
                linewidth=1.0,
                alpha=0.75,
                zorder=9,
            )
        )
        ax.add_patch(plt.Circle((tx, ty), 4.5, color=colour, zorder=10))

        ax.text(
            tx + 12,
            ty + 4,
            f"D{label_no}",
            color="white",
            fontsize=6.5,
            fontfamily="monospace",
            zorder=11,
            path_effects=[withStroke(linewidth=1.5, foreground="black")],
        )

    add_caption(ax, "2025-07-03 09:38 UTC \u2014 swarm dispersing, vessel departing")
    save(fig, "satellite_03_drones_dispersing.png")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Render the frames in a fixed order: they all draw from the single
    # seeded module RNG, so reordering would change the generated scenes.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    print("Generating satellite images for Scenario S8...")
    for render in (generate_image_01, generate_image_02, generate_image_03):
        render()
    print("Done.")
