#!/usr/bin/env python3
"""Pyvorin Edge Benchmark CLI.

Measures latency percentiles, throughput, memory usage, and reduction ratio
for Edge workloads running under CPython and Pyvorin.
"""

import argparse
import io
import json
import math
import sys
import time
import tracemalloc
from contextlib import redirect_stdout
from datetime import datetime, timezone
from typing import Any, Callable, Dict, List, Optional, Tuple

try:
    import pyvorin

    _HAS_PYVORIN = True
except Exception:  # pragma: no cover
    _HAS_PYVORIN = False

# Suite names accepted by ``--suite`` and printed by the ``list`` command;
# they should match the keys returned by ``_get_workloads``.
SUITES = (
    "sensor_pipeline window_aggregation privacy_filter cloud_sync full_stack"
).split()

# Workload entry: (function name, function source code, zero-arg input factory).
_WorkloadDef = Tuple[str, str, Callable[[], Any]]


def _get_workloads() -> Dict[str, _WorkloadDef]:
    """Build the catalogue of benchmark workloads, keyed by suite name.

    Each value is ``(function_name, source_code, data_factory)``: the name of
    the function that ``source_code`` defines, and a zero-argument callable
    that produces the input for one call. ``_measure`` unpacks a tuple input
    into positional arguments and passes anything else as the single argument.
    A new dictionary (with new factory callables) is built on every call.
    """
    sensor_source = """
def sensor_pipeline(data):
    total = 0.0
    count = 0
    for v in data:
        if v > 10.0:
            total += v
            count += 1
    return total / count if count else 0.0
"""
    window_source = """
def window_aggregation(data, window_size):
    results = []
    n = len(data)
    for i in range(n - window_size + 1):
        s = 0.0
        for j in range(window_size):
            s += data[i + j]
        results.append(s / window_size)
    return results
"""
    privacy_source = """
def privacy_filter(values):
    out = []
    for v in values:
        if v > 100:
            v = 100
        out.append(v + 1)
    return out
"""
    cloud_source = """
def cloud_sync(batch):
    total = 0
    for item in batch:
        total += len(str(item))
    return total
"""
    full_stack_source = """
def full_stack(data):
    # sensor pipeline
    total = 0.0
    count = 0
    for v in data["sensor"]:
        if v > 10.0:
            total += v
            count += 1
    avg = total / count if count else 0.0

    # window aggregation
    arr = data["window"]
    ws = data["window_size"]
    w_results = []
    n = len(arr)
    for i in range(n - ws + 1):
        s = 0.0
        for j in range(ws):
            s += arr[i + j]
        w_results.append(s / ws)

    # privacy filter
    out = []
    for v in data["privacy"]:
        if v > 100:
            v = 100
        out.append(v + 1)

    # cloud sync
    total_len = 0
    for item in data["cloud"]:
        total_len += len(str(item))

    return {
        "avg": avg,
        "windows": len(w_results),
        "privacy": len(out),
        "sync_len": total_len,
    }
"""

    def full_stack_input() -> Dict[str, Any]:
        # Combined input: one slot per stage of the full_stack workload.
        return {
            "sensor": [float(i) for i in range(200)],
            "window": [float(i % 50) for i in range(200)],
            "window_size": 10,
            "privacy": [i % 150 for i in range(200)],
            "cloud": [i * 123 for i in range(100)],
        }

    catalogue = [
        ("sensor_pipeline", sensor_source, lambda: [float(i) for i in range(200)]),
        (
            "window_aggregation",
            window_source,
            # Tuple input: unpacked into (data, window_size) by _measure.
            lambda: ([float(i % 50) for i in range(200)], 10),
        ),
        ("privacy_filter", privacy_source, lambda: [i % 150 for i in range(200)]),
        ("cloud_sync", cloud_source, lambda: [i * 123 for i in range(100)]),
        ("full_stack", full_stack_source, full_stack_input),
    ]
    # The suite key doubles as the name of the function each source defines.
    return {name: (name, source, factory) for name, source, factory in catalogue}


def _percentile(sorted_values: List[float], p: float) -> float:
    """Return the ``p``-th percentile of ``sorted_values`` by linear interpolation.

    ``sorted_values`` must already be sorted ascending and ``p`` is on a
    0-100 scale. An empty list yields ``0.0``. When the fractional rank falls
    between two samples, the result is the weighted blend of those neighbours.
    """
    if not sorted_values:
        return 0.0
    rank = (len(sorted_values) - 1) * p / 100.0
    lo = math.floor(rank)
    hi = math.ceil(rank)
    if lo == hi:
        # Rank lands exactly on a sample; no interpolation needed.
        return sorted_values[int(rank)]
    lower_weight = hi - rank
    upper_weight = rank - lo
    return sorted_values[lo] * lower_weight + sorted_values[hi] * upper_weight


def _make_python_func(source_code: str, func_name: str) -> Callable[..., Any]:
    """Execute ``source_code`` in a fresh namespace and return ``func_name`` from it.

    The source is run with ``exec``, so it must be trusted (here it is the
    built-in workload catalogue). Raises ``KeyError`` if the executed code
    does not bind ``func_name``.
    """
    scope: Dict[str, Any] = {}
    code = compile(source_code, "<string>", "exec")
    exec(code, scope)
    return scope[func_name]


def _measure(func: Callable[..., Any], args: Any, iterations: int) -> Dict[str, Any]:
    """Benchmark ``func`` and summarise latency, throughput and peak memory.

    ``args`` is unpacked into positional arguments when it is a tuple and is
    passed as the single argument otherwise. ``max(1, iterations // 10)``
    untimed warm-up calls run first.

    Latency percentiles (ms) and throughput (calls/sec) come from an untraced
    timing loop of ``iterations`` calls. Peak memory (MB) comes from one extra
    call made afterwards with ``tracemalloc`` active. The two are kept apart
    because tracemalloc intercepts Python-level allocations. Leaving it on
    while timing inflates latency for allocation-heavy interpreted code and
    can skew CPython-vs-Pyvorin comparisons. It would also count this
    function's own ``times`` list in the peak.

    Returns:
        ``{"latency_ms": {"p50", "p95", "p99"}, "throughput_ops_sec",
        "memory_peak_mb"}`` with values rounded to 4, 2 and 4 decimals.

    Raises:
        ValueError: if ``iterations`` is less than 1.
        Any exception raised by ``func`` propagates; tracing is always stopped.
    """
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")

    # Normalise once so the timed loop does no per-call isinstance check.
    call_args: Tuple[Any, ...] = args if isinstance(args, tuple) else (args,)

    for _ in range(max(1, iterations // 10)):
        func(*call_args)

    clock = time.perf_counter
    times: List[float] = []
    overall_start = clock()
    for _ in range(iterations):
        t0 = clock()
        func(*call_args)
        times.append((clock() - t0) * 1000.0)
    overall_elapsed = clock() - overall_start

    # Memory pass: a single traced call; stop tracing even if func raises.
    tracemalloc.start()
    try:
        func(*call_args)
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()

    times.sort()
    throughput = iterations / overall_elapsed if overall_elapsed > 0 else 0.0
    return {
        "latency_ms": {
            "p50": round(_percentile(times, 50.0), 4),
            "p95": round(_percentile(times, 95.0), 4),
            "p99": round(_percentile(times, 99.0), 4),
        },
        "throughput_ops_sec": round(throughput, 2),
        "memory_peak_mb": round(peak / (1024 * 1024), 4),
    }


def _compile_pyvorin(source_code: str, func_name: str) -> Optional[Callable[..., Any]]:
    """Compile ``func_name`` from ``source_code`` with Pyvorin, best-effort.

    Returns the compiled callable, or ``None`` when pyvorin could not be
    imported or compilation raised. Anything the compiler prints to stdout
    during ``compile`` is discarded, so it cannot end up on stdout next to
    the JSON report. Compilation failures are reported on stderr. Without
    that, they would be indistinguishable from a missing pyvorin install,
    since callers treat both ``None`` cases the same way.
    """
    if not _HAS_PYVORIN:
        return None
    try:
        compiler = pyvorin.PyvorinCompiler()
        with redirect_stdout(io.StringIO()):
            return compiler.compile(source_code, function_name=func_name)
    except Exception as exc:  # best-effort: caller falls back to CPython only
        print(
            f"Warning: Pyvorin compilation of {func_name!r} failed: {exc}",
            file=sys.stderr,
        )
        return None


def _run_single(suite_name: str, iterations: int) -> Dict[str, Any]:
    """Benchmark one suite under plain CPython.

    Returns ``{suite_name: metrics}``, where ``metrics`` is the ``_measure``
    result extended with ``"reduction_ratio": None`` (nothing is compared here).

    Raises:
        ValueError: if ``suite_name`` is not a known workload.
    """
    entry = _get_workloads().get(suite_name)
    if entry is None:
        raise ValueError(f"Unknown suite: {suite_name}")
    func_name, source, data_factory = entry
    python_impl = _make_python_func(source, func_name)
    metrics = _measure(python_impl, data_factory(), iterations)
    return {suite_name: {**metrics, "reduction_ratio": None}}


def _compare_single(suite_name: str, iterations: int) -> Dict[str, Any]:
    """Benchmark one suite under CPython and, when possible, under Pyvorin.

    Both runs use the same input produced once by the workload's factory.

    Returns:
        ``{suite_name: {"cpython", "pyvorin", "reduction_ratio", "backend_used"}}``.
        ``pyvorin`` holds all-``None`` metrics when no compiled function could
        be measured. ``backend_used`` is one of:

        - the compiled function's ``backend_used`` attribute (or ``"unknown"``);
        - ``"unavailable"`` when compilation produced nothing;
        - ``"failed"`` when the compiled function raised while being measured.

        ``reduction_ratio`` is ``1 - pyvorin_p50 / cpython_p50``, rounded to
        4 places, or ``None`` if it cannot be computed.

    Raises:
        ValueError: if ``suite_name`` is not a known workload.
    """
    workloads = _get_workloads()
    if suite_name not in workloads:
        raise ValueError(f"Unknown suite: {suite_name}")

    func_name, source, data_factory = workloads[suite_name]
    args = data_factory()
    py_func = _make_python_func(source, func_name)

    cpython_metrics = _measure(py_func, args, iterations)

    compiled_func = _compile_pyvorin(source, func_name)
    pyvorin_metrics: Optional[Dict[str, Any]] = None
    backend = "unavailable"
    if compiled_func is not None:
        try:
            pyvorin_metrics = _measure(compiled_func, args, iterations)
        except Exception as exc:  # best-effort: keep the CPython results
            print(
                f"Warning: Pyvorin run of {func_name!r} failed: {exc}",
                file=sys.stderr,
            )
            backend = "failed"
        else:
            backend = getattr(compiled_func, "backend_used", "unknown")

    # NOTE: both p50 values arrive already rounded to 4 decimals (ms) by
    # _measure, so this ratio is coarse for sub-microsecond workloads.
    cpython_p50 = cpython_metrics["latency_ms"]["p50"]
    reduction_ratio: Optional[float] = None
    if pyvorin_metrics is not None and cpython_p50 > 0:
        reduction_ratio = round(
            1.0 - pyvorin_metrics["latency_ms"]["p50"] / cpython_p50, 4
        )

    if pyvorin_metrics is None:
        # Same shape as real metrics so report consumers can index uniformly.
        pyvorin_metrics = {
            "latency_ms": {"p50": None, "p95": None, "p99": None},
            "throughput_ops_sec": None,
            "memory_peak_mb": None,
        }

    return {
        suite_name: {
            "cpython": cpython_metrics,
            "pyvorin": pyvorin_metrics,
            "reduction_ratio": reduction_ratio,
            "backend_used": backend,
        }
    }


def _write_output(data: Dict[str, Any], path: Optional[str]) -> None:
    """Serialise ``data`` as indented JSON to ``path``, or to stdout if no path.

    When a path is given (non-empty), the file is written as UTF-8 and a
    confirmation line is printed instead of the JSON itself.
    """
    text = json.dumps(data, indent=2)
    if not path:
        print(text)
        return
    with open(path, "w", encoding="utf-8") as handle:
        handle.write(text)
    print(f"Results written to {path}")


def cmd_run(args: argparse.Namespace) -> int:
    """Handle the ``run`` sub-command: benchmark one suite and emit JSON.

    Returns 0 on success. On any error it prints ``Error: ...`` to stderr and
    returns 1.
    """
    try:
        stamp = datetime.now(timezone.utc).isoformat()
        report = {
            "command": "run",
            "suite": args.suite,
            "iterations": args.iterations,
            "timestamp": stamp,
            "results": _run_single(args.suite, args.iterations),
        }
        _write_output(report, args.output)
    except Exception as exc:
        print(f"Error: {exc}", file=sys.stderr)
        return 1
    return 0


def cmd_compare(args: argparse.Namespace) -> int:
    """Handle the ``compare`` sub-command: CPython vs Pyvorin, emitted as JSON.

    Returns 0 on success. On any error it prints ``Error: ...`` to stderr and
    returns 1.
    """
    try:
        stamp = datetime.now(timezone.utc).isoformat()
        report = {
            "command": "compare",
            "suite": args.suite,
            "iterations": args.iterations,
            "timestamp": stamp,
            "results": _compare_single(args.suite, args.iterations),
        }
        _write_output(report, args.output)
    except Exception as exc:
        print(f"Error: {exc}", file=sys.stderr)
        return 1
    return 0


def cmd_list(_args: argparse.Namespace) -> int:
    """Handle the ``list`` sub-command: print every suite name and return 0."""
    bullets = [f"  - {suite}" for suite in SUITES]
    print("Available benchmark suites:", *bullets, sep="\n")
    return 0


def _report_comparison_rows(metrics: Dict[str, Any]) -> List[str]:
    """Render the CPython-vs-Pyvorin Markdown table for one ``compare`` entry."""
    c = metrics["cpython"]
    p = metrics["pyvorin"]
    return [
        "| Metric | CPython | Pyvorin |",
        "|--------|---------|---------|",
        f"| p50 latency (ms) | {c['latency_ms']['p50']} | {p['latency_ms']['p50']} |",
        f"| p95 latency (ms) | {c['latency_ms']['p95']} | {p['latency_ms']['p95']} |",
        f"| p99 latency (ms) | {c['latency_ms']['p99']} | {p['latency_ms']['p99']} |",
        f"| throughput (ops/sec) | {c['throughput_ops_sec']} | {p['throughput_ops_sec']} |",
        f"| memory peak (MB) | {c['memory_peak_mb']} | {p['memory_peak_mb']} |",
        f"| reduction ratio | - | {metrics.get('reduction_ratio')} |",
        f"| backend | - | {metrics.get('backend_used')} |",
        "",
    ]


def _report_single_rows(metrics: Dict[str, Any]) -> List[str]:
    """Render the single-column Markdown table for one ``run`` entry."""
    latency = metrics["latency_ms"]
    return [
        "| Metric | Value |",
        "|--------|-------|",
        f"| p50 latency (ms) | {latency['p50']} |",
        f"| p95 latency (ms) | {latency['p95']} |",
        f"| p99 latency (ms) | {latency['p99']} |",
        f"| throughput (ops/sec) | {metrics['throughput_ops_sec']} |",
        f"| memory peak (MB) | {metrics['memory_peak_mb']} |",
        "",
    ]


def _render_report(data: Dict[str, Any]) -> str:
    """Build the full Markdown report for a results document from run/compare.

    Result keys starting with ``_`` are skipped. Raises ``KeyError``,
    ``TypeError`` or ``AttributeError`` if ``data`` is not shaped like a
    results document.
    """
    lines: List[str] = [
        "# Pyvorin Edge Benchmark Report\n",
        f"**Command:** {data.get('command', 'unknown')}",
        f"**Suite:** {data.get('suite', 'unknown')}",
        f"**Iterations:** {data.get('iterations', 'unknown')}",
        f"**Timestamp:** {data.get('timestamp', 'unknown')}\n",
    ]
    for key, metrics in data.get("results", {}).items():
        if key.startswith("_"):
            continue
        lines.append(f"## {key}\n")
        # Compare results nest per-runtime metrics; run results are flat.
        if isinstance(metrics, dict) and "cpython" in metrics:
            lines.extend(_report_comparison_rows(metrics))
        else:
            lines.extend(_report_single_rows(metrics))
    return "\n".join(lines)


def cmd_report(args: argparse.Namespace) -> int:
    """Handle the ``report`` sub-command: turn a results JSON file into Markdown.

    Reads ``args.input`` and writes the report to ``args.output``, or prints
    it if no output path is given. Returns 0 on success and 1 if the input
    cannot be read or parsed, is not shaped like a results document, or the
    output file cannot be written. Each failure prints an error to stderr.
    """
    try:
        with open(args.input, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as exc:
        print(f"Error reading input: {exc}", file=sys.stderr)
        return 1

    try:
        md = _render_report(data)
    except (AttributeError, KeyError, TypeError) as exc:
        # Valid JSON but not a run/compare results document.
        print(f"Error: malformed results in {args.input}: {exc!r}", file=sys.stderr)
        return 1

    if args.output:
        try:
            with open(args.output, "w", encoding="utf-8") as f:
                f.write(md)
        except OSError as exc:
            print(f"Error writing report: {exc}", file=sys.stderr)
            return 1
        print(f"Report written to {args.output}")
    else:
        print(md)
    return 0


def _positive_int(text: str) -> int:
    """Parse a CLI value as an integer >= 1 (an ``argparse`` ``type=`` callable).

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not an integer or is < 1.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value


def main(argv: Optional[List[str]] = None) -> int:
    """Parse ``argv`` and dispatch to the selected ``cmd_*`` handler.

    ``argv`` defaults to ``sys.argv[1:]`` (argparse behaviour). Returns the
    handler's exit status. argparse itself exits on invalid arguments,
    including a non-positive ``--iterations``. A zero or negative count would
    otherwise produce meaningless zero or negative throughput figures.
    """
    parser = argparse.ArgumentParser(
        prog="pyv-edge-benchmark",
        description="Pyvorin Edge Benchmark CLI",
    )
    sub = parser.add_subparsers(dest="command", required=True)

    # "run" and "compare" share identical options; only help and handler differ.
    for name, help_text, handler in (
        ("run", "Run a benchmark suite", cmd_run),
        ("compare", "Compare CPython vs Pyvorin", cmd_compare),
    ):
        bench = sub.add_parser(name, help=help_text)
        bench.add_argument(
            "--suite", required=True, choices=SUITES, help="Benchmark suite"
        )
        bench.add_argument(
            "--iterations",
            type=_positive_int,
            default=100,
            help="Number of iterations (>= 1)",
        )
        bench.add_argument("--output", default=None, help="Output JSON file")
        bench.set_defaults(func=handler)

    p_list = sub.add_parser("list", help="List available suites")
    p_list.set_defaults(func=cmd_list)

    p_report = sub.add_parser("report", help="Generate markdown report")
    p_report.add_argument("--input", required=True, help="Results JSON file")
    p_report.add_argument("--output", default=None, help="Output markdown file")
    p_report.set_defaults(func=cmd_report)

    args = parser.parse_args(argv)
    return args.func(args)


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
