"""
Binance 实时数据拉取模块

与 binance_downloader.py 不同，这个模块负责:
- 每小时增量拉取最新K线
- 维护本地数据库的最新状态
- 处理API限流和错误重试
"""

import time
import logging
import requests
import pandas as pd
from typing import Optional, Dict, List
from datetime import datetime, timezone, timedelta

from ..config.coins import COINS
from ..data.database import Database
from ..data.aggregator import prepare_all_timeframes

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Binance REST endpoint that every kline request in this module is sent to.
BINANCE_KLINE_URL = "https://api.binance.com/api/v3/klines"


class BinanceLiveFetcher:
    """
    实时K线数据拉取器

    用法:
        fetcher = BinanceLiveFetcher(db)
        fetcher.update_all()  # 更新所有币种
        data = fetcher.get_all_timeframes('BTCUSDT')  # 获取多周期数据
    """

    def __init__(self, db: Database, max_retries: int = 3):
        """
        Keep the database handle and open a reusable HTTP session.

        Args:
            db: Database object used by the other methods to read and
                store klines.
            max_retries: Retry budget for HTTP requests.
        """
        http = requests.Session()
        http.headers.update({'User-Agent': 'LangLangAI/1.0'})

        self.session = http
        self.max_retries = max_retries
        self.db = db

    def fetch_latest_klines(
        self,
        symbol: str,
        interval: str = '1h',
        limit: int = 500,
    ) -> pd.DataFrame:
        """
        Fetch the newest closed klines for ``symbol`` (incremental).

        The request starts 1 ms after the newest timestamp stored in the
        database for ``(symbol, interval)``. When nothing is stored yet,
        ``startTime`` is omitted and the API decides which bars to return.
        Failed requests are retried up to ``self.max_retries`` times with
        exponential backoff (1 s, 2 s, 4 s, ...).

        The bar that is still forming is dropped by comparing each bar's
        ``close_time`` (as reported by the API) with the current time, so the
        filter is correct for every interval, not only ``1h``/``4h``/``1d``.

        Args:
            symbol: Trading pair, e.g. ``'BTCUSDT'``.
            interval: Binance kline interval string, e.g. ``'1h'``.
            limit: Maximum number of bars requested in the single API call.

        Returns:
            DataFrame with columns ``timestamp`` (tz-aware UTC), ``open``,
            ``high``, ``low``, ``close``, ``volume`` containing only closed
            bars. An empty DataFrame when the API returns no rows or when
            every attempt fails.
        """
        latest = self.db.get_latest_timestamp(symbol, interval)

        params = {
            "symbol": symbol,
            "interval": interval,
            "limit": limit,
        }

        if latest:
            # pd.Timestamp(x, tz=...) raises when x is already tz-aware, so
            # normalise to UTC explicitly for both naive and aware inputs.
            latest_ts = pd.Timestamp(latest)
            if latest_ts.tzinfo is None:
                latest_ts = latest_ts.tz_localize('UTC')
            else:
                latest_ts = latest_ts.tz_convert('UTC')
            # +1 ms so the bar already stored is not returned again.
            params["startTime"] = int(latest_ts.timestamp() * 1000) + 1

        for attempt in range(self.max_retries):
            try:
                resp = self.session.get(
                    BINANCE_KLINE_URL, params=params, timeout=30
                )
                resp.raise_for_status()
                data = resp.json()
            except (requests.exceptions.RequestException, ValueError) as e:
                # ValueError covers a malformed JSON body on requests versions
                # whose JSON errors are not RequestException subclasses.
                logger.warning(
                    "Fetch attempt %d failed for %s: %s", attempt + 1, symbol, e
                )
                if attempt < self.max_retries - 1:
                    time.sleep(2 ** attempt)  # exponential backoff
                continue

            if not data:
                return pd.DataFrame()

            df = pd.DataFrame(data, columns=[
                'open_time', 'open', 'high', 'low', 'close', 'volume',
                'close_time', 'quote_volume', 'trades', 'taker_buy_vol',
                'taker_buy_quote_vol', 'ignore'
            ])

            df['timestamp'] = pd.to_datetime(df['open_time'], unit='ms', utc=True)
            for col in ['open', 'high', 'low', 'close', 'volume']:
                df[col] = df[col].astype(float)

            # A bar is closed once its close_time lies in the past; the bar
            # that is still forming has close_time >= now and is dropped.
            now_ms = int(pd.Timestamp.now(tz='UTC').timestamp() * 1000)
            closed = df['close_time'].astype('int64') < now_ms
            df = df.loc[
                closed,
                ['timestamp', 'open', 'high', 'low', 'close', 'volume'],
            ]

            logger.info("Fetched %d new bars for %s %s", len(df), symbol, interval)
            return df

        logger.error(
            "Failed to fetch %s after %d attempts", symbol, self.max_retries
        )
        return pd.DataFrame()

    def update_symbol(self, symbol: str) -> int:
        """
        Fetch the newest 1h klines for ``symbol`` and store them.

        Returns:
            Number of bars handed to the database; 0 when the fetch came
            back empty.
        """
        interval = '1h'
        new_bars = self.fetch_latest_klines(symbol, interval)

        if not new_bars.empty:
            self.db.save_klines(symbol, interval, new_bars)
            return len(new_bars)

        return 0

    def update_all(self) -> Dict[str, int]:
        """
        Run ``update_symbol`` for every symbol in ``COINS``.

        Sleeps 0.2 s after each symbol as a simple rate limit.

        Returns:
            Mapping ``{symbol: number_of_new_bars}``.
        """
        results: Dict[str, int] = {}

        for symbol in COINS:
            n = self.update_symbol(symbol)
            if n > 0:
                logger.info(f"  {symbol}: +{n} bars")
            results[symbol] = n
            time.sleep(0.2)  # pause between symbols

        return results

    def initialize_history(self, symbol: str, days: int = 365):
        """
        Backfill 1h history for ``symbol`` on first run.

        Pages through the klines endpoint 1000 bars at a time, starting at
        the coin's configured ``data_start`` and stopping at the time this
        method was called, saving every page to the database.

        Only closed bars are stored: the bar that is still forming is dropped
        so a partial candle is never persisted.

        A failing page request is retried after 5 s. After
        ``self.max_retries`` consecutive failures the backfill is aborted
        instead of retrying forever; bars saved so far stay in the database.

        Args:
            symbol: Trading pair; must be a key of ``COINS``.
            days: Amount of history (in days) regarded as "enough". It only
                drives the skip check (more than ``days * 20`` stored bars)
                and the log message; the download itself always starts at
                ``data_start``.
        """
        existing = self.db.get_kline_count(symbol, '1h')
        if existing > days * 20:  # roughly enough data is already stored
            logger.info("  %s: already has %s bars, skip init", symbol, existing)
            return

        cfg = COINS.get(symbol)
        if not cfg:
            logger.warning("  %s: not configured in COINS, skip init", symbol)
            return

        logger.info("  Initializing %s (%s days)...", symbol, days)

        raw_columns = [
            'open_time', 'open', 'high', 'low', 'close', 'volume',
            'close_time', 'quote_volume', 'trades', 'taker_buy_vol',
            'taker_buy_quote_vol', 'ignore'
        ]
        kept_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']

        end_ts = int(datetime.now(timezone.utc).timestamp() * 1000)
        start_ts = int(pd.Timestamp(cfg.data_start, tz='UTC').timestamp() * 1000)

        all_bars = 0
        failures = 0  # consecutive failed requests for the current page
        current = start_ts

        while current < end_ts:
            params = {
                "symbol": symbol,
                "interval": "1h",
                "startTime": current,
                "limit": 1000,
            }
            try:
                resp = self.session.get(BINANCE_KLINE_URL, params=params, timeout=30)
                resp.raise_for_status()
                data = resp.json()
            except (requests.exceptions.RequestException, ValueError) as e:
                failures += 1
                logger.error("  Init error %s: %s", symbol, e)
                if failures >= self.max_retries:
                    logger.error(
                        "  %s: aborting init after %d consecutive failures",
                        symbol, failures,
                    )
                    break
                time.sleep(5)
                continue
            failures = 0

            if not data:
                break

            df = pd.DataFrame(data, columns=raw_columns)
            df['timestamp'] = pd.to_datetime(df['open_time'], unit='ms', utc=True)
            for col in ['open', 'high', 'low', 'close', 'volume']:
                df[col] = df[col].astype(float)

            # Keep closed bars only: the forming bar closes after end_ts.
            closed = df['close_time'].astype('int64') < end_ts
            df = df.loc[closed, kept_columns]

            if not df.empty:
                self.db.save_klines(symbol, '1h', df)
                all_bars += len(df)

            # Index 6 is the bar's close time in ms; the next page starts
            # right after the last bar received.
            next_start = data[-1][6] + 1
            if next_start <= current:
                # No forward progress; stop rather than loop forever.
                break
            current = next_start
            time.sleep(0.3)

        logger.info("  %s: initialized %s bars", symbol, all_bars)

    def initialize_all(self, days: int = 365):
        """Call ``initialize_history`` for every symbol in ``COINS``."""
        for coin in COINS:
            self.initialize_history(coin, days=days)

    def get_all_timeframes(self, symbol: str) -> Dict[str, pd.DataFrame]:
        """
        Load the stored 1h klines for ``symbol`` and build every timeframe.

        Returns:
            Whatever ``prepare_all_timeframes`` produces for the 1h frame
            (expected shape: ``{'1H': df, '4H': df, '1D': df}``), or an empty
            dict when the database holds no 1h data for the symbol.
        """
        hourly = self.db.load_klines(symbol, '1h')
        return {} if hourly.empty else prepare_all_timeframes(hourly)
