"""Shared confidence-scoring signal helpers.

All signals return values in [0, 1]. See catalogs/ontology.md for definitions.
"""
from __future__ import annotations

import math
from typing import Iterable

from .common import haversine_m


def sigmoid(x: float) -> float:
    """Return the logistic function 1 / (1 + e^-x), a value in [0, 1].

    The computation is split on the sign of *x* so that ``math.exp`` is only
    ever called with a non-positive argument.  The naive form
    ``1 / (1 + exp(-x))`` raises ``OverflowError`` for strongly negative
    inputs (roughly x < -709), whereas here the exponential simply
    underflows towards 0.0 and the result saturates at 0.0 or 1.0.
    """
    if x >= 0:
        # exp(-x) lies in (0, 1]; it cannot overflow.
        return 1.0 / (1.0 + math.exp(-x))
    # Algebraically identical form for x < 0: e^x / (1 + e^x).
    e = math.exp(x)
    return e / (1.0 + e)


def clamp(x: float, lo: float = 0.0, hi: float = 1.0) -> float:
    """Limit *x* to the interval [lo, hi].

    The upper bound is applied first, then the lower bound, so *lo* wins
    if the bounds are ever passed in inverted order.
    """
    # Cap from above: keep x only when it is strictly below hi.
    capped = x if x < hi else hi
    # Then raise to the floor: keep the capped value only when strictly above lo.
    return capped if capped > lo else lo


def ais_gap_score(gap_seconds: float) -> float:
    """Score a reporting gap of *gap_seconds* on a logistic curve.

    The curve is centred on 300 s (where the score is 0.5) and uses a
    600 s scale, so longer gaps push the score towards 1.
    """
    excess_s = gap_seconds - 300.0
    return sigmoid(excess_s / 600.0)


def ais_radar_delta_score(max_delta_m: float, threshold_m: float = 1000.0) -> float:
    """Return *max_delta_m* as a fraction of *threshold_m*, clamped to [0, 1].

    A delta at or beyond the threshold scores 1.0.  *threshold_m* is used
    as a divisor, so it must be non-zero.
    """
    fraction_of_threshold = max_delta_m / threshold_m
    return clamp(fraction_of_threshold)


def mac_first_seen_ratio(new_macs: int, total_macs: int) -> float:
    """Return the share of *total_macs* that are new, clamped to [0, 1].

    Yields 0.0 when *total_macs* is zero or negative, which also avoids a
    division by zero.
    """
    nothing_observed = total_macs <= 0
    return 0.0 if nothing_observed else clamp(new_macs / total_macs)


def mac_count_zscore_signal(count: float, mean: float, std: float) -> float:
    """Turn the z-score of *count* against (*mean*, *std*) into a [0, 1] signal.

    The z-score is passed through a logistic curve centred on z = 3 with a
    scale of 2, so a count three standard deviations above the mean scores
    0.5.  A non-positive *std* yields 0.0 because no z-score can be formed.
    """
    if std <= 0:
        return 0.0
    z_score = (count - mean) / std
    shifted = (z_score - 3.0) / 2.0
    return sigmoid(shifted)


def mac_fingerprint_anomaly_score(today_macs: set[str], baseline_macs: set[str]) -> float:
    """Return the Jaccard distance between *today_macs* and *baseline_macs*.

    The result is 1 - |intersection| / |union|: 0.0 for identical sets and
    1.0 for disjoint ones.  Two empty sets score 0.0.
    """
    if not (today_macs or baseline_macs):
        return 0.0
    shared = today_macs & baseline_macs
    combined = today_macs | baseline_macs
    # The union cannot be empty past the guard above; the fallback is kept
    # purely as a defensive default.
    similarity = len(shared) / len(combined) if combined else 0.0
    return clamp(1.0 - similarity)


def jensen_shannon(p: dict[str, float], q: dict[str, float]) -> float:
    """Return the Jensen-Shannon divergence between two weighted key maps.

    *p* and *q* are each normalised by the sum of their values (a zero sum
    is replaced by 1.0 so an empty or all-zero map becomes all zeros) over
    the union of their keys.  The result is
    0.5 * (KL(p || m) + KL(q || m)) with m the pointwise mean, clamped to
    [0, 1].

    NOTE(review): ``math.log`` is the natural logarithm, so for two proper
    distributions the value tops out at ln 2 (about 0.693), not 1.0 --
    confirm this is the intended scale for the signal.
    """
    support = set(p) | set(q)

    def _normalise(weights: dict[str, float]) -> dict[str, float]:
        # Fall back to 1.0 so an empty/all-zero map does not divide by zero.
        total = sum(weights.values()) or 1.0
        return {key: weights.get(key, 0.0) / total for key in support}

    p_norm = _normalise(p)
    q_norm = _normalise(q)
    mixture = {key: 0.5 * (p_norm[key] + q_norm[key]) for key in support}

    def _kl_to_mixture(dist: dict[str, float]) -> float:
        divergence = 0.0
        for key in support:
            mass = dist[key]
            mix = mixture[key]
            # Terms with zero mass contribute nothing (0 * log 0 := 0).
            if mass > 0 and mix > 0:
                divergence += mass * math.log(mass / mix)
        return divergence

    return clamp(0.5 * (_kl_to_mixture(p_norm) + _kl_to_mixture(q_norm)))


def spatial_proximity_infra_score(distance_m: float, threshold_m: float = 1000.0) -> float:
    """Score closeness linearly: 1.0 at distance 0, falling to 0.0 at *threshold_m*.

    Any distance at or beyond the threshold scores 0.0; the result is
    clamped to [0, 1].
    """
    out_of_range = distance_m >= threshold_m
    if out_of_range:
        return 0.0
    closeness = 1.0 - distance_m / threshold_m
    return clamp(closeness)


def temporal_dwell_score(dwell_minutes: float) -> float:
    """Score a dwell time of *dwell_minutes* on a logistic curve.

    The curve is centred on 30 minutes (score 0.5) with a 30-minute scale,
    so longer dwells push the score towards 1.
    """
    excess_min = dwell_minutes - 30.0
    return sigmoid(excess_min / 30.0)


def vessel_pair_proximity_score(min_distance_m: float, duration_s: float) -> float:
    """Combine how close and for how long two tracks stayed together.

    Pairs whose minimum distance exceeds 200 m score 0.0.  Otherwise the
    score is the product of a duration factor (saturating at 600 s) and a
    closeness factor (1.0 at 0 m, 0.0 at 200 m), clamped to [0, 1].
    """
    if min_distance_m > 200:
        return 0.0
    closeness = 1.0 - min_distance_m / 200.0
    persistence = min(1.0, duration_s / 600.0)
    return clamp(persistence * closeness)


def track_origin_offshore_anomaly_score(offshore_km: float, has_flight_plan: bool, threshold_km: float = 10.0) -> float:
    """Score a track origin by how far offshore it lies, unless a flight plan exists.

    Origins no farther than *threshold_km* score 0.0.  Beyond that the score
    is offshore_km / (2 * threshold_km), clamped to [0, 1], so it starts
    just above 0.5 and saturates at twice the threshold.  A track with a
    flight plan always scores 0.0.
    """
    if offshore_km > threshold_km:
        saturation_km = 2.0 * threshold_km
        distance_score = clamp(offshore_km / saturation_km)
    else:
        distance_score = 0.0
    # A filed flight plan suppresses the signal entirely.
    if has_flight_plan:
        return 0.0
    return distance_score


def co_observation_score(distinct_macs_above_threshold: int) -> float:
    """Return 1.0 when at least two distinct MACs qualify, otherwise 0.0."""
    if distinct_macs_above_threshold >= 2:
        return 1.0
    return 0.0


def duplicate_mmsi_score(min_pair_separation_km: float, threshold_km: float = 5.0) -> float:
    """Return 1.0 when the separation reaches *threshold_km*, otherwise 0.0.

    The comparison is inclusive: a separation exactly equal to the
    threshold scores 1.0.
    """
    if min_pair_separation_km >= threshold_km:
        return 1.0
    return 0.0


def dimension_speed_plausibility_score(declared_loa_m: float, observed_loa_m: float,
                                       declared_max_kn: float, observed_kn: float) -> float:
    """Score how far observed length and speed depart from the declared values.

    Two equally weighted components, each capped at 1.0:

    * length: |observed - declared| relative to the declared length;
    * speed: the amount by which the observed speed exceeds the declared
      maximum, relative to that maximum (speeds below it contribute 0).

    Both denominators are floored at 1.0 so tiny or zero declared values
    cannot blow up the ratio.  The sum is clamped to [0, 1].
    """
    loa_reference = max(declared_loa_m, 1.0)
    loa_mismatch = abs(observed_loa_m - declared_loa_m) / loa_reference

    overspeed_kn = max(0.0, observed_kn - declared_max_kn)
    speed_mismatch = overspeed_kn / max(declared_max_kn, 1.0)

    return clamp(0.5 * min(1.0, loa_mismatch) + 0.5 * min(1.0, speed_mismatch))


def composite_incident_score(signals: dict[str, float], weights: dict[str, float]) -> float:
    """Blend named signal values into one weighted score in [0, 1].

    ``signals[name]`` is a signal value and ``weights[name]`` its weight.
    Weights are divided by their sum, so they need not add up to 1.0; a
    zero sum falls back to a divisor of 1.0.  A name present in *weights*
    but missing from *signals* counts as 0.0, and signals without a weight
    are ignored.
    """
    weight_sum = sum(weights.values()) or 1.0
    blended = 0.0
    for signal_name, weight in weights.items():
        share = weight / weight_sum
        blended += share * float(signals.get(signal_name, 0.0))
    return clamp(blended)
