"""
回测验证工具套件（对应文档"提示词十：回测验证清单"）

1. 蒙特卡洛模拟 — 打乱交易顺序，验证策略稳健性
2. 参数敏感性分析 — trailing/风险比例/zigzag阈值微调后的收益变化
3. 交易抽查工具 — 逐笔查看入场理由、K线环境、出场过程
4. 前视偏差检查器 — 静态分析代码中可能存在的前视偏差
"""

import numpy as np
import pandas as pd
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
import copy


# ============================================================
# 1. 蒙特卡洛模拟
# ============================================================

@dataclass
class MonteCarloResult:
    n_simulations: int
    median_final: float
    mean_final: float
    p5_final: float         # 5th percentile
    p25_final: float
    p75_final: float
    p95_final: float
    loss_probability: float  # 亏损概率（终值<初始）
    max_dd_median: float
    max_dd_p95: float


def monte_carlo(
    trades_pnl: List[float],
    initial_capital: float = 1000.0,
    n_simulations: int = 10000,
    seed: int = 42,
) -> MonteCarloResult:
    """
    蒙特卡洛模拟：打乱交易顺序，重新计算权益曲线

    验证策略的稳健性 — 如果只是碰巧某几笔大赢单的顺序好，
    打乱后收益会大幅下降。好策略在大多数排列下都盈利。

    Args:
        trades_pnl: 每笔交易的盈亏金额列表
        initial_capital: 初始资金
        n_simulations: 模拟次数
    Returns:
        MonteCarloResult
    """
    rng = np.random.RandomState(seed)
    pnl_array = np.array(trades_pnl)
    n_trades = len(pnl_array)

    finals = np.zeros(n_simulations)
    max_dds = np.zeros(n_simulations)

    for sim in range(n_simulations):
        # 打乱交易顺序
        shuffled = rng.permutation(pnl_array)

        # 计算权益曲线
        equity = np.zeros(n_trades + 1)
        equity[0] = initial_capital
        for i in range(n_trades):
            equity[i + 1] = equity[i] + shuffled[i]
            if equity[i + 1] <= 0:
                equity[i + 1:] = 0
                break

        finals[sim] = equity[-1]

        # 最大回撤
        peak = np.maximum.accumulate(equity)
        dd = np.where(peak > 0, (peak - equity) / peak, 0)
        max_dds[sim] = dd.max()

    return MonteCarloResult(
        n_simulations=n_simulations,
        median_final=float(np.median(finals)),
        mean_final=float(np.mean(finals)),
        p5_final=float(np.percentile(finals, 5)),
        p25_final=float(np.percentile(finals, 25)),
        p75_final=float(np.percentile(finals, 75)),
        p95_final=float(np.percentile(finals, 95)),
        loss_probability=float(np.mean(finals < initial_capital)),
        max_dd_median=float(np.median(max_dds) * 100),
        max_dd_p95=float(np.percentile(max_dds, 95) * 100),
    )


def print_monte_carlo(result: MonteCarloResult, initial_capital: float = 1000.0):
    """打印蒙特卡洛结果"""
    print(f"\n{'='*60}")
    print(f"蒙特卡洛模拟 ({result.n_simulations:,} 次)")
    print(f"{'='*60}")
    print(f"  初始资金:          ${initial_capital:>14,.0f}")
    print(f"{'─'*45}")
    print(f"  终值中位数:        ${result.median_final:>14,.0f}  "
          f"({result.median_final/initial_capital:,.0f}x)")
    print(f"  终值均值:          ${result.mean_final:>14,.0f}")
    print(f"  5th percentile:    ${result.p5_final:>14,.0f}")
    print(f"  25th percentile:   ${result.p25_final:>14,.0f}")
    print(f"  75th percentile:   ${result.p75_final:>14,.0f}")
    print(f"  95th percentile:   ${result.p95_final:>14,.0f}")
    print(f"{'─'*45}")
    print(f"  亏损概率:          {result.loss_probability:>13.1%}")
    print(f"  回撤中位数:        {result.max_dd_median:>13.1f}%")
    print(f"  回撤 95th pctile:  {result.max_dd_p95:>13.1f}%")
    print(f"{'='*60}")

    # 验收标准（文档要求）
    print(f"\n  验收检查:")
    check_loss = result.loss_probability < 0.01
    check_median = result.median_final > initial_capital * 1000
    print(f"    亏损概率 < 1%:        {'✓' if check_loss else '✗'} "
          f"(实际 {result.loss_probability:.2%})")
    print(f"    中位数收益 > 1000倍:  {'✓' if check_median else '✗'} "
          f"(实际 {result.median_final/initial_capital:,.0f}倍)")


# ============================================================
# 2. 参数敏感性分析
# ============================================================

def parameter_sensitivity(
    run_backtest_func,
    base_params: dict,
    param_name: str,
    param_values: List[float],
    n_runs: int = 1,
) -> pd.DataFrame:
    """
    参数敏感性分析

    对单个参数取不同值，分别运行回测，比较结果。

    Args:
        run_backtest_func: callable(params) -> (final_balance, n_trades, win_rate, max_dd, pf)
        base_params: 基准参数字典
        param_name: 要测试的参数名
        param_values: 参数取值列表
    Returns:
        DataFrame with results for each parameter value
    """
    results = []
    for val in param_values:
        params = copy.deepcopy(base_params)
        params[param_name] = val

        try:
            final, n_trades, wr, max_dd, pf = run_backtest_func(params)
            results.append({
                'param_value': val,
                'final_balance': final,
                'n_trades': n_trades,
                'win_rate': wr,
                'max_drawdown': max_dd,
                'profit_factor': pf,
            })
        except Exception as e:
            results.append({
                'param_value': val,
                'final_balance': 0,
                'n_trades': 0,
                'win_rate': 0,
                'max_drawdown': 100,
                'profit_factor': 0,
                'error': str(e),
            })

    return pd.DataFrame(results)


def print_sensitivity(df: pd.DataFrame, param_name: str):
    """打印参数敏感性结果"""
    print(f"\n{'='*60}")
    print(f"参数敏感性: {param_name}")
    print(f"{'='*60}")
    print(f"  {'值':<10} {'终值':>14} {'交易':>6} {'胜率':>7} {'回撤':>7} {'PF':>7}")
    print(f"  {'─'*52}")
    for _, row in df.iterrows():
        print(f"  {row['param_value']:<10.3f} "
              f"${row['final_balance']:>13,.0f} "
              f"{row['n_trades']:>5.0f} "
              f"{row['win_rate']:>6.1f}% "
              f"{row['max_drawdown']:>6.1f}% "
              f"{row['profit_factor']:>6.2f}")


# ============================================================
# 3. 交易抽查工具
# ============================================================

def inspect_trade(
    trade_record,
    df_4h: pd.DataFrame,
    df_1d: pd.DataFrame,
    context_bars: int = 20,
) -> str:
    """
    逐笔交易抽查 — 生成可读的交易详情报告

    Args:
        trade_record: TradeRecord对象
        df_4h: 该币种的4H K线
        df_1d: 该币种的日线
        context_bars: 显示入场前后的K线数量
    Returns:
        可读文本报告
    """
    t = trade_record
    lines = []
    lines.append(f"{'='*60}")
    lines.append(f"交易 #{t.trade_id}: {t.coin} {t.direction.upper()} ({t.signal_type})")
    lines.append(f"{'='*60}")

    lines.append(f"\n  入场时间:    {t.entry_time}")
    lines.append(f"  出场时间:    {t.exit_time}")
    lines.append(f"  持仓时间:    {t.hold_bars} 根4H K线 (~{t.hold_bars*4/24:.0f}天)")
    lines.append(f"  入场方式:    {t.entry_method}")

    lines.append(f"\n  入场价:      ${t.entry_price:,.2f}")
    lines.append(f"  出场价:      ${t.exit_price:,.2f}")
    lines.append(f"  止损价:      ${t.stop_loss:,.2f} (距离 {t.sl_distance:.2%})")

    lines.append(f"\n  保证金:      ${t.margin:,.2f}")
    lines.append(f"  名义持仓:    ${t.notional:,.2f}")
    lines.append(f"  杠杆:        {t.leverage}x")

    color = '+' if t.pnl_dollar >= 0 else ''
    lines.append(f"\n  盈亏:        {color}${t.pnl_dollar:,.2f} ({color}{t.pnl_pct:.2%})")
    lines.append(f"  最大浮盈:    {t.max_unrealized_pct:.2%}")
    lines.append(f"  出场效率:    {t.exit_efficiency:.2f}")
    lines.append(f"  出场原因:    {t.exit_reason}")
    lines.append(f"  手续费:      ${t.fees:,.2f}")

    lines.append(f"\n  余额变动:    ${t.balance_after:,.2f}")
    lines.append(f"  盒子宽度:    {t.box_width_pct:.2%}")
    lines.append(f"  量比:        {t.volume_ratio:.1f}x")
    lines.append(f"  BTC趋势:    {t.btc_trend_60d:.1%}")
    lines.append(f"  本币趋势:    {t.coin_trend_60d:.1%}")

    # 入场前后的K线概览
    if df_4h is not None and len(df_4h) > 0:
        entry_t = t.entry_time
        # Ensure timezone compatibility
        if df_4h['timestamp'].dt.tz is None and hasattr(entry_t, 'tz') and entry_t.tz is not None:
            entry_t = entry_t.tz_localize(None)
        elif df_4h['timestamp'].dt.tz is not None and (not hasattr(entry_t, 'tz') or entry_t.tz is None):
            entry_t = entry_t.tz_localize('UTC')

        mask = (df_4h['timestamp'] >= entry_t - pd.Timedelta(hours=4*context_bars)) & \
               (df_4h['timestamp'] <= entry_t + pd.Timedelta(hours=4*5))
        context = df_4h[mask]

        if len(context) > 0:
            lines.append(f"\n  入场前后4H K线:")
            lines.append(f"  {'时间':<22} {'开':>10} {'高':>10} {'低':>10} "
                         f"{'收':>10} {'量':>10} {'标记'}")
            lines.append(f"  {'─'*85}")
            for _, bar in context.iterrows():
                marker = ""
                try:
                    bar_ts = bar['timestamp']
                    ref_ts = entry_t
                    # normalize tz
                    if hasattr(bar_ts, 'tz') and bar_ts.tz is not None:
                        if not hasattr(ref_ts, 'tz') or ref_ts.tz is None:
                            ref_ts = pd.Timestamp(ref_ts, tz='UTC')
                    elif hasattr(ref_ts, 'tz') and ref_ts.tz is not None:
                        bar_ts = pd.Timestamp(bar_ts, tz='UTC')
                    if abs((bar_ts - ref_ts).total_seconds()) < 14400:
                        marker = " ← 入场"
                except Exception:
                    pass
                lines.append(
                    f"  {str(bar['timestamp'])[:19]:<22} "
                    f"{bar['open']:>10.2f} {bar['high']:>10.2f} "
                    f"{bar['low']:>10.2f} {bar['close']:>10.2f} "
                    f"{bar['volume']:>10.0f}{marker}"
                )

    lines.append(f"\n{'='*60}")
    return '\n'.join(lines)


def random_trade_inspection(
    trades: list,
    all_tf_data: dict,
    n_samples: int = 5,
    seed: int = 42,
) -> str:
    """
    随机抽取n笔交易进行详细检查

    Args:
        trades: List of TradeRecord
        all_tf_data: {symbol: {'4H': df, '1D': df, ...}}
        n_samples: 抽样数量
    Returns:
        可读文本报告
    """
    rng = np.random.RandomState(seed)
    indices = rng.choice(len(trades), min(n_samples, len(trades)), replace=False)
    indices.sort()

    reports = []
    for idx in indices:
        t = trades[idx]
        tf = all_tf_data.get(t.coin, {})
        df_4h = tf.get('4H')
        df_1d = tf.get('1D')
        reports.append(inspect_trade(t, df_4h, df_1d))

    return '\n\n'.join(reports)


# ============================================================
# 4. 快速统计工具
# ============================================================

def compute_stats(trades: list, initial_capital: float) -> dict:
    """
    从交易列表计算关键统计指标
    Returns dict with: final, n_trades, win_rate, max_dd, pf, payoff, etc.
    """
    if not trades:
        return {'final': initial_capital, 'n_trades': 0, 'win_rate': 0,
                'max_dd': 0, 'pf': 0, 'payoff': 0}

    pnl_list = [t.pnl_dollar for t in trades]
    balance_list = [t.balance_after for t in trades]

    wins = sum(1 for p in pnl_list if p > 0)
    losses = len(pnl_list) - wins

    gross_profit = sum(p for p in pnl_list if p > 0)
    gross_loss = abs(sum(p for p in pnl_list if p <= 0))

    win_pcts = [t.pnl_pct for t in trades if t.pnl_dollar > 0]
    loss_pcts = [abs(t.pnl_pct) for t in trades if t.pnl_dollar <= 0]

    avg_win = np.mean(win_pcts) if win_pcts else 0
    avg_loss = np.mean(loss_pcts) if loss_pcts else 0.001

    eq = np.array(balance_list)
    peak = np.maximum.accumulate(eq)
    dd = np.where(peak > 0, (peak - eq) / peak, 0)

    return {
        'final': balance_list[-1],
        'multiple': balance_list[-1] / initial_capital,
        'n_trades': len(trades),
        'wins': wins,
        'losses': losses,
        'win_rate': wins / len(trades) * 100,
        'avg_win_pct': avg_win * 100,
        'avg_loss_pct': avg_loss * 100,
        'payoff': avg_win / avg_loss if avg_loss > 0 else float('inf'),
        'pf': gross_profit / gross_loss if gross_loss > 0 else float('inf'),
        'max_dd': dd.max() * 100,
        'total_fees': sum(t.fees for t in trades),
    }


# ============================================================
# 5. 牛熊分段分析
# ============================================================

# BTC主要牛熊周期（用于分段统计）
MARKET_PHASES = [
    ('2017-08', '2017-12', '牛市', '2017牛市'),
    ('2018-01', '2018-12', '熊市', '2018大熊'),
    ('2019-01', '2019-06', '复苏', '2019复苏'),
    ('2019-07', '2019-12', '震荡', '2019下半年'),
    ('2020-01', '2020-03', '暴跌', 'COVID崩盘'),
    ('2020-04', '2020-09', '复苏', '2020复苏'),
    ('2020-10', '2021-04', '牛市', '2020-21牛市'),
    ('2021-05', '2021-07', '回调', '2021回调'),
    ('2021-08', '2021-11', '牛市', '2021二次高点'),
    ('2021-12', '2022-12', '熊市', '2022大熊'),
    ('2023-01', '2023-12', '复苏', '2023慢牛'),
    ('2024-01', '2024-12', '牛市', '2024ETF牛'),
    ('2025-01', '2025-12', '震荡', '2025'),
]


def bull_bear_analysis(trades: list) -> pd.DataFrame:
    """
    按牛熊周期分段分析交易表现

    文档要求:
    - 牛市应该赚大钱（占总利润90%+）
    - 熊市应该小亏或不亏（回撤<5%）
    """
    results = []

    for start, end, phase_type, label in MARKET_PHASES:
        start_ts = pd.Timestamp(start + '-01', tz='UTC')
        end_ts = pd.Timestamp(end + '-28', tz='UTC')  # approximate month end

        phase_trades = [t for t in trades
                        if start_ts <= t.entry_time <= end_ts]

        if not phase_trades:
            continue

        n = len(phase_trades)
        wins = sum(1 for t in phase_trades if t.pnl_dollar > 0)
        pnl = sum(t.pnl_dollar for t in phase_trades)

        results.append({
            'phase': label,
            'type': phase_type,
            'trades': n,
            'wins': wins,
            'win_rate': wins / n * 100 if n > 0 else 0,
            'net_pnl': pnl,
        })

    return pd.DataFrame(results)


def print_bull_bear(df: pd.DataFrame, total_profit: float = None):
    """打印牛熊分段分析"""
    print(f"\n{'='*60}")
    print(f"牛熊分段分析")
    print(f"{'='*60}")
    print(f"  {'阶段':<14} {'类型':<6} {'交易':>4} {'胜率':>7} {'净盈亏':>14}")
    print(f"  {'─'*50}")

    bull_pnl = 0
    bear_pnl = 0

    for _, row in df.iterrows():
        marker = '🟢' if row['net_pnl'] >= 0 else '🔴'
        print(f"  {row['phase']:<14} {row['type']:<6} "
              f"{row['trades']:>4} {row['win_rate']:>6.1f}% "
              f"${row['net_pnl']:>13,.0f} {marker}")

        if row['type'] in ('牛市', '复苏'):
            bull_pnl += row['net_pnl']
        elif row['type'] in ('熊市', '暴跌'):
            bear_pnl += row['net_pnl']

    if total_profit and total_profit > 0:
        print(f"\n  牛市+复苏利润:    ${bull_pnl:>13,.0f} "
              f"({bull_pnl/total_profit:.0%} of total)")
        print(f"  熊市+暴跌利润:    ${bear_pnl:>13,.0f}")

    print(f"{'='*60}")
