import os
import sqlite3
import pandas as pd
import MetaTrader5 as mt5
from datetime import datetime, timedelta, timezone
import numpy as np

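# Manages the whole workflow: loading prices from SQLite (back-filling from
# MetaTrader 5 when needed), computing Pearson correlations, and storing results.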
class SymbolCorrelation:
    def __init__(self, db_path='market_data.db'):
        """
        Initializes the SymbolCorrelation object with the database path.
        """
        self.db_path = os.getenv('DB_PATH', db_path)
        self.conn = None

    def connect_db(self):
        """
        Opens the SQLite database at `self.db_path` and stores the handle
        in `self.conn`. On failure the error is printed and `self.conn`
        is set to None.
        """
        try:
            connection = sqlite3.connect(self.db_path)
        except sqlite3.Error as exc:
            print(f"Error connecting to database: {exc}")
            connection = None
        else:
            print("Successfully connected to the SQLite database.")
        self.conn = connection

    def close_db(self):
        """
        Closes the database connection.
        """
        if self.conn:
            self.conn.close()
            print("Database connection closed.")
            
    def connect_mt5(self):
        """
        Initializes the MetaTrader 5 terminal session.

        Returns True when `mt5.initialize()` succeeds; otherwise prints
        the value of `mt5.last_error()` and returns False.
        """
        if mt5.initialize():
            return True
        print("Failed to initialize MT5, error code:", mt5.last_error())
        return False

    def get_symbol_id(self, ticker):
        """
        Retrieves the symbol_id from the 'symbol' table for a given ticker.
        """
        if not self.conn:
            print("Database connection not established.")
            return None
        try:
            query = "SELECT symbol_id FROM symbol WHERE ticker = ?"
            cursor = self.conn.cursor()
            cursor.execute(query, (ticker,))
            result = cursor.fetchone()
            return result[0] if result else None
        except sqlite3.Error as e:
            print(f"Error retrieving symbol_id for {ticker}: {e}")
            return None

    def get_market_data_from_db(self, symbol_id, timeframe, lookback):
        """
        Retrieves market data from the 'market_data' table.
        """
        if not self.conn:
            print("Database connection not established.")
            return pd.DataFrame()

        # Get the current UTC time once to ensure consistency
        end_datetime = datetime.now(timezone.utc)
        start_datetime = end_datetime - timedelta(days=lookback)
    
        end_timestamp = int(end_datetime.timestamp())
        start_timestamp = int(start_datetime.timestamp())

        try:
            query = """
            SELECT price_close, tstamp
            FROM market_data
            WHERE symbol_id = ? AND timeframe = ? AND tstamp >= ? AND tstamp <= ?
            ORDER BY tstamp DESC
            LIMIT ?
            """
            df = pd.read_sql_query(query, self.conn, params=(symbol_id, timeframe, start_timestamp, end_timestamp, lookback))
            # Check if we have enough data
            # Correctly check for enough data and fetch from MT5 if needed
            if len(df) < lookback:
                print(f"Not enough data in DB for symbol_id {symbol_id}, timeframe {timeframe}, lookback {lookback}.")
                mt5_df = self.fetch_market_data_from_mt5(symbol_id, timeframe, lookback)
                if mt5_df is not None and not mt5_df.empty:
                    # Re-query the database to get the combined data
                    df = pd.read_sql_query(query, self.conn, params=(symbol_id, timeframe, start_timestamp, end_timestamp, lookback))
                else:
                    print(f"Failed to fetch data for symbol_id {symbol_id} from MT5. Returning empty DataFrame.")
                    return pd.DataFrame(columns=['price_close', 'tstamp']) # Return a DataFrame with the correct columns
            return df
        except pd.io.sql.DatabaseError as e:
            print(f"Error querying market data: {e}")
            return pd.DataFrame(columns=['price_close', 'tstamp']) # Return a DataFrame with the correct columns

    def fetch_market_data_from_mt5(self, symbol_id, timeframe, lookback):
        """
        Fetches market data from MetaTrader 5 and stores it in the database.
        """
        if not self.connect_mt5():
            print("MT5 connection failed.")
            return pd.DataFrame()

        # Get symbol ticker from ID
        cursor = self.conn.cursor()
        cursor.execute("SELECT ticker FROM symbol WHERE symbol_id = ?", (symbol_id,))
        ticker_name = cursor.fetchone()[0]
        
        # Convert timeframe string to MT5 constant
        mt5_timeframes = {
            'M1': mt5.TIMEFRAME_M1, 'M5': mt5.TIMEFRAME_M5, 'M15': mt5.TIMEFRAME_M15,
            'M30': mt5.TIMEFRAME_M30, 'H1': mt5.TIMEFRAME_H1, 'H4':mt5.TIMEFRAME_H4, 'D1': mt5.TIMEFRAME_D1
        }
        if timeframe not in mt5_timeframes:
            print(f"Unsupported timeframe: {timeframe}")
            return pd.DataFrame()
            
        rate = mt5.copy_rates_from_pos(ticker_name, mt5_timeframes[timeframe], 0, lookback)
        mt5.shutdown()

        # Fix: Correctly check for empty or invalid data
        if rate is None or rate.size == 0:
            print(f"Failed to fetch data for {ticker_name} from MT5.")
            return pd.DataFrame(columns=['price_close', 'tstamp']) # Return a DataFrame with the correct columns

        df = pd.DataFrame(rate)
        df['tstamp'] = pd.to_datetime(df['time'], unit='s')
        
        # Prepare data for insertion
        records_to_insert = []
        for _, row in df.iterrows():
            record = (
                int(row['time']),
                timeframe,
                row['open'],
                row['high'],
                row['low'],
                row['close'],
                row['tick_volume'],
                row['real_volume'],
                row['spread'],
                symbol_id
            )
            records_to_insert.append(record)
            
        try:
            cursor = self.conn.cursor()
            cursor.executemany("""
                INSERT OR IGNORE INTO market_data (tstamp, timeframe, price_open, price_high, price_low, price_close, tick_volume, real_volume, spread, symbol_id)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, records_to_insert)
            self.conn.commit()
            print(f"Successfully stored {len(records_to_insert)} records for {ticker_name} in the database.")
            
            # Now, query the newly inserted data
            query = """
                SELECT price_close FROM market_data
                WHERE symbol_id = ? AND timeframe = ?
                ORDER BY tstamp DESC
                LIMIT ?
            """
            # The return value from this function is not directly used in your main loop, but it's good practice to return it.
            return pd.read_sql_query(query, self.conn, params=(symbol_id, timeframe, lookback))
        except sqlite3.Error as e:
            print(f"Error storing market data: {e}")
            return pd.DataFrame(columns=['price_close', 'tstamp']) # Return a DataFrame with the correct columns

    def calculate_pearson_correlation(self, ref_df, corr_df):
        """
        Calculates the Pearson correlation coefficient between two dataframes
        after aligning them by timestamp, using pd.merge on the index.
        """
        if ref_df.empty or corr_df.empty:
            print("One of the dataframes is empty. Cannot calculate correlation.")
            return None

        # Ensure 'tstamp' is the index for proper alignment
        ref_df = ref_df.set_index('tstamp')
        corr_df = corr_df.set_index('tstamp')

        # Rename columns to avoid conflicts and make them identifiable
        ref_df = ref_df.rename(columns={'price_close': 'ref_price_close'})
        corr_df = corr_df.rename(columns={'price_close': 'corr_price_close'})

        # Merge on the index, not a column
        merged_df = pd.merge(ref_df, corr_df, left_index=True, right_index=True, how='inner')

        # If the merged dataframe is empty, there is no common data
        if merged_df.empty:
            print("No common data found to calculate correlation.")
            return None

        # Calculate the Pearson correlation matrix
        correlation_matrix = merged_df.corr(method='pearson')

        # Extract the specific correlation coefficient
        correlation_coefficient = correlation_matrix.loc['ref_price_close', 'corr_price_close']

        return correlation_coefficient

    def calculate_pearson_correlation_Pandas(self, ref_df, corr_df):
        """
        Calculates the Pearson correlation coefficient between two dataframes
        after aligning them by timestamp, using the Pandas built-in method.
        """
        if ref_df.empty or corr_df.empty:
            print("One of the dataframes is empty. Cannot calculate correlation.")
            return None

        # Ensure 'tstamp' is the index for proper alignment
        ref_df = ref_df.set_index('tstamp')
        corr_df = corr_df.set_index('tstamp')

        # Rename columns to avoid conflicts and make them identifiable
        ref_df = ref_df.rename(columns={'price_close': 'ref_price_close'})
        corr_df = corr_df.rename(columns={'price_close': 'corr_price_close'})

        # Concatenate the two dataframes along the columns
        combined_df = pd.concat([ref_df, corr_df], axis=1, join='inner')

        # If the combined dataframe is empty, there is no common data
        if combined_df.empty:
            print("No common data found to calculate correlation.")
            return None

        # Calculate the Pearson correlation matrix
        correlation_matrix = combined_df.corr(method='pearson')

        # Extract the specific correlation coefficient
        correlation_coefficient = correlation_matrix.loc['ref_price_close', 'corr_price_close']

        return correlation_coefficient

    def store_correlation_results(self, tstamp, ref_ticker, corr_ticker, timeframe, lookback, coefficient):
        """
        Stores the calculated correlation coefficient in the 'corr_pearson' table.
        """
        if not self.conn:
            print("Database connection not established.")
            return
        
        try:
            cursor = self.conn.cursor()
            cursor.execute("""
                INSERT INTO corr_pearson (tstamp, ref_ticker, corr_ticker, timeframe, lookback, coefficient)
                VALUES (?, ?, ?, ?, ?, ?)
            """, (tstamp, ref_ticker, corr_ticker, timeframe, lookback, coefficient))
            self.conn.commit()
            print(f"Successfully stored correlation for {corr_ticker} with {ref_ticker}.")
        except sqlite3.Error as e:
            print(f"Error storing correlation results: {e}")

    def run_correlation_analysis(self, asset_type, industry, ref_ticker, timeframe, lookback):
        """
        Main function to orchestrate the entire process.
        """
        self.connect_db()
        if not self.conn:
            return

        ref_symbol_id = self.get_symbol_id(ref_ticker)
        if ref_symbol_id is None:
            print(f"Reference symbol '{ref_ticker}' not found in the database.")
            self.close_db()
            return
            
        try:
            # Get the price data for the reference symbol
            ref_data = self.get_market_data_from_db(ref_symbol_id, timeframe, lookback)
            if ref_data.empty:
                print(f"Could not get data for reference symbol {ref_ticker}. Exiting.")
                self.close_db()
                return

            # Get the current UTC timestamp for the new table entry
            correlation_tstamp = int(datetime.now(timezone.utc).timestamp())

            # Select all symbols with the specific asset_type and industry
            query = "SELECT symbol_id, ticker FROM symbol WHERE asset_type = ? AND industry = ?"
            df_symbols = pd.read_sql_query(query, self.conn, params=(asset_type, industry))
            
            if df_symbols.empty:
                print(f"No symbols found for asset_type '{asset_type}' and industry '{industry}'.")
                self.close_db()
                return
            
            for _, row in df_symbols.iterrows():
                symbol_id = row['symbol_id']
                ticker = row['ticker']
                
                print(f"Processing symbol: {ticker}...")
                
                # Get the price data for the current symbol
                corr_data = self.get_market_data_from_db(symbol_id, timeframe, lookback)
                
                if corr_data.empty:
                    print(f"Skipping {ticker} due to lack of data.")
                    continue
                
                # Calculate the correlation
                correlation_coefficient = self.calculate_pearson_correlation(ref_data, corr_data)
                
                if correlation_coefficient is not None:
                    print(f"Pearson correlation between {ticker} and {ref_ticker} is: {correlation_coefficient:.4f}")
                    
                    # Store the results
                    self.store_correlation_results(correlation_tstamp, ref_ticker, ticker, timeframe, lookback, correlation_coefficient)
                else:
                    print(f"Could not calculate correlation for {ticker}.")

        except pd.io.sql.DatabaseError as e:
            print(f"An error occurred during the analysis: {e}")
        finally:
            self.close_db()
            print("Analysis complete.")

if __name__ == '__main__':
    # Example run: correlate every 'Stock' in the 'Semiconductors' industry
    # against 'NVDA' on the 'D1' timeframe with a lookback of 180.
    analysis_settings = {
        'asset_type': 'Stock',
        'industry': 'Semiconductors',
        'ref_ticker': 'NVDA',
        'timeframe': 'D1',
        'lookback': 180,
    }
    correlation_runner = SymbolCorrelation()
    correlation_runner.run_correlation_analysis(**analysis_settings)