#!/usr/bin/env python3
"""
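Pre-trade validator v3.

Runs data-quality, trading-session, order enum-mapping, single-position and
industry exposure limit, and daily-drawdown kill-switch checks against the
evidence, gate config, order and portfolio inputs. Prints a JSON verdict,
appends it to a JSONL audit log, and exits 0 when every check passes or 2
when the order is blocked.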
"""

from __future__ import annotations
import argparse
import json
import sys
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any

try:
    from zoneinfo import ZoneInfo
except Exception:
    ZoneInfo = None

try:
    import yaml
except Exception:
    yaml = None


@dataclass
class CheckResult:
    """Outcome of a single pre-trade check.

    Every ``check_*`` function in this module returns one of these, and
    ``run`` folds them into the overall verdict.
    """

    ok: bool  # True when the check passed
    code: str  # "PASS" on success, otherwise a failure code such as "DATA_MISSING"
    reason: str  # human-readable explanation copied into the verdict output


def load_json(path: Path) -> dict:
    """Decode the UTF-8 JSON document stored at *path* and return it."""
    with path.open(encoding="utf-8") as handle:
        return json.load(handle)


def load_yaml(path: Path) -> dict:
    """Load the YAML mapping stored at *path* with ``yaml.safe_load``.

    Raises:
        RuntimeError: if PyYAML could not be imported.
        ValueError: if the document's top level is not a mapping. In
            particular, ``safe_load`` returns ``None`` for an empty or
            comment-only file. Without this check, that would surface much
            later as an ``AttributeError`` on ``.get`` rather than as a
            clear config error.
    """
    if yaml is None:
        raise RuntimeError("pyyaml not installed. Use .venv and install pyyaml.")
    data = yaml.safe_load(path.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        raise ValueError(f"{path}: expected a YAML mapping at top level, got {type(data).__name__}")
    return data


def parse_dt(value: str) -> datetime:
    """Parse an ISO-8601 timestamp, treating a trailing ``Z`` as UTC.

    The ``Z`` suffix is rewritten to an explicit ``+00:00`` offset before
    calling ``datetime.fromisoformat``, which does not accept ``Z`` on older
    Python versions. Strings without an offset yield naive datetimes.
    """
    if not value.endswith("Z"):
        return datetime.fromisoformat(value)
    return datetime.fromisoformat(value[:-1] + "+00:00")


def in_trading_window(now_local: datetime, windows: list[str]) -> bool:
    """Return True if the wall-clock minute of *now_local* is inside any window.

    Each window is an ``"HH:MM-HH:MM"`` string. Both endpoints are inclusive
    and seconds are ignored, so ``"09:30-11:30"`` matches 09:30:00 through
    11:30:59. A window whose start is later than its end can never match.
    Windows are parsed lazily, and scanning stops at the first match.
    """
    def _minute_of_day(hhmm: str) -> int:
        hours, minutes = map(int, hhmm.split(":"))
        return hours * 60 + minutes

    def _spans():
        for window in windows:
            opens_at, closes_at = window.split("-")
            # Both bounds are parsed before any comparison is made.
            yield _minute_of_day(opens_at), _minute_of_day(closes_at)

    current = now_local.hour * 60 + now_local.minute
    return any(lower <= current <= upper for lower, upper in _spans())


def check_required_fields(ev: dict, required: list[str]) -> CheckResult:
    """Fail with ``DATA_MISSING`` on the first required field that is absent.

    ``symbol`` is accepted from the top level of *ev*. Every name, including
    ``symbol`` when the top-level value is missing, may also be satisfied by
    ``ev["metrics"]``. A field whose value is ``None`` counts as missing.
    """
    metrics = ev.get("metrics", {})
    top_level = {"symbol": ev.get("symbol")}

    def _is_present(name: str) -> bool:
        if name in top_level and top_level[name] is not None:
            return True
        return name in metrics and metrics.get(name) is not None

    for name in required:
        if not _is_present(name):
            return CheckResult(False, "DATA_MISSING", f"missing field: {name}")
    return CheckResult(True, "PASS", "required fields ok")


def check_independent_sources(ev: dict, minimum: int) -> CheckResult:
    """Require at least *minimum* entries in ``ev["sources"]`` flagged independent.

    Only ``"independent": true`` counts. Truthy look-alikes such as ``1`` or
    ``"true"`` are ignored because the test is ``is True``.
    """
    sources = ev.get("sources", [])
    independent_count = len([src for src in sources if src.get("independent") is True])
    if independent_count < minimum:
        return CheckResult(
            False,
            "DATA_INSUFFICIENT_SOURCES",
            f"independent sources {independent_count} < {minimum}",
        )
    return CheckResult(True, "PASS", "independent sources ok")


def check_quality_score(ev: dict, threshold: float = 0.75) -> CheckResult:
    """Require the mean of the three validation scores to reach *threshold*.

    Reads ``consistency_score``, ``freshness_score`` and
    ``completeness_score`` from ``ev["validation"]``; a missing score counts
    as 0.

    Returns:
        PASS with the average, or ``DATA_QUALITY_LOW`` when the average is
        below *threshold* or is NaN.
    """
    v = ev.get("validation", {})
    scores = [
        float(v.get("consistency_score", 0)),
        float(v.get("freshness_score", 0)),
        float(v.get("completeness_score", 0)),
    ]
    avg = sum(scores) / len(scores)
    # Written as "not >=" rather than "<" so that a NaN average fails closed:
    # json.loads accepts a bare NaN literal, and "nan < threshold" is False,
    # which previously let such evidence pass.
    if not (avg >= threshold):
        return CheckResult(False, "DATA_QUALITY_LOW", f"avg quality {avg:.3f} < {threshold}")
    return CheckResult(True, "PASS", f"quality score {avg:.3f}")


def check_trading_session(gates: dict, now_iso: str | None) -> CheckResult:
    """Check that the evaluation time falls within a configured trading session.

    Configuration read from *gates*:
      * ``execution_guards.trading_session_required`` (default True). When
        falsy, the check passes immediately.
      * ``trading_sessions.windows`` (default
        ``["09:30-11:30", "13:00-15:00"]``).
      * ``trading_sessions.timezone`` (default ``"Asia/Shanghai"``).

    Args:
        gates: parsed gate configuration.
        now_iso: ISO-8601 timestamp to evaluate. The current time is used
            when it is None or empty. A naive timestamp is interpreted in the
            configured timezone.

    Returns:
        PASS, or a failure coded ``ENV_ZONEINFO_MISSING`` (zoneinfo not
        importable), ``ENV_TIMEZONE_INVALID`` (timezone cannot be loaded), or
        ``EXEC_NOT_TRADING_SESSION`` (weekend or outside every window).
    """
    eg = gates.get("execution_guards", {})
    if not eg.get("trading_session_required", True):
        return CheckResult(True, "PASS", "trading session check disabled")

    ts = gates.get("trading_sessions", {})
    windows = ts.get("windows", ["09:30-11:30", "13:00-15:00"])
    tz_name = ts.get("timezone", "Asia/Shanghai")

    if ZoneInfo is None:
        return CheckResult(False, "ENV_ZONEINFO_MISSING", "zoneinfo unavailable")

    try:
        tz = ZoneInfo(tz_name)
    except (KeyError, TypeError, ValueError, OSError) as exc:
        # ZoneInfoNotFoundError (a KeyError subclass) is raised for unknown
        # names and when no tz database is installed (e.g. Windows without
        # the tzdata package). Malformed or non-string keys raise other
        # errors. Fail closed with a result instead of letting the exception
        # abort the whole run.
        return CheckResult(False, "ENV_TIMEZONE_INVALID", f"cannot load timezone {tz_name!r}: {exc}")

    now = parse_dt(now_iso) if now_iso else datetime.now(tz)
    if now.tzinfo is None:
        # A naive timestamp is taken as exchange-local wall-clock time.
        now = now.replace(tzinfo=tz)
    local = now.astimezone(tz)

    # Only weekends are excluded here; exchange holidays are not checked.
    if local.weekday() >= 5:
        return CheckResult(False, "EXEC_NOT_TRADING_SESSION", f"weekend: {local.isoformat()}")
    if not in_trading_window(local, windows):
        return CheckResult(False, "EXEC_NOT_TRADING_SESSION", f"outside windows: {local.strftime('%H:%M')}")
    return CheckResult(True, "PASS", "trading session ok")


def check_enum_mapping(gates: dict, order: dict | None) -> CheckResult:
    """Verify the order's XTP market/exchange enum values match its venue.

    The venue must be ``"SH"`` or ``"SZ"``. The expected codes come from
    ``gates["enum_mapping"]["xtp_market_type"]`` (default SH=2, SZ=1) and
    ``gates["enum_mapping"]["xtp_exchange_type"]`` (default SH=1, SZ=2).
    An empty or missing order fails with ``ORDER_MISSING``.
    """
    if not order:
        return CheckResult(False, "ORDER_MISSING", "order payload is required for enum check")

    enum_cfg = gates.get("enum_mapping", {})
    market_by_venue = enum_cfg.get("xtp_market_type", {"SH": 2, "SZ": 1})
    exchange_by_venue = enum_cfg.get("xtp_exchange_type", {"SH": 1, "SZ": 2})

    venue = order.get("venue")
    got_market = order.get("market_type")
    got_exchange = order.get("exchange_type")

    if venue not in ("SH", "SZ"):
        return CheckResult(False, "ENUM_MAPPING_ERROR", f"venue must be SH/SZ, got {venue}")

    want_market = market_by_venue.get(venue)
    want_exchange = exchange_by_venue.get(venue)
    mismatch = got_market != want_market or got_exchange != want_exchange
    if not mismatch:
        return CheckResult(True, "PASS", "enum mapping ok")
    return CheckResult(
        False,
        "ENUM_MAPPING_ERROR",
        f"venue={venue} expected ({want_market},{want_exchange}), got ({got_market},{got_exchange})",
    )


def check_position_limit(gates: dict, evidence: dict, order: dict | None, portfolio: dict | None) -> CheckResult:
    """Check that the order keeps the single-name position within its NAV cap.

    The projected weight is ``(portfolio.positions[symbol] + order.notional)
    / portfolio.total_nav``. It must not exceed
    ``gates["portfolio_limits"]["max_single_position_pct"]`` (default 0.12).
    The symbol is taken from the order, falling back to the evidence.

    NOTE(review): the notional is always added to the current position, so
    a sell reduces exposure only if its notional is negative. Confirm the
    sign convention with the order producer.

    Returns:
        PASS with the projected weight, or a failure coded
        ``PORTFOLIO_OR_ORDER_MISSING``, ``PORTFOLIO_INVALID`` or
        ``RISK_LIMIT_BREACH``.
    """
    if not order or not portfolio:
        return CheckResult(False, "PORTFOLIO_OR_ORDER_MISSING", "order+portfolio required")

    max_single = float(gates.get("portfolio_limits", {}).get("max_single_position_pct", 0.12))
    total_nav = float(portfolio.get("total_nav", 0))
    # "not > 0" (rather than "<= 0") also rejects a NaN NAV, which json.loads
    # accepts as a bare NaN literal and which compares False both ways.
    if not (total_nav > 0):
        return CheckResult(False, "PORTFOLIO_INVALID", "total_nav must be > 0")

    symbol = order.get("symbol") or evidence.get("symbol")
    current_notional = float(portfolio.get("positions", {}).get(symbol, 0))
    order_notional = float(order.get("notional", 0))
    projected = (current_notional + order_notional) / total_nav

    # "not <=" makes a NaN projection (or NaN limit) fail closed.
    if not (projected <= max_single):
        return CheckResult(False, "RISK_LIMIT_BREACH", f"single position {projected:.4f} > {max_single:.4f}")
    return CheckResult(True, "PASS", f"projected position {projected:.4f}")


def check_industry_exposure(gates: dict, order: dict | None, portfolio: dict | None) -> CheckResult:
    """Check that the order keeps its industry's exposure within the NAV cap.

    The projected exposure is ``(portfolio.industry_positions[industry] +
    order.notional) / portfolio.total_nav``. It must not exceed
    ``gates["portfolio_limits"]["max_industry_exposure_pct"]`` (default 0.35).

    Returns:
        PASS with the projected exposure, or a failure coded
        ``PORTFOLIO_OR_ORDER_MISSING``, ``PORTFOLIO_INVALID`` (NAV missing,
        zero, negative or NaN), ``ORDER_INDUSTRY_MISSING`` or
        ``RISK_LIMIT_BREACH``.
    """
    if not order or not portfolio:
        return CheckResult(False, "PORTFOLIO_OR_ORDER_MISSING", "order+portfolio required")

    max_ind = float(gates.get("portfolio_limits", {}).get("max_industry_exposure_pct", 0.35))
    total_nav = float(portfolio.get("total_nav", 0))
    # Guard the division below: a missing/zero NAV used to raise
    # ZeroDivisionError and abort the whole run. "not > 0" also rejects NaN.
    if not (total_nav > 0):
        return CheckResult(False, "PORTFOLIO_INVALID", "total_nav must be > 0")

    industry_map = portfolio.get("industry_positions", {})
    industry = order.get("industry")
    if not industry:
        return CheckResult(False, "ORDER_INDUSTRY_MISSING", "order.industry is required for industry check")

    current = float(industry_map.get(industry, 0))
    projected = (current + float(order.get("notional", 0))) / total_nav
    # "not <=" makes a NaN projection (or NaN limit) fail closed.
    if not (projected <= max_ind):
        return CheckResult(False, "RISK_LIMIT_BREACH", f"industry exposure {projected:.4f} > {max_ind:.4f}")
    return CheckResult(True, "PASS", f"projected industry exposure {projected:.4f}")


def check_drawdown_killswitch(gates: dict, portfolio: dict | None) -> CheckResult:
    """Block trading once the portfolio's daily drawdown exceeds its limit.

    The limit is ``gates["loss_controls"]["max_daily_drawdown_pct"]``
    (default 0.03). The drawdown is compared by magnitude, so a feed that
    reports a loss as ``-0.05`` trips the switch just like one reporting
    ``0.05``. Previously a negative-convention feed could never trigger it.
    A NaN drawdown also trips the switch.

    Returns:
        PASS with the drawdown magnitude, or ``PORTFOLIO_MISSING`` /
        ``KILL_SWITCH_TRIGGERED``.
    """
    if not portfolio:
        return CheckResult(False, "PORTFOLIO_MISSING", "portfolio required")
    lc = gates.get("loss_controls", {})
    dd_limit = float(lc.get("max_daily_drawdown_pct", 0.03))
    daily_dd = abs(float(portfolio.get("daily_drawdown_pct", 0)))
    # "not <=" rather than ">" so NaN fails closed.
    if not (daily_dd <= dd_limit):
        return CheckResult(False, "KILL_SWITCH_TRIGGERED", f"daily drawdown {daily_dd:.4f} > {dd_limit:.4f}")
    return CheckResult(True, "PASS", f"daily drawdown {daily_dd:.4f}")


def append_audit_log(path: Path, result: dict, context: dict) -> None:
    """Append one JSON line describing *result* and *context* to *path*.

    Missing parent directories are created. Each line is an object with keys
    ``ts`` (current local time as ISO-8601 with UTC offset), ``result`` and
    ``context``. Non-ASCII text is written as-is in UTF-8.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    record = dict(
        ts=datetime.now().astimezone().isoformat(),
        result=result,
        context=context,
    )
    with path.open("a", encoding="utf-8") as sink:
        sink.write(json.dumps(record, ensure_ascii=False) + "\n")


def run(evidence: dict, gates: dict, order: dict | None, portfolio: dict | None, now_iso: str | None) -> dict[str, Any]:
    """Evaluate every pre-trade check and fold the results into one verdict.

    All eight checks are always evaluated, in a fixed order, with no
    short-circuiting. Data-quality settings come from
    ``gates["data_quality_gates"]`` (``required_fields`` default ``[]``,
    ``min_independent_sources`` default 2). The quality-score threshold is
    fixed at 0.75.

    Returns:
        ``{"ok": False, "code": "BLOCKED", "reason": ..., "failures": [...]}``
        when any check fails. Otherwise returns ``{"ok": True, "code":
        "PASS", "reason": ..., "details": [...]}``. Each list entry is a
        ``{"code": ..., "reason": ...}`` dict.
    """
    dq_cfg = gates.get("data_quality_gates", {})
    results = [
        check_required_fields(evidence, dq_cfg.get("required_fields", [])),
        check_independent_sources(evidence, int(dq_cfg.get("min_independent_sources", 2))),
        check_quality_score(evidence, 0.75),
        check_trading_session(gates, now_iso),
        check_enum_mapping(gates, order),
        check_position_limit(gates, evidence, order, portfolio),
        check_industry_exposure(gates, order, portfolio),
        check_drawdown_killswitch(gates, portfolio),
    ]

    def _summarize(items: list[CheckResult]) -> list[dict[str, str]]:
        return [{"code": item.code, "reason": item.reason} for item in items]

    blocked = [res for res in results if not res.ok]
    if not blocked:
        return {
            "ok": True,
            "code": "PASS",
            "reason": "all checks passed",
            "details": _summarize(results),
        }
    return {
        "ok": False,
        "code": "BLOCKED",
        "reason": f"{len(blocked)} checks failed",
        "failures": _summarize(blocked),
    }


def main() -> int:
    """Command-line entry point.

    Loads the evidence, order and portfolio JSON files and the gates YAML
    file, runs all checks, appends an audit row, and prints the verdict as
    one JSON line. Load errors propagate as exceptions.

    Returns:
        0 when every check passed, 2 when the order is blocked.
    """
    parser = argparse.ArgumentParser()
    for flag in ("--evidence", "--gates", "--order", "--portfolio"):
        parser.add_argument(flag, required=True)
    parser.add_argument("--now")
    parser.add_argument("--audit-log", default="docs/ai-invest-research/audit/pretrade_audit.jsonl")
    args = parser.parse_args()

    evidence = load_json(Path(args.evidence))
    gates = load_yaml(Path(args.gates))
    order = load_json(Path(args.order))
    portfolio = load_json(Path(args.portfolio))

    verdict = run(evidence, gates, order, portfolio, args.now)
    audit_context = {
        "symbol": order.get("symbol") or evidence.get("symbol"),
        "evidence_id": evidence.get("evidence_id"),
    }
    append_audit_log(Path(args.audit_log), verdict, audit_context)

    print(json.dumps(verdict, ensure_ascii=False))
    return 0 if verdict.get("ok") else 2


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
