"""
交易记忆库 — AI学习系统第一层

每笔交易存储完整的环境快照（对应提示词七的字段定义）:
- 基础交易数据（入场/出场/盈亏）
- 市场环境（BTC趋势、本币趋势）
- 盒子质量（宽度、持续K线数）
- 突破强度（量比）
- 持仓过程（最大浮盈、出场效率）
- 信号分类（共振/日线/4H/做空）

这些数据供月度复盘和参数自适应使用
"""

import json
import logging
from typing import List, Dict, Optional
from datetime import datetime, timezone

import pandas as pd
import numpy as np

from ..data.database import Database

logger = logging.getLogger(__name__)


# Columns of the trade_memory table that a caller may populate (field
# definitions follow "prompt #7"). `created_at` is filled in by the database
# default and is therefore intentionally absent. The order below mirrors the
# CREATE TABLE column order.
MEMORY_FIELDS = [
    # Identity / classification
    'trade_id', 'coin', 'direction', 'signal_type',
    # Entry and exit
    'entry_dt', 'exit_dt',
    'entry_price', 'exit_price',
    # Result and holding time
    'pnl_pct', 'pnl_dollar',
    'hold_bars', 'hold_days',
    # Market environment
    'btc_trend_60d', 'coin_trend_60d',
    # Box quality
    'box_width_pct', 'box_duration_bars',
    # Breakout strength and stop distance
    'breakout_vol_ratio', 'sl_distance',
    # Holding process and context
    'concurrent_positions', 'max_unrealized_pct',
    'exit_efficiency', 'exit_reason',
    'market_phase', 'entry_method',
    # Sizing and accounting
    'leverage', 'margin', 'notional',
    'fees', 'balance_after',
]


class TradeMemory:
    """
    Trade memory store (layer 1 of the AI learning system).

    Persists one row per trade in the SQLite table ``trade_memory``, together
    with the environment snapshot fields listed in ``MEMORY_FIELDS``, and
    provides query and statistics helpers on top of it.

    Usage:
        memory = TradeMemory(db)

        # Bulk import from backtest results
        memory.import_from_backtest(trades)

        # Add a single trade
        memory.add_trade(trade_snapshot)

        # Queries
        recent = memory.get_recent(months=3)
        by_type = memory.get_by_signal_type('共振')
        btc_trades = memory.get_by_coin('BTCUSDT')

        # Statistics
        stats = memory.calc_stats(recent)
    """

    def __init__(self, db: Database):
        # NOTE(review): `db` must expose a `conn` with execute()/commit();
        # the SQL used below is SQLite dialect (datetime('now'), INSERT OR REPLACE).
        self.db = db
        self._ensure_table()

    def _ensure_table(self):
        """Create the trade_memory table and its lookup indexes if missing."""
        self.db.conn.execute("""
            CREATE TABLE IF NOT EXISTS trade_memory (
                trade_id TEXT PRIMARY KEY,
                coin TEXT NOT NULL,
                direction TEXT NOT NULL,
                signal_type TEXT NOT NULL,
                entry_dt TEXT,
                exit_dt TEXT,
                entry_price REAL,
                exit_price REAL,
                pnl_pct REAL,
                pnl_dollar REAL,
                hold_bars INTEGER,
                hold_days REAL,
                btc_trend_60d REAL,
                coin_trend_60d REAL,
                box_width_pct REAL,
                box_duration_bars INTEGER,
                breakout_vol_ratio REAL,
                sl_distance REAL,
                concurrent_positions INTEGER DEFAULT 0,
                max_unrealized_pct REAL,
                exit_efficiency REAL,
                exit_reason TEXT,
                market_phase TEXT,
                entry_method TEXT,
                leverage INTEGER,
                margin REAL,
                notional REAL,
                fees REAL,
                balance_after REAL,
                created_at TEXT DEFAULT (datetime('now'))
            )
        """)
        # Indexes support the per-coin and per-signal-type queries, both of
        # which also filter/sort on entry_dt.
        self.db.conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_memory_coin
            ON trade_memory(coin, entry_dt)
        """)
        self.db.conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_memory_type
            ON trade_memory(signal_type, entry_dt)
        """)
        self.db.conn.commit()

    # ==================== Writes ====================

    def add_trade(self, snapshot: dict):
        """
        Insert (or replace, keyed on ``trade_id``) one trade snapshot and commit.

        Only keys present in ``MEMORY_FIELDS`` are written; other keys are
        ignored. If ``market_phase`` is missing or empty it is derived from
        ``btc_trend_60d`` (a missing key counts as 0). The caller's dict is
        not modified.

        Args:
            snapshot: mapping of column name -> value for one trade.
        """
        self._insert(snapshot)
        self.db.conn.commit()

    def _insert(self, snapshot: dict):
        """Write one snapshot row without committing (shared by single and bulk writes)."""
        # Copy so that filling in market_phase is not a side effect on the caller's dict.
        row = dict(snapshot)
        if not row.get('market_phase'):
            row['market_phase'] = self._detect_market_phase(
                row.get('btc_trend_60d', 0),
                row.get('entry_dt', ''),
            )

        cols = [f for f in MEMORY_FIELDS if f in row]
        placeholders = ','.join(['?'] * len(cols))
        col_names = ','.join(cols)
        values = [row[f] for f in cols]

        # Column names come only from the fixed MEMORY_FIELDS whitelist, so the
        # f-string cannot inject SQL; all values are bound as parameters.
        self.db.conn.execute(
            f"INSERT OR REPLACE INTO trade_memory ({col_names}) VALUES ({placeholders})",
            values,
        )

    def import_from_backtest(self, trades: list, bar_hours: float = 4):
        """
        Bulk-import backtest TradeRecord objects.

        Each record is stored with ``trade_id`` prefixed by ``"bt_"`` so it
        cannot collide with live trade ids. All rows are written and then
        committed once; rows inserted before an error are still committed,
        matching the previous per-row-commit behaviour.

        Args:
            trades: list of TradeRecord objects from the backtest engine.
            bar_hours: length of one bar in hours, used to convert
                ``hold_bars`` to ``hold_days``. Defaults to 4 (4H bars),
                which was previously hard-coded.
        """
        count = 0
        try:
            for t in trades:
                self._insert(self._snapshot_from_record(t, bar_hours))
                count += 1
        finally:
            # A single commit for the batch instead of one per trade.
            self.db.conn.commit()

        logger.info(f"Imported {count} trades from backtest")

    @staticmethod
    def _snapshot_from_record(t, bar_hours: float = 4) -> dict:
        """Map one backtest TradeRecord onto trade_memory column names."""
        return {
            'trade_id': f"bt_{t.trade_id}",
            'coin': t.coin,
            'direction': t.direction,
            'signal_type': t.signal_type,
            'entry_dt': str(t.entry_time),
            'exit_dt': str(t.exit_time),
            'entry_price': t.entry_price,
            'exit_price': t.exit_price,
            'pnl_pct': t.pnl_pct,
            'pnl_dollar': t.pnl_dollar,
            'hold_bars': t.hold_bars,
            'hold_days': t.hold_bars * bar_hours / 24,
            'btc_trend_60d': t.btc_trend_60d,
            'coin_trend_60d': t.coin_trend_60d,
            'box_width_pct': t.box_width_pct,
            # Use the record's value when it provides one; otherwise 0 as before.
            'box_duration_bars': getattr(t, 'box_duration_bars', 0),
            'breakout_vol_ratio': t.volume_ratio,
            'sl_distance': t.sl_distance,
            'max_unrealized_pct': t.max_unrealized_pct,
            'exit_efficiency': t.exit_efficiency,
            'exit_reason': t.exit_reason,
            'entry_method': t.entry_method,
            'leverage': t.leverage,
            'margin': t.margin,
            'notional': t.notional,
            'fees': t.fees,
            'balance_after': t.balance_after,
        }

    # ==================== Queries ====================

    def get_all(self) -> pd.DataFrame:
        """Return every stored trade, ordered by entry_dt."""
        return pd.read_sql_query(
            "SELECT * FROM trade_memory ORDER BY entry_dt", self.db.conn
        )

    def get_recent(self, months: int = 3) -> pd.DataFrame:
        """
        Return trades entered within the last ``months`` months, ordered by entry_dt.

        entry_dt is compared as text against SQLite's ``datetime('now', ...)``
        ('YYYY-MM-DD HH:MM:SS'), so stored timestamps should use that format.
        """
        return pd.read_sql_query(
            "SELECT * FROM trade_memory WHERE entry_dt >= datetime('now', ?) ORDER BY entry_dt",
            self.db.conn, params=[f'-{months} months']
        )

    def _select(self, where: str, params: list, months: int = 0) -> pd.DataFrame:
        """
        Run ``SELECT *`` with a fixed WHERE clause, an optional recency window and ordering.

        ``where`` must be an internal constant (never user input); values go in ``params``.
        ``months > 0`` restricts to trades entered within the last N months.
        """
        query = f"SELECT * FROM trade_memory WHERE {where}"
        bound = list(params)
        if months > 0:
            query += " AND entry_dt >= datetime('now', ?)"
            bound.append(f'-{months} months')
        query += " ORDER BY entry_dt"
        return pd.read_sql_query(query, self.db.conn, params=bound)

    def get_by_signal_type(self, signal_type: str, months: int = 0) -> pd.DataFrame:
        """Return trades of one signal type (optionally last ``months`` months), ordered by entry_dt."""
        return self._select("signal_type = ?", [signal_type], months)

    def get_by_coin(self, coin: str, months: int = 0) -> pd.DataFrame:
        """Return trades for one coin (optionally last ``months`` months), ordered by entry_dt."""
        return self._select("coin = ?", [coin], months)

    def get_profitable(self, months: int = 0) -> pd.DataFrame:
        """Return trades with pnl_dollar > 0 (optionally last ``months`` months), ordered by entry_dt."""
        return self._select("pnl_dollar > 0", [], months)

    def count(self) -> int:
        """Return the total number of stored trades."""
        cursor = self.db.conn.execute("SELECT COUNT(*) FROM trade_memory")
        return cursor.fetchone()[0]

    # ==================== Statistics ====================

    def calc_stats(self, df: pd.DataFrame) -> dict:
        """
        Compute summary statistics for a set of trades.

        A trade is a win when pnl_dollar > 0 and a loss when pnl_dollar <= 0.
        avg_win / avg_loss are means of pnl_pct scaled by 100.

        Args:
            df: a trade_memory DataFrame (e.g. the result of get_recent()).
        Returns:
            dict with keys n, wins, losses, win_rate, avg_win, avg_loss,
            payoff, pf, total_pnl, avg_exit_efficiency, avg_hold_bars.
            The same key set is returned for an empty DataFrame (all zeros).
        """
        if df.empty:
            return {
                'n': 0, 'wins': 0, 'losses': 0,
                'win_rate': 0, 'avg_win': 0, 'avg_loss': 0,
                'payoff': 0, 'pf': 0, 'total_pnl': 0,
                'avg_exit_efficiency': 0, 'avg_hold_bars': 0,
            }

        n = len(df)
        wins = df[df['pnl_dollar'] > 0]
        losses = df[df['pnl_dollar'] <= 0]

        win_rate = len(wins) / n * 100

        avg_win = wins['pnl_pct'].mean() * 100 if len(wins) > 0 else 0
        # With no losing trades a 0.01 floor replaces 0 so that payoff and pf
        # stay finite; note the floor is what gets reported in that case.
        avg_loss = abs(losses['pnl_pct'].mean()) * 100 if len(losses) > 0 else 0.01

        gross_profit = wins['pnl_dollar'].sum() if len(wins) > 0 else 0
        gross_loss = abs(losses['pnl_dollar'].sum()) if len(losses) > 0 else 0.01

        # Exit efficiency is only averaged over winning trades.
        avg_exit_eff = wins['exit_efficiency'].mean() if len(wins) > 0 else 0

        return {
            'n': n,
            'wins': len(wins),
            'losses': len(losses),
            'win_rate': win_rate,
            'avg_win': avg_win,
            'avg_loss': avg_loss,
            'payoff': avg_win / avg_loss if avg_loss > 0 else 0,
            'pf': gross_profit / gross_loss if gross_loss > 0 else 0,
            'total_pnl': df['pnl_dollar'].sum(),
            'avg_exit_efficiency': avg_exit_eff,
            'avg_hold_bars': df['hold_bars'].mean(),
        }

    def calc_stats_by_group(self, df: pd.DataFrame, group_col: str) -> Dict[str, dict]:
        """
        Compute calc_stats() per value of ``group_col``, keyed by str(value).

        Rows whose group value is NULL/NaN are dropped by pandas groupby's
        default ``dropna=True`` and therefore do not appear in the result.
        """
        result = {}
        for name, grp in df.groupby(group_col):
            result[str(name)] = self.calc_stats(grp)
        return result

    # ==================== Helpers ====================

    def _detect_market_phase(self, btc_trend: float, entry_dt: str) -> str:
        """
        Classify the market phase from the BTC trend value (``btc_trend_60d``).

        Thresholds: > 0.30 bull, > 0.05 recovery, > -0.15 consolidation,
        otherwise bear. A None/NaN trend yields 'unknown' instead of raising
        (None) or silently falling through to 'bear' (NaN).

        ``entry_dt`` is accepted for interface compatibility but is not used.
        """
        # NOTE(review): thresholds presumably treat the trend as a fractional
        # return (0.30 == +30%) — confirm against the producer of btc_trend_60d.
        if btc_trend is None or pd.isna(btc_trend):
            return 'unknown'
        if btc_trend > 0.30:
            return 'bull'
        elif btc_trend > 0.05:
            return 'recovery'
        elif btc_trend > -0.15:
            return 'consolidation'
        else:
            return 'bear'
