"""
参数自适应优化器 — AI学习系统第三层

核心规则（提示词七）:
- 每个参数只允许±20%浮动
- 最少N个样本才能做调整
- 每次只调一个参数（控制变量）
- A/B测试: 70%旧参数 + 30%新参数，对比1个月

不用深度学习/神经网络，用贝叶斯更新+滑动窗口统计
"""

import logging
import json
from typing import Dict, List, Optional, Tuple
from datetime import datetime, timezone
from dataclasses import dataclass

import numpy as np
import pandas as pd

from .memory import TradeMemory
from ..data.database import Database
from ..config import settings as S

logger = logging.getLogger(__name__)


@dataclass
class ABTest:
    """A/B测试记录"""
    test_id: str
    param_name: str
    old_value: float
    new_value: float
    start_date: str
    allocation: float      # 新参数分配比例（默认30%）
    old_trades: int = 0
    old_pnl: float = 0
    new_trades: int = 0
    new_pnl: float = 0
    status: str = 'running'  # running / completed / rejected
    result: str = ''


# 参数配置表（对应提示词七）
PARAM_SPECS = {
    'btc_trailing': {
        'base': 0.15, 'min': 0.12, 'max': 0.18,
        'min_samples': 20, 'description': 'BTC trailing回撤%',
    },
    'eth_trailing': {
        'base': 0.20, 'min': 0.16, 'max': 0.24,
        'min_samples': 20, 'description': 'ETH trailing回撤%',
    },
    'alt_trailing': {
        'base': 0.25, 'min': 0.20, 'max': 0.30,
        'min_samples': 20, 'description': '山寨trailing回撤%',
    },
    '4h_position_limit': {
        'base': 0.25, 'min': 0.15, 'max': 0.35,
        'min_samples': 30, 'description': '4H信号仓位上限',
    },
    'consecutive_loss_pause': {
        'base': 5, 'min': 3, 'max': 7,
        'min_samples': 10, 'description': '连败暂停阈值',
    },
    'vol_filter_alt': {
        'base': 2.0, 'min': 1.5, 'max': 4.0,
        'min_samples': 20, 'description': '山寨量比过滤阈值',
    },
    'vol_filter_btc': {
        'base': 3.0, 'min': 2.0, 'max': 4.0,
        'min_samples': 20, 'description': 'BTC量比减仓阈值',
    },
}


class ParamOptimizer:
    """
    参数自适应优化器

    用法:
        optimizer = ParamOptimizer(memory, db)

        # 获取当前参数
        params = optimizer.get_current_params()

        # 应用月度复盘建议
        optimizer.apply_suggestion(suggestion)

        # 启动A/B测试
        optimizer.start_ab_test('btc_trailing', new_value=0.13)

        # 评估A/B测试结果（1个月后）
        optimizer.evaluate_ab_tests()

        # 获取某笔交易应该用的参数（考虑A/B分配）
        params = optimizer.get_trade_params(trade_id_hash)
    """

    def __init__(self, memory: TradeMemory, db: Database):
        self.memory = memory
        self.db = db

        # 当前参数
        self.current_params: Dict[str, float] = {}
        self.ab_tests: Dict[str, ABTest] = {}

        self._restore_state()

    def _restore_state(self):
        """从数据库恢复"""
        # 当前参数
        params_json = self.db.get_state('optimized_params', '{}')
        try:
            self.current_params = json.loads(params_json)
        except:
            self.current_params = {}

        # 填充默认值
        for name, spec in PARAM_SPECS.items():
            if name not in self.current_params:
                self.current_params[name] = spec['base']

        # A/B测试
        ab_json = self.db.get_state('ab_tests', '{}')
        try:
            ab_data = json.loads(ab_json)
            for tid, td in ab_data.items():
                self.ab_tests[tid] = ABTest(**td)
        except:
            self.ab_tests = {}

    def _save_state(self):
        """保存状态"""
        self.db.set_state('optimized_params', json.dumps(self.current_params))
        ab_data = {tid: {
            'test_id': t.test_id, 'param_name': t.param_name,
            'old_value': t.old_value, 'new_value': t.new_value,
            'start_date': t.start_date, 'allocation': t.allocation,
            'old_trades': t.old_trades, 'old_pnl': t.old_pnl,
            'new_trades': t.new_trades, 'new_pnl': t.new_pnl,
            'status': t.status, 'result': t.result,
        } for tid, t in self.ab_tests.items()}
        self.db.set_state('ab_tests', json.dumps(ab_data))

    # ==================== 参数管理 ====================

    def get_current_params(self) -> Dict[str, float]:
        """获取当前所有参数"""
        return dict(self.current_params)

    def get_param(self, name: str) -> float:
        """获取单个参数"""
        return self.current_params.get(name, PARAM_SPECS.get(name, {}).get('base', 0))

    def update_param(self, name: str, value: float, reason: str = '') -> bool:
        """
        更新单个参数（带范围检查）

        Returns: 是否更新成功
        """
        spec = PARAM_SPECS.get(name)
        if not spec:
            logger.error(f"Unknown param: {name}")
            return False

        # 范围检查
        clamped = max(spec['min'], min(spec['max'], value))
        if clamped != value:
            logger.warning(f"Param {name} clamped: {value} → {clamped}")

        # ±20%检查
        base = spec['base']
        if abs(clamped - base) / base > 0.20:
            logger.warning(f"Param {name} exceeds ±20% from base "
                           f"({base} → {clamped})")

        old = self.current_params.get(name, base)
        self.current_params[name] = clamped
        self._save_state()

        logger.info(f"Param updated: {name} = {old} → {clamped} ({reason})")
        return True

    # ==================== A/B测试 ====================

    def start_ab_test(
        self,
        param_name: str,
        new_value: float,
        allocation: float = 0.30,
    ) -> Optional[ABTest]:
        """
        启动A/B测试

        allocation: 新参数的仓位分配比例（默认30%）
        """
        # 检查是否已有该参数的测试
        for test in self.ab_tests.values():
            if test.param_name == param_name and test.status == 'running':
                logger.warning(f"AB test for {param_name} already running")
                return None

        test_id = f"ab_{param_name}_{int(datetime.now(timezone.utc).timestamp())}"
        old_value = self.current_params.get(param_name, PARAM_SPECS[param_name]['base'])

        test = ABTest(
            test_id=test_id,
            param_name=param_name,
            old_value=old_value,
            new_value=new_value,
            start_date=datetime.now(timezone.utc).isoformat(),
            allocation=allocation,
        )

        self.ab_tests[test_id] = test
        self._save_state()

        logger.info(f"AB test started: {param_name} "
                    f"{old_value} vs {new_value} ({allocation:.0%} new)")
        return test

    def get_trade_params(self, trade_hash: int) -> Dict[str, float]:
        """
        获取某笔交易应该使用的参数

        根据trade_hash决定是否使用A/B测试的新参数:
        - hash % 100 < allocation * 100 → 用新参数
        - 否则 → 用旧参数

        Args:
            trade_hash: 交易的hash值（用于确定分组）
        Returns:
            完整参数字典
        """
        params = dict(self.current_params)

        for test in self.ab_tests.values():
            if test.status != 'running':
                continue

            # 根据hash确定分组
            use_new = (trade_hash % 100) < (test.allocation * 100)
            if use_new:
                params[test.param_name] = test.new_value

        return params

    def record_ab_result(self, test_id: str, used_new: bool, pnl: float):
        """记录A/B测试的交易结果"""
        test = self.ab_tests.get(test_id)
        if not test or test.status != 'running':
            return

        if used_new:
            test.new_trades += 1
            test.new_pnl += pnl
        else:
            test.old_trades += 1
            test.old_pnl += pnl

        self._save_state()

    def evaluate_ab_tests(self) -> List[dict]:
        """
        评估所有运行中的A/B测试

        判定标准:
        - 至少运行1个月
        - 新旧参数各至少10笔交易
        - 新参数盈亏 >= 旧参数盈亏 → 采纳
        - 否则 → 回退

        Returns:
            List of evaluation results
        """
        results = []

        for test_id, test in list(self.ab_tests.items()):
            if test.status != 'running':
                continue

            total = test.old_trades + test.new_trades
            if total < 20:
                results.append({
                    'test_id': test_id,
                    'status': 'insufficient_data',
                    'detail': f'样本不足: {total}/20',
                })
                continue

            if test.old_trades < 10 or test.new_trades < 10:
                results.append({
                    'test_id': test_id,
                    'status': 'insufficient_per_group',
                    'detail': f'分组不足: old={test.old_trades}, new={test.new_trades}',
                })
                continue

            # 比较
            old_avg = test.old_pnl / test.old_trades
            new_avg = test.new_pnl / test.new_trades

            if new_avg >= old_avg:
                # 新参数更好 → 采纳
                test.status = 'completed'
                test.result = f'采纳: 新${new_avg:+,.0f}/笔 vs 旧${old_avg:+,.0f}/笔'
                self.update_param(
                    test.param_name, test.new_value,
                    f'AB测试通过: {test.result}'
                )
                logger.info(f"AB test {test_id}: ADOPTED "
                            f"({test.param_name} = {test.new_value})")
            else:
                # 旧参数更好 → 回退
                test.status = 'rejected'
                test.result = f'回退: 新${new_avg:+,.0f}/笔 vs 旧${old_avg:+,.0f}/笔'
                logger.info(f"AB test {test_id}: REJECTED")

            results.append({
                'test_id': test_id,
                'param': test.param_name,
                'old_value': test.old_value,
                'new_value': test.new_value,
                'old_avg_pnl': old_avg,
                'new_avg_pnl': new_avg,
                'status': test.status,
                'result': test.result,
            })

        self._save_state()
        return results

    # ==================== 自动发现 ====================

    def auto_discover(self) -> List[str]:
        """
        自动发现规则（提示词七第四层）:
        - 某币种连续3个月4H信号全亏 → 暂停
        - 某币种连续3笔日线信号盈利 → 加仓10%
        """
        discoveries = []
        all_trades = self.memory.get_all()

        if all_trades.empty:
            return discoveries

        # 检查每个币种的4H信号表现
        for coin in all_trades['coin'].unique():
            coin_4h = all_trades[
                (all_trades['coin'] == coin) &
                (all_trades['signal_type'] == '4H')
            ].sort_values('entry_dt')

            if len(coin_4h) < 5:
                continue

            # 最近的4H信号是否全亏
            recent_4h = coin_4h.tail(5)
            if all(recent_4h['pnl_dollar'] <= 0):
                msg = f"⚠️ {coin} 最近5笔4H信号全亏,建议暂停"
                discoveries.append(msg)
                logger.warning(msg)

            # 日线信号连续盈利
            coin_daily = all_trades[
                (all_trades['coin'] == coin) &
                (all_trades['signal_type'].isin(['日线', '共振']))
            ].sort_values('entry_dt')

            if len(coin_daily) >= 3:
                last3 = coin_daily.tail(3)
                if all(last3['pnl_dollar'] > 0):
                    msg = f"✅ {coin} 最近3笔日线信号盈利,可考虑加仓10%"
                    discoveries.append(msg)

        return discoveries
