"""
Step 1: ZigZag 摆动点识别
输入：K线数据
输出：高点(H)和低点(L)交替序列

只用价格 (high/low)，不用任何技术指标。
阈值：日线15% / 4H 8%
"""

import numpy as np
import pandas as pd
from typing import List, Tuple


def zigzag(
    highs: np.ndarray,
    lows: np.ndarray,
    threshold: float,
) -> List[Tuple[int, str, float]]:
    """
    ZigZag摆动点识别

    Args:
        highs: K线最高价数组
        lows: K线最低价数组
        threshold: 摆动阈值（如0.15 = 15%）

    Returns:
        List of (bar_index, type, price)
        type: 'H' = 高点, 'L' = 低点
        H 和 L 交替出现
    """
    n = len(highs)
    if n < 2:
        return []

    points = []       # (index, 'H'/'L', price)
    direction = 0     # 0=未定, 1=上行追踪高点, -1=下行追踪低点
    last_high_idx = 0
    last_high_val = highs[0]
    last_low_idx = 0
    last_low_val = lows[0]

    for i in range(1, n):
        if direction == 0:
            # 初始方向未定，看先涨还是先跌达到阈值
            if highs[i] > last_low_val * (1 + threshold):
                # 先涨达到阈值 → 确认低点，开始追踪高点
                points.append((last_low_idx, 'L', last_low_val))
                direction = 1
                last_high_idx = i
                last_high_val = highs[i]
            elif lows[i] < last_high_val * (1 - threshold):
                # 先跌达到阈值 → 确认高点，开始追踪低点
                points.append((last_high_idx, 'H', last_high_val))
                direction = -1
                last_low_idx = i
                last_low_val = lows[i]
            else:
                # 更新候选极值
                if highs[i] > last_high_val:
                    last_high_idx = i
                    last_high_val = highs[i]
                if lows[i] < last_low_val:
                    last_low_idx = i
                    last_low_val = lows[i]

        elif direction == 1:
            # 上行中追踪最高点
            if highs[i] > last_high_val:
                last_high_idx = i
                last_high_val = highs[i]
            elif lows[i] < last_high_val * (1 - threshold):
                # 从高点回撤超过阈值 → 确认高点
                points.append((last_high_idx, 'H', last_high_val))
                direction = -1
                last_low_idx = i
                last_low_val = lows[i]

        elif direction == -1:
            # 下行中追踪最低点
            if lows[i] < last_low_val:
                last_low_idx = i
                last_low_val = lows[i]
            elif highs[i] > last_low_val * (1 + threshold):
                # 从低点反弹超过阈值 → 确认低点
                points.append((last_low_idx, 'L', last_low_val))
                direction = 1
                last_high_idx = i
                last_high_val = highs[i]

    return points
