import numpy as np
import pandas as pd
import statsmodels.api as sm
from scipy.stats import f
import MetaTrader5 as mt5
import sqlite3
import os
from datetime import datetime

class ChowTestAnalyzer:
    """
    Performs the Chow Test for structural stability on a pair of assets
    to detect a change in the hedge ratio (beta) at a specified breakpoint.

    The pair is modelled as ``Y = a + b * X`` by OLS, where Y is the first
    column of the price data and X the second. Prices are fetched from
    MetaTrader5 and results can be stored in a SQLite database.
    """

    def __init__(self, db_path=None):
        """
        Parameters:
        - db_path: SQLite database file used by ``save_results_to_db``. When
          omitted (None), the ``DB_PATH`` environment variable is used,
          falling back to "cointegration_data.db".
        """
        # Resolved here rather than in the signature so DB_PATH is read when the
        # analyzer is created, not frozen once when the module is imported.
        if db_path is None:
            db_path = os.getenv("DB_PATH", "cointegration_data.db")
        self.db_path = db_path
        # Last SQLite connection opened by save_results_to_db (closed after each save).
        self.conn = None

    def connect_mt5(self):
        """
        Connect to MetaTrader5.

        Raises:
        - ConnectionError: if ``mt5.initialize()`` reports failure.
        """
        if not mt5.initialize():
            raise ConnectionError("MT5 initialization failed")
        print("✅ MT5 Connected")

    def fetch_data(self, symbols, timeframe=mt5.TIMEFRAME_D1, n_bars=504):
        """
        Fetch the latest ``n_bars`` close prices from MT5 for a pair of symbols.

        Parameters:
        - symbols: Two distinct symbol names; the first becomes Y, the second X.
        - timeframe: MT5 timeframe constant (default: daily bars).
        - n_bars: Number of most recent bars requested per symbol.

        Returns:
        - DataFrame indexed by bar time, with one close-price column per symbol
          (named after the symbol, in the given order), restricted to the
          timestamps where both symbols have data.

        Raises:
        - ValueError: if ``symbols`` is not exactly two distinct names.
        - ConnectionError: if either symbol could not be fetched, or the two
          series share no timestamps.
        """
        if len(symbols) != 2:
            raise ValueError("Chow test is typically applied to a pair of assets (Y and X).")
        if symbols[0] == symbols[1]:
            raise ValueError("The two symbols must be different.")

        data = {}
        for symbol in symbols:
            # Fetch latest n_bars
            rates = mt5.copy_rates_from_pos(symbol, timeframe, 0, n_bars)
            if rates is None or len(rates) == 0:
                print(f"Failed to fetch data for {symbol}")
                continue

            df = pd.DataFrame(rates)
            df['time'] = pd.to_datetime(df['time'], unit='s')
            data[symbol] = df.set_index('time')['close'].rename(symbol)

        # Both legs are required. A single fetched series would otherwise slip
        # through here and only fail later with a confusing column-count error.
        missing = [symbol for symbol in symbols if symbol not in data]
        if missing:
            raise ConnectionError(f"Could not fetch data for {missing}")

        # Align on the timestamps common to both symbols.
        combined_df = pd.concat(data, axis=1).dropna()
        if combined_df.empty:
            raise ConnectionError(f"Could not fetch or align sufficient data for {symbols}")

        return combined_df

    @staticmethod
    def _fit_ols(frame, symbol_y, symbol_x):
        """Fit ``Y = a + b * X`` by OLS on ``frame`` and return the fitted results."""
        return sm.OLS(frame[symbol_y], sm.add_constant(frame[symbol_x])).fit()

    def run_chow_test(self, data: pd.DataFrame, break_point_date: str, alpha: float = 0.05):
        """
        Performs the Chow test on the asset pair Y and X.

        The full sample is fitted first (restricted model). Then the rows before
        the breakpoint and the rows from it onwards are fitted separately
        (unrestricted models). The statistic is

            F = ((SSR_full - (SSR_pre + SSR_post)) / k) / ((SSR_pre + SSR_post) / (N - 2k))

        with k = 2 parameters. The p-value is taken from F(k, N - 2k).

        Parameters:
        - data: DataFrame with columns [Y, X] (first column is Y, second is X)
          and a datetime-like index.
        - break_point_date: String date (e.g., 'YYYY-MM-DD') for the structural
          break. Rows strictly before it form the pre-break sample; the rest
          form the post-break sample.
        - alpha: Significance level used for the textual conclusion.

        Returns:
        - dict with the F statistic, p-value, per-regime and full-sample
          betas/intercepts, sample sizes, SSRs and a 'conclusion' string.

        Raises:
        - ValueError: if ``data`` does not have exactly two columns, the
          breakpoint is not strictly inside the index range, or there are too
          few observations to compute the statistic.
        """
        if data.shape[1] != 2:
            raise ValueError("Input data must have exactly two columns (Y and X).")

        # Assign Y and X based on column order
        symbol_y = data.columns[0]
        symbol_x = data.columns[1]
        N = len(data)
        k = 2  # Number of parameters (intercept + slope)

        # 1. Split data at suspected break point
        break_date = pd.to_datetime(break_point_date)

        # The breakpoint must lie strictly inside the sample so both sides are non-empty.
        if break_date <= data.index.min() or break_date >= data.index.max():
            raise ValueError("Breakpoint date is outside the data range.")

        df_pre = data[data.index < break_date]
        df_post = data[data.index >= break_date]
        N1 = len(df_pre)
        N2 = len(df_post)

        if N1 < k or N2 < k:
            raise ValueError("Insufficient data points in one or both sub-periods to estimate the model.")
        # The F denominator has N - 2k degrees of freedom. At N == 2k both
        # sub-regressions fit exactly and the statistic is undefined (0/0).
        if N - 2 * k <= 0:
            raise ValueError("Insufficient data points: the Chow test needs more than 2k observations in total.")

        # 2. Full sample regression (restricted model)
        full_model = self._fit_ols(data, symbol_y, symbol_x)
        # 3. Separate regressions (unrestricted models)
        pre_model = self._fit_ols(df_pre, symbol_y, symbol_x)
        post_model = self._fit_ols(df_post, symbol_y, symbol_x)

        ssr_full = full_model.ssr
        ssr_pre = pre_model.ssr
        ssr_post = post_model.ssr

        # 4. Compute Chow test statistic and its p-value under F(k, N - 2k)
        ssr_unrestricted = ssr_pre + ssr_post
        numerator = (ssr_full - ssr_unrestricted) / k
        denominator = ssr_unrestricted / (N - 2 * k)
        F_stat = numerator / denominator
        p_value = f.sf(F_stat, k, N - 2 * k)

        # `params` is a label-indexed Series (intercept first, slope second).
        # Use .iloc for positional access: plain params[i] on a label index is
        # deprecated in pandas 2.x.
        results = {
            'symbol_y': symbol_y,
            'symbol_x': symbol_x,
            'break_date': break_point_date,
            'F_stat': F_stat,
            'p_value': p_value,
            'pre_beta': pre_model.params.iloc[1],
            'post_beta': post_model.params.iloc[1],
            'full_beta': full_model.params.iloc[1],
            'N1': N1,
            'N2': N2,
            'ssr_full': ssr_full,
            'ssr_pre': ssr_pre,
            'ssr_post': ssr_post,
            'full_intercept': full_model.params.iloc[0],
            'pre_intercept': pre_model.params.iloc[0],
            'post_intercept': post_model.params.iloc[0],
        }

        # Determine the conclusion
        if p_value < alpha:
            results['conclusion'] = "Reject H0 (Structural break detected)"
        else:
            results['conclusion'] = "Fail to reject H0 (No structural break evidence)"

        return results

    def save_results_to_db(self, results: dict, timeframe: str, test_id: int):
        """
        Saves the Chow test results to the 'chow_results' table.

        The table is created if it does not exist. The connection is always
        closed before returning, including when the insert fails.

        Parameters:
        - results: dict as returned by ``run_chow_test``.
        - timeframe: Timeframe label stored with the row.
        - test_id: Caller-chosen identifier stored with the row.
        """
        record = (
            int(datetime.now().timestamp()),  # Unix time (seconds) of the save
            test_id,
            results['symbol_y'],
            results['symbol_x'],
            timeframe,
            results['break_date'],
            round(float(results['F_stat']), 6),
            round(float(results['p_value']), 6),
            round(float(results['pre_beta']), 6),
            round(float(results['post_beta']), 6),
            results['conclusion'],
        )

        self.conn = sqlite3.connect(self.db_path)
        try:
            cursor = self.conn.cursor()

            # Create table if it doesn't exist
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS chow_results (
                    tstamp INTEGER,
                    test_id INTEGER,
                    symbol_y TEXT,
                    symbol_x TEXT,
                    timeframe TEXT,
                    break_date TEXT,
                    F_stat REAL,
                    p_value REAL,
                    pre_beta REAL,
                    post_beta REAL,
                    conclusion TEXT
                )
            """)

            # Name the columns explicitly so the insert does not depend on the
            # physical column order of a pre-existing table.
            cursor.execute(
                """
                INSERT INTO chow_results (
                    tstamp, test_id, symbol_y, symbol_x, timeframe, break_date,
                    F_stat, p_value, pre_beta, post_beta, conclusion
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                record,
            )
            self.conn.commit()
        finally:
            # A bare sqlite3 connection is not closed automatically; close it
            # even if one of the statements above raised.
            self.conn.close()

        print(f"Inserted Chow test result for {results['symbol_y']}/{results['symbol_x']} into 'chow_results' table.")

    def plot_chow_test(self, data, break_point_date, results):
        """
        Generates a visualization of the Chow Test mechanics.

        Draws a scatter of X vs Y prices. Over it, it draws the full-sample
        (restricted) regression line and the pre- and post-break (unrestricted)
        regression lines, each over its own X range. It adds a text box with the
        F statistic and p-value.

        Parameters:
        - data: Same [Y, X] DataFrame passed to ``run_chow_test``.
        - break_point_date: Same breakpoint passed to ``run_chow_test``.
        - results: dict returned by ``run_chow_test`` (intercepts and betas are read).
        """
        import matplotlib.pyplot as plt

        # Setup data
        symbol_y = data.columns[0]
        symbol_x = data.columns[1]
        break_date = pd.to_datetime(break_point_date)

        df_pre = data[data.index < break_date]
        df_post = data[data.index >= break_date]

        plt.figure(figsize=(7, 5))

        # 1. Plot the raw data points
        plt.scatter(data[symbol_x], data[symbol_y], color='gray', alpha=0.3, label='Data Points', s=15)

        # 2. Plot the Full Regression Line (Restricted) - Dashed
        x_range_full = np.linspace(data[symbol_x].min(), data[symbol_x].max(), 100)
        full_line = results['full_intercept'] + results['full_beta'] * x_range_full
        plt.plot(x_range_full, full_line, color='black', linestyle='--', linewidth=2, label='Full Model (Restricted)')

        # 3. Plot the Pre-Break Regression Line (Unrestricted)
        x_range_pre = np.linspace(df_pre[symbol_x].min(), df_pre[symbol_x].max(), 100)
        pre_line = results['pre_intercept'] + results['pre_beta'] * x_range_pre
        plt.plot(x_range_pre, pre_line, color='blue', linewidth=3, label=f'Pre-Break (Beta: {results["pre_beta"]:.2f})')

        # 4. Plot the Post-Break Regression Line (Unrestricted)
        x_range_post = np.linspace(df_post[symbol_x].min(), df_post[symbol_x].max(), 100)
        post_line = results['post_intercept'] + results['post_beta'] * x_range_post
        plt.plot(x_range_post, post_line, color='red', linewidth=3, label=f'Post-Break (Beta: {results["post_beta"]:.2f})')

        # Formatting
        plt.title(f"Chow Test Visualization: {symbol_y} vs {symbol_x}\nBreakpoint: {break_point_date}", fontsize=14)
        plt.xlabel(f"Price: {symbol_x}")
        plt.ylabel(f"Price: {symbol_y}")
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Add text box with statistics
        stats_text = f"F-stat: {results['F_stat']:.4f}\np-value: {results['p_value']:.6f}"
        plt.gca().text(0.05, 0.95, stats_text, transform=plt.gca().transAxes,
                       verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        plt.tight_layout()
        plt.show()

    def run_analysis(self, symbols: list, break_date: str, timeframe=mt5.TIMEFRAME_D1, n_bars=504, test_id=1):
        """
        Main method to execute the data fetch, Chow test, and database save.

        It connects to MT5 and fetches the pair. It then runs the Chow test,
        prints a summary, saves the result and plots it. MT5 is shut down
        afterwards, even if any step raises.

        Parameters:
        - symbols: [Y, X] symbol names for the regression Y = a + bX.
        - break_date: Breakpoint date string passed to ``run_chow_test``.
        - timeframe: MT5 timeframe constant; its ``str()`` is stored in the DB.
        - n_bars: Number of most recent bars fetched per symbol.
        - test_id: Identifier stored with the DB row.

        Returns:
        - The results dict from ``run_chow_test``.
        """
        self.connect_mt5()
        try:
            # 1. Fetch Data
            # The first symbol is Y and the second is X for the regression Y = a + bX
            data = self.fetch_data(symbols, timeframe, n_bars)

            # 2. Run Chow Test
            results = self.run_chow_test(data, break_date)

            # 3. Print Results
            print("\n--- Chow Test Analysis ---")
            print(f"Pair: {results['symbol_y']} vs {results['symbol_x']}")
            print(f"Breakpoint: {results['break_date']}")
            print(f"F-statistic: {results['F_stat']:.4f}")
            print(f"p-value: {results['p_value']:.6f}")
            print(f"Conclusion: {results['conclusion']}")
            print(f"Pre-break Beta: {results['pre_beta']:.4f}")
            print(f"Post-break Beta: {results['post_beta']:.4f}")

            # 4. Save Results to DB
            self.save_results_to_db(results, str(timeframe), test_id)

            # 5. Plot Results
            self.plot_chow_test(data, break_date, results)
        finally:
            # Release the terminal connection even when a step above raised.
            mt5.shutdown()

        return results

    def run_early_warning(self, data: pd.DataFrame, alpha=0.05):
        """
        Performs a CUSUM test to detect structural instability without
        needing a pre-defined breakpoint.

        Fits ``Y = a + b * X`` by recursive least squares. It then compares the
        most recent CUSUM value of the recursive residuals with the significance
        bounds that statsmodels' ``plot_cusum`` draws for the same ``alpha``.
        Finally it shows that plot.

        Parameters:
        - data: DataFrame with columns [Y, X] (first column is Y, second is X).
        - alpha: Significance level of the CUSUM bounds. statsmodels only
          tabulates 0.01, 0.05 and 0.10 and raises ValueError otherwise.

        Returns:
        - dict with 'is_breaking', the latest CUSUM value, its (lower, upper)
          bounds and a human-readable 'signal'.

        Raises:
        - ValueError: if ``data`` does not have exactly two columns or
          ``alpha`` is not a supported significance level.
        """
        import matplotlib.pyplot as plt
        from statsmodels.regression.recursive_ls import RecursiveLS

        if data.shape[1] != 2:
            raise ValueError("Input data must have exactly two columns (Y and X).")

        symbol_y = data.columns[0]
        symbol_x = data.columns[1]

        # 1. Fit the Recursive Least Squares model
        model = RecursiveLS(data[symbol_y], sm.add_constant(data[symbol_x]))
        results = model.fit()

        # 2. CUSUM path (a 1-D series over the post-burn-in window) and its
        # bounds. The bounds are straight lines evaluated at the window's start
        # and end; the same helper feeds plot_cusum, so the alert agrees with
        # the chart.
        # NOTE(review): _cusum_significance_bounds is a private statsmodels
        # helper; re-check it when upgrading statsmodels.
        cusum = np.asarray(results.cusum)
        lower_line, upper_line = results._cusum_significance_bounds(alpha)

        # 3. Early warning: is the latest CUSUM value outside the bounds?
        current_cusum = float(cusum[-1])
        lower_bound = float(lower_line[-1])
        upper_bound = float(upper_line[-1])
        is_breaking = current_cusum > upper_bound or current_cusum < lower_bound

        # 4. Visualization. plot_cusum creates and returns its own figure; it
        # does not accept an `ax` argument.
        fig = results.plot_cusum(alpha=alpha, figsize=(12, 6))
        fig.axes[0].set_title(f"Early Warning Monitor (CUSUM): {symbol_y} vs {symbol_x}")
        plt.show()

        status = {
            "is_breaking": is_breaking,
            "current_value": round(current_cusum, 4),
            "bounds": (round(lower_bound, 4), round(upper_bound, 4)),
            "signal": "⚠️ STABILITY ALERT" if is_breaking else "✅ MODEL STABLE"
        }

        return status
# Example Usage
if __name__ == "__main__":
    # Ensure MetaTrader 5 is running and logged in before executing this script.
    # The symbols and break date below are placeholders; adjust them to your broker/data.
    import sys
    import traceback

    # Define the pair and the hypothetical break date.
    # The first symbol is the dependent variable Y and the second is X (Y = a + b*X).
    ASSET_PAIR = ['US500', 'JPN225']  # US500 is Y, JPN225 is X
    # ASSET_PAIR = ['NVDA', 'WOLF']  # NVDA is Y, WOLF is X
    BREAK_DATE = '2025-04-02'  # Hypothetical structural-break date to test

    analyzer = ChowTestAnalyzer()

    try:
        results = analyzer.run_analysis(
            symbols=ASSET_PAIR,
            break_date=BREAK_DATE,
            timeframe=mt5.TIMEFRAME_D1,
            n_bars=504,  # Roughly 2 years of daily data
            test_id=1
        )
    except ConnectionError as e:
        print(f"Error during MT5/Data connection: {e}")
        sys.exit(1)
    except ValueError as e:
        print(f"Error in test parameters: {e}")
        sys.exit(1)
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        # Keep the traceback visible instead of hiding where the failure happened.
        traceback.print_exc()
        sys.exit(1)
    