"""
6层信号过滤 + 做空过滤

按优先级依次检查，任一层被过滤则信号无效。
"""

import pandas as pd
import numpy as np
from typing import List, Optional, Dict
from datetime import timedelta

from .pipeline import Signal
from ..config import settings as S


def calc_trend_60d(closes: np.ndarray, timestamps: pd.DatetimeIndex,
                   target_date: pd.Timestamp, days: int = 60) -> float:
    """
    Compute the trailing trend (close[date] - close[date-days]) / close[date-days].

    ``close[date]`` is the close of the last bar at or before ``target_date``;
    ``close[date-days]`` is the close of the last bar at or before
    ``target_date - days`` calendar days.

    Args:
        closes: close prices, aligned element-wise with ``timestamps``.
        timestamps: bar timestamps. Assumed sorted ascending: a bar's index is
            derived from the number of timestamps <= a given date.
        target_date: the evaluation date.
        days: look-back window in calendar days (default 60, the original
            fixed window).

    Returns:
        The fractional change, or 0.0 when it cannot be computed: fewer than
        two bars up to ``target_date``, no bar at or before the look-back
        date, a zero base close, or a non-finite (NaN/inf) close.
    """
    # Count of bars at or before target_date -> index of the latest such bar
    # (only valid because timestamps are sorted ascending).
    n_now = int((timestamps <= target_date).sum())
    if n_now < 2:
        return 0.0
    close_now = closes[n_now - 1]

    # Latest bar at or before the look-back date.
    n_ago = int((timestamps <= target_date - timedelta(days=days)).sum())
    if n_ago < 1:
        return 0.0
    close_ago = closes[n_ago - 1]

    # Missing/corrupt prices follow the "0.0 when not computable" convention.
    # A NaN result would compare False against every trend threshold and
    # therefore let signals pass the trend filters unchecked.
    if not (np.isfinite(close_now) and np.isfinite(close_ago)) or close_ago == 0:
        return 0.0
    return (close_now - close_ago) / close_ago


class SignalFilter:
    """
    Stateful signal filter: checks signals one at a time against the filter
    layers and keeps the state those layers depend on.
    """

    def __init__(self):
        self.consecutive_losses = 0         # consecutive losing long trades
        self.paused_4h = False              # True -> long 4H signals are paused
        self.last_entry_by_coin: Dict[str, pd.Timestamp] = {}  # coin -> last recorded entry time (same-coin dedup)
        self.last_short_by_coin: Dict[str, pd.Timestamp] = {}  # coin -> last recorded short entry time (short dedup)
        # str(entry date) -> coin that holds that day's 4H slot (same-day dedup)
        self.today_4h_entries: Dict[str, str] = {}

    def update_trade_result(self, signal: Signal, pnl: float):
        """
        Trade-result callback: update the consecutive-loss counter.

        Only long trades count. A losing long (pnl < 0) increments the counter
        and pauses 4H signals once it reaches S.CONSECUTIVE_LOSS_PAUSE; any
        other long result resets the counter. Short trades are ignored.
        """
        if signal.direction == 'long':
            if pnl < 0:
                self.consecutive_losses += 1
                if self.consecutive_losses >= S.CONSECUTIVE_LOSS_PAUSE:
                    self.paused_4h = True
            else:
                self.consecutive_losses = 0

    def check_signal(
        self,
        signal: Signal,
        btc_trend_60d: float,
        coin_trend_60d: float,
        btc_daily_breakout_times: List[pd.Timestamp],
    ) -> tuple:
        """
        Check whether a signal passes every filter layer.

        Short signals are delegated to ``_filter_short``. Long signals go
        through six layers in order (the first rejection wins) and then
        receive position-size adjustments.

        Side effects (long signals only):
            * a 4H signal that passes every layer claims its calendar day in
              ``today_4h_entries``;
            * a '日线' / '共振' signal clears the 4H pause and resets
              ``consecutive_losses``.

        Args:
            signal: the candidate signal (reads direction, signal_type, coin,
                volume_ratio, entry_time, sl_distance).
            btc_trend_60d: BTC 60-day trend, in the same units as the S.*
                trend thresholds.
            coin_trend_60d: the signal coin's 60-day trend.
            btc_daily_breakout_times: BTC daily breakout times, used for the
                cross-coin resonance boost.

        Returns:
            (passed: bool, reason: str, adjustments: dict). ``adjustments``
            holds 'position_mult' (position-size multiplier) and
            'cross_resonance' (bool).
        """
        adjustments = {'position_mult': 1.0, 'cross_resonance': False}

        # ==================== Short filter (independent rules) ====================
        if signal.direction == 'short':
            return self._filter_short(signal, btc_trend_60d, coin_trend_60d)

        # ==================== Long filter (6 layers) ====================

        # Layer 1: trend dead zone (4H only).
        # NOTE(review): the reason text hard-codes the range; it does not
        # track S.DEAD_ZONE_LOW / S.DEAD_ZONE_HIGH if those change.
        if signal.signal_type == '4H':
            if S.DEAD_ZONE_LOW <= coin_trend_60d <= S.DEAD_ZONE_HIGH:
                return False, "趋势死区(-30%~-15%)", adjustments

        # Layer 2: BTC market linkage. Altcoin 4H signals need a BTC trend above
        # the threshold.
        if signal.coin != "BTCUSDT" and signal.signal_type == '4H':
            if btc_trend_60d <= S.BTC_TREND_THRESHOLD:
                return False, "BTC趋势<=0,山寨4H不做", adjustments

        # Layer 3: volume-ratio filter. Altcoin 4H signals are rejected above the
        # coin's limit; BTC signals above the reduce limit get half size.
        # Imported locally, presumably to avoid a circular import — confirm.
        from ..config.coins import get_coin
        cfg = get_coin(signal.coin)
        if signal.signal_type == '4H' and signal.coin != "BTCUSDT":
            if signal.volume_ratio > cfg.vol_filter_4h:
                return False, f"量比{signal.volume_ratio:.1f}x>{cfg.vol_filter_4h}x", adjustments
        if signal.coin == "BTCUSDT" and signal.volume_ratio > cfg.vol_filter_reduce:
            adjustments['position_mult'] *= 0.5

        # Layer 4: same-day dedup. Only one coin may take the 4H slot per
        # calendar day. The slot is claimed only after every rejection layer
        # has passed (end of the filter section), so a signal rejected by
        # layer 5 or 6 does not block other coins for the rest of the day.
        date_key = None
        if signal.signal_type == '4H':
            date_key = str(signal.entry_time.date())
            claimed_by = self.today_4h_entries.get(date_key)
            if claimed_by is not None and claimed_by != signal.coin:
                return False, "同天已有其他币4H信号", adjustments

        # Layer 5: losing-streak pause. 4H signals are blocked while paused;
        # daily/resonance signals lift the pause (next step).
        if signal.signal_type == '4H' and self.paused_4h:
            return False, f"连败{self.consecutive_losses}笔,4H暂停", adjustments

        # Resume: a daily or resonance signal re-enables 4H signals.
        if signal.signal_type in ('日线', '共振') and self.paused_4h:
            self.paused_4h = False
            self.consecutive_losses = 0

        # Layer 6: same-coin cooldown (4H only), measured in whole days since
        # the coin's last recorded entry.
        if signal.coin in self.last_entry_by_coin:
            last_entry = self.last_entry_by_coin[signal.coin]
            if signal.signal_type == '4H':
                if (signal.entry_time - last_entry).days < S.SAME_COIN_COOLDOWN_DAYS:
                    return False, f"同币{S.SAME_COIN_COOLDOWN_DAYS}天内已有入场", adjustments

        # Every rejection layer passed: claim today's 4H slot for this coin.
        if date_key is not None:
            self.today_4h_entries[date_key] = signal.coin

        # ==================== Position adjustments ====================

        # 4H: scale down on a weak coin trend (the very-weak tier takes priority).
        if signal.signal_type == '4H':
            if coin_trend_60d < S.ADJ_TREND_VERY_WEAK[0]:
                adjustments['position_mult'] *= S.ADJ_TREND_VERY_WEAK[1]
            elif coin_trend_60d < S.ADJ_TREND_WEAK[0]:
                adjustments['position_mult'] *= S.ADJ_TREND_WEAK[1]

        # 4H with stop-loss distance inside [sl_lo, sl_hi] and negative coin trend.
        if signal.signal_type == '4H':
            sl_lo, sl_hi, sl_mult = S.ADJ_4H_SL_WIDE
            if sl_lo <= signal.sl_distance <= sl_hi and coin_trend_60d < 0:
                adjustments['position_mult'] *= sl_mult

        # Cross-coin resonance: an altcoin daily/resonance signal near a BTC
        # daily breakout gets a boost (applied once, on the first match).
        if signal.coin != "BTCUSDT" and signal.signal_type in ('日线', '共振'):
            window_days, mult, cap = S.ADJ_CROSS_RESONANCE
            # NOTE(review): `cap` is unpacked but never applied — confirm
            # whether position_mult should be limited by it here.
            # NOTE(review): abs() also matches breakouts up to `window_days`
            # AFTER entry_time; in a backtest that is look-ahead unless the
            # caller only passes breakouts that precede the signal.
            for bt in btc_daily_breakout_times:
                if abs((signal.entry_time - bt).days) <= window_days:
                    adjustments['cross_resonance'] = True
                    adjustments['position_mult'] *= mult
                    break

        return True, "通过", adjustments

    def _filter_short(self, signal: Signal, btc_trend: float, coin_trend: float):
        """
        Filter a short signal.

        Rejects when BTC's trend exceeds S.SHORT_BTC_TREND_MAX, when the coin's
        own trend exceeds S.SHORT_COIN_TREND_MAX, or when the coin's last
        recorded short was fewer than S.SHORT_SAME_COIN_COOLDOWN whole days
        ago. Short signals get no position adjustments.

        Returns:
            (passed, reason, adjustments), same shape as ``check_signal``.
        """
        adjustments = {'position_mult': 1.0, 'cross_resonance': False}

        # Do not short into a strong BTC uptrend.
        # NOTE(review): the reject messages below hard-code the thresholds.
        if btc_trend > S.SHORT_BTC_TREND_MAX:
            return False, "BTC趋势>30%,不做空", adjustments

        # Do not short a coin in a strong uptrend of its own.
        if coin_trend > S.SHORT_COIN_TREND_MAX:
            return False, "本币趋势>15%,不做空", adjustments

        # Minimum spacing between shorts on the same coin.
        if signal.coin in self.last_short_by_coin:
            last = self.last_short_by_coin[signal.coin]
            if (signal.entry_time - last).days < S.SHORT_SAME_COIN_COOLDOWN:
                return False, "做空间隔不足7天", adjustments

        return True, "通过", adjustments

    def record_entry(self, signal: Signal):
        """
        Record an actual entry for the dedup layers.

        Every entry (long or short) updates ``last_entry_by_coin``, which
        drives the 4H same-coin cooldown. Shorts also update
        ``last_short_by_coin`` for the short cooldown.
        """
        self.last_entry_by_coin[signal.coin] = signal.entry_time
        if signal.direction == 'short':
            self.last_short_by_coin[signal.coin] = signal.entry_time
