"""
RWEC: Rolling Window Eigenvector Comparison
Detects when the cointegration vector drifts over time
"""
import MetaTrader5 as mt5
import sqlite3
import pandas as pd
import numpy as np
from statsmodels.tsa.vector_ar.vecm import coint_johansen
import matplotlib.pyplot as plt
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

class RollingVectorAnalyzer:
    """Track how the Johansen cointegration vector of a set of symbols drifts.

    Closing prices are pulled from MetaTrader5, the Johansen procedure is re-run
    on rolling windows, and each window's leading eigenvector is compared with
    the previous one by the angle between them.
    """

    # Angle (degrees) below which consecutive vectors are considered "stable".
    STABILITY_THRESHOLD_DEG = 30

    def __init__(self, db_path="cointegration_data.db"):
        """Remember the SQLite path; the connection is opened by ``run_analysis``.

        Args:
            db_path: SQLite database file used by ``store_data`` / ``load_data``.
        """
        self.db_path = db_path
        # Set by run_analysis(); store_data/load_data require it to be open.
        self.conn = None

    def connect_mt5(self):
        """Initialize the MetaTrader5 terminal connection.

        Raises:
            ConnectionError: if ``mt5.initialize()`` reports failure.
        """
        if not mt5.initialize():
            # last_error() carries the terminal's (code, description) for diagnosis.
            raise ConnectionError(f"MT5 initialization failed: {mt5.last_error()}")
        print("✅ MT5 Connected")

    def fetch_data(self, symbols, timeframe=mt5.TIMEFRAME_H4, n_bars=504):
        """Fetch closing prices for ``symbols`` from MT5.

        Args:
            symbols: Iterable of MT5 symbol names.
            timeframe: MT5 timeframe constant (default H4).
            n_bars: Number of most recent bars requested per symbol.

        Returns:
            DataFrame indexed by bar time with one ``close`` column per symbol,
            keeping only timestamps for which every symbol has a value.

        Raises:
            ValueError: if MT5 returns no bars for a symbol.
        """
        data = {}
        for symbol in symbols:
            rates = mt5.copy_rates_from_pos(symbol, timeframe, 0, n_bars)
            # A failed request yields None/empty, which would otherwise surface
            # as an opaque KeyError on the 'time' column below.
            if rates is None or len(rates) == 0:
                raise ValueError(f"No rates returned for {symbol}: {mt5.last_error()}")
            df = pd.DataFrame(rates)
            df['time'] = pd.to_datetime(df['time'], unit='s')
            data[symbol] = df.set_index('time')['close'].rename(symbol)
        # Align on time; dropna keeps only bars present for all symbols.
        return pd.concat(data, axis=1).dropna()

    def store_data(self, df, table_name="etf_pairs"):
        """Append ``df`` (index included as a column) to ``table_name`` in SQLite."""
        df.to_sql(table_name, self.conn, if_exists='append', index=True)
        print(f"✅ Data stored: {table_name}")

    def load_data(self, table_name="etf_pairs"):
        """Load ``table_name`` from SQLite, indexed by its parsed ``time`` column."""
        # Identifiers cannot be bound as SQL parameters, so quote the table name
        # (doubling embedded quotes) instead of splicing it in raw.
        quoted = '"' + table_name.replace('"', '""') + '"'
        return pd.read_sql(f"SELECT * FROM {quoted}", self.conn, index_col='time', parse_dates=['time'])

    def rolling_cointegration(self, data, window=90, step=22):
        """Run the Johansen test on rolling windows and collect the leading vector.

        Args:
            data: Price DataFrame (rows = bars, columns = symbols).
            window: Number of rows in each estimation window.
            step: Row offset between consecutive window starts.

        Returns:
            DataFrame indexed by each window's last timestamp (axis name
            ``date``), one positional column per input column, each row being
            the unit-norm eigenvector for the largest eigenvalue. Windows where
            the Johansen computation fails are skipped.
        """
        vectors = []
        dates = []
        # "+ 1" so the final full window (starting at len(data) - window) is used.
        for start in range(0, len(data) - window + 1, step):
            window_data = data.iloc[start:start + window]
            try:
                # Johansen test: constant term (det_order=0), one lagged difference.
                result = coint_johansen(window_data, det_order=0, k_ar_diff=1)
            except (np.linalg.LinAlgError, ValueError):
                # Degenerate/singular window (e.g. flat prices): skip it.
                continue
            # Column 0 of evec pairs with the largest eigenvalue.
            vec = result.evec[:, 0]
            vec = vec / np.linalg.norm(vec)
            # Eigenvectors are defined only up to sign. Pin the first component
            # non-negative so a bare sign flip between windows is not reported
            # as a ~180 degree drift.
            if vec[0] < 0:
                vec = -vec
            vectors.append(vec)
            dates.append(window_data.index[-1])

        return pd.DataFrame(vectors, index=dates).rename_axis('date')

    def vector_similarity(self, vectors_df, threshold_deg=STABILITY_THRESHOLD_DEG):
        """Compare each vector with its predecessor.

        Args:
            vectors_df: DataFrame whose rows are vectors, ordered in time.
            threshold_deg: Angle (degrees) below which a step counts as stable.

        Returns:
            DataFrame indexed by ``date`` (the later vector's index label) with
            columns ``cosine_similarity``, ``angle_degrees`` and ``stable``.
            It is empty, with the same columns, when there are fewer than two vectors.
        """
        values = vectors_df.to_numpy(dtype=float)
        prev, curr = values[:-1], values[1:]
        cos_sim = (prev * curr).sum(axis=1) / (
            np.linalg.norm(prev, axis=1) * np.linalg.norm(curr, axis=1)
        )
        # Clip guards arccos against tiny floating-point overshoot beyond [-1, 1].
        angle_deg = np.degrees(np.arccos(np.clip(cos_sim, -1, 1)))
        return pd.DataFrame(
            {
                'cosine_similarity': cos_sim,
                'angle_degrees': angle_deg,
                'stable': angle_deg < threshold_deg,
            },
            index=vectors_df.index[1:].rename('date'),
        )

    def _plot_results(self, symbols, vectors, similarities, threshold_deg):
        """Save a two-panel PNG: vector components and inter-window angles."""
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(7, 4))
        try:
            # Vector components over time.
            for col in vectors.columns:
                ax1.plot(vectors.index, vectors[col], label=col, alpha=0.8)
            ax1.axhline(0, color='k', alpha=0.3)
            ax1.set_title(f"Cointegration Vector Evolution {'/'.join(symbols)}")
            # Vector columns are positional; label them with the symbol names.
            ax1.legend(symbols, loc='upper right')
            ax1.grid(True, alpha=0.3)

            # Similarity metrics.
            ax2.plot(similarities.index, similarities['angle_degrees'], 'o-', color='red', label='Angle (°)')
            ax2.axhline(threshold_deg, color='orange', linestyle='--', label='Stability Threshold')
            ax2.set_title('Vector Stability (Angle between consecutive vectors)')
            ax2.set_ylabel('Angle (degrees)')
            ax2.legend()
            ax2.grid(True, alpha=0.3)

            # Rotate date labels on both axes *before* tight_layout so the layout
            # accounts for them.
            for ax in (ax1, ax2):
                plt.setp(ax.get_xticklabels(), rotation=45)
            fig.tight_layout()
            fig.savefig(f"RWEC__{'_'.join(symbols)}.png", bbox_inches='tight')
        finally:
            # Release the figure so repeated runs don't accumulate open figures.
            plt.close(fig)

    def run_analysis(self, symbols=('XLK', 'AAA')):
        """Run the full pipeline: fetch, rolling Johansen, compare, plot, report.

        Writes ``RWEC__<symbols joined by _>.png`` to the working directory and
        prints a summary. The SQLite connection and MT5 session are always
        released, even if a step fails.

        Args:
            symbols: Sequence of MT5 symbol names.

        Returns:
            The DataFrame produced by ``vector_similarity``.
        """
        self.connect_mt5()
        try:
            self.conn = sqlite3.connect(self.db_path)
            try:
                data = self.fetch_data(symbols)
                # Persistence is disabled; call self.store_data(data) to append to SQLite.

                vectors = self.rolling_cointegration(data)
                similarities = self.vector_similarity(vectors)

                self._plot_results(symbols, vectors, similarities, self.STABILITY_THRESHOLD_DEG)

                print("\n📊 RWEC Results:")
                print(similarities['stable'].value_counts())
                print(f"Mean angle: {similarities['angle_degrees'].mean():.1f}°")
                print(f"Median angle: {similarities['angle_degrees'].median():.1f}°")
            finally:
                self.conn.close()
                self.conn = None
        finally:
            mt5.shutdown()
        return similarities

# Script entry point: analyze the NVDA/INTC pair when executed directly.
if __name__ == "__main__":
    pair = ['NVDA', 'INTC']
    analyzer = RollingVectorAnalyzer()
    results1 = analyzer.run_analysis(pair)