#!/usr/bin/env python3
"""Lightweight OpenBB bridge for OpenClaw workflows.

Usage:
  python scripts/openbb_bridge.py quote TSLA
  python scripts/openbb_bridge.py history AAPL --interval 1d --limit 5
  python scripts/openbb_bridge.py topic lhlb --date 2026-03-13 --limit 20
"""

from __future__ import annotations

import argparse
import json
import os
import ssl
import sys
import traceback
import urllib.parse
import urllib.request
from datetime import UTC, datetime, timedelta, timezone
from pathlib import Path
from typing import Any

# Root of the dedicated OpenBB runtime. It can be overridden with the OPENBB_RUNTIME
# environment variable; without it, the original workspace location is used.
RUNTIME = Path(
    os.environ.get("OPENBB_RUNTIME") or "/Users/mibo/.openclaw/workspace/data/openbb-runtime"
).expanduser()
# Interpreter inside the runtime's virtualenv; _ensure_runtime() checks that it exists.
VENV_PY = RUNTIME / ".venv/bin/python"

# Asia/Shanghai UTC offset in ISO-8601 form (China Standard Time has no DST).
TZ_SH = "+08:00"


def _now_utc() -> str:
    return datetime.now(UTC).isoformat()


def _now_shanghai() -> str:
    return datetime.now(UTC).astimezone().astimezone().isoformat()


def _ensure_runtime() -> None:
    """Abort with a JSON error unless the runtime's virtualenv interpreter exists.

    Only existence of ``VENV_PY`` is checked; this function does not re-launch the
    script under that interpreter.

    Raises:
        SystemExit: with a JSON object (``"ok": false``) naming the expected path.
    """
    if VENV_PY.exists():
        return
    problem = {
        "ok": False,
        "error": "openbb runtime missing",
        "hint": f"expected python at {VENV_PY}",
    }
    raise SystemExit(json.dumps(problem, ensure_ascii=False))


def _load_openbb():
    """Import and return OpenBB's ``obb`` entry point.

    Raises:
        SystemExit: with a JSON error object if importing ``openbb`` fails for any reason.
    """
    try:
        from openbb import obb  # type: ignore
    except Exception as exc:  # pragma: no cover
        failure = {"ok": False, "error": f"import openbb failed: {exc}"}
        raise SystemExit(json.dumps(failure, ensure_ascii=False))
    return obb


def _route_symbol(symbol: str) -> tuple[str, str, str | None]:
    """Return (asset_type, routed_symbol, note)."""
    s = symbol.strip()
    upper = s.upper()

    explicit_map = {
        "CN=F": ("equity", "ASHR", "mapped CN=F -> ASHR (A50 proxy)"),
        "XAUUSD=X": ("futures", "GC=F", "mapped XAUUSD=X -> GC=F"),
        "XAUUSD": ("futures", "GC=F", "mapped XAUUSD -> GC=F"),
    }
    if upper in explicit_map:
        return explicit_map[upper]

    if upper.endswith("=F"):
        return ("futures", s, None)
    if upper.endswith("=X") or "/" in upper:
        normalized = upper.replace("/", "").replace("=X", "")
        return ("currency", normalized, None)
    return ("equity", s, None)


def _history_df(obb: Any, symbol: str, interval: str):
    """Fetch price history for *symbol* at *interval* and return it as a DataFrame.

    The symbol is routed with ``_route_symbol`` (its note is discarded):
    futures -> ``obb.derivatives.futures.historical``, currencies ->
    ``obb.currency.price.historical`` starting from 2025-01-01, anything else ->
    ``obb.equity.price.historical``.
    """
    kind, target, _note = _route_symbol(symbol)
    if kind == "futures":
        response = obb.derivatives.futures.historical(target, interval=interval)
    elif kind == "currency":
        # Currency history is requested from a fixed start date.
        response = obb.currency.price.historical(target, interval=interval, start_date="2025-01-01")
    else:
        response = obb.equity.price.historical(target, interval=interval)
    return response.to_dataframe()


def _http_get_json(url: str, params: dict[str, Any] | None = None, timeout: int = 12) -> dict[str, Any]:
    query = urllib.parse.urlencode(params or {}, doseq=True)
    full_url = f"{url}?{query}" if query else url
    req = urllib.request.Request(
        full_url,
        headers={
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0 Safari/537.36",
            "Accept": "application/json,text/plain,*/*",
            "Referer": "https://data.eastmoney.com/",
        },
    )
    ctx = ssl.create_default_context()
    with urllib.request.urlopen(req, timeout=timeout, context=ctx) as resp:
        raw = resp.read().decode("utf-8", errors="replace")
    return json.loads(raw)


def _compact_record(rec: dict[str, Any], keys: list[str]) -> dict[str, Any]:
    """Project *rec* onto *keys* (in *keys* order), skipping keys it lacks.

    When none of *keys* are present, *rec* itself (the same object, not a copy) is
    returned so callers never receive an empty row.
    """
    picked = {k: rec[k] for k in keys if k in rec}
    return picked if picked else rec


def _eastmoney_topic(topic: str, date: str, limit: int) -> tuple[list[dict[str, Any]], str, str | None]:
    """Return (records, source, error)."""
    topic = topic.lower()

    # Multiple reportName candidates to maximize compatibility across EM schema revisions.
    if topic == "lhlb":
        candidates = [
            {
                "report": "RPT_DAILYBILLBOARD_DETAILSNEW",
                "keys": ["SECURITY_CODE", "SECURITY_NAME_ABBR", "TRADE_DATE", "EXPLAIN", "CLOSE_PRICE", "CHANGE_RATE", "BILLBOARD_NET_AMT"],
            },
            {
                "report": "RPT_BILLBOARD_DAILYDETAILS",
                "keys": ["SECURITY_CODE", "SECURITY_NAME_ABBR", "TRADE_DATE", "EXPLAIN", "CLOSE_PRICE", "CHANGE_RATE", "NET_BUY_AMT"],
            },
        ]
    elif topic == "rzrq":
        candidates = [
            {
                "report": "RPTA_WEB_RZRQ_GGMX",
                "keys": ["SECURITY_CODE", "SECURITY_NAME_ABBR", "TRADE_DATE", "FIN_BUY_AMT", "FIN_BALANCE", "SEC_LENDING_BALANCE", "RZYE", "RQYE"],
            },
            {
                "report": "RPTA_WEB_RZRQ_GG",
                "keys": ["SECURITY_CODE", "SECURITY_NAME_ABBR", "TRADE_DATE", "RZYE", "RQYE", "RZRQYE"],
            },
        ]
    elif topic == "dzjy":
        candidates = [
            {
                "report": "RPT_DATA_BLOCKTRADE",
                "keys": ["SECURITY_CODE", "SECURITY_NAME_ABBR", "TRADE_DATE", "DEAL_PRICE", "DEAL_VOLUME", "DEAL_AMT", "PREMIUM_RATIO"],
            },
            {
                "report": "RPT_BLOCKTRADE_STA",
                "keys": ["SECURITY_CODE", "SECURITY_NAME_ABBR", "TRADE_DATE", "DEAL_PRICE", "DEAL_VOLUME", "DEAL_AMT"],
            },
        ]
    else:
        return [], "eastmoney", f"unsupported topic: {topic}"

    last_err: str | None = None
    for item in candidates:
        report_name = item["report"]
        keys = item["keys"]
        try:
            req = {
                "reportName": report_name,
                "columns": "ALL",
                "pageNumber": 1,
                "pageSize": max(50, min(limit * 8, 500)),
                "sortColumns": "TRADE_DATE",
                "sortTypes": "-1",
                "source": "WEB",
                "client": "WEB",
            }
            # Some reports (esp. rzrq) don't always support strict TRADE_DATE filters.
            if topic in ("lhlb", "dzjy"):
                req["filter"] = f"(TRADE_DATE='{date}')"

            payload = _http_get_json("https://datacenter-web.eastmoney.com/api/data/v1/get", req)
            result = payload.get("result") or {}
            data = result.get("data") or []
            if data:
                filtered = []
                for rec in data:
                    if not isinstance(rec, dict):
                        continue
                    if topic == "rzrq":
                        # Try to pin target date using any DATE-like field.
                        date_values = [str(v) for k, v in rec.items() if "DATE" in str(k).upper()]
                        if date_values and not any(v.startswith(date) for v in date_values):
                            continue
                    filtered.append(rec)
                use_rows = filtered if filtered else data
                rows = [_compact_record(r, keys) for r in use_rows[:limit] if isinstance(r, dict)]
                return rows, f"eastmoney:{report_name}", None
            last_err = f"{report_name}: empty result"
        except Exception as e:  # pragma: no cover
            last_err = f"{report_name}: {e}"

    return [], "eastmoney", last_err or "all eastmoney candidates failed"


def _eastmoney_legacy_rzrq(date: str, limit: int) -> tuple[list[dict[str, Any]], str, str | None]:
    """Fetch margin-trading (rzrq) rows from the legacy Eastmoney datacenter endpoint.

    Requests the newest rows (sorted by DATE descending) and keeps those whose DATE
    starts with *date*. If none match, all fetched rows are used. Returns at most
    *limit* raw rows as ``(rows, source, error)``; failures are returned, not raised.
    """
    source = "eastmoney-legacy:RPTA_WEB_RZRQ_GGMX"
    try:
        payload = _http_get_json(
            "https://datacenter.eastmoney.com/api/data/get",
            {
                "type": "RPTA_WEB_RZRQ_GGMX",
                "sty": "ALL",
                "p": 1,
                "ps": max(50, min(limit * 8, 500)),
                "st": "DATE",
                "sr": -1,
            },
        )
        result = (payload or {}).get("result") or {}
        data = result.get("data") or []
        if not data:
            return [], source, "empty result"
        same_day = [r for r in data if isinstance(r, dict) and str(r.get("DATE", "")).startswith(date)]
        return (same_day or data)[:limit], source, None
    except Exception as exc:
        return [], source, str(exc)


def _akshare_topic(topic: str, date: str, limit: int) -> tuple[list[dict[str, Any]], str, str | None]:
    try:
        import akshare as ak  # type: ignore
    except Exception as e:
        return [], "akshare", f"import failed: {e}"

    d = date.replace("-", "")
    try:
        if topic == "lhlb":
            df = ak.stock_lhb_detail_em(start_date=d, end_date=d)
        elif topic == "rzrq":
            sse = ak.stock_margin_detail_sse(date=d)
            szse = ak.stock_margin_detail_szse(date=d)
            try:
                import pandas as pd  # type: ignore
                df = pd.concat([sse, szse], ignore_index=True)
            except Exception:
                # fallback without pandas concat (very unlikely in runtime)
                rows = []
                for frame in (sse, szse):
                    rows.extend(frame.to_dict(orient="records"))
                return rows[:limit], "akshare:stock_margin_detail_sse+szse", None
        elif topic == "dzjy":
            df = ak.stock_dzjy_mrtj(start_date=d, end_date=d)
        else:
            return [], "akshare", f"unsupported topic: {topic}"

        if df is None or df.empty:
            return [], "akshare", "empty result"
        rows = df.head(limit).to_dict(orient="records")
        return rows, "akshare", None
    except Exception as e:
        return [], "akshare", str(e)


def cmd_quote(args: argparse.Namespace) -> int:
    obb = _load_openbb()
    asset_type, routed_symbol, note = _route_symbol(args.symbol)

    if asset_type == "equity":
        out = obb.equity.price.quote(routed_symbol)
        if not out.results:
            print(json.dumps({"ok": False, "error": "no quote result", "symbol": args.symbol}, ensure_ascii=False))
            return 2
        q = out.results[0]
        price = getattr(q, "last_price", None) or getattr(q, "price", None)
        prev_close = getattr(q, "prev_close", None) or getattr(q, "previous_close", None)
        change = getattr(q, "change", None)
        change_percent = getattr(q, "change_percent", None)

        if price is None:
            hdf = _history_df(obb, routed_symbol, "1d")
            if not hdf.empty and "close" in hdf.columns:
                closes = hdf["close"].dropna()
                if not closes.empty:
                    price = float(closes.iloc[-1])
                    if len(closes) > 1:
                        prev_close = float(closes.iloc[-2])
                        change = price - prev_close
                        change_percent = (change / prev_close * 100) if prev_close not in (None, 0) else None

        payload = {
            "ok": True,
            "kind": "quote",
            "symbol": getattr(q, "symbol", args.symbol),
            "requested_symbol": args.symbol,
            "price": price,
            "change": change,
            "change_percent": change_percent,
            "prev_close": prev_close,
            "source": "openbb",
            "as_of": _now_utc(),
            "fetched_at": _now_shanghai(),
        }
        if note:
            payload["note"] = note
        print(json.dumps(payload, ensure_ascii=False))
        return 0

    df = _history_df(obb, args.symbol, "1d")
    if df.empty:
        print(json.dumps({"ok": False, "error": "no quote result", "symbol": args.symbol}, ensure_ascii=False))
        return 2
    close_col = "close" if "close" in df.columns else df.columns[-1]
    last = df[close_col].dropna().iloc[-1] if not df[close_col].dropna().empty else None
    prev = df[close_col].dropna().iloc[-2] if len(df[close_col].dropna()) > 1 else None
    change = (last - prev) if (last is not None and prev is not None) else None
    change_pct = (change / prev * 100) if (change is not None and prev not in (None, 0)) else None

    payload = {
        "ok": True,
        "kind": "quote",
        "symbol": routed_symbol,
        "requested_symbol": args.symbol,
        "asset_type": asset_type,
        "price": None if last is None else float(last),
        "change": None if change is None else float(change),
        "change_percent": None if change_pct is None else float(change_pct),
        "prev_close": None if prev is None else float(prev),
        "source": "openbb",
        "as_of": _now_utc(),
        "fetched_at": _now_shanghai(),
    }
    if note:
        payload["note"] = note
    print(json.dumps(payload, ensure_ascii=False))
    return 0


def _history_cell(value: Any) -> Any:
    """Convert one history DataFrame cell into a JSON-serializable value.

    Missing markers (None, NaN, NaT, pandas NA) become None. Date/time values become
    ISO-8601 strings. NumPy scalars (which ``json`` cannot encode, e.g. int64) are
    unwrapped with ``.item()``. Anything else is returned unchanged.
    """
    if value is None or str(value) in ("nan", "NaT", "<NA>"):
        return None
    if hasattr(value, "isoformat"):
        return value.isoformat()
    if hasattr(value, "item"):
        value = value.item()
        return value.isoformat() if hasattr(value, "isoformat") else value
    return value


def cmd_history(args: argparse.Namespace) -> int:
    """Print the last ``args.limit`` history rows for ``args.symbol`` as one JSON object.

    The index is reset so it appears as a regular column, and every cell is made
    JSON-safe with ``_history_cell``. Returns 0 on success, or 2 (after printing a
    JSON error) when the history is empty.
    """
    obb = _load_openbb()
    asset_type, routed_symbol, note = _route_symbol(args.symbol)
    df = _history_df(obb, args.symbol, args.interval)
    if df.empty:
        print(json.dumps({"ok": False, "error": "empty history", "symbol": args.symbol}, ensure_ascii=False))
        return 2
    tail = df.tail(args.limit).reset_index()
    rows = [{k: _history_cell(v) for k, v in row.items()} for _, row in tail.iterrows()]
    payload = {
        "ok": True,
        "kind": "history",
        "symbol": routed_symbol,
        "requested_symbol": args.symbol,
        "asset_type": asset_type,
        "interval": args.interval,
        "rows": rows,
        "source": "openbb",
        "as_of": _now_utc(),
        "fetched_at": _now_shanghai(),
    }
    if note:
        payload["note"] = note
    print(json.dumps(payload, ensure_ascii=False))
    return 0


def cmd_topic(args: argparse.Namespace) -> int:
    """Fetch a CN market topic (lhlb/rzrq/dzjy) through a chain of fallback sources.

    OpenBB currently does not expose these topics through this bridge, so every
    response is marked as degraded from "openbb". Sources are tried in order:
    Eastmoney datacenter, the legacy Eastmoney endpoint (rzrq only), then akshare.
    Prints one JSON object; returns 0 when any source yields rows, else 2 with the
    per-source failures appended to ``degrade_reason``.
    """
    gap_reason = "openbb bridge未暴露该专题接口"
    failures: list[str] = []

    rows, source, error = _eastmoney_topic(args.topic, args.date, args.limit)
    if error:
        failures.append(f"eastmoney失败: {error}")

    if args.topic == "rzrq" and not rows:
        legacy_rows, legacy_source, legacy_error = _eastmoney_legacy_rzrq(args.date, args.limit)
        if not legacy_rows:
            failures.append(f"eastmoney-legacy失败: {legacy_error}")
        else:
            rows, source = legacy_rows, legacy_source

    if not rows:
        ak_rows, ak_source, ak_error = _akshare_topic(args.topic, args.date, args.limit)
        if not ak_rows:
            failures.append(f"akshare失败: {ak_error}")
        else:
            rows, source = ak_rows, ak_source

    succeeded = bool(rows)
    body: dict[str, Any] = {"ok": succeeded, "kind": "topic", "topic": args.topic, "date": args.date}
    if succeeded:
        body["rows"] = rows
        reason = gap_reason
    else:
        body["error"] = "topic fetch failed"
        reason = f"{gap_reason}; {'; '.join(failures)}"
    body.update(
        source=source,
        degraded_from="openbb",
        degrade_reason=reason,
        as_of=args.date,
        fetched_at=_now_shanghai(),
    )
    print(json.dumps(body, ensure_ascii=False))
    return 0 if succeeded else 2


def main() -> int:
    _ensure_runtime()
    parser = argparse.ArgumentParser(description="OpenBB bridge")
    sub = parser.add_subparsers(dest="command", required=True)

    p_quote = sub.add_parser("quote", help="latest quote")
    p_quote.add_argument("symbol", help="ticker symbol")
    p_quote.set_defaults(func=cmd_quote)

    p_hist = sub.add_parser("history", help="historical prices")
    p_hist.add_argument("symbol", help="ticker symbol")
    p_hist.add_argument("--interval", default="1d", help="interval, e.g. 1d/1h")
    p_hist.add_argument("--limit", type=int, default=5, help="rows to return")
    p_hist.set_defaults(func=cmd_history)

    p_topic = sub.add_parser("topic", help="CN special topics: lhlb/rzrq/dzjy")
    p_topic.add_argument("topic", choices=["lhlb", "rzrq", "dzjy"], help="topic code")
    p_topic.add_argument("--date", default=datetime.now(UTC).date().isoformat(), help="trade date, YYYY-MM-DD")
    p_topic.add_argument("--limit", type=int, default=20, help="rows to return")
    p_topic.set_defaults(func=cmd_topic)

    args = parser.parse_args()
    return args.func(args)


if __name__ == "__main__":
    # sys.exit(code) is equivalent to raising SystemExit(code) directly.
    raise SystemExit(main())
