import os
import pandas as pd
import numpy as np
import sqlite3
import itertools
import MetaTrader5 as mt5  # Required by CointegrationRanker, must be installed and initialized if connecting to a broker
from datetime import datetime, timedelta, timezone
from statsmodels.tsa.vector_ar.vecm import coint_johansen  # Required by CointegrationRanker
from statsmodels.tsa.stattools import coint              # Required by CointegrationRanker
from typing import List, Dict, Any, Optional
from tqdm import tqdm
import logging
import time
import threading

# --- CointegrationRanker class definition ---
# (Defined inline so this script runs standalone; in a larger project it could
# instead be imported from its own module.)

class CointegrationRanker:
    """
    Screen and rank pairs or baskets of assets by cointegration strength.

    Closing prices come from a SQLite database with a ``symbol`` table
    (``symbol_id``, ``ticker``) and a ``market_data`` table filtered by
    ``symbol_id``, ``timeframe`` and ``tstamp`` (bar open time in epoch
    seconds). Bars missing from the database can be pulled from a
    MetaTrader 5 terminal and appended to ``market_data``.

    The instance holds at most one SQLite connection (``self.conn``); each
    public method opens it on demand and closes it again before returning.
    """

    def __init__(self, db_path: str):
        self.db_path = db_path
        self.conn: Optional[sqlite3.Connection] = None
        self.logger = self._setup_logger()

    # -----------------------------------------------------
    # TIMEFRAME CONVERSION HELPERS
    # -----------------------------------------------------
    @staticmethod
    def _timeframe_map() -> dict:
        """Return the mapping from timeframe strings ('D1', 'H1', ...) to MT5 constants."""
        return {
            "M1": mt5.TIMEFRAME_M1, "M2": mt5.TIMEFRAME_M2, "M3": mt5.TIMEFRAME_M3,
            "M4": mt5.TIMEFRAME_M4, "M5": mt5.TIMEFRAME_M5, "M6": mt5.TIMEFRAME_M6,
            "M10": mt5.TIMEFRAME_M10, "M12": mt5.TIMEFRAME_M12, "M15": mt5.TIMEFRAME_M15,
            "M20": mt5.TIMEFRAME_M20, "M30": mt5.TIMEFRAME_M30, "H1": mt5.TIMEFRAME_H1,
            "H2": mt5.TIMEFRAME_H2, "H3": mt5.TIMEFRAME_H3, "H4": mt5.TIMEFRAME_H4,
            "H6": mt5.TIMEFRAME_H6, "H8": mt5.TIMEFRAME_H8, "H12": mt5.TIMEFRAME_H12,
            "D1": mt5.TIMEFRAME_D1, "W1": mt5.TIMEFRAME_W1, "MN1": mt5.TIMEFRAME_MN1,
        }

    @staticmethod
    def _tf_to_mt5(timeframe: str) -> int:
        """
        Convert a string timeframe ('D1', 'H1', etc.) to its MetaTrader5 constant.

        Unknown strings fall back to ``mt5.TIMEFRAME_D1``.
        NOTE(review): with that fallback a typo such as 'H5' fetches daily bars
        that callers then store under the 'H5' label.
        """
        return CointegrationRanker._timeframe_map().get(timeframe.upper(), mt5.TIMEFRAME_D1)

    @staticmethod
    def _mt5_to_tf(mt5_timeframe: int) -> str:
        """Convert a MetaTrader5 timeframe constant back to its string form (default 'D1')."""
        reverse_map = {const: name for name, const in CointegrationRanker._timeframe_map().items()}
        return reverse_map.get(mt5_timeframe, "D1")

    # -----------------------------------------------------
    # LOGGER SETUP
    # -----------------------------------------------------
    def _setup_logger(self) -> logging.Logger:
        """Return the shared 'CointegrationRanker' logger, attaching a stream handler once."""
        logger = logging.getLogger("CointegrationRanker")
        logger.setLevel(logging.INFO)
        if not logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(message)s"))
            logger.addHandler(handler)
        return logger

    # -----------------------------------------------------
    # DATABASE CONNECTION MANAGEMENT
    # -----------------------------------------------------
    def _connect(self):
        """Open the SQLite connection if it is not already open."""
        if self.conn is None:
            self.conn = sqlite3.connect(self.db_path)

    def _disconnect(self):
        """Close the SQLite connection (no-op when already closed)."""
        if self.conn:
            self.conn.close()
            self.conn = None

    def _last_tstamp(self, symbol_id, timeframe: str) -> Optional[int]:
        """
        Return MAX(tstamp) stored for ``symbol_id``/``timeframe``, or None when no rows exist.

        Requires an open connection (``self._connect()``).
        """
        row = self.conn.execute(
            "SELECT MAX(tstamp) FROM market_data WHERE symbol_id = ? AND timeframe = ?;",
            # sqlite3 cannot bind numpy integer types, hence int().
            (int(symbol_id), timeframe),
        ).fetchone()
        return row[0]

    # -----------------------------------------------------
    # METATRADER 5 FETCH HELPERS
    # -----------------------------------------------------
    def _fetch_rates(self, ticker: str, timeframe: str, date_from: datetime,
                     date_to: datetime, mt5_timeout: int):
        """
        Fetch bars for ``ticker`` between ``date_from`` and ``date_to`` from MT5.

        The call runs in a daemon worker thread joined with ``mt5_timeout``
        seconds. Returns the rates array, or None (after logging a warning) on
        timeout, error, or an empty result.

        NOTE(review): a timed-out worker keeps running in the background while
        the caller moves on to the next MT5 request.
        """
        result = {}

        def worker():
            try:
                result["rates"] = mt5.copy_rates_range(
                    ticker, self._tf_to_mt5(timeframe), date_from, date_to
                )
            except Exception as e:
                self.logger.warning(f"MT5 fetch error for {ticker}: {e}")

        thread = threading.Thread(target=worker, daemon=True)
        thread.start()
        thread.join(timeout=mt5_timeout)

        if thread.is_alive():
            self.logger.warning(f"Timeout fetching data for {ticker} ({timeframe}). Skipping.")
            return None
        rates = result.get("rates")
        if rates is None or len(rates) == 0:
            self.logger.warning(f"MT5 returned no data for {ticker} ({timeframe}).")
            return None
        return rates

    @staticmethod
    def _rates_to_frame(rates, timeframe: str, symbol_id) -> pd.DataFrame:
        """
        Convert an MT5 rates array into a DataFrame using market_data column names.

        MT5 fields time/open/high/low/close become tstamp/price_open/price_high/
        price_low/price_close; other fields (tick_volume, spread, real_volume)
        keep their names. ``timeframe`` and ``symbol_id`` columns are added.
        """
        df = pd.DataFrame(rates).rename(columns={
            "time": "tstamp",
            "open": "price_open",
            "high": "price_high",
            "low": "price_low",
            "close": "price_close",
        })
        df["timeframe"] = timeframe
        df["symbol_id"] = symbol_id
        return df

    # -----------------------------------------------------
    # AUTOMATIC MARKET DATA UPDATE FROM METATRADER 5
    # -----------------------------------------------------
    def update_market_data(self, timeframe: str, days_back: int = 30, mt5_timeout: int = 10, use_persistent_mt5: bool = True):
        """
        Synchronize the local market_data table with MetaTrader5 for all symbols.

        For each symbol the fetch starts at the newest stored bar for this
        timeframe, but never earlier than ``days_back`` days ago (older gaps are
        not backfilled). Only bars newer than the stored maximum are appended.

        Parameters
        ----------
        timeframe : str
            MT5 timeframe (e.g. 'D1', 'H1', 'M15').
        days_back : int
            Maximum number of past calendar days to fetch from MT5.
        mt5_timeout : int
            Timeout (seconds) for each MT5 request.
        use_persistent_mt5 : bool
            When True, the caller has already initialized MT5 and remains
            responsible for shutting it down. When False, this method
            initializes MT5 itself and shuts it down when done.
        """
        timeframe = timeframe.upper()
        now = datetime.now(timezone.utc)
        since = now - timedelta(days=days_back)

        self._connect()
        try:
            symbol_df = pd.read_sql("SELECT symbol_id, ticker FROM symbol;", self.conn)

            if not use_persistent_mt5 and not mt5.initialize():
                self.logger.error("Failed to initialize MetaTrader 5 terminal.")
                return

            try:
                self.logger.info(f"Updating {len(symbol_df)} symbols for timeframe {timeframe}...")
                for sid, ticker in tqdm(
                    zip(symbol_df["symbol_id"], symbol_df["ticker"]),
                    total=len(symbol_df),
                    desc="Synchronizing MT5 data",
                ):
                    last_ts = self._last_tstamp(sid, timeframe)
                    if last_ts is None:
                        date_from = since
                    else:
                        # Resume at the newest stored bar (it is filtered out below),
                        # but stay within the days_back window.
                        date_from = max(since, datetime.fromtimestamp(int(last_ts), tz=timezone.utc))

                    rates = self._fetch_rates(ticker, timeframe, date_from, now, mt5_timeout)
                    if rates is None:
                        continue

                    df = self._rates_to_frame(rates, timeframe, sid)
                    # Keep only new candles so existing rows are not duplicated.
                    if last_ts is not None:
                        df = df[df["tstamp"] > last_ts]
                    if df.empty:
                        continue

                    df.to_sql("market_data", self.conn, if_exists="append", index=False)
                    self.logger.info(f"Updated {ticker}: {len(df)} new bars added.")
            finally:
                if not use_persistent_mt5:
                    mt5.shutdown()
        finally:
            self._disconnect()

        self.logger.info("Market data synchronization completed.")

    # -----------------------------------------------------
    # MULTI-TIMEFRAME SYNCHRONIZATION
    # -----------------------------------------------------
    def update_market_data_multi(self, timeframes: List[str], days_back: int = 30, delay_sec: int = 2, mt5_timeout: int = 10, use_persistent_mt5: bool = True):
        """
        Update/synchronize multiple MT5 timeframes sequentially.

        A failure on one timeframe is logged and does not stop the others.

        Parameters
        ----------
        timeframes : list of str
            Example: ['H1', 'D1', 'W1']
        days_back : int
            Number of calendar days to fetch for each timeframe.
        delay_sec : int
            Pause between consecutive timeframe updates (helps avoid MT5 rate limits).
        mt5_timeout : int
            Timeout for each MT5 request.
        use_persistent_mt5 : bool
            Passed through to ``update_market_data``.
        """
        self.logger.info(f"Starting multi-timeframe sync for: {timeframes}")
        for i, tf in enumerate(timeframes):
            if i:
                time.sleep(delay_sec)  # pause only *between* timeframe updates
            try:
                self.update_market_data(
                    timeframe=tf,
                    days_back=days_back,
                    mt5_timeout=mt5_timeout,
                    use_persistent_mt5=use_persistent_mt5,
                )
            except Exception as e:
                self.logger.warning(f"Error updating timeframe {tf}: {e}")
        self.logger.info("Multi-timeframe synchronization completed.")

    # -----------------------------------------------------
    # PRICE DATA LOADING (DB + MT5 FALLBACK)
    # -----------------------------------------------------
    def _load_price_data(self, timeframe: str, lookback_days: int, tickers_filter: Optional[List[str]] = None, mt5_timeout: int = 10) -> pd.DataFrame:
        """
        Load ``price_close`` series for the requested tickers over the lookback window.

        A ticker with no bars in the window is fetched from MT5 (full candles).
        The fetched bars are appended to ``market_data`` and the window is
        re-read from the database.

        Parameters
        ----------
        timeframe : str
            Timeframe label as stored in market_data (case-insensitive).
        lookback_days : int
            Window length in calendar days, ending now.
        tickers_filter : list of str, optional
            Restrict loading to these tickers. A falsy value (None or an empty
            list) loads every symbol.
        mt5_timeout : int
            Timeout (seconds) for each MT5 request.

        Returns
        -------
        pd.DataFrame
            One column per ticker, indexed by bar open time, outer-joined across
            tickers. Rows where only some tickers have data are kept; callers
            drop NaNs per basket.

        Raises
        ------
        ValueError
            If no price data could be loaded for any ticker.
        """
        timeframe = timeframe.upper()
        now = datetime.now(timezone.utc)
        start_dt = now - timedelta(days=lookback_days)
        start_ts = int(start_dt.timestamp())
        query = (
            "SELECT tstamp, price_close FROM market_data "
            "WHERE symbol_id = ? AND timeframe = ? AND tstamp >= ? "
            "ORDER BY tstamp;"
        )

        price_data = {}
        self._connect()
        try:
            symbol_df = pd.read_sql("SELECT symbol_id, ticker FROM symbol;", self.conn)
            if tickers_filter:
                symbol_df = symbol_df[symbol_df["ticker"].isin(tickers_filter)]
            symbol_map = dict(zip(symbol_df.ticker, symbol_df.symbol_id))

            for ticker, sid in tqdm(symbol_map.items(), desc="Loading price data"):
                params = (int(sid), timeframe, start_ts)
                df = pd.read_sql(query, self.conn, params=params)

                if df.empty:
                    self.logger.info(f"Fetching missing data for {ticker} from MetaTrader 5.")
                    if not mt5.initialize():
                        self.logger.error("Failed to initialize MetaTrader 5 terminal.")
                        continue

                    rates = self._fetch_rates(ticker, timeframe, start_dt, now, mt5_timeout)
                    if rates is None:
                        continue

                    df_rates = self._rates_to_frame(rates, timeframe, sid)
                    # The DB has nothing inside the window, so only in-window bars
                    # are new; older bars may already be stored.
                    df_rates = df_rates[df_rates["tstamp"] >= start_ts]
                    try:
                        df_rates.to_sql("market_data", self.conn, if_exists="append", index=False)
                        self.logger.info(f"Inserted {len(df_rates)} fetched bars for {ticker} into DB.")
                    except Exception as e:
                        # Best effort: ranking proceeds with whichever tickers do load.
                        self.logger.warning(f"Failed to insert fetched bars for {ticker}: {e}")

                    # Re-read so the series reflects exactly what is stored.
                    df = pd.read_sql(query, self.conn, params=params)

                if df.empty:
                    continue

                closes = (
                    df.assign(date=pd.to_datetime(df["tstamp"], unit="s"))
                    .set_index("date")["price_close"]
                    .dropna()
                )
                # Duplicate bars (e.g. from earlier double inserts) would make the
                # column-wise concat below fail on a non-unique index.
                closes = closes[~closes.index.duplicated(keep="last")]
                if closes.empty:
                    self.logger.warning(f"No price_close for {ticker} after DB/MT5 fetch. Skipping ticker.")
                    continue

                price_data[ticker] = closes
        finally:
            self._disconnect()

        if not price_data:
            raise ValueError("No price data could be loaded or fetched.")

        # Outer-join on time. NaNs are dropped per basket at test time, so one
        # short-history ticker does not truncate every other series.
        prices = pd.concat(price_data, axis=1).sort_index()
        self.logger.info(f"Loaded {len(prices)} rows of price data for {len(prices.columns)} symbols.")
        return prices

    # -----------------------------------------------------
    # COINTEGRATION RANKING
    # -----------------------------------------------------
    @staticmethod
    def _test_basket(basket: List[str], sub_prices: pd.DataFrame, det_order: int, k_ar_diff: int) -> dict:
        """Run Engle-Granger (exactly 2 assets) or Johansen (otherwise) and return a result row."""
        if len(basket) == 2:
            score, p_value, _ = coint(sub_prices.iloc[:, 0], sub_prices.iloc[:, 1])
            return {
                "assets": basket,
                "method": "Engle-Granger",
                "strength_stat": float(score),
                "p_value": float(p_value),
                "eigen_strength": None,
            }
        # det_order: 0 = constant, 1 = linear trend; k_ar_diff is the lag order.
        johansen_res = coint_johansen(sub_prices, det_order, k_ar_diff)
        return {
            "assets": basket,
            "method": "Johansen",
            # Largest trace statistic (lr1), i.e. the r=0 test for any cointegration.
            "strength_stat": float(max(johansen_res.lr1)),
            "p_value": None,
            "eigen_strength": float(max(johansen_res.eig)),
        }

    @staticmethod
    def _rank_scores(results: pd.DataFrame) -> np.ndarray:
        """
        Compute rank scores (higher = stronger cointegration).

        Engle-Granger rows score -p_value and Johansen rows score their maximum
        eigenvalue. NOTE: the scales differ (-p_value lies in [-1, 0]; Johansen
        eigenvalues are >= 0), so Johansen baskets always sort above pairs.
        """
        is_eg = results["method"].eq("Engle-Granger").to_numpy()
        # astype(float) turns None into NaN, so negation works even when every
        # row is Johansen (p_value then holds only None).
        p_values = results["p_value"].astype(float).to_numpy()
        eigen = results["eigen_strength"].astype(float).to_numpy()
        return np.where(is_eg, -p_values, eigen)

    def rank_cointegration(
        self,
        timeframe: str,
        lookback_days: int,
        baskets: Optional[List[List[str]]] = None,
        det_order: int = 0,
        k_ar_diff: int = 1,
        save_to_csv: str = "",
        save_to_db: str = "",
        min_obs: int = 50,
    ) -> pd.DataFrame:
        """
        Rank pairs/baskets of assets by cointegration strength.

        Parameters
        ----------
        timeframe : str
            Timeframe label (e.g. 'D1', 'H4').
        lookback_days : int
            Calendar days of history to test.
        baskets : list of list of str, optional
            Baskets to test. When None, every pair of available symbols is tested.
            Baskets containing a ticker without price data are skipped.
        det_order, k_ar_diff : int
            Passed to ``coint_johansen`` for baskets that are not pairs.
        save_to_csv : str
            If non-empty, the ranked results are written to this CSV path.
        save_to_db : str
            If non-empty, the results are appended to this SQLite table. The
            returned frame then carries the DB layout (tstamp, timeframe,
            lookback, comma-joined assets).
        min_obs : int
            Minimum number of jointly non-missing observations a basket needs.

        Returns
        -------
        pd.DataFrame
            One row per tested basket, sorted by ``rank_score`` descending.
            The frame is empty when nothing could be tested.
        """
        # Normalize timeframe and determine which tickers need to be loaded
        timeframe = timeframe.upper()
        if baskets is not None:
            baskets = [list(basket) for basket in baskets]  # accept tuples too
            if not baskets:
                self.logger.warning("Empty basket list given; nothing to test.")
                return pd.DataFrame()
            tickers_needed = sorted({ticker for basket in baskets for ticker in basket})
            self.logger.info(f"Loading price data only for specified basket tickers: {tickers_needed}")
        else:
            tickers_needed = None  # will trigger full load
            self.logger.info("No basket specified, loading all available symbols.")

        prices = self._load_price_data(timeframe, lookback_days, tickers_filter=tickers_needed)

        if baskets is None:
            baskets = [list(pair) for pair in itertools.combinations(prices.columns, 2)]

        available = set(prices.columns)
        unavailable = sorted({t for basket in baskets for t in basket} - available)
        if unavailable:
            self.logger.warning(f"No price data for {unavailable}; baskets containing them are skipped.")

        results = []
        self.logger.info(f"Testing {len(baskets)} baskets for cointegration...")
        for basket in tqdm(baskets, desc="Running cointegration tests"):
            if not available.issuperset(basket):
                continue
            sub_prices = prices[basket].dropna()
            if len(sub_prices) < min_obs:
                continue
            try:
                results.append(self._test_basket(basket, sub_prices, det_order, k_ar_diff))
            except Exception as e:
                self.logger.warning(f"Error testing {basket}: {e}")

        if not results:
            self.logger.warning("No valid cointegration results obtained.")
            return pd.DataFrame()

        df = pd.DataFrame(results)
        df["rank_score"] = self._rank_scores(df)
        df = df.sort_values("rank_score", ascending=False).reset_index(drop=True)

        if save_to_csv:
            df.to_csv(save_to_csv, index=False)
            self.logger.info(f"Results saved to {save_to_csv}")

        if save_to_db:
            # Unique per-row microsecond timestamps: one base time plus the row offset.
            df["tstamp"] = int(time.time() * 1e6) + np.arange(len(df))
            df["timeframe"] = timeframe
            df["lookback"] = lookback_days
            df["assets"] = df["assets"].apply(lambda x: ",".join(x) if isinstance(x, list) else str(x))

            # Column order matches the coint_rank table definition.
            columns_order = [
                "tstamp", "timeframe", "lookback", "assets", "method",
                "strength_stat", "p_value", "eigen_strength", "rank_score"
            ]
            df = df[columns_order]

            self._connect()
            try:
                df.to_sql(save_to_db, self.conn, if_exists="append", index=False)
            finally:
                self._disconnect()
            self.logger.info(
                f"Results appended to table '{save_to_db}' "
                f"(timeframe={timeframe}, lookback={lookback_days}, rows={len(df)})"
            )

        self.logger.info(f"Cointegration ranking completed for timeframe {timeframe}.")
        return df

# --- End of CointegrationRanker class definition ---

# -----------------------------------------------------
# SCRIPT CONFIGURATION
# -----------------------------------------------------

# ⚠️ IMPORTANT: point DB_PATH at your SQLite database file. It is read from the
# DB_PATH environment variable and is an empty string when that variable is unset.
DB_PATH = os.getenv("DB_PATH", "")

# Bar size, e.g. daily ("D1"), four-hour ("H4"), hourly ("H1").
TIMEFRAME = "H4"

# History window in calendar days; expressed in months (6 x 30 days) for readability.
LOOKBACK_DAYS = 6 * 30

# Basket sizes to generate: every combination of MIN_SYMBOLS..MAX_SYMBOLS tickers.
MIN_SYMBOLS, MAX_SYMBOLS = 2, 4

# Output destinations: a CSV file and a SQLite table (appended to) for the results.
SAVE_CSV_TO = "coint_ranking_results.csv"
SAVE_DB_TO = "coint_rank"

# Example universe of 10 tickers; replace with your own symbols.
SYMBOLS = "NVDA INTC AMD WOLF NVTS AVGO LAES MRVL MU ASX".split()

# -----------------------------------------------------
# SCRIPT EXECUTION
# -----------------------------------------------------

def _build_baskets(symbols: List[str], min_size: int, max_size: int) -> List[List[str]]:
    """
    Return every combination of ``symbols`` whose size lies in [min_size, max_size].

    Baskets are ordered by size, then in ``itertools.combinations`` order; an
    empty list is returned when no size in the range is achievable.
    """
    return [
        list(combo)
        for size in range(min_size, max_size + 1)
        for combo in itertools.combinations(symbols, size)
    ]


def run_cointegration_ranking():
    """
    Generate all baskets from SYMBOLS and run the cointegration ranker on them.

    Uses the module-level configuration constants. MetaTrader 5 is initialized
    for the ranking run and is always shut down afterwards, including on error.
    """
    print(f"Generating baskets of {MIN_SYMBOLS} to {MAX_SYMBOLS} symbols from {len(SYMBOLS)} tickers...")
    all_baskets = _build_baskets(SYMBOLS, MIN_SYMBOLS, MAX_SYMBOLS)
    print(f"Total baskets to test: {len(all_baskets)}")

    if not all_baskets:
        print("No baskets generated. Exiting.")
        return

    if not DB_PATH:
        # sqlite3.connect("") opens a private temporary database, which would
        # only fail later with a confusing "no such table" error.
        print("DB_PATH is not set. Set the DB_PATH environment variable to your SQLite database file.")
        return

    mt5_initialized = False
    try:
        # 1. Initialize MetaTrader 5 (required by the class)
        if not mt5.initialize():
            print("Failed to initialize MetaTrader 5 terminal. Check installation and path.")
            return
        mt5_initialized = True

        # 2. Instantiate the ranker and 3. run the ranking
        ranker = CointegrationRanker(db_path=DB_PATH)
        print(f"Starting cointegration ranking for {TIMEFRAME} with {LOOKBACK_DAYS} days lookback...")
        ranked_df = ranker.rank_cointegration(
            timeframe=TIMEFRAME,
            lookback_days=LOOKBACK_DAYS,
            baskets=all_baskets,
            save_to_csv=SAVE_CSV_TO,
            save_to_db=SAVE_DB_TO
        )
    except Exception as e:
        print(f"An error occurred during execution: {e}")
        return
    finally:
        # 4. Shut down MetaTrader 5 whenever it was successfully initialized.
        if mt5_initialized:
            mt5.shutdown()

    if ranked_df.empty:
        # rank_cointegration saves nothing when it has no results.
        print("No valid cointegration results were obtained; nothing was saved.")
        return

    print("\n--- Top 10 Cointegrated Baskets ---")
    top = ranked_df.head(10)
    try:
        print(top.to_markdown(index=False))
    except ImportError:
        # DataFrame.to_markdown requires the optional 'tabulate' package.
        print(top.to_string(index=False))
    print(f"\nResults saved to '{SAVE_CSV_TO}' and appended to table '{SAVE_DB_TO}' in '{DB_PATH}'.")


if __name__ == "__main__":
    run_cointegration_ranking()