import random
import copy
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Optional

# ---------------------------------------------------------------------------
# Hex grid (axial coordinates, flat-top, radius 3 = 4 tiles per edge, 37 tiles)
# ---------------------------------------------------------------------------

# Map radius in rings around the centre hex (0, 0); hex_range(3) yields 37 hexes.
GRID_RADIUS = 3

def hex_range(radius: int) -> list[tuple[int, int]]:
    """Return every axial (q, r) with |q|, |r| and |q + r| all <= `radius`.

    Coordinates are listed by ascending q and, within one q, by ascending r.
    A negative radius produces an empty list.
    """
    return [
        (q, r)
        for q in range(-radius, radius + 1)
        for r in range(max(-radius, -q - radius), min(radius, radius - q) + 1)
    ]

def hex_distance(q1: int, r1: int, q2: int = 0, r2: int = 0) -> int:
    """Return the hex-step distance between two axial coordinates.

    The second coordinate defaults to the origin (0, 0), so
    ``hex_distance(q, r)`` is the ring number of (q, r).
    """
    dq = q1 - q2
    dr = r1 - r2
    return (abs(dq) + abs(dq + dr) + abs(dr)) // 2

# The six axial (dq, dr) offsets from a hex to its adjacent hexes.
HEX_DIRECTIONS = [
    (1, 0),
    (-1, 0),
    (0, 1),
    (0, -1),
    (1, -1),
    (-1, 1),
]
# Every coordinate on the map, held as a set for fast membership tests.
ALL_HEXES = set(hex_range(GRID_RADIUS))

def valid_neighbours(q: int, r: int) -> list[tuple[int, int]]:
    """Return the hexes adjacent to (q, r) that lie on the map.

    Neighbours are reported in HEX_DIRECTIONS order; off-map ones are dropped.
    """
    on_map = []
    for dq, dr in HEX_DIRECTIONS:
        neighbour = (q + dq, r + dr)
        if neighbour in ALL_HEXES:
            on_map.append(neighbour)
    return on_map


# ---------------------------------------------------------------------------
# Population setup
# ---------------------------------------------------------------------------

# Player roster. NOTE: PLAYERS and IMMUNE_PLAYER are both re-bound further
# down this file (PLAYERS from COLOUR_CODES, in a different order).
PLAYERS = [
    "orange",
    "green",
    "red",
    "purple",
    "pink",
    "yellow",
]
# The player excluded from mortal-population counts and from victim selection.
IMMUNE_PLAYER = "pink"

def mortal_population(hex_pops: dict) -> int:
    """Add up the positive head-counts of all non-immune players in one hex."""
    total = 0
    for player, count in hex_pops.items():
        if player != IMMUNE_PLAYER and count > 0:
            total += count
    return total

def make_populations(seed_populations: Optional[dict] = None) -> dict[tuple, dict]:
    """Return the starting population map {hex_coord: {player: count}}.

    A truthy `seed_populations` is deep-copied and returned untouched.
    Otherwise a map is generated with a private RNG seeded with 42: the
    immune player gets 10-50 units per hex, every other player gets between
    0 and a ceiling that shrinks by 60 per ring of distance from the centre
    (500 at the centre, never below 0).
    """
    if seed_populations:
        return copy.deepcopy(seed_populations)

    rng = random.Random(42)
    generated = {}
    # Draw order (hexes in ALL_HEXES order, players in PLAYERS order) fixes
    # which random number each slot receives.
    for coord in ALL_HEXES:
        ceiling = max(0, 500 - hex_distance(*coord) * 60)
        per_player = {}
        for name in PLAYERS:
            if name != IMMUNE_PLAYER:
                per_player[name] = rng.randint(0, ceiling) if ceiling > 0 else 0
            else:
                per_player[name] = rng.randint(10, 50)
        generated[coord] = per_player
    return generated


# ---------------------------------------------------------------------------
# Infection mechanic
# ---------------------------------------------------------------------------

def infectable_candidates(
    current: tuple[int, int],
    populations: dict,
    cleared: set,
) -> list[tuple[int, int]]:
    """
    List the systems the infection may occupy next.

    Considers the current system first, then its on-map neighbours, and keeps
    only those that are not in `cleared` and still hold mortal population.
    """
    reachable = []
    for coord in (current, *valid_neighbours(*current)):
        if coord in cleared:
            continue
        if mortal_population(populations[coord]) > 0:
            reachable.append(coord)
    return reachable

def roll_d10() -> int:
    """Roll one ten-sided die (1-10) using the module-level `random` state."""
    return random.randrange(1, 11)

import math
from pathlib import Path

# ---------------------------------------------------------------------------
# HTML visualisation helpers (casualties)
# ---------------------------------------------------------------------------

# CSS colour for each player, used for map tints and the legend.
# (Original note: this could be reused from population.py instead of being
# defined here.)
PLAYER_COLOURS = dict(
    orange="#e65c00",
    green="#2e7d32",
    red="#c62828",
    pink="#c2185b",
    yellow="#f9a825",
    purple="#6a1b9a",
)

def _hex_points(x: float, y: float, size: float) -> str:
    """Return an SVG `points` string for a hexagon centred on (x, y).

    The six corners sit `size` away from the centre at 0°, 60°, ... 300°,
    each formatted as "x,y" with one decimal and separated by spaces.
    """
    step = math.pi / 3
    return " ".join(
        f"{x + size * math.cos(step * i):.1f},{y + size * math.sin(step * i):.1f}"
        for i in range(6)
    )

def _axial_to_pixel(q: int, r: int, cx: float, cy: float, hex_size: float) -> tuple[float, float]:
    """Convert axial (q, r) to pixel coordinates for a flat-top layout.

    (cx, cy) is the pixel position of hex (0, 0); `hex_size` is the
    centre-to-corner distance.
    """
    root3 = math.sqrt(3)
    col_offset = 3/2 * q
    row_offset = root3/2 * q + root3 * r
    return cx + hex_size * col_offset, cy + hex_size * row_offset

def _dominant_player(cas: dict) -> tuple[str, float] | None:
    """Return (player, share) for the non-immune player with the most casualties.

    `share` is that player's count divided by the total of all positive
    non-immune counts. Returns None when no non-immune player has a positive
    count. Ties go to the player that appears first in `cas`.
    """
    counted = [
        (name, amount)
        for name, amount in cas.items()
        if name != IMMUNE_PLAYER and amount > 0
    ]
    if not counted:
        return None
    leader, leader_amount = max(counted, key=lambda item: item[1])
    return leader, leader_amount / sum(amount for _, amount in counted)

def _build_hex_svg(
    data: dict[tuple, dict],   # hex -> {player: count (float or int)}
    hex_size: float = 54,
    W: float = 760,
    H: float = 540,
) -> tuple[str, int | float]:
    """
    Build SVG hex elements for a casualties map.

    For every hex in `data` this emits, in order: a greyscale polygon (darker
    means more non-immune casualties relative to the busiest hex), a hidden
    player-colour tint polygon, a transparent hover target carrying
    data-q / data-r / data-total attributes, a coordinate label and, for
    hexes with a total of at least 0.1, the total itself.

    Args:
        data: hex coordinate (q, r) -> {player: casualty count}.
        hex_size: centre-to-corner size of one hex in SVG units.
        W, H: canvas size; the grid is centred on (W / 2, H / 2).

    Returns:
        (svg_elements_str, max_total). `max_total` is the largest per-hex
        non-immune total exactly as computed: it is not truncated to an
        integer, and it is 0 when `data` is empty or holds no casualties.
    """
    CX, CY = W / 2, H / 2

    # Per-hex non-immune totals, computed once and reused below.
    totals = {
        coord: sum(v for p, v in cas.items() if p != IMMUNE_PLAYER)
        for coord, cas in data.items()
    }
    max_total = max(totals.values(), default=0)
    # Shading denominator: never zero, so an all-empty map still renders.
    scale = max_total if max_total > 0 else 1

    elements = []

    for coord, cas in data.items():
        q, r = coord
        x, y = _axial_to_pixel(q, r, CX, CY, hex_size)
        total = totals[coord]
        pts = _hex_points(x, y, hex_size - 2)
        is_empty = total < 0.1

        # greyscale fill: lightness runs from 210 (few) down to 40 (most)
        if is_empty:
            grey_fill = "transparent"
            stroke = "rgba(0,0,0,0.10)"
        else:
            ratio = total / scale
            lightness = int(210 - ratio * 170)
            grey_fill = f"rgb({lightness},{lightness},{lightness})"
            stroke = "rgba(0,0,0,0.25)"

        # white text on the darker half of the scale, dark text otherwise
        text_col = "#fff" if (not is_empty and total / scale > 0.5) else "#222"

        dom = _dominant_player(cas)
        tint_col = PLAYER_COLOURS[dom[0]] if dom else "transparent"
        tint_opacity = round(0.25 + dom[1] * 0.45, 2) if dom else 0

        # greyscale base
        elements.append(
            f'<polygon points="{pts}" fill="{grey_fill}" '
            f'stroke="{stroke}" stroke-width="1" class="hex-grey"/>'
        )
        # tint overlay (hidden until toggled on in the page)
        elements.append(
            f'<polygon points="{pts}" fill="{tint_col}" '
            f'opacity="{tint_opacity}" stroke="none" '
            f'class="hex-tint" style="display:none"/>'
        )
        # hover target
        elements.append(
            f'<polygon points="{pts}" fill="transparent" stroke="none" '
            f'class="hex-hover" data-q="{q}" data-r="{r}" '
            f'data-total="{total:.1f}" style="cursor:pointer"/>'
        )
        # coord
        elements.append(
            f'<text x="{x:.1f}" y="{y - hex_size*0.52:.1f}" '
            f'text-anchor="middle" font-size="9" fill="rgba(0,0,0,0.35)">{q},{r}</text>'
        )
        if not is_empty:
            elements.append(
                f'<text x="{x:.1f}" y="{y + 5:.1f}" '
                f'text-anchor="middle" font-size="13" font-weight="bold" fill="{text_col}">'
                f'{total:.1f}</text>'
            )

    return "\n    ".join(elements), max_total


def _html_wrapper(
    svg_str: str,
    title: str,
    subtitle: str,
    viewbox: str,
    W: float,
    H: float,
    PAD: float = 40,
) -> str:
    """Wrap pre-built SVG elements in a complete, self-contained HTML page.

    The page shows the title and subtitle, a "Player colours" toggle that
    reveals the `.hex-tint` polygons, the SVG itself, a per-player legend, a
    greyscale scale bar and a tooltip driven by the `.hex-hover` polygons.
    `W`, `H` and `PAD` are accepted but not used by the template.
    """
    def _legend_entry(player: str) -> str:
        # One colour swatch plus label; the immune player is flagged.
        suffix = " (immune)" if player == IMMUNE_PLAYER else ""
        swatch = PLAYER_COLOURS[player]
        return (
            f'<span style="display:inline-flex;align-items:center;gap:4px;margin-right:12px">'
            f'<span style="width:10px;height:10px;border-radius:2px;background:{swatch};display:inline-block"></span>'
            f'<span style="font-size:11px;color:#666">{player}{suffix}</span>'
            f'</span>'
        )

    legend_html = "\n".join(_legend_entry(name) for name in PLAYERS)

    scale_html = "".join((
        '<div style="display:flex;align-items:center;gap:8px;margin-top:4px;font-size:11px;color:#666">',
        '<span>no casualties</span>',
        '<div style="width:160px;height:10px;border-radius:3px;',
        'background:linear-gradient(to right,transparent,rgb(210,210,210),rgb(40,40,40));',
        'border:0.5px solid rgba(0,0,0,0.15)"></div>',
        '<span>high casualties</span>',
        '</div>',
    ))

    return f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>{title}</title>
<style>
  body {{ margin:0; font-family:sans-serif; background:transparent; }}
  .hex-grey, .hex-tint {{ transition: opacity 0.2s; }}
  #tooltip {{
    display:none; position:fixed; bottom:12px; left:50%;
    transform:translateX(-50%);
    background:rgba(0,0,0,0.75); color:#fff;
    font-size:12px; padding:4px 10px; border-radius:4px; pointer-events:none;
  }}
  .toggle-track {{
    position:relative; width:32px; height:18px;
    background:#ccc; border-radius:9px;
    transition:background .2s; display:inline-block;
  }}
  .toggle-track:has(input:checked) {{ background:#1565c0; }}
  #tint-toggle:checked + span {{ transform:translateX(14px); }}
  h2 {{ margin:8px 16px 0; font-size:14px; font-weight:500; color:#333; }}
  p.sub {{ margin:2px 16px 6px; font-size:11px; color:#888; }}
</style>
</head>
<body>
<h2>{title}</h2>
<p class="sub">{subtitle}</p>
<div style="padding:4px 16px;display:flex;align-items:center;gap:12px;font-size:12px;color:#666">
  <label style="display:flex;align-items:center;gap:6px;cursor:pointer">
    <span class="toggle-track">
      <input type="checkbox" id="tint-toggle"
        style="position:absolute;opacity:0;width:100%;height:100%;cursor:pointer;margin:0"
        onchange="toggleTint(this.checked)">
      <span style="position:absolute;top:2px;left:2px;width:14px;height:14px;
        background:#fff;border-radius:50%;transition:transform .2s;pointer-events:none"></span>
    </span>
    Player colours
  </label>
</div>
<svg width="100%" viewBox="{viewbox}" xmlns="http://www.w3.org/2000/svg">
    {svg_str}
</svg>
<div style="padding:2px 16px 8px">
  {legend_html}
  {scale_html}
</div>
<div id="tooltip"></div>
<script>
function toggleTint(on) {{
  document.querySelectorAll('.hex-tint').forEach(el => {{
    el.style.display = on ? 'block' : 'none';
  }});
}}
document.querySelectorAll('.hex-hover').forEach(h => {{
  const tip = document.getElementById('tooltip');
  h.addEventListener('mouseenter', () => {{
    tip.textContent = '(' + h.dataset.q + ', ' + h.dataset.r + ')  casualties: ' + h.dataset.total;
    tip.style.display = 'block';
  }});
  h.addEventListener('mouseleave', () => {{ tip.style.display = 'none'; }});
}});
</script>
</body>
</html>
"""


def save_run_html(
    casualties: dict[tuple, dict],
    run_index: int,
    output_dir: str = "runs",
    hex_size: float = 54,
    W: float = 760,
    H: float = 540,
    PAD: float = 40,
) -> None:
    """Save a single run's casualties as an HTML map.

    The file is written to ``<output_dir>/run_NNNN.html`` where NNNN is
    ``run_index + 1`` zero-padded to four digits. The output directory,
    including any missing parent directories, is created if needed.

    Args:
        casualties: hex coordinate -> {player: casualty count} for one run.
        run_index: zero-based index of the run.
        output_dir: folder that receives the per-run HTML files.
        hex_size, W, H: forwarded to the SVG builder.
        PAD: margin added around the W x H canvas in the SVG viewBox.
    """
    out_dir = Path(output_dir)
    # parents=True so nested folders such as "out/runs" work as well.
    out_dir.mkdir(parents=True, exist_ok=True)

    svg_str, max_total = _build_hex_svg(casualties, hex_size, W, H)
    viewbox = f"{-PAD} {-PAD} {W + PAD*2} {H + PAD*2}"
    total_dead = sum(
        v for cas in casualties.values()
        for p, v in cas.items() if p != IMMUNE_PLAYER
    )
    html = _html_wrapper(
        svg_str,
        title=f"Run {run_index + 1} — Casualties",
        subtitle=f"Total casualties: {total_dead}  ·  max in any system: {max_total}",
        viewbox=viewbox,
        W=W, H=H, PAD=PAD,
    )
    path = out_dir / f"run_{run_index + 1:04d}.html"
    # The page declares <meta charset="utf-8"> and contains non-ASCII
    # punctuation, so write UTF-8 explicitly instead of the locale default.
    path.write_text(html, encoding="utf-8")


def save_average_html(
    mean_casualties: dict[tuple, dict],
    n: int,
    output_path: str = "average_casualties.html",
    hex_size: float = 54,
    W: float = 760,
    H: float = 540,
    PAD: float = 40,
) -> None:
    """Save the averaged casualties map as an HTML map.

    Args:
        mean_casualties: hex coordinate -> {player: mean casualty count}.
        n: number of runs the means were taken over (shown in the title).
        output_path: file to write.
        hex_size, W, H: forwarded to the SVG builder.
        PAD: margin added around the W x H canvas in the SVG viewBox.
    """
    svg_str, _ = _build_hex_svg(mean_casualties, hex_size, W, H)
    viewbox = f"{-PAD} {-PAD} {W + PAD*2} {H + PAD*2}"
    avg_total = sum(
        v for cas in mean_casualties.values()
        for p, v in cas.items() if p != IMMUNE_PLAYER
    )
    # Compute the peak per-hex mean here so the subtitle shows the real
    # fractional value rather than a truncated or clamped one.
    peak_mean = max(
        (
            sum(v for p, v in cas.items() if p != IMMUNE_PLAYER)
            for cas in mean_casualties.values()
        ),
        default=0.0,
    )
    html = _html_wrapper(
        svg_str,
        title=f"Average Casualties — {n} runs",
        subtitle=f"Mean total casualties: {avg_total:.1f}  ·  peak system mean: {peak_mean:.1f}",
        viewbox=viewbox,
        W=W, H=H, PAD=PAD,
    )
    # The page declares <meta charset="utf-8"> and contains non-ASCII
    # punctuation, so write UTF-8 explicitly instead of the locale default.
    Path(output_path).write_text(html, encoding="utf-8")
    print(f"[+] average casualties map → {output_path}")

def run_simulation(
    seed_populations: Optional[dict] = None,
    origin: tuple[int, int] = (0, 0),
    max_steps: int = 10000,
) -> dict[tuple, dict]:
    """
    Simulate a single outbreak that starts at `origin`.

    Each step (at most `max_steps` of them):
      - a d10 is rolled for the current system
      - 1 or 2: the infection stops for good
      - 3 or more: one unit of a randomly chosen mortal player in the current
        system dies, then the infection picks at random among the current
        system and its neighbours that still hold mortal population and are
        not cleared
      - with no such system available the infection stops

    Raises ValueError when `origin` is not on the map.

    Returns the casualties map: {hex_coord: {player: int}}
    """
    if origin not in ALL_HEXES:
        raise ValueError(f"Origin {origin} is not on the map")

    populations = make_populations(seed_populations)
    casualties = {}
    for coord in ALL_HEXES:
        casualties[coord] = dict.fromkeys(PLAYERS, 0)
    exhausted = set()   # hexes whose mortal population has run out

    # nothing to infect at the starting hex: report an all-zero map
    if mortal_population(populations[origin]) == 0:
        return casualties

    here = origin
    for _ in range(max_steps):
        if roll_d10() <= 2:
            # infection burns out
            break

        local_pop = populations[here]
        targets = [
            name for name in PLAYERS
            if name != IMMUNE_PLAYER and local_pop.get(name, 0) > 0
        ]
        if not targets:
            # defensive: candidate filtering should make this unreachable
            exhausted.add(here)
            break

        struck = random.choice(targets)
        local_pop[struck] -= 1
        casualties[here][struck] += 1

        # retire the hex once its last mortal unit is gone
        if mortal_population(local_pop) == 0:
            exhausted.add(here)

        next_options = infectable_candidates(here, populations, exhausted)
        if not next_options:
            # no system left to move into
            break

        here = random.choice(next_options)

    return casualties


# ---------------------------------------------------------------------------
# Monte Carlo
# ---------------------------------------------------------------------------

def monte_carlo(
    n: int = 1000,
    seed_populations: Optional[dict] = None,
    origin: tuple[int, int] = (0, 0),
    max_steps: int = 10000,
    verbose: bool = True,
    save_runs: bool = True,          # save each run as HTML
    runs_dir: str = "runs",          # output folder for per-run maps
    save_average: bool = True,       # save averaged map
    average_path: str = "average_casualties.html",
) -> dict:
    """
    Run `n` independent simulations and aggregate their casualties.

    Args:
        n: number of simulations; must be at least 1.
        seed_populations: starting populations passed to every run.
        origin: hex where each infection starts.
        max_steps: per-run step limit.
        verbose: print progress roughly every tenth of the runs, plus a
            final summary line.
        save_runs: write one HTML map per run into `runs_dir`.
        runs_dir: folder for the per-run maps.
        save_average: write the averaged map to `average_path`.
        average_path: file for the averaged map.

    Returns:
        A dict with keys "mean_casualties" ({hex: {player: mean}}), "runs"
        (the per-run casualty maps), "total_casualties_per_run" (non-immune
        totals, one per run), "origin" and "n".

    Raises:
        ValueError: if `n` is less than 1 (the means would otherwise divide
            by zero or be meaningless), or if `origin` is not on the map.
    """
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")

    runs = []
    total_per_run = []

    for i in range(n):
        if verbose and i % max(1, n // 10) == 0:
            print(f"[+] run {i}/{n}")

        casualties = run_simulation(
            seed_populations=seed_populations,
            origin=origin,
            max_steps=max_steps,
        )
        runs.append(casualties)

        if save_runs:
            save_run_html(casualties, run_index=i, output_dir=runs_dir)

        # non-immune casualties across the whole map for this run
        total = sum(
            v for hex_cas in casualties.values()
            for p, v in hex_cas.items() if p != IMMUNE_PLAYER
        )
        total_per_run.append(total)

    mean_casualties = {
        hex_coord: {
            player: sum(run[hex_coord][player] for run in runs) / n
            for player in PLAYERS
        }
        for hex_coord in ALL_HEXES
    }

    if save_average:
        save_average_html(mean_casualties, n=n, output_path=average_path)

    if verbose:
        avg = sum(total_per_run) / n
        print(f"[+] done — avg total casualties: {avg:.1f} over {n} runs")

    return {
        "mean_casualties": mean_casualties,
        "runs": runs,
        "total_casualties_per_run": total_per_run,
        "origin": origin,
        "n": n,
    }

# ---------------------------------------------------------------------------
# Summary helpers
# ---------------------------------------------------------------------------

def casualties_by_player(mean_casualties: dict) -> dict[str, float]:
    """Sum each player's casualties over every hex.

    Every player found in the input is included (the immune player too);
    the sums are floats.
    """
    per_player: dict[str, float] = {}
    for hex_cas in mean_casualties.values():
        for name, amount in hex_cas.items():
            per_player[name] = per_player.get(name, 0.0) + amount
    return per_player

def casualties_by_hex(mean_casualties: dict) -> dict[tuple, float]:
    """Return each hex's casualties summed over the non-immune players."""
    per_hex = {}
    for coord, hex_cas in mean_casualties.items():
        per_hex[coord] = sum(
            amount for name, amount in hex_cas.items() if name != IMMUNE_PLAYER
        )
    return per_hex

def print_summary(results: dict) -> None:
    """Print a text report for a result dict as returned by monte_carlo().

    Reads the "mean_casualties", "total_casualties_per_run", "n" and
    "origin" entries and prints mean casualties per player, the per-run
    distribution (mean, min, max, p10, p90) and the five hexes with the
    highest mean casualties.
    """
    mean_map = results["mean_casualties"]
    run_totals = results["total_casualties_per_run"]
    ordered = sorted(run_totals)
    n = results["n"]
    rule = "=" * 50

    print(f"\n{rule}")
    print(f"Monte Carlo summary — {n} runs")
    print(f"Origin: {results['origin']}")
    print(rule)

    print("\nMean casualties by player:")
    for name, mean_dead in sorted(casualties_by_player(mean_map).items()):
        tag = " [IMMUNE]" if name == IMMUNE_PLAYER else ""
        print(f"  {name:12s}: {mean_dead:8.2f}{tag}")

    print("\nCasualties per run:")
    print(f"  mean : {sum(run_totals)/n:.1f}")
    print(f"  min  : {min(run_totals)}")
    print(f"  max  : {max(run_totals)}")
    # percentiles are read straight off the sorted per-run totals
    print(f"  p10  : {ordered[n//10]}")
    print(f"  p90  : {ordered[int(n*0.9)]}")

    print("\nTop 5 hexes by mean casualties:")
    ranked = sorted(
        casualties_by_hex(mean_map).items(),
        key=lambda item: item[1],
        reverse=True,
    )
    for coord, mean_dead in ranked[:5]:
        print(f"  {str(coord):12s} dist={hex_distance(*coord)}  mean casualties: {mean_dead:.2f}")


# Short colour codes used in population files, mapped to full player names.
COLOUR_CODES = dict(
    o="orange",
    g="green",
    r="red",
    pi="pink",
    y="yellow",
    pu="purple",
)

# Player roster in COLOUR_CODES order. This re-binds the PLAYERS and
# IMMUNE_PLAYER names that are also assigned near the top of the file.
PLAYERS       = [*COLOUR_CODES.values()]
IMMUNE_PLAYER = "pink"

# ---------------------------------------------------------------------------
# Scan order — outermost ring down to ring 0 (the centre), clockwise from top each ring
# ---------------------------------------------------------------------------

# Axial steps taken, in this order, to walk one ring starting from its top hex.
RING_DIRS = [
    (1, 0),    # south-east
    (0, 1),    # south
    (-1, 1),   # south-west
    (-1, 0),   # north-west
    (0, -1),   # north
    (1, -1),   # north-east
]

def ring_hexes(radius: int) -> list[tuple[int, int]]:
    """Return the hexes of one ring around (0, 0), starting at its top hex.

    The walk starts at (0, -radius) and takes `radius` steps along each
    RING_DIRS direction in turn. Ring 0 is just the centre hex.
    """
    if radius == 0:
        return [(0, 0)]

    ring = []
    cursor = (0, -radius)          # top of the ring
    for dq, dr in RING_DIRS:
        for _ in range(radius):
            ring.append(cursor)
            cursor = (cursor[0] + dq, cursor[1] + dr)
    return ring

import re

# One population token: a decimal count immediately followed by a colour code
# (matched case-insensitively).
TOKEN_RE = re.compile(
    r'(\d+)(o|pi|pu|g|r|y)',
    flags=re.IGNORECASE,
)

def parse_hex_line(line: str) -> dict[str, int]:
    """
    Parse one hex line e.g. '2o 1pi 3g' → {'orange':2, 'pink':1, 'green':3, ...rest 0}
    Blank line or '-' → all zeros.

    Tokens are matched with TOKEN_RE (case-insensitive); a colour that
    appears more than once has its counts added together.

    Raises:
        ValueError: if the line contains no valid token, or contains any
            text that is not part of a valid token (for example a typo such
            as '3x', which would otherwise be dropped silently).
    """
    pop = {p: 0 for p in PLAYERS}
    line = line.strip()
    if not line or line == '-':
        return pop

    matches = TOKEN_RE.findall(line)
    if not matches:
        raise ValueError(f"Could not parse line: {repr(line)}")

    # Whatever remains once every valid token is removed must be whitespace;
    # anything else is input that findall() skipped over.
    leftover = TOKEN_RE.sub("", line).strip()
    if leftover:
        raise ValueError(
            f"Could not parse {leftover!r} in line: {repr(line)}"
        )

    for count_str, code in matches:
        player = COLOUR_CODES[code.lower()]
        pop[player] += int(count_str)

    return pop

def scan_order(grid_radius: int = 4) -> list[tuple[int, int]]:
    """Return all hexes ring by ring, from ring `grid_radius` in to ring 0.

    Each ring is listed in the order produced by ring_hexes().
    """
    return [
        coord
        for ring in reversed(range(grid_radius + 1))
        for coord in ring_hexes(ring)
    ]

# Derived from GRID_RADIUS so the population-file order always covers exactly
# the hexes on the map (37 hexes for radius 3).
SCAN_ORDER = scan_order(GRID_RADIUS)


def load_population(path: str = "population.txt") -> dict[tuple, dict]:
    """
    Read population.txt and return a seed dict for monte_carlo().

    File format — one entry per hex in scan order (37 hexes, radius 3).
    Each line is space-separated tokens: <count><colour_code>
    Colour codes: o=orange  g=green  r=red  pi=pink  y=yellow  pu=purple
    Blank lines or '-' = empty hex.
    Lines starting with # are ignored as comments.

    Example:
        2o 1pi
        3g
        -
        1o 1g 1r 1pi 1y 2pu

    Raises:
        ValueError: if the number of entries differs from len(SCAN_ORDER),
            or if an entry cannot be parsed. Parse errors name the file
            line and the hex the entry belongs to.
    """
    # Keep each entry's 1-based line number in the file so that a parse
    # error can point at the offending line even after comments are dropped.
    entries = []
    with open(path) as f:
        for line_no, raw in enumerate(f, start=1):
            if not raw.strip().startswith('#'):
                entries.append((line_no, raw))

    if len(entries) != len(SCAN_ORDER):
        raise ValueError(
            f"{path} has {len(entries)} entries, expected {len(SCAN_ORDER)}. "
            f"Use blank lines or '-' for empty hexes."
        )

    population = {}
    for hex_coord, (line_no, raw) in zip(SCAN_ORDER, entries):
        try:
            population[hex_coord] = parse_hex_line(raw)
        except ValueError as err:
            raise ValueError(
                f"{path}, line {line_no} (hex {hex_coord}): {err}"
            ) from err
    return population

# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

if __name__ == "__main__":

    # population.txt is read from the top, going clockwise.
    # A token such as "1y" means 1 yellow; "-" marks an empty system.
    seed = load_population("population.txt")

    # `origin` can be any valid axial coordinate on the grid.
    summary = monte_carlo(
        n=500,
        seed_populations=seed,
        origin=(-1, 1),
        verbose=True,
    )
    print_summary(summary)