import os
import sqlite3
import pandas as pd
from statsmodels.tsa.stattools import adfuller, kpss

class StationarityTester:
    """
    Runs ADF and KPSS stationarity tests on the price series of every
    Johansen test with cointegrating rank > 0 and stores the results in the
    ``coint_adf_kpss`` table of a SQLite database.
    """

    def __init__(self, db_path='market_data.db'):
        """
        Initializes the StationarityTester object with the database path.

        The ``DB_PATH`` environment variable, when set, takes precedence over
        ``db_path``. No connection is opened here; see ``connect_db``.
        """
        self.db_path = os.getenv('DB_PATH', db_path)
        self.conn = None

    def connect_db(self):
        """
        Establishes a connection to the SQLite database.

        On failure the error is printed and ``self.conn`` is left as ``None``.
        """
        try:
            self.conn = sqlite3.connect(self.db_path)
            print("Successfully connected to the SQLite database.")
        except sqlite3.Error as e:
            print(f"Error connecting to database: {e}")
            self.conn = None

    def close_db(self):
        """
        Closes the database connection, if one is open.

        ``self.conn`` is reset to ``None`` so the ``if not self.conn`` guards
        in the other methods see the closed state. Calling this method more
        than once is therefore a harmless no-op.
        """
        if self.conn is not None:
            self.conn.close()
            self.conn = None
            print("Database connection closed.")

    def get_symbol_id(self, ticker):
        """
        Retrieves the symbol_id from the 'symbol' table for a given ticker.

        Returns ``None`` when there is no connection, no matching row, or a
        database error (which is printed).
        """
        if not self.conn:
            return None
        try:
            query = "SELECT symbol_id FROM symbol WHERE ticker = ?"
            cursor = self.conn.cursor()
            cursor.execute(query, (ticker,))
            result = cursor.fetchone()
            return result[0] if result else None
        except sqlite3.Error as e:
            print(f"Error retrieving symbol_id for {ticker}: {e}")
            return None

    def get_ticker_from_id(self, symbol_id):
        """
        Retrieves the ticker from the 'symbol' table for a given symbol_id.

        Returns ``None`` when there is no connection, no matching row, or a
        database error (which is printed).
        """
        if not self.conn:
            return None
        try:
            query = "SELECT ticker FROM symbol WHERE symbol_id = ?"
            cursor = self.conn.cursor()
            cursor.execute(query, (symbol_id,))
            result = cursor.fetchone()
            return result[0] if result else None
        except sqlite3.Error as e:
            print(f"Error retrieving ticker for {symbol_id}: {e}")
            return None

    def get_coint_groups(self):
        """
        Retrieves information for all Johansen tests with a cointegrating rank > 0.

        Returns a list of dictionaries with keys ``test_id``, ``timeframe``,
        ``lookback`` and ``symbol_ids`` (a list). Returns an empty list when
        there is no connection, no qualifying test, or a database error.
        """
        if not self.conn:
            print("Database connection not established.")
            return []

        tests_query = """
        SELECT test_id, timeframe, lookback
        FROM coint_johansen_test
        WHERE coint_rank > 0
        """
        assets_query = """
        SELECT symbol_id FROM coint_johansen_test_assets
        WHERE test_id = ?
        """

        try:
            df_tests = pd.read_sql_query(tests_query, self.conn)

            if df_tests.empty:
                print("No cointegrated groups found in the database.")
                return []

            coint_groups = []
            # itertuples yields native Python scalars. iterrows can yield
            # numpy.int64 (when every column is numeric), which sqlite3 binds
            # as a BLOB, so "test_id = ?" would silently match nothing.
            for test in df_tests.itertuples(index=False):
                df_assets = pd.read_sql_query(
                    assets_query, self.conn, params=(test.test_id,)
                )
                coint_groups.append({
                    'test_id': test.test_id,
                    'timeframe': test.timeframe,
                    'lookback': test.lookback,
                    'symbol_ids': df_assets['symbol_id'].tolist(),
                })

            return coint_groups
        except pd.io.sql.DatabaseError as e:
            print(f"Error retrieving cointegrated groups: {e}")
            return []

    def get_market_data(self, symbol_ids, timeframe, lookback):
        """
        Retrieves combined close prices for a list of symbol_ids.

        Returns a DataFrame indexed by ``tstamp`` with one column per symbol.
        Columns are labelled by ticker, or by the raw symbol_id when no ticker
        is found. Only timestamps where every returned symbol has a price are
        kept. An empty DataFrame is returned when there is no connection, no
        symbols, or no data.

        NOTE(review): ``lookback`` is accepted but not applied; the full
        history for ``timeframe`` is returned. Confirm whether the series
        should be truncated to the Johansen test's lookback window.
        """
        if not self.conn or not symbol_ids:
            return pd.DataFrame()

        # One query for all symbols, using one placeholder per symbol_id.
        placeholders = ', '.join(['?'] * len(symbol_ids))
        query = f"""
        SELECT tstamp, symbol_id, price_close
        FROM market_data
        WHERE symbol_id IN ({placeholders}) AND timeframe = ?
        ORDER BY tstamp ASC
        """
        params = tuple(symbol_ids) + (timeframe,)
        df = pd.read_sql_query(query, self.conn, params=params)

        if df.empty:
            print("No market data found for the given symbols.")
            return pd.DataFrame()

        # Long -> wide: one column per symbol_id (pivot raises ValueError on
        # duplicate (tstamp, symbol_id) pairs).
        df_pivoted = df.pivot(index='tstamp', columns='symbol_id', values='price_close')

        # Ensure only common timestamps are kept.
        df_final = df_pivoted.dropna()

        # Label columns by ticker. Fall back to the symbol_id when the lookup
        # fails so that several unknown symbols never collapse onto duplicate
        # ``None`` column labels.
        ticker_map = {sid: self.get_ticker_from_id(sid) for sid in symbol_ids}
        labels = []
        for sid in df_final.columns:
            ticker = ticker_map.get(sid)
            labels.append(ticker if ticker is not None else sid)
        df_final.columns = labels

        return df_final

    def calculate_stationarity_tests(self, prices_df, significance_level=0.05):
        """
        Performs ADF and KPSS tests on each price series and determines stationarity.

        Args:
            prices_df: DataFrame with one price series per column.
            significance_level: alpha used for both tests' conclusions.

        Returns:
            Dict keyed by column label. Each value holds ``adf_stat``,
            ``adf_pvalue``, ``is_adf_stationary``, ``kpss_stat``,
            ``kpss_pvalue`` and ``is_kpss_stationary`` (the flags are 0/1).

        Exceptions raised by statsmodels (e.g. for a series that is too short)
        propagate to the caller.
        """
        results = {}
        for column in prices_df.columns:
            series = prices_df[column].dropna()

            # ADF null hypothesis: unit root (non-stationary).
            # Reject the null, i.e. conclude stationary, when p-value < alpha.
            adf_stat, adf_pvalue = adfuller(series, autolag='AIC')[:2]

            # KPSS null hypothesis: level-stationary.
            # Stationary when we fail to reject, i.e. p-value > alpha.
            # statsmodels interpolates KPSS p-values from a table, so they
            # are bounded to the table's range.
            kpss_stat, kpss_pvalue = kpss(series, regression='c', nlags='auto')[:2]

            results[column] = {
                'adf_stat': adf_stat,
                'adf_pvalue': adf_pvalue,
                'is_adf_stationary': 1 if adf_pvalue < significance_level else 0,
                'kpss_stat': kpss_stat,
                'kpss_pvalue': kpss_pvalue,
                'is_kpss_stationary': 1 if kpss_pvalue > significance_level else 0,
            }
        return results

    def store_stationarity_results(self, test_id, test_results):
        """
        Stores the ADF and KPSS stationarity test results in ``coint_adf_kpss``.

        ``test_results`` maps ticker -> result dict (as produced by
        ``calculate_stationarity_tests``). Tickers without a matching
        symbol_id are skipped. On a database error, the transaction is rolled
        back so no partial insert is committed later.
        """
        if not self.conn:
            print("Database connection not established.")
            return

        try:
            cursor = self.conn.cursor()

            records_to_insert = []
            for ticker, result in test_results.items():
                symbol_id = self.get_symbol_id(ticker)
                if symbol_id is None:
                    continue
                records_to_insert.append((
                    test_id,
                    symbol_id,
                    result['adf_stat'],
                    result['adf_pvalue'],
                    result['is_adf_stationary'],
                    result['kpss_stat'],
                    result['kpss_pvalue'],
                    result['is_kpss_stationary'],
                ))

            cursor.executemany("""
                INSERT INTO coint_adf_kpss (test_id, symbol_id, adf_stat, adf_pvalue, is_adf_stationary, kpss_stat, kpss_pvalue, is_kpss_stationary)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """, records_to_insert)

            self.conn.commit()
            print(f"Successfully stored stationarity test results for {len(records_to_insert)} assets.")
        except sqlite3.Error as e:
            # Discard rows already inserted by a failed executemany so a later
            # commit cannot persist a half-written batch.
            self.conn.rollback()
            print(f"Error storing stationarity results: {e}")

    def run_stationarity_analysis(self):
        """
        Main function to orchestrate the entire stationarity analysis.

        Connects, analyzes every cointegrated group and stores its results.
        A failure in one group is reported and that group is skipped; the
        remaining groups are still processed. The connection is always closed
        on exit.
        """
        self.connect_db()
        if not self.conn:
            return

        try:
            coint_groups = self.get_coint_groups()

            if not coint_groups:
                print("No cointegrated groups with rank > 0 to analyze. Exiting.")
                return  # the finally clause closes the connection

            for group in coint_groups:
                test_id = group['test_id']
                try:
                    symbol_ids = group['symbol_ids']
                    print(f"Analyzing test_id: {test_id} with symbols: {[self.get_ticker_from_id(sid) for sid in symbol_ids]}")

                    prices_df = self.get_market_data(symbol_ids, group['timeframe'], group['lookback'])

                    if prices_df.empty:
                        print(f"Could not retrieve data for test_id {test_id}. Skipping.")
                        continue

                    stationarity_results = self.calculate_stationarity_tests(prices_df)
                    self.store_stationarity_results(test_id, stationarity_results)
                except Exception as e:
                    # Isolate per-group failures (e.g. series too short for
                    # the tests, duplicate pivot keys) so other groups still run.
                    print(f"Error analyzing test_id {test_id}: {e}. Skipping.")

        except Exception as e:
            print(f"An unexpected error occurred during the analysis: {e}")
        finally:
            self.close_db()
            print("Stationarity analysis complete.")

if __name__ == '__main__':
    # Script entry point: run the full analysis against the configured database.
    StationarityTester().run_stationarity_analysis()