import pandas as pd
import wbdata
import MetaTrader5 as mt5
from catboost import CatBoostRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import warnings
import logging
from datetime import datetime, timedelta
from typing import Dict, Optional

# Module-level logger shared by the classes below.
logger = logging.getLogger(__name__)

# Suppress UserWarning messages emitted from the wbdata module only.
warnings.filterwarnings(action="ignore", category=UserWarning, module="wbdata")

class ForexAnalyzer:
    """Collect macro and price data, fit per-symbol CatBoost models, summarise them.

    Typical order of calls: ``initialize_mt5`` -> ``fetch_economic_data`` ->
    ``fetch_mt5_data`` -> ``prepare_all_data`` -> ``run_forecasts`` ->
    ``interpret_results``; ``cleanup`` shuts the MetaTrader5 connection down.
    """

    def __init__(self):
        # World Bank indicator code -> column name requested from wbdata.
        self.indicators = {
            'NY.GDP.MKTP.KD.ZG': 'GDP growth',
            'FP.CPI.TOTL.ZG': 'Inflation',
            'FR.INR.RINR': 'Real interest rate',
            'NE.EXP.GNFS.ZS': 'Exports',
            'NE.IMP.GNFS.ZS': 'Imports',
            'BN.CAB.XOKA.GD.ZS': 'Current account balance',
            'GC.DOD.TOTL.GD.ZS': 'Government debt',
            'SL.UEM.TOTL.ZS': 'Unemployment rate',
            'NY.GNP.PCAP.CD': 'GNI per capita',
            'NY.GDP.PCAP.KD.ZG': 'GDP per capita growth'
        }
        self.economic_data = None      # combined DataFrame from fetch_economic_data(), or None
        self.historical_data = {}      # symbol -> DataFrame of daily MT5 bars indexed by time
        self.prepared_data = {}        # symbol -> feature DataFrame from prepare_data()
        self.forecasts = {}            # symbol -> model predictions from forecast()
        self.feature_importances = {}  # symbol -> feature/importance DataFrame, sorted descending

    def fetch_economic_data(self) -> Optional[pd.DataFrame]:
        """Download every configured World Bank indicator for all countries.

        Indicators whose download raises are logged and skipped. The frames that
        were fetched are concatenated column-wise into ``self.economic_data``.

        Returns:
            The combined DataFrame, or None when no indicator could be fetched
            (``self.economic_data`` is left unchanged in that case).
        """
        data_frames = []
        for indicator, name in self.indicators.items():
            try:
                data_frames.append(wbdata.get_dataframe({indicator: name}, country='all'))
            except Exception as e:
                logger.error(f"Error fetching data for indicator '{indicator}': {e}")

        if not data_frames:
            return None
        self.economic_data = pd.concat(data_frames, axis=1)
        return self.economic_data

    def initialize_mt5(self, path: Optional[str] = None) -> bool:
        """Connect to the MetaTrader5 terminal.

        Args:
            path: Optional path to the terminal executable. When omitted or
                empty, ``mt5.initialize()`` is called without arguments.

        Returns:
            True on success.

        Raises:
            Exception: if the terminal could not be initialised.
        """
        connected = mt5.initialize(path) if path else mt5.initialize()
        if not connected:
            raise Exception(f"Failed to initialize MetaTrader5: {mt5.last_error()}")
        return True

    def fetch_mt5_data(self, lookback_days: int = 1000) -> Dict[str, pd.DataFrame]:
        """Load up to ``lookback_days`` daily bars for every symbol the terminal lists.

        Each symbol's bars are stored in ``self.historical_data`` as a DataFrame
        indexed by the bar's ``time`` field (converted from epoch seconds).
        Symbols for which no rates come back are skipped.

        Returns:
            ``self.historical_data``.
        """
        symbols = mt5.symbols_get()
        if symbols is None:
            # The MT5 API signals failure with None; iterating it would raise TypeError.
            logger.error(f"Failed to list MT5 symbols: {mt5.last_error()}")
            return self.historical_data

        for symbol in (entry.name for entry in symbols):
            rates = mt5.copy_rates_from_pos(symbol, mt5.TIMEFRAME_D1, 0, lookback_days)
            if rates is None or len(rates) == 0:
                continue
            df = pd.DataFrame(rates)
            df['time'] = pd.to_datetime(df['time'], unit='s')
            self.historical_data[symbol] = df.set_index('time')

        return self.historical_data

    def prepare_data(self, symbol_data: pd.DataFrame) -> pd.DataFrame:
        """Build the model feature frame for one symbol.

        Adds ``close_diff`` (bar-over-bar change in ``close``) and ``close_corr``
        (30-bar rolling correlation of ``close`` with its one-bar lag). It also
        adds any indicator code that appears as a column of
        ``self.economic_data``, then drops every row containing NaN.

        Args:
            symbol_data: price DataFrame with a ``close`` column; not modified.

        Returns:
            A new DataFrame.
        """
        data = symbol_data.copy()
        data['close_diff'] = data['close'].diff()
        data['close_corr'] = data['close'].rolling(window=30).corr(data['close'].shift(1))

        # Without economic data (never fetched, or every download failed) fall back
        # to price-only features instead of raising AttributeError on None.
        if self.economic_data is not None:
            # NOTE(review): fetch_economic_data() asks wbdata for columns named by
            # the dict *values* ('GDP growth', ...), but this checks the indicator
            # *codes*, so it looks like no macro column ever matches. Before fixing
            # the lookup, align the economic index (presumably country/date) with
            # the price index, otherwise dropna() below would likely drop every row.
            for indicator in self.indicators:
                if indicator in self.economic_data.columns:
                    data[indicator] = self.economic_data[indicator].ffill()

        return data.dropna()

    def prepare_all_data(self) -> Dict[str, pd.DataFrame]:
        """Run ``prepare_data`` on every fetched symbol and store the results."""
        for symbol, df in self.historical_data.items():
            self.prepared_data[symbol] = self.prepare_data(df)
        return self.prepared_data

    def forecast(self, symbol, symbol_data):
        """Fit a CatBoost regressor on ``close`` and predict the last 30 rows.

        The first 80% of rows (chronological, no shuffle) train the model. The
        remaining 20% are scored, and the hold-out RMSE is logged.

        Args:
            symbol: symbol name, used for logging.
            symbol_data: prepared feature frame containing a ``close`` column.

        Returns:
            ``(predictions, importance_df)``. ``predictions`` holds the model's
            outputs for the last (up to) 30 rows of ``symbol_data``.
            ``importance_df`` has ``feature``/``importance`` columns sorted by
            descending importance. Returns ``(None, None)`` when there are fewer
            than 50 rows or a split is empty.
        """
        if len(symbol_data) < 50:
            return None, None

        X = symbol_data.drop(columns=['close'])
        y = symbol_data['close']

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=False)
        if len(X_train) == 0 or len(X_test) == 0:
            return None, None

        model = CatBoostRegressor(iterations=1000, learning_rate=0.1, depth=8, loss_function='RMSE')
        model.fit(X_train, y_train, verbose=False)

        # Score the chronological hold-out split so model quality shows up in the logs.
        rmse = mean_squared_error(y_test, model.predict(X_test)) ** 0.5
        logger.info("%s hold-out RMSE: %.6f", symbol, rmse)

        importance_df = pd.DataFrame(
            {'feature': X.columns, 'importance': model.feature_importances_}
        ).sort_values('importance', ascending=False)

        # NOTE(review): X keeps every non-'close' column. That includes close_diff
        # (close minus the previous close) and presumably the bar's own MT5
        # open/high/low. These are also rows already present in symbol_data, so
        # the values below are fitted outputs for the last 30 observed bars, not
        # forward-looking forecasts.
        future_predictions = model.predict(symbol_data.tail(30).drop(columns=['close']))
        return future_predictions, importance_df

    def run_forecasts(self):
        """Forecast every prepared symbol, logging (with traceback) any that fail.

        Returns:
            ``(self.forecasts, self.feature_importances)``.
        """
        for symbol, df in self.prepared_data.items():
            try:
                forecast_result, importance_df = self.forecast(symbol, df)
                if forecast_result is not None and importance_df is not None:
                    self.forecasts[symbol] = forecast_result
                    self.feature_importances[symbol] = importance_df
            except Exception as e:
                logger.exception(f"Error forecasting for {symbol}: {e}")

        return self.forecasts, self.feature_importances

    def interpret_results(self, symbol):
        """Summarise the stored forecast for ``symbol``.

        Returns:
            A message string if no forecast or importances are stored for the
            symbol. Otherwise a dict with these keys:
            ``symbol``;
            ``trend`` ("upward" when the last prediction exceeds the first,
            else "downward");
            ``volatility`` ("high" when std / |mean| of the predictions exceeds
            0.1, else "low");
            ``top_feature``;
            ``forecast_values`` (as a list);
            ``feature_importance`` (``DataFrame.to_dict()`` output).
        """
        forecast = self.forecasts.get(symbol)
        importance_df = self.feature_importances.get(symbol)

        if forecast is None or importance_df is None:
            return f"Insufficient data for interpretation of {symbol}"

        trend = "upward" if forecast[-1] > forecast[0] else "downward"

        # Coefficient of variation over |mean| so a negative mean cannot flip the
        # comparison; a zero mean is treated as infinitely volatile.
        mean = float(forecast.mean())
        variation = float(forecast.std()) / abs(mean) if mean else float("inf")
        volatility = "high" if variation > 0.1 else "low"

        top_feature = importance_df.iloc[0]['feature']

        return {
            "symbol": symbol,
            "trend": trend,
            "volatility": volatility,
            "top_feature": top_feature,
            "forecast_values": forecast.tolist(),
            "feature_importance": importance_df.to_dict()
        }

    def cleanup(self):
        """Shut down the MetaTrader5 connection."""
        mt5.shutdown()

class EconomicSignalModule:
    """Expose ForexAnalyzer forecasts as simple BUY/SELL signals.

    The constructor connects to MetaTrader5 and runs the data/forecast pipeline
    once. Any failure is logged rather than raised, and ``self.forecasts`` then
    stays an empty dict.
    """

    def __init__(self, terminal_path: str):
        """Create the analyzer, connect to MT5 and compute initial forecasts.

        Args:
            terminal_path: path to the MetaTrader5 terminal executable, passed to
                ``mt5.initialize``. An empty value falls back to the analyzer's
                argument-less initialisation.
        """
        self.analyzer = ForexAnalyzer()
        self.terminal_path = terminal_path
        # Defined up front so the attribute exists even if the pipeline fails below.
        self.forecasts = {}
        try:
            self._connect()
            self._refresh()
        except Exception as e:
            logger.error(f"Error initializing EconomicSignalModule: {e}")

    def _connect(self) -> None:
        """Initialise MetaTrader5, honouring ``terminal_path`` when one is given."""
        if self.terminal_path:
            if not mt5.initialize(self.terminal_path):
                raise Exception(f"Failed to initialize MetaTrader5 at {self.terminal_path}")
        else:
            self.analyzer.initialize_mt5()

    def _refresh(self) -> None:
        """Run fetch -> prepare -> forecast on the analyzer and keep the forecasts."""
        self.analyzer.fetch_economic_data()
        self.analyzer.fetch_mt5_data()
        self.analyzer.prepare_all_data()
        self.forecasts, _ = self.analyzer.run_forecasts()

    async def get_economic_signal(self, pair: str) -> Optional[Dict]:
        """Turn the analyzer's interpretation of ``pair`` into a trade signal.

        Returns:
            ``{'direction', 'volatility', 'forecast_confidence'}`` on success:
            - ``direction`` is "BUY" for an upward trend and "SELL" for a
              downward one.
            - ``forecast_confidence`` is |last - first| / |first| over the
              forecast values, or 0 when there are no values or the first one
              is 0.
            Returns None when no interpretation is available or an error occurs.
        """
        try:
            result = self.analyzer.interpret_results(pair)

            # interpret_results returns a message string when data is missing.
            if not result or isinstance(result, str):
                return None

            direction = {"upward": "BUY", "downward": "SELL"}.get(result['trend'])
            if direction is None:
                return None

            values = result['forecast_values']
            if values and values[0]:
                confidence = abs(values[-1] - values[0]) / abs(values[0])
            else:
                # Avoid dividing by a zero (or missing) starting value.
                confidence = 0

            return {
                'direction': direction,
                'volatility': result['volatility'],
                'forecast_confidence': confidence
            }

        except Exception as e:
            logger.error(f"Error getting economic signal for {pair}: {e}")
            return None

    async def update_forecasts(self):
        """Re-run the data/forecast pipeline; errors are logged, not raised."""
        # NOTE(review): the pipeline is synchronous (network fetches + CatBoost
        # training), so awaiting this blocks the event loop until it finishes.
        try:
            self._refresh()
        except Exception as e:
            logger.error(f"Error updating forecasts: {e}")

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # Smoke test: constructing the module connects to MT5 and runs the full pipeline.
    terminal_path = "C:/Program Files/RannForex MetaTrader 5/terminal64.exe"
    economic_module = EconomicSignalModule(terminal_path)
    try:
        forecasts = getattr(economic_module, "forecasts", {})
        logger.info("Forecasts available for %d symbols", len(forecasts))
    finally:
        # Release the terminal connection opened during construction.
        economic_module.analyzer.cleanup()
