"""
SQLite 数据库模块

存储:
1. K线数据（增量更新，避免重复拉取）
2. 信号历史（所有检测到的信号）
3. 系统状态（上次运行时间、连败计数等）
"""

import sqlite3
import os
import pandas as pd
import numpy as np
from typing import Optional, List, Dict
from datetime import datetime, timezone


# Default database file: ``langlang.db`` in the parent of the directory that
# contains this module.
DEFAULT_DB_PATH = os.path.join(
    os.path.dirname(os.path.dirname(__file__)), "langlang.db"
)


class Database:
    """
    SQLite database manager.

    Tables:
      * ``klines``       - candlestick rows keyed by (symbol, interval,
                           timestamp); saved with upsert semantics so that
                           re-fetching an overlapping window adds no duplicates.
      * ``signals``      - history of every detected signal.
      * ``system_state`` - small key/value store (last run time, losing
                           streak counters, ...).

    Every write runs inside ``with self.conn:`` so it is committed on success
    and rolled back on error; a failed batch never leaves a half-written
    transaction behind for a later ``commit()`` to persist.

    Usage:
        db = Database()
        db.save_klines('BTCUSDT', '1h', df)
        df = db.load_klines('BTCUSDT', '1h')
        db.save_signal(signal_dict)
    """

    def __init__(self, db_path: str = DEFAULT_DB_PATH):
        """
        Open the SQLite database at *db_path* and make sure the schema exists.

        Args:
            db_path: Path of the SQLite file (``":memory:"`` gives a private
                in-memory database).
        """
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        try:
            self._init_tables()
        except BaseException:
            # Do not leak the connection if schema creation fails.
            self.conn.close()
            raise

    @staticmethod
    def _to_sql_value(value):
        """
        Convert *value* into something the ``sqlite3`` module can bind.

        numpy scalars (``np.int64``, ``np.bool_``, ...) are rejected by
        ``sqlite3`` and are unwrapped to the equivalent Python object.
        ``datetime`` values, including ``pd.Timestamp``, are stored as
        ``str(value)`` - the same text form ``save_klines`` uses for kline
        timestamps. Any other value is returned unchanged.
        """
        if isinstance(value, np.datetime64):
            value = pd.Timestamp(value)
        elif isinstance(value, np.generic):
            value = value.item()
        if isinstance(value, datetime):
            return str(value)
        return value

    def _init_tables(self):
        """Create the tables and indexes if they do not exist yet."""
        with self.conn:
            cursor = self.conn.cursor()

            # Kline data. The composite PRIMARY KEY already gives SQLite a
            # unique index on (symbol, interval, timestamp), which serves the
            # range/ORDER BY queries below, so no extra index is created.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS klines (
                    symbol TEXT NOT NULL,
                    interval TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    open REAL NOT NULL,
                    high REAL NOT NULL,
                    low REAL NOT NULL,
                    close REAL NOT NULL,
                    volume REAL NOT NULL,
                    PRIMARY KEY (symbol, interval, timestamp)
                )
            """)

            # Signal history.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS signals (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    created_at TEXT NOT NULL DEFAULT (datetime('now')),
                    coin TEXT NOT NULL,
                    direction TEXT NOT NULL,
                    signal_type TEXT NOT NULL,
                    entry_time TEXT NOT NULL,
                    entry_price REAL NOT NULL,
                    stop_loss REAL NOT NULL,
                    sl_distance REAL NOT NULL,
                    box_high REAL,
                    box_low REAL,
                    box_width_pct REAL,
                    box_duration_bars INTEGER,
                    volume_ratio REAL,
                    entry_method TEXT,
                    suggested_position_pct REAL,
                    btc_trend_60d REAL,
                    coin_trend_60d REAL,
                    filter_passed INTEGER DEFAULT 1,
                    filter_reason TEXT,
                    notified INTEGER DEFAULT 0,
                    executed INTEGER DEFAULT 0,
                    notes TEXT
                )
            """)

            # System state (key/value).
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS system_state (
                    key TEXT PRIMARY KEY,
                    value TEXT NOT NULL,
                    updated_at TEXT NOT NULL DEFAULT (datetime('now'))
                )
            """)

            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_signals_coin_time
                ON signals(coin, entry_time)
            """)

    # ==================== Klines ====================

    def save_klines(self, symbol: str, interval: str, df: pd.DataFrame):
        """
        Upsert kline rows for (*symbol*, *interval*).

        Rows whose (symbol, interval, timestamp) already exist are replaced.
        The whole batch is written in one transaction: if any row fails, none
        of the batch is kept.

        Args:
            symbol: Trading pair, e.g. ``'BTCUSDT'``.
            interval: Kline interval, e.g. ``'1h'``.
            df: Frame with ``timestamp``, ``open``, ``high``, ``low``,
                ``close`` and ``volume`` columns. Timestamps are stored as
                ``str(value)``; the other columns as floats.
        """
        if df.empty:
            return

        # Zip the columns instead of iterrows(): no per-row Series is built.
        rows = [
            (symbol, interval, str(ts),
             float(op), float(hi), float(lo), float(cl), float(vol))
            for ts, op, hi, lo, cl, vol in zip(
                df['timestamp'], df['open'], df['high'],
                df['low'], df['close'], df['volume'],
            )
        ]

        with self.conn:
            self.conn.executemany("""
                INSERT OR REPLACE INTO klines
                (symbol, interval, timestamp, open, high, low, close, volume)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """, rows)

    def load_klines(self, symbol: str, interval: str,
                    start: Optional[str] = None,
                    limit: Optional[int] = None) -> pd.DataFrame:
        """
        Load klines for (*symbol*, *interval*), oldest first.

        Args:
            symbol: Trading pair.
            interval: Kline interval.
            start: Inclusive lower bound on ``timestamp``. The comparison is
                done on the stored text, so *start* should use the same form
                that ``save_klines`` stored (``str(timestamp)``). A
                ``datetime``/``pd.Timestamp`` is converted with ``str()``.
                NOTE: ``str(pd.Timestamp)`` uses a space separator, so an ISO
                string with ``'T'`` sorts after same-day stored values.
            limit: Maximum number of rows to return (the oldest ones, given
                the ascending order). A falsy value (``None`` or ``0``) means
                no limit.

        Returns:
            DataFrame with ``timestamp`` (parsed as UTC datetimes when the
            result is non-empty), ``open``, ``high``, ``low``, ``close`` and
            ``volume``.

        Raises:
            ValueError / TypeError: if *limit* is not convertible to ``int``.
        """
        query = ("SELECT timestamp, open, high, low, close, volume "
                 "FROM klines WHERE symbol=? AND interval=?")
        params: list = [symbol, interval]

        if start:
            query += " AND timestamp >= ?"
            params.append(self._to_sql_value(start))

        query += " ORDER BY timestamp ASC"

        if limit:
            # Bind LIMIT rather than formatting it into the SQL text.
            query += " LIMIT ?"
            params.append(int(limit))

        df = pd.read_sql_query(query, self.conn, params=params)

        if not df.empty:
            df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)

        return df

    def get_latest_timestamp(self, symbol: str, interval: str) -> Optional[str]:
        """
        Return the greatest stored timestamp text for (*symbol*, *interval*),
        or ``None`` if there are no rows. The maximum is taken over the stored
        text, so it is the latest time only if all rows share one format.
        """
        cursor = self.conn.execute(
            "SELECT MAX(timestamp) FROM klines WHERE symbol=? AND interval=?",
            (symbol, interval)
        )
        return cursor.fetchone()[0]

    def get_kline_count(self, symbol: str, interval: str) -> int:
        """Return how many kline rows are stored for (*symbol*, *interval*)."""
        cursor = self.conn.execute(
            "SELECT COUNT(*) FROM klines WHERE symbol=? AND interval=?",
            (symbol, interval)
        )
        return cursor.fetchone()[0]

    # ==================== Signal history ====================

    def save_signal(self, signal: dict) -> int:
        """
        Insert one signal record.

        Only keys of *signal* that name known ``signals`` columns are written;
        other keys are ignored and missing columns take their table defaults.
        numpy scalars and datetimes are converted to bindable values first.

        Returns:
            The ``id`` of the inserted row.
        """
        cols = [
            'coin', 'direction', 'signal_type', 'entry_time', 'entry_price',
            'stop_loss', 'sl_distance', 'box_high', 'box_low', 'box_width_pct',
            'box_duration_bars', 'volume_ratio', 'entry_method',
            'suggested_position_pct', 'btc_trend_60d', 'coin_trend_60d',
            'filter_passed', 'filter_reason', 'notes'
        ]

        # Column names come from the fixed whitelist above, never from the
        # caller, so building the column list into the SQL text is safe.
        present_cols = [c for c in cols if c in signal]
        placeholders = ','.join(['?'] * len(present_cols))
        col_names = ','.join(present_cols)
        values = [self._to_sql_value(signal[c]) for c in present_cols]

        with self.conn:
            cursor = self.conn.execute(
                f"INSERT INTO signals ({col_names}) VALUES ({placeholders})",
                values
            )
        return cursor.lastrowid

    def get_recent_signals(self, hours: int = 24,
                           coin: Optional[str] = None) -> pd.DataFrame:
        """
        Return signals created within the last *hours* hours, newest first.

        ``created_at`` defaults to SQLite ``datetime('now')`` and the cutoff is
        computed with the same function, so both sides use SQLite's UTC clock.

        Args:
            hours: Size of the look-back window.
            coin: If given, only signals for this coin are returned.
        """
        query = """
            SELECT * FROM signals
            WHERE created_at >= datetime('now', ?)
        """
        params = [f'-{hours} hours']

        if coin:
            query += " AND coin = ?"
            params.append(coin)

        query += " ORDER BY created_at DESC"
        return pd.read_sql_query(query, self.conn, params=params)

    def mark_signal_notified(self, signal_id: int):
        """Set ``notified=1`` on the signal with id *signal_id*."""
        with self.conn:
            self.conn.execute(
                "UPDATE signals SET notified=1 WHERE id=?", (signal_id,)
            )

    # ==================== System state ====================

    def set_state(self, key: str, value: str):
        """Store *value* under *key*, replacing any previous value and
        refreshing ``updated_at``."""
        with self.conn:
            self.conn.execute(
                "INSERT OR REPLACE INTO system_state (key, value, updated_at) VALUES (?, ?, datetime('now'))",
                (key, value)
            )

    def get_state(self, key: str, default: str = '') -> str:
        """Return the value stored under *key*, or *default* if absent."""
        cursor = self.conn.execute(
            "SELECT value FROM system_state WHERE key=?", (key,)
        )
        result = cursor.fetchone()
        return result[0] if result else default

    # ==================== Cleanup ====================

    def close(self):
        """Close the database connection."""
        self.conn.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Always close; exceptions from the with-body propagate.
        self.close()
