"""
信号生成管线 (Steps 2-6)
Step 2: 下跌段识别
Step 3: 盒子识别
Step 4: 放量阳线突破
Step 5: 多周期叠加
Step 6: 1H精准入场
"""

import numpy as np
import pandas as pd
from dataclasses import dataclass, field
from typing import List, Optional, Dict
from datetime import timedelta

from .zigzag import zigzag
from ..config import settings as S


@dataclass
class DropLeg:
    """A down-swing from a swing high to the swing low that follows it."""

    # Swing-high end of the leg: bar index and price
    high_idx: int
    high_price: float

    # Swing-low end of the leg: bar index and price
    low_idx: int
    low_price: float

    # Fractional drop; find_drop_legs fills it as (high_price - low_price) / high_price
    drop_pct: float


@dataclass
class Box:
    """Bottom consolidation zone that forms right after a drop leg."""

    # Positional bar range of the box (inclusive); start_idx is the drop leg's low bar
    start_idx: int
    end_idx: int

    # Box bottom and top prices
    box_low: float
    box_high: float

    # Amplitude as a fraction of the box bottom
    width_pct: float

    # Number of bars the box lasted
    duration_bars: int

    # The drop leg that precedes this box
    drop_leg: DropLeg


@dataclass
class Breakout:
    """A bullish, high-volume close above a box top."""

    # Positional index of the breakout bar and its close price
    bar_idx: int
    price: float

    # Breakout-bar volume relative to the average volume of the preceding bars
    volume_ratio: float

    # The box that was broken out of
    box: Box

    # Timeframe label: '1D' or '4H'
    timeframe: str


@dataclass
class Signal:
    """A complete trade signal emitted by the pipeline."""

    coin: str
    # 'long' or 'short'
    direction: str
    # '共振' (4H + daily resonance), '日线' (daily only) or '4H' (4H only)
    signal_type: str

    # Entry
    entry_time: pd.Timestamp
    entry_price: float

    # Risk: stop price and its distance from entry as a fraction of entry price
    stop_loss: float
    sl_distance: float

    # Geometry of the box the signal is based on
    box_high: float
    box_low: float
    box_width_pct: float
    box_duration_bars: int

    # Volume ratio of the breakout bar
    volume_ratio: float
    # '1H回踩' (1H pullback entry) or '4H直接' (direct entry on the breakout bar)
    entry_method: str

    # Breakouts behind this signal; None for a timeframe that did not participate
    breakout_4h: Optional[Breakout] = None
    breakout_1d: Optional[Breakout] = None


# ==================== Step 2: 下跌段识别 ====================

def find_drop_legs(
    swing_points: list,
    min_drop: float,
) -> List[DropLeg]:
    """
    Extract every qualifying drop leg (H -> L) from a swing-point sequence.

    Args:
        swing_points: Ordered ``(bar_idx, kind, price)`` triples where ``kind``
            is 'H' or 'L' (presumably the output of ``zigzag``; confirm).
        min_drop: Minimum fractional drop ``(high - low) / high`` for a leg.

    Returns:
        One ``DropLeg`` per adjacent H-then-L pair whose drop is >= ``min_drop``,
        in sequence order.
    """
    result = []
    # Walk adjacent pairs (p0, p1), (p1, p2), ...
    for (h_idx, h_kind, h_price), (l_idx, l_kind, l_price) in zip(
        swing_points, swing_points[1:]
    ):
        if h_kind != 'H' or l_kind != 'L':
            continue
        pct = (h_price - l_price) / h_price
        if pct >= min_drop:
            result.append(DropLeg(
                high_idx=h_idx,
                high_price=h_price,
                low_idx=l_idx,
                low_price=l_price,
                drop_pct=pct,
            ))
    return result


# ==================== Step 3: 盒子识别 ====================

def find_boxes(
    drop_legs: List[DropLeg],
    highs: np.ndarray,
    lows: np.ndarray,
    closes: np.ndarray,
    min_bars: int,
    max_width: float,
    new_low_tolerance: float = S.BOX_NEW_LOW_TOLERANCE,
    max_scan: int = S.BOX_MAX_SCAN_BARS,
) -> List[Box]:
    """
    Identify a box (bottom consolidation zone) after each drop leg.

    Scanning starts at the leg's low bar and extends one bar at a time:
    1. If a bar's low is below ``leg.low_price * (1 - new_low_tolerance)``
       (a new low beyond tolerance), the whole box is discarded.
    2. If including a bar would push the amplitude
       ``(box_high - box_low) / box_low`` above ``max_width``, the box ends at
       the previous bar; that bar does NOT contribute to the box high/low.
    3. At most ``max_scan`` bars (the start bar included) are examined.
    A box is kept only when it spans at least ``min_bars`` bars.

    Args:
        drop_legs: Drop legs; ``low_idx`` is a positional bar index.
        highs: Per-bar highs, positionally indexed.
        lows: Per-bar lows, positionally indexed.
        closes: Not used; kept for interface compatibility.
        min_bars: Minimum box duration in bars.
        max_width: Maximum amplitude, as a fraction of the box bottom.
        new_low_tolerance: Fraction below the leg low still tolerated.
        max_scan: Maximum number of bars scanned from the leg low.

    Returns:
        At most one ``Box`` per drop leg, in ``drop_legs`` order.
    """
    n = len(highs)
    boxes = []

    for leg in drop_legs:
        start = leg.low_idx
        # Bars start..n-1 are available; skip legs that can never reach min_bars.
        if start >= n or start + min_bars > n:
            continue

        new_low_floor = leg.low_price * (1 - new_low_tolerance)
        box_low = lows[start]
        box_high = highs[start]
        end_idx = start
        valid = True

        for j in range(start + 1, min(start + max_scan, n)):
            lo = lows[j]

            # New low beyond tolerance -> the box is invalid
            if lo < new_low_floor:
                valid = False
                break

            # Tentatively extend; commit only if the width limit still holds,
            # so the bar that ends the box never leaks into its high/low.
            cand_low = min(box_low, lo)
            cand_high = max(box_high, highs[j])
            if (cand_high - cand_low) / cand_low > max_width:
                break

            box_low, box_high = cand_low, cand_high
            end_idx = j

        duration = end_idx - start + 1
        if valid and duration >= min_bars:
            boxes.append(Box(
                start_idx=start,
                end_idx=end_idx,
                box_low=box_low,
                box_high=box_high,
                width_pct=(box_high - box_low) / box_low,
                duration_bars=duration,
                drop_leg=leg,
            ))

    return boxes


# ==================== Step 4: 放量阳线突破 ====================

def find_breakouts(
    boxes: List[Box],
    opens: np.ndarray,
    highs: np.ndarray,
    lows: np.ndarray,
    closes: np.ndarray,
    volumes: np.ndarray,
    confirm_bars: int,
    max_wait: int = S.BREAKOUT_MAX_WAIT,
    vol_ratio_threshold: float = S.BREAKOUT_VOL_RATIO,
    timeframe: str = '1D',
    vol_lookback: int = 20,
) -> List[Breakout]:
    """
    Find the first high-volume bullish breakout after each box.

    Among the ``max_wait`` bars following ``box.end_idx``, bar ``i`` qualifies when:
    1. ``closes[i] > box.box_high``;
    2. it is bullish (``closes[i] > opens[i]``);
    3. ``volumes[i]`` divided by the mean volume of the previous
       ``vol_lookback`` bars is >= ``vol_ratio_threshold`` (NaN volumes never
       qualify);
    4. all of the next ``confirm_bars`` bars exist and none closes below
       ``box.box_high``.
    Only the first qualifying bar per box is kept.

    NOTE(review): condition 4 reads bars after ``i``, so the breakout is only
    known ``confirm_bars`` bars later although ``bar_idx``/``price`` describe
    bar ``i`` -- callers entering at ``bar_idx`` should account for that.

    Args:
        boxes: Boxes whose ``end_idx`` is a positional bar index.
        opens, closes, volumes: Per-bar arrays, positionally indexed.
        highs, lows: Not used; kept for interface compatibility.
        confirm_bars: Number of bars after the breakout that must hold.
        max_wait: Number of bars after the box to search.
        vol_ratio_threshold: Minimum volume ratio.
        timeframe: Label stored on each ``Breakout``.
        vol_lookback: Number of preceding bars averaged for the volume ratio.

    Returns:
        At most one ``Breakout`` per box, in ``boxes`` order.
    """
    n = len(closes)
    breakouts = []

    for box in boxes:
        search_start = box.end_idx + 1
        search_end = min(search_start + max_wait, n)

        for i in range(search_start, search_end):
            # Conditions 1 & 2: bullish close above the box top
            if closes[i] <= box.box_high or closes[i] <= opens[i]:
                continue

            # Condition 3: volume expansion vs. the preceding bars
            vol_start = max(0, i - vol_lookback)
            avg_vol = np.mean(volumes[vol_start:i]) if i > vol_start else volumes[i]
            if not avg_vol > 0:  # zero or NaN average -> ratio undefined
                continue
            vol_ratio = volumes[i] / avg_vol
            if not vol_ratio >= vol_ratio_threshold:  # also rejects NaN
                continue

            # Condition 4: confirm_bars full bars must follow, none closing back below the top
            confirm_end = i + confirm_bars + 1
            if confirm_end > n:
                continue
            if any(closes[k] < box.box_high for k in range(i + 1, confirm_end)):
                continue

            breakouts.append(Breakout(
                bar_idx=i,
                price=closes[i],
                volume_ratio=vol_ratio,
                box=box,
                timeframe=timeframe,
            ))
            break  # first valid breakout per box only

    return breakouts


# ==================== Step 5: 多周期叠加 ====================

def match_resonance(
    breakouts_4h: List[Breakout],
    breakouts_1d: List[Breakout],
    timestamps_4h: pd.DatetimeIndex,
    timestamps_1d: pd.DatetimeIndex,
    window_days: int = S.RESONANCE_WINDOW_DAYS,
) -> List[dict]:
    """
    Multi-timeframe matching: pair 4H breakouts with daily breakouts.

    Each 4H breakout is paired with the nearest (in time) daily breakout that is
    within ``window_days`` of it and has not been paired yet. Every daily
    breakout is used in at most one pair. Ties keep the earlier daily breakout
    in list order. Unpaired breakouts become standalone signals.

    Args:
        breakouts_4h: 4H breakouts; ``bar_idx`` is positional into ``timestamps_4h``.
        breakouts_1d: Daily breakouts; ``bar_idx`` is positional into ``timestamps_1d``.
        timestamps_4h: Bar timestamps (Index or Series) for the 4H data.
        timestamps_1d: Bar timestamps (Index or Series) for the daily data.
        window_days: Maximum time gap, in days, for a resonance pair.

    Returns:
        List of dicts with keys:
        - signal_type: '共振', '日线' or '4H'
        - breakout_4h: Breakout or None
        - breakout_1d: Breakout or None
        - primary_time: reference time for entry (the later of the two for pairs)
        Pairs come first, then unpaired daily, then unpaired 4H signals.
    """
    window = timedelta(days=window_days)

    # bar_idx is positional (upstream arrays come from `.values`). Indexing a
    # Series with an int is label-based, which breaks on sliced/filtered frames,
    # so convert to an Index of the values and index it positionally.
    ts_4h = pd.Index(timestamps_4h)
    ts_1d = pd.Index(timestamps_1d)
    times_4h = [ts_4h[bo.bar_idx] for bo in breakouts_4h]
    times_1d = [ts_1d[bo.bar_idx] for bo in breakouts_1d]

    results = []
    matched_4h = set()
    matched_1d = set()

    # Resonance pairs first
    for i, (bo4, t4) in enumerate(zip(breakouts_4h, times_4h)):
        best_j = None
        best_gap = None
        for j, t1d in enumerate(times_1d):
            if j in matched_1d:  # each daily breakout backs at most one pair
                continue
            gap = abs(t4 - t1d)
            if gap <= window and (best_gap is None or gap < best_gap):
                best_j, best_gap = j, gap
        if best_j is None:
            continue

        results.append({
            'signal_type': '共振',
            'breakout_4h': bo4,
            'breakout_1d': breakouts_1d[best_j],
            'primary_time': max(t4, times_1d[best_j]),  # the later one is the entry reference
        })
        matched_4h.add(i)
        matched_1d.add(best_j)

    # Unpaired daily breakouts as standalone signals
    for j, (bo1d, t1d) in enumerate(zip(breakouts_1d, times_1d)):
        if j not in matched_1d:
            results.append({
                'signal_type': '日线',
                'breakout_4h': None,
                'breakout_1d': bo1d,
                'primary_time': t1d,
            })

    # Unpaired 4H breakouts as standalone signals
    for i, (bo4, t4) in enumerate(zip(breakouts_4h, times_4h)):
        if i not in matched_4h:
            results.append({
                'signal_type': '4H',
                'breakout_4h': bo4,
                'breakout_1d': None,
                'primary_time': t4,
            })

    return results


# ==================== Step 6: 1H 精准入场 ====================

def find_1h_entry(
    box_high: float,
    breakout_time: pd.Timestamp,
    df_1h: pd.DataFrame,
    tolerance: float = S.ENTRY_1H_TOLERANCE,
    max_wait: int = S.ENTRY_1H_MAX_WAIT_BARS,
) -> Optional[dict]:
    """
    Precise 1H entry: after a breakout, wait for a pullback to the box top.

    Examines at most ``max_wait`` 1H bars whose ``timestamp`` is strictly
    after ``breakout_time``. A pullback is registered at the first bar whose
    low is <= ``box_high * (1 + tolerance)``. From that bar on (inclusive),
    the first bar that is bullish (close > open) and closes above ``box_high``
    is the entry, and its close is the entry price.

    NOTE(review): only the upper edge of the +/-tolerance band is tested, so a
    pullback far below the box top still counts as a touch -- confirm intended.

    Returns:
        ``{'entry_time', 'entry_price', 'entry_method'}`` for the entry bar, or
        None when no entry occurs within the ``max_wait`` bars (the caller then
        falls back to a direct entry).
    """
    after_breakout = df_1h[df_1h['timestamp'] > breakout_time].head(max_wait)
    touch_level = box_high * (1 + tolerance)

    has_touched = False
    for ts, bar_open, bar_low, bar_close in zip(
        after_breakout['timestamp'],
        after_breakout['open'],
        after_breakout['low'],
        after_breakout['close'],
    ):
        if not has_touched and bar_low <= touch_level:
            has_touched = True

        # After the pullback, wait for a bullish bar closing back above the top
        if has_touched and bar_close > bar_open and bar_close > box_high:
            return {
                'entry_time': ts,
                'entry_price': bar_close,
                'entry_method': '1H回踩',
            }

    return None


# ==================== 完整信号管线 ====================

def run_pipeline_single_tf(
    df: pd.DataFrame,
    timeframe: str,
    zigzag_threshold: float,
    drop_min: float,
    box_min_bars: int,
    box_max_width: float,
    confirm_bars: int,
) -> List[Breakout]:
    """
    Run Steps 1-4 (ZigZag -> drop legs -> boxes -> breakouts) on one timeframe.

    Args:
        df: Bars with columns timestamp, open, high, low, close, volume. Only
            the raw column arrays are passed on, so every bar index in the
            result is positional.
        timeframe: Label stored on each breakout ('1D' or '4H').
        zigzag_threshold: Passed through to ``zigzag``.
        drop_min: Minimum fractional drop for a drop leg.
        box_min_bars: Minimum box duration in bars.
        box_max_width: Maximum box amplitude.
        confirm_bars: Bars after a breakout that must hold above the box top.

    Returns:
        Breakouts found on this timeframe.
    """
    bar_high = df['high'].values
    bar_low = df['low'].values
    bar_open = df['open'].values
    bar_close = df['close'].values
    bar_volume = df['volume'].values

    swings = zigzag(bar_high, bar_low, zigzag_threshold)        # Step 1
    drop_legs = find_drop_legs(swings, drop_min)                # Step 2
    candidate_boxes = find_boxes(                               # Step 3
        drop_legs, bar_high, bar_low, bar_close, box_min_bars, box_max_width,
    )
    return find_breakouts(                                      # Step 4
        candidate_boxes, bar_open, bar_high, bar_low, bar_close, bar_volume,
        confirm_bars=confirm_bars,
        timeframe=timeframe,
    )


def generate_long_signals(
    coin: str,
    df_1h: pd.DataFrame,
    df_4h: pd.DataFrame,
    df_1d: pd.DataFrame,
) -> List[Signal]:
    """
    Generate all long signals for one coin.

    Pipeline:
    1. Daily Steps 1-4
    2. 4H Steps 1-4
    3. Step 5: multi-timeframe matching
    4. Step 6: 1H pullback entry, falling back to the breakout bar's close

    The box (and therefore the stop) comes from the daily breakout when one is
    present, otherwise from the 4H breakout. The stop sits ``S.SL_LONG_OFFSET``
    below the box top. A signal is dropped when the stop is at or above the
    entry price, or when its distance exceeds ``S.SL_MAX_DISTANCE``.

    Returns:
        Signals in the order produced by the matching step.
    """
    # Daily pipeline
    breakouts_1d = run_pipeline_single_tf(
        df_1d, '1D',
        S.ZIGZAG_THRESHOLD_DAILY,
        S.DROP_LEG_MIN_DAILY,
        S.BOX_MIN_BARS_DAILY,
        S.BOX_MAX_WIDTH_DAILY,
        S.BREAKOUT_CONFIRM_DAILY,
    )

    # 4H pipeline
    breakouts_4h = run_pipeline_single_tf(
        df_4h, '4H',
        S.ZIGZAG_THRESHOLD_4H,
        S.DROP_LEG_MIN_4H,
        S.BOX_MIN_BARS_4H,
        S.BOX_MAX_WIDTH_4H,
        S.BREAKOUT_CONFIRM_4H,
    )

    # Step 5: multi-timeframe matching
    matched = match_resonance(
        breakouts_4h, breakouts_1d,
        df_4h['timestamp'], df_1d['timestamp'],
    )

    signals = []
    for m in matched:
        bo_4h = m['breakout_4h']
        bo_1d = m['breakout_1d']

        # The higher timeframe's box wins when both are present
        primary_bo = bo_1d if bo_1d is not None else bo_4h
        box = primary_bo.box
        box_high = box.box_high

        # Step 6: 1H precise entry
        entry_info = find_1h_entry(box_high, m['primary_time'], df_1h)

        if entry_info is not None:
            entry_price = entry_info['entry_price']
            entry_time = entry_info['entry_time']
            entry_method = '1H回踩'
        else:
            # Fallback: enter at the close of the breakout bar (4H when present).
            # NOTE(review): this bar precedes the breakout's confirmation bars and,
            # for 共振, may precede the daily breakout -- possible look-ahead.
            if bo_4h is not None:
                fallback_bo, tf_df = bo_4h, df_4h
            else:
                fallback_bo, tf_df = bo_1d, df_1d
            bar = tf_df.iloc[fallback_bo.bar_idx]
            entry_price = bar['close']
            entry_time = bar['timestamp']
            entry_method = '4H直接'

        # Stop loss below the box top
        stop_loss = box_high * (1 - S.SL_LONG_OFFSET)
        sl_distance = (entry_price - stop_loss) / entry_price

        # A stop at/above entry is not a valid long (possible when the fallback
        # enters at the 4H close but the stop comes from the higher daily box);
        # too-wide stops are skipped as well.
        if sl_distance <= 0 or sl_distance > S.SL_MAX_DISTANCE:
            continue

        signals.append(Signal(
            coin=coin,
            direction='long',
            signal_type=m['signal_type'],
            entry_time=entry_time,
            entry_price=entry_price,
            stop_loss=stop_loss,
            sl_distance=sl_distance,
            box_high=box_high,
            box_low=box.box_low,
            box_width_pct=box.width_pct,
            box_duration_bars=box.duration_bars,
            volume_ratio=primary_bo.volume_ratio,
            entry_method=entry_method,
            breakout_4h=bo_4h,
            breakout_1d=bo_1d,
        ))

    return signals
