import os
import sqlite3
import pandas as pd
import MetaTrader5 as mt5
import json
from datetime import datetime, timedelta, timezone
from statsmodels.tsa.vector_ar.vecm import coint_johansen

class SymbolJohansenMulti:
    """
    Runs a Johansen cointegration test over several symbols.

    Close prices are read from a SQLite database (tables 'symbol' and
    'market_data'); missing history is pulled from a MetaTrader 5 terminal
    and cached in the database. Test results are written to the
    'coint_johansen_test' and 'coint_johansen_test_assets' tables.
    """

    # Column layout shared by every price DataFrame this class returns.
    _PRICE_COLUMNS = ('price_close', 'tstamp')

    # Returns the most recent N bars of a time window, in chronological order.
    # The inner query picks the newest rows; the outer one restores ASC order.
    _RECENT_WINDOW_QUERY = """
            SELECT price_close, tstamp FROM (
                SELECT price_close, tstamp
                FROM market_data
                WHERE symbol_id = ? AND timeframe = ? AND tstamp >= ? AND tstamp <= ?
                ORDER BY tstamp DESC
                LIMIT ?
            )
            ORDER BY tstamp ASC
            """

    # Same as above but without a time window.
    _RECENT_QUERY = """
            SELECT price_close, tstamp FROM (
                SELECT price_close, tstamp FROM market_data
                WHERE symbol_id = ? AND timeframe = ?
                ORDER BY tstamp DESC
                LIMIT ?
            )
            ORDER BY tstamp ASC
            """

    def __init__(self, db_path='market_data.db'):
        """
        Initializes the SymbolJohansenMulti object with the database path.

        The DB_PATH environment variable, when set, takes precedence over
        the db_path argument.
        """
        self.db_path = os.getenv('DB_PATH', db_path)
        self.conn = None

    @classmethod
    def _empty_price_frame(cls):
        """Returns an empty DataFrame carrying the standard price columns."""
        return pd.DataFrame(columns=list(cls._PRICE_COLUMNS))

    def _rollback_quietly(self):
        """Discards the pending transaction; a failing rollback is only reported."""
        try:
            self.conn.rollback()
        except sqlite3.Error as e:
            print(f"Error rolling back transaction: {e}")

    def connect_db(self):
        """
        Establishes a connection to the SQLite database.

        On failure self.conn is left as None; no exception is raised.
        """
        try:
            self.conn = sqlite3.connect(self.db_path)
            print("Successfully connected to the SQLite database.")
        except sqlite3.Error as e:
            print(f"Error connecting to database: {e}")
            self.conn = None

    def close_db(self):
        """
        Closes the database connection.

        The handle is reset to None afterwards so that the
        'if not self.conn' guards in the other methods reject a closed
        connection, and so that calling close_db() twice is a no-op.
        """
        if self.conn:
            self.conn.close()
            self.conn = None
            print("Database connection closed.")

    def connect_mt5(self):
        """
        Connects to the MetaTrader 5 terminal.

        Returns True on success, False otherwise.
        """
        if not mt5.initialize():
            print("Failed to initialize MT5, error code:", mt5.last_error())
            return False
        return True

    def get_symbol_id(self, ticker):
        """
        Retrieves the symbol_id from the 'symbol' table for a given ticker.

        Returns None when there is no connection, the ticker is unknown,
        or the query fails.
        """
        if not self.conn:
            print("Database connection not established.")
            return None
        try:
            query = "SELECT symbol_id FROM symbol WHERE ticker = ?"
            cursor = self.conn.cursor()
            cursor.execute(query, (ticker,))
            result = cursor.fetchone()
            return result[0] if result else None
        except sqlite3.Error as e:
            print(f"Error retrieving symbol_id for {ticker}: {e}")
            return None

    def get_market_data_from_db(self, symbol_id, timeframe, lookback):
        """
        Retrieves market data from the 'market_data' table.

        Returns the most recent rows (at most `lookback` of them) whose
        tstamp falls within the last `lookback` days, ordered oldest first.
        If fewer than `lookback` rows are available, the data is fetched
        from MT5, stored, and the table is queried again.

        NOTE(review): `lookback` is used both as a number of days (time
        window) and as a number of bars (row limit / MT5 request size).
        For timeframes with less than one bar per calendar day the window
        can never hold `lookback` bars, so MT5 is contacted on every call.

        Returns an empty DataFrame with columns ['price_close', 'tstamp']
        when nothing could be retrieved.
        """
        if not self.conn:
            print("Database connection not established.")
            return self._empty_price_frame()

        end_datetime = datetime.now(timezone.utc)
        start_datetime = end_datetime - timedelta(days=lookback)
        params = (
            symbol_id,
            timeframe,
            int(start_datetime.timestamp()),
            int(end_datetime.timestamp()),
            lookback,
        )

        try:
            df = pd.read_sql_query(self._RECENT_WINDOW_QUERY, self.conn, params=params)
            if len(df) >= lookback:
                return df

            print(f"Not enough data in DB for symbol_id {symbol_id}, timeframe {timeframe}, lookback {lookback}.")
            mt5_df = self.fetch_market_data_from_mt5(symbol_id, timeframe, lookback)
            if mt5_df is None or mt5_df.empty:
                print(f"Failed to fetch data for symbol_id {symbol_id} from MT5. Returning empty DataFrame.")
                return self._empty_price_frame()

            # Re-query so the result combines stored and freshly fetched bars.
            return pd.read_sql_query(self._RECENT_WINDOW_QUERY, self.conn, params=params)
        except pd.io.sql.DatabaseError as e:
            print(f"Error querying market data: {e}")
            return self._empty_price_frame()

    def fetch_market_data_from_mt5(self, symbol_id, timeframe, lookback):
        """
        Fetches market data from MetaTrader 5 and stores it in the database.

        Requests the latest `lookback` bars for the ticker mapped to
        `symbol_id`, inserts them with INSERT OR IGNORE, and returns the
        most recent `lookback` stored rows (oldest first).

        Returns an empty DataFrame with columns ['price_close', 'tstamp']
        when there is no DB connection, the symbol_id is unknown, the
        timeframe is unsupported, MT5 is unavailable or returns nothing,
        or the insert fails.
        """
        if not self.conn:
            print("Database connection not established.")
            return self._empty_price_frame()

        # Resolve the ticker before touching MT5 so an unknown symbol_id
        # fails cleanly instead of raising on fetchone()[0].
        try:
            cursor = self.conn.cursor()
            cursor.execute("SELECT ticker FROM symbol WHERE symbol_id = ?", (symbol_id,))
            symbol_row = cursor.fetchone()
        except sqlite3.Error as e:
            print(f"Error retrieving ticker for symbol_id {symbol_id}: {e}")
            return self._empty_price_frame()
        if symbol_row is None:
            print(f"No ticker found for symbol_id {symbol_id}.")
            return self._empty_price_frame()
        ticker_name = symbol_row[0]

        mt5_timeframes = {
            'M1': mt5.TIMEFRAME_M1, 'M5': mt5.TIMEFRAME_M5, 'M15': mt5.TIMEFRAME_M15,
            'M30': mt5.TIMEFRAME_M30, 'H1': mt5.TIMEFRAME_H1, 'H4': mt5.TIMEFRAME_H4, 'D1': mt5.TIMEFRAME_D1,
            'W1': mt5.TIMEFRAME_W1, 'MN1': mt5.TIMEFRAME_MN1
        }
        # Validate before initializing the terminal so nothing is left open.
        if timeframe not in mt5_timeframes:
            print(f"Unsupported timeframe: {timeframe}")
            return self._empty_price_frame()

        if not self.connect_mt5():
            print("MT5 connection failed.")
            return self._empty_price_frame()
        try:
            rate = mt5.copy_rates_from_pos(ticker_name, mt5_timeframes[timeframe], 0, lookback)
        finally:
            # Always release the terminal, even if the request raises.
            mt5.shutdown()

        if rate is None or rate.size == 0:
            print(f"Failed to fetch data for {ticker_name} from MT5.")
            return self._empty_price_frame()

        df = pd.DataFrame(rate)
        rate_columns = ['time', 'open', 'high', 'low', 'close', 'tick_volume', 'real_volume', 'spread']
        # Convert to plain Python numbers: sqlite3 cannot bind NumPy integer scalars.
        records_to_insert = [
            (
                int(bar_time), timeframe,
                float(p_open), float(p_high), float(p_low), float(p_close),
                int(tick_vol), int(real_vol), int(spread),
                symbol_id,
            )
            for bar_time, p_open, p_high, p_low, p_close, tick_vol, real_vol, spread
            in df[rate_columns].itertuples(index=False, name=None)
        ]

        try:
            cursor.executemany("""
                INSERT OR IGNORE INTO market_data (tstamp, timeframe, price_open, price_high, price_low, price_close, tick_volume, real_volume, spread, symbol_id)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, records_to_insert)
            self.conn.commit()
            print(f"Successfully stored {len(records_to_insert)} records for {ticker_name} in the database.")

            return pd.read_sql_query(self._RECENT_QUERY, self.conn, params=(symbol_id, timeframe, lookback))
        except sqlite3.Error as e:
            print(f"Error storing market data: {e}")
            self._rollback_quietly()
            return self._empty_price_frame()

    def calculate_johansen_test(self, prices_df, det_order=0, k_ar_diff=1):
        """
        Calculates the Johansen cointegration test results for multiple assets.

        prices_df holds one price column per asset; det_order and k_ar_diff
        are passed straight to statsmodels' coint_johansen.

        Returns a tuple with (trace_stats, trace_crit_vals, eigen_stats,
        eigen_crit_vals, coint_rank, coint_vectors). For an empty input the
        tuple is (None, None, None, None, 0, None). coint_vectors is the
        first eigenvector scaled so its first element is 1.0, or None when
        the rank is 0 or that first element is zero.
        """
        if prices_df.empty:
            return None, None, None, None, 0, None

        johansen_result = coint_johansen(prices_df.to_numpy(), det_order, k_ar_diff)

        # Sequential trace test: count leading hypotheses rejected at the 5%
        # level (column index 1 of the critical-value table), stop at the
        # first one that is not rejected.
        coint_rank = 0
        for trace_stat, crit_row in zip(johansen_result.lr1, johansen_result.cvt):
            if trace_stat > crit_row[1]:
                coint_rank += 1
            else:
                break

        # Extract the primary cointegrating vector (hedge ratios) if a relationship exists.
        coint_vectors = None
        if coint_rank > 0:
            # Columns of .evec are the eigenvectors; the first column goes
            # with the largest eigenvalue.
            v = johansen_result.evec[:, 0].tolist()

            # Normalize so the first asset's hedge ratio is 1.0.
            if v[0] != 0:
                coint_vectors = [element / v[0] for element in v]

        return (
            johansen_result.lr1.tolist(),  # Trace statistics
            johansen_result.cvt.tolist(),  # Trace critical values
            johansen_result.lr2.tolist(),  # Max-eigenvalue statistics
            johansen_result.cvm.tolist(),  # Max-eigenvalue critical values
            coint_rank,
            coint_vectors
        )

    def store_johansen_results(self, tstamp, timeframe, lookback, num_assets, asset_symbol_ids, trace_stats, trace_crit_vals, eigen_stats, eigen_crit_vals, coint_rank, coint_vectors):
        """
        Stores the Johansen test results in the 'coint_johansen_test' and 'coint_johansen_test_assets' tables.

        The statistics are serialized as JSON text. One row is written to the
        main table and one junction row per entry of asset_symbol_ids. Both
        writes are committed together; if either fails the transaction is
        rolled back so no partial result remains pending on the connection.
        """
        if not self.conn:
            print("Database connection not established.")
            return

        try:
            cursor = self.conn.cursor()

            # Insert into the main test table
            cursor.execute("""
                INSERT INTO coint_johansen_test (tstamp, timeframe, lookback, num_assets, trace_stats_json, trace_crit_vals_json, eigen_stats_json, eigen_crit_vals_json, coint_rank,coint_vectors_json)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?,?)
            """, (
                tstamp,
                timeframe,
                lookback,
                num_assets,
                json.dumps(trace_stats),
                json.dumps(trace_crit_vals),
                json.dumps(eigen_stats),
                json.dumps(eigen_crit_vals),
                coint_rank,
                json.dumps(coint_vectors)
            ))

            test_id = cursor.lastrowid

            # Insert into the junction table for each asset
            asset_records = [(test_id, symbol_id) for symbol_id in asset_symbol_ids]
            cursor.executemany("""
                INSERT INTO coint_johansen_test_assets (test_id, symbol_id)
                VALUES (?, ?)
            """, asset_records)

            self.conn.commit()
            print(f"Successfully stored Johansen test results for {num_assets} assets with test_id: {test_id}.")
        except sqlite3.Error as e:
            print(f"Error storing Johansen results: {e}")
            # Drop the main-table row if the junction insert (or commit) failed.
            self._rollback_quietly()

    def run_johansen_analysis(self, asset_tickers, timeframe, lookback):
        """
        Main function to orchestrate the entire Johansen analysis.

        Opens the database, loads close prices for every ticker, aligns them
        on common timestamps, runs the Johansen test and stores the result.
        The analysis is abandoned (with a message) if any ticker is unknown,
        has no data, or the assets share no timestamps. The database
        connection is always closed before returning.
        """
        self.connect_db()
        if not self.conn:
            return

        asset_symbol_ids = []
        all_dataframes = []

        # Cleanup happens once, in 'finally'; early returns must not close
        # the connection themselves.
        try:
            # First, validate all tickers and collect data
            for ticker in asset_tickers:
                symbol_id = self.get_symbol_id(ticker)
                if symbol_id is None:
                    print(f"Symbol '{ticker}' not found in the database. Exiting.")
                    return
                asset_symbol_ids.append(symbol_id)

                print(f"Fetching data for symbol: {ticker}...")
                data = self.get_market_data_from_db(symbol_id, timeframe, lookback)

                if data.empty:
                    print(f"Skipping analysis due to missing data for {ticker}.")
                    return

                data = data.rename(columns={'price_close': ticker})
                data = data.set_index('tstamp')
                all_dataframes.append(data)

            # Combine all data into a single DataFrame based on common timestamps
            combined_df = pd.concat(all_dataframes, axis=1, join='inner')
            if combined_df.empty:
                print("No common data found for all assets. Exiting.")
                return

            analysis_tstamp = int(datetime.now(timezone.utc).timestamp())

            (trace_stats, trace_crit_vals, eigen_stats, eigen_crit_vals,
             coint_rank, coint_vectors) = self.calculate_johansen_test(combined_df)

            if trace_stats is None:
                print("Could not calculate Johansen test for the asset group.")
                return

            print(f"Johansen test results for {len(asset_tickers)} assets:")
            print(f"  Cointegrating rank: {coint_rank}")
            print(f"  Trace Statistics: {trace_stats}")
            print(f"  Max-Eigenvalue Statistics: {eigen_stats}")
            print(f"  Cointegrating Vectors: {coint_vectors}")

            self.store_johansen_results(
                analysis_tstamp,
                timeframe,
                lookback,
                len(asset_tickers),
                asset_symbol_ids,
                trace_stats,
                trace_crit_vals,
                eigen_stats,
                eigen_crit_vals,
                coint_rank,
                coint_vectors
            )

        except (pd.io.sql.DatabaseError, sqlite3.Error) as e:
            print(f"An error occurred during the analysis: {e}")
        finally:
            self.close_db()
            print("Analysis complete.")

if __name__ == '__main__':
    # Example usage:
    # Test a six-ticker portfolio for cointegration on the 'D1' timeframe
    # with a lookback of 30.
    portfolio_tickers = ['AMD', 'INTC', 'LAES', 'RMBS', 'TSM', 'WOLF']
    analysis_timeframe = 'D1'
    analysis_lookback = 30

    analyzer = SymbolJohansenMulti()
    analyzer.run_johansen_analysis(
        asset_tickers=portfolio_tickers,
        timeframe=analysis_timeframe,
        lookback=analysis_lookback,
    )