"""
实时信号扫描器

M1核心：每小时扫描所有币种，检测新信号并推送

架构:
1. 每小时整点触发
2. 拉取最新K线 → 聚合 → 跑信号管线
3. 新信号 → 过滤 → 仓位计算 → Telegram推送
4. 记录到数据库
"""

import logging
import pandas as pd
from typing import Dict, List, Optional
from datetime import datetime, timezone

from ..config.coins import COINS, get_coin
from ..config import settings as S
from ..core.pipeline import generate_long_signals, Signal
from ..core.shorts import find_short_signals
from ..core.filters import SignalFilter, calc_trend_60d
from ..core.position import calc_position, calc_effective_equity
from ..data.database import Database
from ..notify.telegram_bot import TelegramNotifier

logger = logging.getLogger(__name__)


class LiveScanner:
    """
    实时信号扫描器

    用法:
        scanner = LiveScanner(db, notifier, balance=1000)
        new_signals = scanner.scan_all(all_tf_data)
    """

    def __init__(
        self,
        db: Database,
        notifier: TelegramNotifier,
        balance: float = 1000.0,
        month_start_balance: Optional[float] = None,
    ):
        """Store the collaborators and reload persisted state.

        Args:
            db: State/signal store; queried through ``get_state`` while the
                instance is being constructed.
            notifier: Telegram notifier; only stored on the instance here.
            balance: Starting account balance. A ``balance`` value saved in
                the database takes precedence over it.
            month_start_balance: Balance at the start of the month. Any falsy
                value (``None`` or ``0``) falls back to ``balance``.
        """
        # In-memory registry of signal keys that were already handled.
        self.known_signal_keys = set()
        self.signal_filter = SignalFilter()

        self.db, self.notifier = db, notifier
        self.balance = balance
        if month_start_balance:
            self.month_start_balance = month_start_balance
        else:
            self.month_start_balance = balance

        # Values stored in the database override the ones assigned above.
        self._restore_state()

    def _restore_state(self):
        """Reload the filter flags and balances persisted in the database.

        ``consecutive_losses`` and ``paused_4h`` are always written onto the
        signal filter (defaulting to 0 / not paused). ``balance`` and
        ``month_start_balance`` are overwritten only when a non-empty value is
        stored; otherwise the values set by the constructor are kept.
        """
        flt = self.signal_filter
        flt.consecutive_losses = int(self.db.get_state('consecutive_losses', '0'))
        flt.paused_4h = (self.db.get_state('paused_4h', '0') == '1')

        # An empty string means "nothing stored": leave the attribute alone.
        for attr in ('balance', 'month_start_balance'):
            stored = self.db.get_state(attr, '')
            if stored:
                setattr(self, attr, float(stored))

        logger.info(f"State restored: balance=${self.balance:,.0f}, "
                    f"consecutive_losses={self.signal_filter.consecutive_losses}, "
                    f"paused_4h={self.signal_filter.paused_4h}")

    def _save_state(self):
        """保存状态到数据库"""
        self.db.set_state('consecutive_losses', str(self.signal_filter.consecutive_losses))
        self.db.set_state('paused_4h', '1' if self.signal_filter.paused_4h else '0')
        self.db.set_state('balance', str(self.balance))
        self.db.set_state('month_start_balance', str(self.month_start_balance))

    def _signal_key(self, sig: Signal) -> str:
        """生成信号唯一标识（用于去重）"""
        return f"{sig.coin}_{sig.direction}_{sig.signal_type}_{sig.entry_time.isoformat()}"

    def scan_all(
        self,
        all_tf_data: Dict[str, Dict[str, pd.DataFrame]],
    ) -> List[dict]:
        """
        Scan all coins and return the newly detected, actionable signals.

        Signals from the last 24 hours that were not handled before are
        filtered, sized and saved to the database (rejected ones included,
        with ``filter_passed=0``). The filter state is persisted at the end.
        This method does not call the notifier itself.

        A signal's dedup key is recorded only after the signal has been saved.
        A signal that cannot be evaluated yet (no 1D data for its coin) or
        whose processing raises is logged and left unmarked, so the next scan
        retries it instead of dropping it for good, and the remaining signals
        of this scan are still processed.

        Args:
            all_tf_data: {symbol: {'1H': df, '4H': df, '1D': df}}
        Returns:
            List of signal dicts that passed the filter and were not flagged
            ``skip`` by position sizing; each includes the position suggestion.
        """
        scan_time = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M UTC')
        logger.info(f"Scanning all coins at {scan_time}")

        all_signals, btc_daily_breakout_times = self._collect_signals(all_tf_data)

        # Sort by time and keep only the last 24 hours so that historical
        # signals are not processed over and over.
        all_signals.sort(key=lambda s: s.entry_time)
        cutoff = pd.Timestamp.now(tz='UTC') - pd.Timedelta(hours=24)
        recent_signals = [s for s in all_signals if s.entry_time > cutoff]

        logger.info(f"Found {len(all_signals)} total signals, "
                    f"{len(recent_signals)} in last 24h")

        new_signals = []
        btc_1d = all_tf_data.get('BTCUSDT', {}).get('1D')

        for sig in recent_signals:
            key = self._signal_key(sig)
            if key in self.known_signal_keys:
                continue  # already handled in an earlier scan

            try:
                outcome = self._evaluate_signal(
                    sig, all_tf_data, btc_1d, btc_daily_breakout_times
                )
                if outcome is None:
                    logger.warning(
                        f"No 1D data for {sig.coin}; "
                        f"{sig.signal_type} signal deferred to next scan"
                    )
                    continue
                signal_dict, actionable, reason = outcome
                self.db.save_signal(signal_dict)
            except Exception as e:
                # Best effort, like the per-coin detection step: one broken
                # signal must not discard the others found in this scan.
                logger.exception(
                    f"Signal processing error {sig.coin} {sig.signal_type}: {e}"
                )
                continue

            # Mark as handled only now that the signal is safely persisted.
            self.known_signal_keys.add(key)

            if actionable:
                new_signals.append(signal_dict)
                logger.info(
                    f"NEW SIGNAL: {sig.coin} {sig.direction} {sig.signal_type} "
                    f"@ ${sig.entry_price:,.2f}, SL=${sig.stop_loss:,.2f}"
                )
            else:
                logger.debug(f"Filtered: {sig.coin} {sig.signal_type} - {reason}")

        self._save_state()
        return new_signals

    def _collect_signals(self, all_tf_data: Dict[str, Dict[str, pd.DataFrame]]) -> tuple:
        """Run the long/short detectors for every configured coin.

        Returns:
            ``(all_signals, btc_daily_breakout_times)``: every signal the
            detectors produced, plus the entry times of BTCUSDT long signals
            whose type is '日线' (daily) or '共振' (resonance).
        """
        all_signals: List[Signal] = []
        btc_daily_breakout_times = []

        for symbol, tf_data in all_tf_data.items():
            if symbol not in COINS:
                continue

            cfg = COINS[symbol]

            # Long signals. A missing timeframe raises KeyError here and is
            # handled like any other detector failure: log it and move on.
            try:
                long_sigs = generate_long_signals(
                    symbol, tf_data['1H'], tf_data['4H'], tf_data['1D']
                )
                all_signals.extend(long_sigs)

                # Remember BTC daily breakout times for the signal filter.
                for s in long_sigs:
                    if s.coin == 'BTCUSDT' and s.signal_type in ('日线', '共振'):
                        btc_daily_breakout_times.append(s.entry_time)

            except Exception as e:
                logger.exception(f"Long signal error {symbol}: {e}")

            # Short signals, only for coins configured to allow shorting.
            if cfg.allow_short:
                try:
                    short_sigs = find_short_signals(symbol, tf_data['4H'])
                    all_signals.extend(short_sigs)
                except Exception as e:
                    logger.exception(f"Short signal error {symbol}: {e}")

        return all_signals, btc_daily_breakout_times

    def _evaluate_signal(self, sig, all_tf_data, btc_1d, btc_daily_breakout_times):
        """Filter and size one signal and build its database record.

        Returns:
            ``None`` when the signal's coin has no (or empty) 1D data,
            otherwise ``(signal_dict, actionable, reason)`` where
            ``actionable`` is True when the filter passed and position sizing
            did not flag the signal as ``skip``.
        """
        coin_1d = all_tf_data.get(sig.coin, {}).get('1D')
        if coin_1d is None or coin_1d.empty:
            return None

        # Trend inputs for the filter; the BTC trend stays 0.0 without BTC data.
        coin_trend = calc_trend_60d(
            coin_1d['close'].values, coin_1d['timestamp'], sig.entry_time
        )
        btc_trend = 0.0
        if btc_1d is not None and not btc_1d.empty:
            btc_trend = calc_trend_60d(
                btc_1d['close'].values, btc_1d['timestamp'], sig.entry_time
            )

        passed, reason, adjustments = self.signal_filter.check_signal(
            sig, btc_trend, coin_trend, btc_daily_breakout_times
        )

        # Position sizing
        effective_equity = calc_effective_equity(self.balance, self.month_start_balance)
        cfg = get_coin(sig.coin)

        if passed:
            pos = calc_position(
                equity=effective_equity,
                sl_distance=sig.sl_distance,
                coin=sig.coin,
                signal_type=sig.signal_type,
                leverage=cfg.leverage,
                position_mult=adjustments['position_mult'],
                cross_resonance=adjustments.get('cross_resonance', False),
            )
        else:
            # Rejected signals are still recorded, with an empty position.
            pos = {'margin': 0, 'notional': 0, 'position_pct': 0,
                   'skip': True, 'leverage': cfg.leverage}

        signal_dict = {
            'coin': sig.coin,
            'direction': sig.direction,
            'signal_type': sig.signal_type,
            'entry_time': str(sig.entry_time),
            'entry_price': sig.entry_price,
            'stop_loss': sig.stop_loss,
            'sl_distance': sig.sl_distance,
            'box_high': sig.box_high,
            'box_low': sig.box_low,
            'box_width_pct': sig.box_width_pct,
            'box_duration_bars': sig.box_duration_bars,
            'volume_ratio': sig.volume_ratio,
            'entry_method': sig.entry_method,
            'btc_trend_60d': btc_trend,
            'coin_trend_60d': coin_trend,
            'filter_passed': 1 if passed else 0,
            'filter_reason': reason,
            'suggested_position_pct': pos.get('position_pct', 0),
            'position_pct': pos.get('position_pct', 0),
            'notional': pos.get('notional', 0),
            'leverage': cfg.leverage,
            'cross_resonance': adjustments.get('cross_resonance', False),
        }

        actionable = bool(passed and not pos.get('skip'))
        return signal_dict, actionable, reason

    def get_coins_status(
        self,
        all_tf_data: Dict[str, Dict[str, pd.DataFrame]],
    ) -> List[dict]:
        """
        获取所有币种当前状态概览
        """
        status_list = []
        btc_tf = all_tf_data.get('BTCUSDT', {})
        btc_1d = btc_tf.get('1D')

        for symbol, tf_data in all_tf_data.items():
            if symbol not in COINS:
                continue

            df_1d = tf_data.get('1D')
            if df_1d is None or df_1d.empty:
                continue

            price = df_1d['close'].iloc[-1]
            now = df_1d['timestamp'].iloc[-1]

            trend = calc_trend_60d(df_1d['close'].values, df_1d['timestamp'], now)

            status_list.append({
                'coin': symbol,
                'price': price,
                'trend_60d': trend,
                'has_active_box': False,  # TODO: detect from pipeline state
            })

        return status_list
