"""
回测引擎

执行流程:
1. 对每个币种独立跑信号管线 (做多 + 做空)
2. 合并所有信号，按时间排序
3. 逐信号：过滤 → 仓位 → 模拟持仓 → 更新余额
4. 输出完整统计
"""

import pandas as pd
import numpy as np
from typing import Dict, List, Optional
from dataclasses import dataclass, field
from datetime import datetime

from ..config import settings as S
from ..config.coins import COINS, get_coin
from ..core.pipeline import generate_long_signals, Signal
from ..core.shorts import find_short_signals
from ..core.filters import SignalFilter, calc_trend_60d
from ..core.position import calc_position, calc_effective_equity
from ..core.trailing import simulate_exit, ExitResult
from ..data.aggregator import prepare_all_timeframes


@dataclass
class TradeRecord:
    """Complete record of one simulated trade, as appended to BacktestEngine.trades."""
    # Sequence number assigned by the engine (counter incremented before use, so it starts at 1)
    trade_id: int
    # Trading symbol, e.g. "BTCUSDT"
    coin: str
    # Signal direction, e.g. 'long'; short signals come from find_short_signals
    direction: str
    # Signal category from the pipeline (e.g. '日线', '共振')
    signal_type: str
    # Entry timestamp and price, copied from the signal
    entry_time: pd.Timestamp
    entry_price: float
    # Exit timestamp and price reported by simulate_exit (typed Optional)
    exit_time: Optional[pd.Timestamp]
    exit_price: Optional[float]
    # Stop-loss price and stop distance copied from the signal
    # (units of sl_distance are not shown here — presumably a fraction of entry price; TODO confirm)
    stop_loss: float
    sl_distance: float
    # Margin committed and position notional, as returned by calc_position
    margin: float
    notional: float
    # Leverage from the coin's configuration
    leverage: int
    # P&L as a fraction (the report multiplies it by 100 for display) and in dollars;
    # pnl_dollar is what gets added to the engine balance
    pnl_pct: float
    pnl_dollar: float
    # Exit reason label reported by simulate_exit
    exit_reason: str
    # Presumably the best unrealized P&L (fraction) seen while the trade was open — confirm in simulate_exit
    max_unrealized_pct: float
    # Number of bars held; simulate_exit is fed 4H data, so presumably 4H bars — confirm
    hold_bars: int
    # Exit efficiency metric as reported by simulate_exit
    exit_efficiency: float
    # Fees reported by simulate_exit (whether already netted into pnl_dollar is not visible here)
    fees: float
    # Engine balance right after this trade's P&L was booked
    balance_after: float
    # Signal diagnostics copied from the signal
    entry_method: str
    volume_ratio: float
    box_width_pct: float
    # 60-day trend (calc_trend_60d) of BTC and of this coin at entry time
    btc_trend_60d: float
    coin_trend_60d: float


class BacktestEngine:
    """
    Backtest engine: from candle data to full trade statistics.

    Usage:
        engine = BacktestEngine(initial_capital=1000)
        results = engine.run(all_data)
        engine.print_report()

    NOTE(review): a trade's P&L is added to ``self.balance`` when its signal
    is processed (i.e. at entry time), although the position closes later at
    ``exit_result.exit_time``. Signals entering while that position is still
    open are therefore sized from a balance that already contains its
    outcome, and the signal filter is told the result early as well. Treat
    this as a known simplification when interpreting results.
    """

    def __init__(self, initial_capital: float = 1000.0, verbose: bool = False):
        self.initial_capital = initial_capital
        self.balance = initial_capital
        self.peak_balance = initial_capital
        self.month_start_balance = initial_capital
        self.current_month = None
        self.verbose = verbose

        self.trades: List[TradeRecord] = []
        # Diagnostic log of rejected signals (only filled when verbose=True)
        self.filtered_log: List[dict] = []
        self.signal_filter = SignalFilter()
        self.trade_counter = 0

        # Open positions as (exit_time, margin, coin); used for concurrent margin
        self.active_positions: List[tuple] = []

        # Entry times of BTC daily/resonance long signals (cross-coin resonance)
        self.btc_daily_breakout_times: List[pd.Timestamp] = []

        # Risk-control state
        self.daily_pnl = 0.0
        self.current_day = None
        self.halted = False

    def run(self, all_coin_data: Dict[str, pd.DataFrame]) -> List[TradeRecord]:
        """
        Run the full backtest.

        Steps: aggregate timeframes -> generate long/short signals -> process
        the signals in entry-time order -> print a short summary.

        Args:
            all_coin_data: {symbol: df_1h} 1H candles per coin with a
                'timestamp' column (naive timestamps are localized to UTC).
                Symbols missing from COINS are ignored.
        Returns:
            self.trades, the list of TradeRecord booked so far.
        """
        print("=" * 60)
        print("浪浪AI 回测引擎 启动")
        print(f"初始资金: ${self.initial_capital:,.0f}")
        print("=" * 60)

        # 1. Build 1H/4H/1D data for every usable coin
        print("\n[1/4] 聚合K线数据...")
        all_tf_data = self._prepare_timeframes(all_coin_data)

        # 2. Generate every signal, then order them chronologically
        print("\n[2/4] 生成交易信号...")
        all_signals = self._generate_signals(all_tf_data)
        all_signals.sort(key=lambda s: s.entry_time)

        # BTC daily/resonance breakout times for cross-coin resonance checks.
        # Collected after sorting so the list is in chronological order.
        self.btc_daily_breakout_times.extend(
            sig.entry_time
            for sig in all_signals
            if sig.coin == "BTCUSDT"
            and sig.direction == 'long'
            and sig.signal_type in ('日线', '共振')
        )
        print(f"\n  总信号数: {len(all_signals)}")

        # 3. Walk the signals one by one
        print("\n[3/4] 逐信号回测...")
        for sig in all_signals:
            self._process_signal(sig, all_tf_data)

        # 4. Summary
        print("\n[4/4] 回测完成!")
        print(f"  总交易: {len(self.trades)} 笔")
        print(f"  最终余额: ${self.balance:,.0f}")

        return self.trades

    def _prepare_timeframes(self, all_coin_data: Dict[str, pd.DataFrame]) -> dict:
        """Trim each configured coin's 1H data to its start date and aggregate timeframes.

        Coins with fewer than 100 remaining 1H rows are skipped.
        Returns {symbol: prepare_all_timeframes(df)}.
        """
        all_tf_data = {}
        for symbol, df_1h in all_coin_data.items():
            if symbol not in COINS:
                continue
            cfg = COINS[symbol]
            start = pd.Timestamp(cfg.data_start, tz='UTC')
            # Ensure timestamp column is tz-aware before comparing to `start`
            if df_1h['timestamp'].dt.tz is None:
                df_1h = df_1h.copy()
                df_1h['timestamp'] = df_1h['timestamp'].dt.tz_localize('UTC')
            df = df_1h[df_1h['timestamp'] >= start].reset_index(drop=True)
            if len(df) < 100:
                print(f"  {symbol}: 数据不足, 跳过")
                continue
            tf = prepare_all_timeframes(df)
            all_tf_data[symbol] = tf
            print(f"  {symbol}: 1H={len(tf['1H'])} "
                  f"4H={len(tf['4H'])} "
                  f"1D={len(tf['1D'])}")
        return all_tf_data

    def _generate_signals(self, all_tf_data: dict) -> List[Signal]:
        """Collect long signals for every coin and short signals where allowed (unsorted)."""
        all_signals: List[Signal] = []
        for symbol, tf_data in all_tf_data.items():
            long_signals = generate_long_signals(
                symbol, tf_data['1H'], tf_data['4H'], tf_data['1D']
            )
            all_signals.extend(long_signals)
            print(f"  {symbol} 做多: {len(long_signals)} 个信号")

            # Shorts only for coins whose config sets allow_short
            # (per the original note, BTC and ETH)
            if COINS[symbol].allow_short:
                short_signals = find_short_signals(symbol, tf_data['4H'])
                all_signals.extend(short_signals)
                print(f"  {symbol} 做空: {len(short_signals)} 个信号")
        return all_signals

    def _get_used_margin(self, current_time: pd.Timestamp) -> float:
        """Return the margin locked by positions still open at current_time.

        Positions whose exit_time is at or before current_time are removed
        from active_positions. A position without an exit_time (None) is
        treated as still open instead of breaking the comparison.
        """
        self.active_positions = [
            (exit_t, margin, coin) for exit_t, margin, coin in self.active_positions
            if exit_t is None or exit_t > current_time
        ]
        return sum(m for _, m, _ in self.active_positions)

    def _maybe_resume_trading(self):
        """Lift a halt at a new day/month boundary when risk limits allow.

        Trading resumes only if the account is not wiped out and the total
        drawdown from peak is within S.TOTAL_DD_HALT. Because the balance
        cannot change while halted, a total-drawdown halt stays in force.
        """
        if not self.halted or self.peak_balance <= 0 or self.balance <= 0:
            return
        total_dd = (self.peak_balance - self.balance) / self.peak_balance
        if total_dd <= S.TOTAL_DD_HALT:
            self.halted = False

    def _process_signal(self, signal: Signal, all_tf_data: dict):
        """
        Filter, size, simulate and book a single signal.

        Args:
            signal: Signal produced by the long/short pipelines.
            all_tf_data: {symbol: {'1H'|'4H'|'1D': DataFrame}} as built by run().
        """
        # Period bookkeeping must run BEFORE the halt check, otherwise a
        # daily-loss halt could never be lifted when a new day/month starts.
        self._update_month(signal.entry_time)

        sig_day = signal.entry_time.date()
        if self.current_day != sig_day:
            self.current_day = sig_day
            self.daily_pnl = 0.0
            self._maybe_resume_trading()

        # Risk-control circuit breaker
        if self.halted:
            return

        tf_data = all_tf_data.get(signal.coin)
        if tf_data is None:
            return

        # 60-day trend of this coin and of BTC at entry time
        df_1d = tf_data['1D']
        coin_trend = calc_trend_60d(
            df_1d['close'].values, df_1d['timestamp'],
            signal.entry_time
        )
        btc_1d = all_tf_data.get("BTCUSDT", {}).get('1D')
        btc_trend = 0.0
        if btc_1d is not None:
            btc_trend = calc_trend_60d(
                btc_1d['close'].values, btc_1d['timestamp'],
                signal.entry_time
            )

        # Signal filter
        passed, reason, adjustments = self.signal_filter.check_signal(
            signal, btc_trend, coin_trend, self.btc_daily_breakout_times
        )
        if not passed:
            if self.verbose:
                self.filtered_log.append({
                    'time': signal.entry_time,
                    'coin': signal.coin,
                    'type': signal.signal_type,
                    'direction': signal.direction,
                    'reason': reason,
                    'btc_trend': f"{btc_trend:.1%}",
                    'coin_trend': f"{coin_trend:.1%}",
                })
            return

        # Position sizing from effective equity and currently used margin
        effective_equity = calc_effective_equity(self.balance, self.month_start_balance)
        used_margin = self._get_used_margin(signal.entry_time)

        cfg = get_coin(signal.coin)
        pos = calc_position(
            equity=effective_equity,
            sl_distance=signal.sl_distance,
            coin=signal.coin,
            signal_type=signal.signal_type,
            leverage=cfg.leverage,
            position_mult=adjustments['position_mult'],
            cross_resonance=adjustments.get('cross_resonance', False),
            used_margin=used_margin,
        )
        if pos['skip']:
            return
        # Margin may not exceed the current balance
        if pos['margin'] > self.balance:
            return

        # Simulate the position on 4H candles
        exit_result = simulate_exit(
            direction=signal.direction,
            entry_price=signal.entry_price,
            stop_loss=signal.stop_loss,
            coin=signal.coin,
            entry_method=signal.entry_method,
            notional=pos['notional'],
            df_4h=tf_data['4H'],
            entry_time=signal.entry_time,
        )
        if exit_result is None:
            return

        # Book the result (see class NOTE: booked at entry time)
        pnl = exit_result.pnl_dollar
        self.balance += pnl
        if self.balance <= 0:
            # Account wiped out: clamp at zero and stop trading for good,
            # but keep going so this trade still appears in the records.
            self.balance = 0.0
            self.halted = True
        self.peak_balance = max(self.peak_balance, self.balance)

        # Daily loss halt: blocks the rest of the day, lifted on a new day.
        # NOTE(review): the ratio uses the post-trade balance rather than the
        # start-of-day balance — confirm this is the intended definition.
        self.daily_pnl += pnl
        if self.balance > 0 and self.daily_pnl / self.balance < -S.DAILY_LOSS_HALT:
            self.halted = True
        # Total drawdown halt
        if self.peak_balance > 0:
            total_dd = (self.peak_balance - self.balance) / self.peak_balance
            if total_dd > S.TOTAL_DD_HALT:
                self.halted = True

        # Track the open position for concurrent-margin accounting
        self.active_positions.append(
            (exit_result.exit_time, pos['margin'], signal.coin)
        )

        # Feed the outcome back to the filter
        self.signal_filter.update_trade_result(signal, pnl)
        self.signal_filter.record_entry(signal)

        self.trade_counter += 1
        self.trades.append(TradeRecord(
            trade_id=self.trade_counter,
            coin=signal.coin,
            direction=signal.direction,
            signal_type=signal.signal_type,
            entry_time=signal.entry_time,
            entry_price=signal.entry_price,
            exit_time=exit_result.exit_time,
            exit_price=exit_result.exit_price,
            stop_loss=signal.stop_loss,
            sl_distance=signal.sl_distance,
            margin=pos['margin'],
            notional=pos['notional'],
            leverage=cfg.leverage,
            pnl_pct=exit_result.pnl_pct,
            pnl_dollar=pnl,
            exit_reason=exit_result.exit_reason,
            max_unrealized_pct=exit_result.max_unrealized_pct,
            hold_bars=exit_result.hold_bars,
            exit_efficiency=exit_result.exit_efficiency,
            fees=exit_result.fees,
            balance_after=self.balance,
            entry_method=signal.entry_method,
            volume_ratio=signal.volume_ratio,
            box_width_pct=signal.box_width_pct,
            btc_trend_60d=btc_trend,
            coin_trend_60d=coin_trend,
        ))

    def _update_month(self, timestamp: pd.Timestamp):
        """On a new calendar month, snapshot the month-start balance and try to lift a halt."""
        month_key = (timestamp.year, timestamp.month)
        if self.current_month != month_key:
            self.current_month = month_key
            self.month_start_balance = self.balance
            self._maybe_resume_trading()

    # ==================== Reporting ====================

    def print_report(self):
        """Print full statistics: overall, drawdown, and per year/coin/signal type/exit reason."""
        if not self.trades:
            print("没有交易记录")
            return

        df = self.to_dataframe()

        print("\n" + "=" * 70)
        print("浪浪AI 回测报告")
        print("=" * 70)

        # Basic statistics (break-even trades count as losses)
        total = len(df)
        winners = df[df['pnl_dollar'] > 0]
        losers = df[df['pnl_dollar'] <= 0]
        wins = len(winners)
        losses = len(losers)
        win_rate = wins / total * 100

        avg_win = winners['pnl_pct'].mean() * 100 if wins > 0 else 0
        avg_loss = abs(losers['pnl_pct'].mean()) * 100 if losses > 0 else 0
        payoff = avg_win / avg_loss if avg_loss > 0 else float('inf')

        gross_profit = winners['pnl_dollar'].sum()
        gross_loss = abs(losers['pnl_dollar'].sum())
        pf = gross_profit / gross_loss if gross_loss > 0 else float('inf')

        final = self.balance
        multiple = final / self.initial_capital if self.initial_capital else float('nan')

        print(f"\n{'─'*40}")
        print(f"  初始资金:      ${self.initial_capital:>15,.0f}")
        print(f"  最终资金:      ${final:>15,.0f}")
        # Two decimals so results below 1x are not rounded to "0x"/"1x"
        print(f"  收益倍数:      {multiple:>15,.2f}x")
        print(f"{'─'*40}")
        print(f"  总交易:        {total:>15}")
        print(f"  盈利:          {wins:>15}")
        print(f"  亏损:          {losses:>15}")
        print(f"  胜率:          {win_rate:>14.1f}%")
        print(f"  平均盈利:      {avg_win:>14.1f}%")
        print(f"  平均亏损:      {avg_loss:>14.1f}%")
        print(f"  盈亏比:        {payoff:>15.1f}")
        print(f"  PF:            {pf:>15.2f}")

        # Drawdown: the curve starts at the initial capital so that losses on
        # the very first trades are measured against it.
        equity_curve = np.concatenate((
            [float(self.initial_capital)],
            df['balance_after'].to_numpy(dtype=float),
        ))
        peak = np.maximum.accumulate(equity_curve)
        with np.errstate(divide='ignore', invalid='ignore'):
            drawdown = np.where(peak > 0, (peak - equity_curve) / peak, 0.0)
        max_dd = drawdown.max() * 100

        # Longest run of consecutive non-winning trades
        max_streak = 0
        run_len = 0
        for trade_pnl in df['pnl_dollar']:
            if trade_pnl <= 0:
                run_len += 1
                max_streak = max(max_streak, run_len)
            else:
                run_len = 0

        print(f"{'─'*40}")
        print(f"  最大回撤:      {max_dd:>14.1f}%")
        print(f"  最大连败:      {max_streak:>15}")
        print(f"  最大单笔盈利:  ${df['pnl_dollar'].max():>14,.0f}")
        print(f"  最大单笔亏损:  ${df['pnl_dollar'].min():>14,.0f}")
        print(f"  总手续费:      ${df['fees'].sum():>14,.0f}")

        # Per year
        print(f"\n{'─'*40}")
        print("  年度表现:")
        df['year'] = df['entry_time'].dt.year
        for year, grp in df.groupby('year'):
            n = len(grp)
            n_win = len(grp[grp['pnl_dollar'] > 0])
            n_loss = len(grp[grp['pnl_dollar'] <= 0])
            pnl = grp['pnl_dollar'].sum()
            end_bal = grp['balance_after'].iloc[-1]
            print(f"  {year}: {n}笔 ({n_win}盈/{n_loss}亏) "
                  f"盈亏 ${pnl:>12,.0f} 余额 ${end_bal:>14,.0f}")

        # Per coin
        print(f"\n{'─'*40}")
        print("  币种表现:")
        for coin, grp in df.groupby('coin'):
            n = len(grp)
            n_win = len(grp[grp['pnl_dollar'] > 0])
            pnl = grp['pnl_dollar'].sum()
            wr = n_win / n * 100 if n > 0 else 0
            print(f"  {coin:<12} {n:>3}笔 胜率{wr:>5.1f}% 盈亏 ${pnl:>14,.0f}")

        # Per signal type
        print(f"\n{'─'*40}")
        print("  信号类型表现:")
        for stype, grp in df.groupby('signal_type'):
            n = len(grp)
            n_win = len(grp[grp['pnl_dollar'] > 0])
            pnl = grp['pnl_dollar'].sum()
            wr = n_win / n * 100 if n > 0 else 0
            print(f"  {stype:<8} {n:>3}笔 胜率{wr:>5.1f}% 盈亏 ${pnl:>14,.0f}")

        # Exit reasons
        print(f"\n{'─'*40}")
        print("  出场原因分布:")
        for reason, grp in df.groupby('exit_reason'):
            n = len(grp)
            pnl = grp['pnl_dollar'].sum()
            print(f"  {reason:<20} {n:>3}笔 盈亏 ${pnl:>14,.0f}")

        print(f"\n{'='*70}")

    def to_dataframe(self) -> pd.DataFrame:
        """Return the trade records as a DataFrame (empty if there are none)."""
        if not self.trades:
            return pd.DataFrame()
        return pd.DataFrame([t.__dict__ for t in self.trades])

    def save_trades(self, path: str):
        """Write the trade records to a CSV file at `path`."""
        df = self.to_dataframe()
        df.to_csv(path, index=False)
        print(f"交易记录已保存: {path}")

    def print_filter_summary(self):
        """Print counts of rejected signals by reason, coin and type (needs verbose=True)."""
        if not self.filtered_log:
            print("没有过滤日志（需要 verbose=True 运行回测）")
            return

        from collections import Counter

        print(f"\n{'='*60}")
        print(f"信号过滤诊断 ({len(self.filtered_log)} 个信号被过滤)")
        print(f"{'='*60}")

        # (log key, section title, label column width)
        sections = (
            ('reason', "\n  过滤原因分布:", 35),
            ('coin', "\n  被过滤币种分布:", 15),
            ('type', "\n  被过滤信号类型:", 10),
        )
        for key, title, width in sections:
            counts = Counter(entry[key] for entry in self.filtered_log)
            print(title)
            for value, count in counts.most_common():
                print(f"    {value:<{width}} {count:>4}笔")

        print(f"\n{'='*60}")

    def save_filter_log(self, path: str):
        """Write the rejected-signal log to CSV at `path` (no-op when the log is empty)."""
        if self.filtered_log:
            pd.DataFrame(self.filtered_log).to_csv(path, index=False)
            print(f"过滤日志已保存: {path}")
