import matplotlib
matplotlib.use('Agg')  # Установить бэкенд Agg для неинтерактивного рендеринга
import os
import asyncio
import logging
import pandas as pd
import numpy as np
import MetaTrader5 as mt5
import requests
import zipfile
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import warnings
import importlib.util
import glob
from datetime import datetime, timedelta

warnings.filterwarnings("ignore", category=UserWarning)

# Configuration: logging, CFTC data sources and the local output directory.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# CFTC 2025 history archives (COT futures-only and Traders in Financial Futures).
COT_URL = "https://www.cftc.gov/files/dea/history/dea_fut_xls_2025.zip"
TFF_URL = "https://www.cftc.gov/files/dea/history/fut_fin_xls_2025.zip"
OUTPUT_DIR = "data"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Chart defaults: a fixed 750x500 px canvas rendered at 100 DPI.
DPI = 100
WIDTH_PX, HEIGHT_PX = 750, 500
FIGSIZE = (WIDTH_PX / DPI, HEIGHT_PX / DPI)
plt.style.use('ggplot')
plt.rcParams.update({'figure.figsize': FIGSIZE, 'figure.dpi': DPI})

class CurrencyForecastModule:
    def __init__(self, pairs: list, days_history: int = 30):
        """Connect to MT5, load COT/TFF data and train one model per currency pair.

        Args:
            pairs: symbols to forecast. The list is copied, so pairs that MT5 does not
                offer can be dropped from ``self.pairs`` without touching the caller's list.
            days_history: number of days of H1 history used for features and training.

        Raises:
            ImportError: if a required Excel reader (xlrd / openpyxl) is missing.
            RuntimeError: if the MT5 terminal cannot be initialised.
        """
        # Copy: symbol validation may drop unresolved pairs from self.pairs.
        self.pairs = list(pairs)
        self.days_history = days_history
        self.models = {}
        self.scalers = {}
        self.forecasts = {}

        # Fail fast before connecting to anything.
        self._check_dependencies()

        if not mt5.initialize():
            logger.error("MT5 initialization failed. Ensure MT5 terminal is running and connected.")
            raise RuntimeError("MT5 initialization failed")

        try:
            self._validate_symbols()
            self._initialize_data()
        except BaseException:
            # Construction failed after the terminal connection was opened; close it
            # so the half-built object does not leave MT5 attached.
            mt5.shutdown()
            raise

    def _check_dependencies(self):
        """Проверяет наличие необходимых библиотек."""
        dependencies = ['xlrd', 'openpyxl']
        for dep in dependencies:
            if not importlib.util.find_spec(dep):
                logger.error(f"Missing dependency: {dep}. Install it using 'pip install {dep}'.")
                raise ImportError(f"Missing dependency: {dep}")

    def _validate_symbols(self):
        """Map each requested pair to a symbol that the MT5 terminal actually offers.

        Fills ``self.symbol_mapping`` (requested name -> MT5 name). A pair that is not
        offered verbatim is retried without its suffix after the first '.', e.g.
        'EURUSD.ecn' -> 'EURUSD'. Pairs that still do not match are logged and removed
        from ``self.pairs``.

        Raises:
            RuntimeError: if the terminal does not return a symbol list.
        """
        symbols = mt5.symbols_get()
        if symbols is None:
            # symbols_get() reports failure with None; iterating it would raise an
            # unhelpful TypeError instead of explaining what went wrong.
            logger.error(f"Failed to retrieve symbols from MT5: {mt5.last_error()}")
            raise RuntimeError("Failed to retrieve symbols from MT5")

        # A set gives O(1) membership tests for every requested pair.
        available_symbols = {s.name for s in symbols}
        logger.info(f"MT5 reports {len(available_symbols)} available symbols")
        logger.debug(f"Available symbols in MT5: {sorted(available_symbols)}")

        self.symbol_mapping = {}
        resolved = []
        for pair in self.pairs:
            if pair in available_symbols:
                self.symbol_mapping[pair] = pair
            else:
                base_pair = pair.split('.')[0]
                if base_pair not in available_symbols:
                    logger.warning(f"Symbol {pair} not found in MT5. Skipping.")
                    continue
                self.symbol_mapping[pair] = base_pair
                logger.info(f"Mapped: {pair} -> {base_pair}")
            resolved.append(pair)
        # Rebind instead of removing in place, so a list shared with the caller is untouched.
        self.pairs = resolved

    def _initialize_data(self):
        """Load the COT and TFF reports, chart them, and train one model per pair."""
        logger.info("Initializing data for CurrencyForecastModule...")
        self.cot_data = self._load_cot_reports()
        self.tff_data = self._load_tff_reports()

        # The side-by-side comparison chart needs both reports.
        if not self.cot_data.empty and not self.tff_data.empty:
            self._create_cot_tff_comparison()

        for pair in self.pairs:
            self._train_model(pair)

        # _train_model builds the performance dashboard itself only once *every* pair
        # has a model. If some pairs failed, build it here for the models that did train.
        if self.models and len(self.models) < len(self.pairs):
            self._create_model_performance_dashboard()

    def _load_cot_reports(self) -> pd.DataFrame:
        """Load the CFTC COT report and keep only the currency-futures rows.

        If ``cot_report.csv`` already exists in OUTPUT_DIR it is returned as-is (there is
        no expiry check). Otherwise the archive at COT_URL is downloaded and extracted.
        Its spreadsheet is reduced to the position columns listed below and to the
        currency markets. ``Net_NonComm`` / ``Net_Comm`` (long minus short) are derived,
        and the result is cached to CSV and charted.

        Returns:
            The processed frame, or an empty DataFrame on any failure (logged).
        """
        cache_path = os.path.join(OUTPUT_DIR, "cot_report.csv")
        if os.path.exists(cache_path):
            logger.info(f"Loading COT data from cache: {cache_path}")
            return pd.read_csv(cache_path)

        try:
            # A timeout keeps a stalled download from hanging object construction forever.
            response = requests.get(COT_URL, timeout=60)
            response.raise_for_status()
            zip_path = os.path.join(OUTPUT_DIR, "cot_data.zip")

            with open(zip_path, "wb") as f:
                f.write(response.content)

            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                members = zip_ref.namelist()
                # Extract member by member so the on-disk path of every file is known.
                extracted = [zip_ref.extract(member, OUTPUT_DIR) for member in members]
                logger.info(f"COT archive contents: {members}")

            # Only consider spreadsheets from *this* archive. A recursive glob over
            # OUTPUT_DIR would also see the TFF spreadsheet extracted into the same
            # directory and could return that one first.
            excel_files = [p for p in extracted if p.lower().endswith(('.xls', '.xlsx', '.xlsm'))]
            if not excel_files:
                logger.error("COT Excel file not found in extracted files")
                return pd.DataFrame()

            excel_file = excel_files[0]
            logger.info(f"Processing COT file: {excel_file}")

            relevant_columns = [
                "Market_and_Exchange_Names",
                "NonComm_Positions_Long_All",
                "NonComm_Positions_Short_All",
                "Comm_Positions_Long_All",
                "Comm_Positions_Short_All",
                "Open_Interest_All"
            ]

            # Legacy .xls needs xlrd; .xlsx/.xlsm are read by openpyxl.
            engine = 'xlrd' if excel_file.lower().endswith('.xls') else 'openpyxl'
            cot_data = pd.read_excel(excel_file, engine=engine)
            logger.info(f"Columns in COT file: {cot_data.columns.tolist()}")
            logger.info(f"First 5 rows of COT data:\n{cot_data.head().to_string()}")

            # The market-name column is required for the currency filter below.
            if "Market_and_Exchange_Names" not in cot_data.columns:
                logger.error("Expected columns not found in COT data")
                return pd.DataFrame()
            available_columns = [col for col in relevant_columns if col in cot_data.columns]

            forex_markets = ["EURO FX", "JAPANESE YEN", "BRITISH POUND", "AUSTRALIAN DOLLAR",
                             "CANADIAN DOLLAR", "SWISS FRANC", "MEXICAN PESO", "NEW ZEALAND DOLLAR"]
            is_forex = cot_data["Market_and_Exchange_Names"].str.contains(
                '|'.join(forex_markets), case=False, na=False)
            # .copy(): the Net_* columns are added to an independent frame, not a slice view.
            cot_data = cot_data.loc[is_forex, available_columns].copy()

            # Derive each net column only when both of its legs are present.
            if {"NonComm_Positions_Long_All", "NonComm_Positions_Short_All"} <= set(cot_data.columns):
                cot_data["Net_NonComm"] = cot_data["NonComm_Positions_Long_All"] - cot_data["NonComm_Positions_Short_All"]
            if {"Comm_Positions_Long_All", "Comm_Positions_Short_All"} <= set(cot_data.columns):
                cot_data["Net_Comm"] = cot_data["Comm_Positions_Long_All"] - cot_data["Comm_Positions_Short_All"]

            cot_data.to_csv(cache_path, index=False)
            logger.info(f"COT report saved to {cache_path}")

            self._visualize_cot_data(cot_data)
            return cot_data
        except Exception as e:
            logger.error(f"Error loading COT data: {e}")
            return pd.DataFrame()

    def _visualize_cot_data(self, cot_data: pd.DataFrame):
        """Plot ``Net_NonComm`` per currency market and save ``cot_net_positions.png``.

        One line per value of ``Market_and_Exchange_Names``. The x axis is each market's
        row position in ``cot_data``, not a date.
        """
        if cot_data.empty or "Net_NonComm" not in cot_data.columns:
            logger.warning("No COT data available for visualization")
            return

        # Resolve markets/colours before creating the figure, so a missing column cannot
        # leave an orphaned figure behind.
        markets = cot_data["Market_and_Exchange_Names"].unique()
        colors = plt.cm.tab10(np.linspace(0, 1, len(markets)))

        fig, ax = plt.subplots(figsize=FIGSIZE, dpi=DPI)
        try:
            for color, market in zip(colors, markets):
                market_data = cot_data[cot_data["Market_and_Exchange_Names"] == market]
                ax.plot(range(len(market_data)), market_data["Net_NonComm"],
                        label=market, alpha=0.8, linewidth=2, color=color)

            ax.set_title("COT Net Non-Commercial Positions for Currency Futures",
                         fontsize=14, fontweight='bold', pad=20)
            ax.set_xlabel("Time Period", fontsize=12)
            ax.set_ylabel("Net Positions (Contracts)", fontsize=12)
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
            ax.grid(True, alpha=0.3)
            ax.axhline(y=0, color='black', linestyle='--', alpha=0.5)

            fig.tight_layout()
            output_path = os.path.join(OUTPUT_DIR, "cot_net_positions.png")
            fig.savefig(output_path, dpi=DPI, bbox_inches='tight')
        finally:
            # Always release the figure, even when plotting raised, so figures do not
            # accumulate in pyplot's registry.
            plt.close(fig)
        logger.info(f"COT net positions chart saved to {output_path} (750px width)")

    def _load_tff_reports(self) -> pd.DataFrame:
        """Load the CFTC Traders in Financial Futures (TFF) report for currency futures.

        If ``tff_report.csv`` already exists in OUTPUT_DIR it is returned as-is (there is
        no expiry check). Otherwise the archive at TFF_URL is downloaded and extracted.
        Its spreadsheet is reduced to the position columns listed below and to the
        currency markets. ``Net_Lev_Money`` / ``Net_Asset_Mgr`` (long minus short) are
        derived, and the result is cached to CSV and charted.

        Returns:
            The processed frame, or an empty DataFrame on any failure (logged).
        """
        cache_path = os.path.join(OUTPUT_DIR, "tff_report.csv")
        if os.path.exists(cache_path):
            logger.info(f"Loading TFF data from cache: {cache_path}")
            return pd.read_csv(cache_path)

        try:
            # A timeout keeps a stalled download from hanging object construction forever.
            response = requests.get(TFF_URL, timeout=60)
            response.raise_for_status()
            zip_path = os.path.join(OUTPUT_DIR, "tff_data.zip")

            with open(zip_path, "wb") as f:
                f.write(response.content)

            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                members = zip_ref.namelist()
                # Extract member by member so the on-disk path of every file is known.
                extracted = [zip_ref.extract(member, OUTPUT_DIR) for member in members]
                logger.info(f"TFF archive contents: {members}")

            # Only consider spreadsheets from *this* archive (OUTPUT_DIR also holds the
            # COT spreadsheet). Prefer names containing "fin"; otherwise take any sheet
            # that came out of the TFF archive.
            excel_files = [p for p in extracted if p.lower().endswith(('.xls', '.xlsx', '.xlsm'))]
            tff_files = [p for p in excel_files if 'fin' in os.path.basename(p).lower()] or excel_files

            if not tff_files:
                logger.error("TFF Excel file not found in extracted files")
                return pd.DataFrame()

            excel_file = tff_files[0]
            logger.info(f"Processing TFF file: {excel_file}")

            relevant_columns = [
                "Market_and_Exchange_Names",
                "Lev_Money_Positions_Long_All",
                "Lev_Money_Positions_Short_All",
                "Asset_Mgr_Positions_Long_All",
                "Asset_Mgr_Positions_Short_All",
                "Open_Interest_All"
            ]

            # Legacy .xls needs xlrd; .xlsx/.xlsm are read by openpyxl.
            engine = 'xlrd' if excel_file.lower().endswith('.xls') else 'openpyxl'
            tff_data = pd.read_excel(excel_file, engine=engine)
            logger.info(f"Columns in TFF file: {tff_data.columns.tolist()}")
            logger.info(f"First 5 rows of TFF data:\n{tff_data.head().to_string()}")

            # The market-name column is required for the currency filter below.
            if "Market_and_Exchange_Names" not in tff_data.columns:
                logger.error("Expected columns not found in TFF data")
                return pd.DataFrame()
            available_columns = [col for col in relevant_columns if col in tff_data.columns]

            forex_markets = ["EURO FX", "JAPANESE YEN", "BRITISH POUND", "AUSTRALIAN DOLLAR",
                             "CANADIAN DOLLAR", "SWISS FRANC", "MEXICAN PESO", "NEW ZEALAND DOLLAR"]
            is_forex = tff_data["Market_and_Exchange_Names"].str.contains(
                '|'.join(forex_markets), case=False, na=False)
            # .copy(): the Net_* columns are added to an independent frame, not a slice view.
            tff_data = tff_data.loc[is_forex, available_columns].copy()

            # Derive each net column only when both of its legs are present.
            if {"Lev_Money_Positions_Long_All", "Lev_Money_Positions_Short_All"} <= set(tff_data.columns):
                tff_data["Net_Lev_Money"] = tff_data["Lev_Money_Positions_Long_All"] - tff_data["Lev_Money_Positions_Short_All"]
            if {"Asset_Mgr_Positions_Long_All", "Asset_Mgr_Positions_Short_All"} <= set(tff_data.columns):
                tff_data["Net_Asset_Mgr"] = tff_data["Asset_Mgr_Positions_Long_All"] - tff_data["Asset_Mgr_Positions_Short_All"]

            tff_data.to_csv(cache_path, index=False)
            logger.info(f"TFF report saved to {cache_path}")

            self._visualize_tff_data(tff_data)
            return tff_data
        except Exception as e:
            logger.error(f"Error loading TFF data: {e}")
            return pd.DataFrame()

    def _visualize_tff_data(self, tff_data: pd.DataFrame):
        """Plot ``Net_Lev_Money`` per currency market and save ``tff_net_positions.png``.

        One line per value of ``Market_and_Exchange_Names``. The x axis is each market's
        row position in ``tff_data``, not a date.
        """
        if tff_data.empty or "Net_Lev_Money" not in tff_data.columns:
            logger.warning("No TFF data available for visualization")
            return

        # Resolve markets/colours before creating the figure, so a missing column cannot
        # leave an orphaned figure behind.
        markets = tff_data["Market_and_Exchange_Names"].unique()
        colors = plt.cm.viridis(np.linspace(0, 1, len(markets)))

        fig, ax = plt.subplots(figsize=FIGSIZE, dpi=DPI)
        try:
            for color, market in zip(colors, markets):
                market_data = tff_data[tff_data["Market_and_Exchange_Names"] == market]
                ax.plot(range(len(market_data)), market_data["Net_Lev_Money"],
                        label=market, alpha=0.8, linewidth=2, color=color)

            ax.set_title("TFF Net Leveraged Funds Positions for Currency Futures",
                         fontsize=14, fontweight='bold', pad=20)
            ax.set_xlabel("Time Period", fontsize=12)
            ax.set_ylabel("Net Positions (Contracts)", fontsize=12)
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
            ax.grid(True, alpha=0.3)
            ax.axhline(y=0, color='black', linestyle='--', alpha=0.5)

            fig.tight_layout()
            output_path = os.path.join(OUTPUT_DIR, "tff_net_positions.png")
            fig.savefig(output_path, dpi=DPI, bbox_inches='tight')
        finally:
            # Always release the figure, even when plotting raised.
            plt.close(fig)
        logger.info(f"TFF net positions chart saved to {output_path} (750px width)")

    def _create_cot_tff_comparison(self):
        """Save ``cot_tff_comparison.png`` comparing COT and TFF positioning.

        There is one panel per major currency. Each panel plots ``Net_NonComm`` (COT
        non-commercials) and ``Net_Lev_Money`` (TFF leveraged funds) against row
        position. Errors are logged, not raised.
        """
        if self.cot_data.empty or self.tff_data.empty:
            logger.warning("Insufficient data for creating comparison chart")
            return

        fig = None
        try:
            major_currencies = ["EURO FX", "JAPANESE YEN", "BRITISH POUND"]
            # (report, column, legend label, colour) for each series drawn per panel.
            series = (
                (self.cot_data, "Net_NonComm", "COT Non-Commercial", 'blue'),
                (self.tff_data, "Net_Lev_Money", "TFF Leveraged Funds", 'red'),
            )

            n_panels = len(major_currencies)
            fig, axes = plt.subplots(1, n_panels, figsize=(WIDTH_PX * n_panels / DPI, HEIGHT_PX / DPI), dpi=DPI)
            fig.suptitle("COT vs TFF Positions Comparison for Major Currencies",
                         fontsize=16, fontweight='bold', y=0.98)

            for ax, currency in zip(axes, major_currencies):
                for report, column, label, color in series:
                    # NOTE(review): str.contains is a substring match, so every market whose
                    # name contains `currency` is merged into one line. Confirm the report
                    # has no other such contracts.
                    subset = report[report["Market_and_Exchange_Names"].str.contains(currency, case=False, na=False)]
                    if not subset.empty and column in subset.columns:
                        ax.plot(range(len(subset)), subset[column],
                                label=label, linewidth=2, color=color, alpha=0.7)

                ax.set_title(currency, fontsize=12, fontweight='bold')
                ax.set_xlabel("Period", fontsize=10)
                ax.set_ylabel("Net Positions", fontsize=10)
                # Only draw a legend when at least one labelled line was plotted.
                if ax.get_legend_handles_labels()[0]:
                    ax.legend(fontsize=9)
                ax.grid(True, alpha=0.3)
                ax.axhline(y=0, color='black', linestyle='--', alpha=0.5)

            fig.tight_layout()
            output_path = os.path.join(OUTPUT_DIR, "cot_tff_comparison.png")
            fig.savefig(output_path, dpi=DPI, bbox_inches='tight')
            logger.info(f"COT vs TFF comparison chart saved to {output_path}")

        except Exception as e:
            logger.error(f"Error creating comparison chart: {e}")
        finally:
            # Close on both the success and the error path; the original leaked the
            # figure whenever an exception was caught above.
            if fig is not None:
                plt.close(fig)

    def _create_model_performance_dashboard(self):
        """Save ``model_performance_dashboard.png`` summarising every trained model.

        Reads ``feature_importance_<pair>.csv`` for each pair in ``self.models``. The left
        panel shows the importance of each pair's top-ranked feature. The right panel
        shows the share of pairs whose top feature falls in each feature category.
        Errors are logged, not raised.
        """
        if not self.models:
            logger.warning("No trained models available for dashboard creation")
            return

        # First matching substring decides the category; order matters ('Net_..._change'
        # counts as a COT/TFF feature, not a price change).
        category_rules = (
            ('Net_', 'COT/TFF Positions'),
            ('volatility', 'Volatility'),
            ('sma', 'Moving Averages'),
            ('change', 'Price Changes'),
        )

        fig = None
        try:
            model_metrics = []
            for pair in self.models:
                feature_importance_path = os.path.join(OUTPUT_DIR, f"feature_importance_{pair}.csv")
                if not os.path.exists(feature_importance_path):
                    continue
                importance_df = pd.read_csv(feature_importance_path)
                if importance_df.empty:
                    top_feature, top_importance = "N/A", 0
                else:
                    # The CSV is written sorted by descending importance, so row 0 is the top feature.
                    top_feature = importance_df.iloc[0]['feature']
                    top_importance = importance_df.iloc[0]['importance']
                model_metrics.append({
                    'pair': pair,
                    'top_feature': top_feature,
                    'top_importance': top_importance
                })

            if not model_metrics:
                return

            metrics_df = pd.DataFrame(model_metrics)

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=FIGSIZE, dpi=DPI)
            fig.suptitle("Machine Learning Models Performance Dashboard",
                         fontsize=14, fontweight='bold')

            # Left panel: top-feature importance per pair.
            pairs = metrics_df['pair'].tolist()
            importances = metrics_df['top_importance'].tolist()

            bars = ax1.bar(pairs, importances, color=plt.cm.Set3(np.linspace(0, 1, len(pairs))), alpha=0.8)
            ax1.set_title("Top Feature Importance by Currency Pairs", fontsize=12)
            ax1.set_xlabel("Currency Pairs", fontsize=10)
            ax1.set_ylabel("Feature Importance", fontsize=10)
            ax1.tick_params(axis='x', rotation=45)

            for bar, importance in zip(bars, importances):
                height = bar.get_height()
                ax1.text(bar.get_x() + bar.get_width() / 2., height + 0.001,
                         f'{importance:.3f}', ha='center', va='bottom', fontsize=9)

            # Right panel: distribution of top-feature categories.
            category_counts = {}
            # astype(str): a blank feature name read back from CSV is NaN (a float), and
            # `'Net_' in nan` would raise TypeError.
            for feature in metrics_df['top_feature'].astype(str):
                category = next((label for needle, label in category_rules if needle in feature), 'Other')
                category_counts[category] = category_counts.get(category, 0) + 1

            if category_counts:
                # pie() converts its values with np.asarray(..., float). A dict_values view is
                # not a sequence to NumPy and fails that conversion, so pass real lists.
                ax2.pie(list(category_counts.values()), labels=list(category_counts.keys()), autopct='%1.1f%%',
                        colors=plt.cm.Pastel1(np.linspace(0, 1, len(category_counts))))
                ax2.set_title("Distribution of Important Feature Types", fontsize=12)

            fig.tight_layout()
            output_path = os.path.join(OUTPUT_DIR, "model_performance_dashboard.png")
            fig.savefig(output_path, dpi=DPI, bbox_inches='tight')
            logger.info(f"Model performance dashboard saved to {output_path} (750px width)")

        except Exception as e:
            logger.error(f"Error creating performance dashboard: {e}")
        finally:
            # Release the figure on both the success and the error path.
            if fig is not None:
                plt.close(fig)

    def _get_historical_prices(self, pair: str) -> pd.DataFrame:
        """Load the most recent ``days_history`` days of H1 bars for *pair* from MetaTrader5.

        Adds ``price_change_24h``, computed as (close 24 bars ahead / current close) - 1.
        Rows where that future close is unavailable (the newest 24 bars) are dropped.

        Returns:
            DataFrame indexed by bar time with open, high, low, close, tick_volume and
            price_change_24h. Empty when no data could be fetched (logged).
        """
        import time
        from datetime import timezone

        try:
            symbol = self.symbol_mapping.get(pair, pair)
            timeframe = mt5.TIMEFRAME_H1
            # MT5 request times are UTC. A naive local datetime.now() would shift the
            # window by the machine's UTC offset.
            utc_to = datetime.now(timezone.utc)
            utc_from = utc_to - timedelta(days=self.days_history)

            max_attempts = 3
            rates = None
            for attempt in range(1, max_attempts + 1):
                # copy_rates_range returns the bars opened within [utc_from, utc_to], i.e. the
                # latest window. copy_rates_from(date, count) would instead count bars
                # backwards from `date`, yielding history that ends days_history days ago.
                rates = mt5.copy_rates_range(symbol, timeframe, utc_from, utc_to)
                if rates is not None and len(rates) > 0:
                    break
                logger.warning(f"Attempt {attempt}: No data for {symbol}. Retrying...")
                if attempt < max_attempts:  # no point waiting after the last attempt
                    time.sleep(1)

            if rates is None or len(rates) == 0:
                logger.warning(f"No historical data for {symbol}")
                return pd.DataFrame()

            df = pd.DataFrame(rates)
            df['time'] = pd.to_datetime(df['time'], unit='s')
            df.set_index('time', inplace=True)
            # NOTE(review): shift(-24) means 24 *bars*, which spans more than 24 hours when
            # the series has gaps (e.g. weekends). Confirm that is acceptable.
            df['price_change_24h'] = df['close'].shift(-24) / df['close'] - 1
            df.dropna(inplace=True)
            return df[['open', 'high', 'low', 'close', 'tick_volume', 'price_change_24h']]
        except Exception as e:
            logger.error(f"Error loading historical prices for {pair}: {e}")
            return pd.DataFrame()

    def _map_pair_to_cot_tff(self, pair: str) -> str:
        """Maps currency pair to COT/TFF market."""
        mapping = {
            'EURUSD': 'EURO FX',
            'GBPUSD': 'BRITISH POUND',
            'USDJPY': 'JAPANESE YEN',
            'AUDUSD': 'AUSTRALIAN DOLLAR',
            'USDCAD': 'CANADIAN DOLLAR',
            'USDCHF': 'SWISS FRANC',
            'NZDUSD': 'NEW ZEALAND DOLLAR'
        }
        base_pair = pair.replace('.ecn', '')[:6]
        return mapping.get(base_pair, '')

    def _prepare_features(self, pair: str) -> pd.DataFrame:
        """Prepares features for prediction model."""
        try:
            df_prices = self._get_historical_prices(pair)
            if df_prices.empty:
                logger.warning(f"No price data for {pair}")
                df = pd.DataFrame(index=[datetime.now()], columns=['close', 'price_change_24h'])
                df['close'] = 1.0
                df['price_change_24h'] = 0.0
            else:
                df = df_prices.copy()

            df['volatility'] = (df['high'] - df['low']) / df['close'] if 'high' in df.columns else 0.0
            df['volume_sma_24'] = df['tick_volume'].rolling(window=24).mean() if 'tick_volume' in df.columns else 0.0
            df['price_sma_24'] = df['close'].rolling(window=24).mean() if 'close' in df.columns else df['close']
            df['price_change_1h'] = df['close'].pct_change() if 'close' in df.columns else 0.0

            market = self._map_pair_to_cot_tff(pair)
            if market and not self.cot_data.empty:
                cot_subset = self.cot_data[self.cot_data["Market_and_Exchange_Names"].str.contains(market, case=False, na=False)]
                if not cot_subset.empty:
                    cot_features = cot_subset[['Net_NonComm', 'Net_Comm']].mean().to_frame().T
                    for col in cot_features.columns:
                        df[col] = cot_features[col].iloc[0]

            if market and not self.tff_data.empty:
                tff_subset = self.tff_data[self.tff_data["Market_and_Exchange_Names"].str.contains(market, case=False, na=False)]
                if not tff_subset.empty:
                    if 'Net_Lev_Money' in tff_subset.columns and 'Net_Asset_Mgr' in tff_subset.columns:
                        tff_features = tff_subset[['Net_Lev_Money', 'Net_Asset_Mgr']].mean().to_frame().T
                        for col in tff_features.columns:
                            df[col] = tff_features[col].iloc[0]
                    else:
                        logger.warning(f"Missing Net_Lev_Money or Net_Asset_Mgr columns in TFF data for {pair}")

            for col in ['Net_NonComm', 'Net_Comm', 'Net_Lev_Money', 'Net_Asset_Mgr']:
                if col in df.columns:
                    df[f'{col}_lag1'] = df[col].shift(1)
                    df[f'{col}_change'] = df[col].pct_change().fillna(0)

            df.dropna(inplace=True)
            return df
        except Exception as e:
            logger.error(f"Error preparing features for {pair}: {e}")
            return pd.DataFrame()

    def _train_model(self, pair: str):
        """Train a RandomForestRegressor that predicts ``price_change_24h`` for *pair*.

        On success the model and scaler are stored in ``self.models`` / ``self.scalers``,
        and feature importances are written to ``feature_importance_<pair>.csv``. Failures
        are logged and leave no model for the pair. When every pair in ``self.pairs`` has
        a model, the performance dashboard is (re)built.
        """
        try:
            df = self._prepare_features(pair)
            if df.empty or len(df) < 10:
                logger.warning(f"Insufficient data for training model for {pair}")
                return

            X = df.drop(columns=['price_change_24h'])
            y = df['price_change_24h']

            # Rows come from a time-indexed price series and neighbouring 24h-ahead
            # targets overlap, so a shuffled split would leak future information into
            # training. Hold out the most recent 20% of rows instead.
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=False)

            # Fit scaling statistics on the training rows only, so they do not leak into the test score.
            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_test_scaled = scaler.transform(X_test)

            model = RandomForestRegressor(n_estimators=100, max_depth=10, random_state=42)
            model.fit(X_train_scaled, y_train)

            train_score = model.score(X_train_scaled, y_train)
            test_score = model.score(X_test_scaled, y_test)
            logger.info(f"Model for {pair}: Train R² = {train_score:.4f}, Test R² = {test_score:.4f}")

            self.models[pair] = model
            self.scalers[pair] = scaler

            feature_importance = pd.DataFrame({
                'feature': X.columns,
                'importance': model.feature_importances_
            }).sort_values('importance', ascending=False)
            output_path = os.path.join(OUTPUT_DIR, f"feature_importance_{pair}.csv")
            feature_importance.to_csv(output_path, index=False)
            logger.info(f"Feature importance saved to {output_path}")

        except Exception as e:
            logger.error(f"Error training model for {pair}: {e}")

        # Create dashboard after training all models
        if len(self.models) == len(self.pairs):
            self._create_model_performance_dashboard()

    async def get_price_forecast(self, pair: str) -> dict:
        """
        Return a 24-hour price forecast for the specified currency pair.

        The trained model predicts the relative 24h change from the last feature row.
        That change is applied to the current MT5 mid price ((bid + ask) / 2). The result
        is also stored in ``self.forecasts``, written to ``forecast_<pair>.csv`` and
        plotted.

        Returns:
            dict: {'pair': str, 'forecast_price': float | None, 'confidence': float}
            ``confidence`` is an R² score on the most recent labelled rows, clamped to
            [0, 1]. On any failure ``forecast_price`` is None and ``confidence`` is 0.0.
        """
        # NOTE(review): everything below is synchronous (MT5, scikit-learn, file I/O,
        # matplotlib), so awaiting this coroutine blocks the event loop while it runs.
        confidence_window = 24  # rows used to score the model for the confidence value
        try:
            if pair not in self.models or pair not in self.scalers:
                logger.warning(f"Model for {pair} is not trained")
                return {'pair': pair, 'forecast_price': None, 'confidence': 0.0}

            df = self._prepare_features(pair)
            if df.empty:
                logger.warning(f"No data available for forecasting {pair}")
                return {'pair': pair, 'forecast_price': None, 'confidence': 0.0}

            model = self.models[pair]
            scaler = self.scalers[pair]
            features = df.drop(columns=['price_change_24h'])

            # NOTE(review): df only keeps rows whose 24h-ahead target is known, so the last
            # row is not the newest bar. Confirm this lag is intended.
            price_change_pred = model.predict(scaler.transform(features.iloc[-1:]))[0]

            # R² needs at least two samples: for a single sample scikit-learn returns NaN,
            # and min/max clamping would silently turn that NaN into 1.0. Score the latest
            # labelled rows instead. They may overlap the training data, so the value is
            # optimistic.
            if len(df) > 1:
                raw_confidence = model.score(
                    scaler.transform(features.iloc[-confidence_window:]),
                    df['price_change_24h'].iloc[-confidence_window:],
                )
            else:
                raw_confidence = 0.6
            confidence = float(raw_confidence) if np.isfinite(raw_confidence) else 0.0
            confidence = max(0.0, min(1.0, confidence))

            tick = mt5.symbol_info_tick(self.symbol_mapping.get(pair, pair))
            if not tick:
                logger.warning(f"No current data available for {pair}")
                return {'pair': pair, 'forecast_price': None, 'confidence': 0.0}

            current_price = (tick.bid + tick.ask) / 2
            forecast_price = current_price * (1 + price_change_pred)

            # Store the same clamped confidence that is returned and plotted.
            self.forecasts[pair] = {
                'forecast_price': forecast_price,
                'confidence': confidence,
                'current_price': current_price,
                'price_change_pred': price_change_pred,
                'timestamp': datetime.now()
            }

            forecast_df = pd.DataFrame([self.forecasts[pair]])
            output_path = os.path.join(OUTPUT_DIR, f"forecast_{pair}.csv")
            forecast_df.to_csv(output_path, index=False)
            logger.info(f"Forecast for {pair} saved to {output_path}")

            self._visualize_forecast(pair, current_price, forecast_price, confidence)

            return {
                'pair': pair,
                'forecast_price': forecast_price,
                'confidence': confidence
            }
        except Exception as e:
            logger.error(f"Error forecasting for {pair}: {e}")
            return {'pair': pair, 'forecast_price': None, 'confidence': 0.0}

    def _visualize_forecast(self, pair: str, current_price: float, forecast_price: float, confidence: float):
        """Save ``forecast_<pair>_plot.png``: current vs forecast price bars and the % change.

        The change label is green for a forecast rise, red for a fall and gray when the
        two prices are equal.
        """
        prices = [current_price, forecast_price]
        # Computed before the figure exists, so a failure here cannot leak a figure.
        price_change_pct = ((forecast_price - current_price) / current_price) * 100
        if price_change_pct > 0:
            change_color = 'green'
        elif price_change_pct < 0:
            change_color = 'red'
        else:
            change_color = 'gray'  # no change is neither bullish nor bearish

        fig, ax = plt.subplots(figsize=FIGSIZE, dpi=DPI)
        try:
            bars = ax.bar(['Current Price', 'Forecast Price'], prices,
                          color=['#3498db', '#2ecc71'], alpha=0.8, edgecolor='black', linewidth=1)

            ax.set_title(f"24-Hour Price Forecast for {pair} (Confidence: {confidence:.2f})",
                         fontsize=14, fontweight='bold', pad=20)
            ax.set_ylabel('Price', fontsize=12)
            ax.grid(True, alpha=0.3, axis='y')

            # Value label just above each bar.
            for bar, price in zip(bars, prices):
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width() / 2., height + height * 0.01,
                        f'{price:.5f}', ha='center', va='bottom', fontsize=11, fontweight='bold')

            ax.text(0.5, 0.95, f'Expected Change: {price_change_pct:+.2f}%',
                    transform=ax.transAxes, ha='center', va='top',
                    fontsize=12, fontweight='bold', color=change_color,
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))

            fig.tight_layout()
            output_path = os.path.join(OUTPUT_DIR, f"forecast_{pair}_plot.png")
            fig.savefig(output_path, dpi=DPI, bbox_inches='tight')
        finally:
            # Called once per forecast, so an unclosed figure on error would accumulate.
            plt.close(fig)
        logger.info(f"Forecast chart saved to {output_path}")

    def _create_summary_dashboard(self):
        """Creates a comprehensive summary dashboard with all forecasts."""
        if not self.forecasts:
            logger.warning("No forecasts available for summary dashboard")
            return

        try:
            fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(WIDTH_PX*2/DPI, HEIGHT_PX*2/DPI), dpi=DPI)
            fig.suptitle("Currency Forecast Summary Dashboard", fontsize=16, fontweight='bold', y=0.95)
            
            # Forecast vs Current Prices
            pairs = list(self.forecasts.keys())
            current_prices = [self.forecasts[pair]['current_price'] for pair in pairs]
            forecast_prices = [self.forecasts[pair]['forecast_price'] for pair in pairs]
            
            x_pos = np.arange(len(pairs))
            width = 0.35
            
            ax1.bar(x_pos - width/2, current_prices, width, label='Current Price', alpha=0.8, color='#3498db')
            ax1.bar(x_pos + width/2, forecast_prices, width, label='Forecast Price', alpha=0.8, color='#2ecc71')
            ax1.set_title('Current vs Forecast Prices', fontweight='bold')
            ax1.set_xlabel('Currency Pairs')
            ax1.set_ylabel('Price')
            ax1.set_xticks(x_pos)
            ax1.set_xticklabels(pairs, rotation=45)
            ax1.legend()
            ax1.grid(True, alpha=0.3)
            
            # Confidence Levels
            confidences = [self.forecasts[pair]['confidence'] for pair in pairs]
            bars = ax2.bar(pairs, confidences, color=plt.cm.RdYlGn(np.array(confidences)), alpha=0.8)
            ax2.set_title('Model Confidence Levels', fontweight='bold')
            ax2.set_xlabel('Currency Pairs')
            ax2.set_ylabel('Confidence Score')
            ax2.set_ylim(0, 1)
            ax2.tick_params(axis='x', rotation=45)
            ax2.grid(True, alpha=0.3)
            
            # Add confidence values on bars
            for bar, conf in zip(bars, confidences):
                height = bar.get_height()
                ax2.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                        f'{conf:.2f}', ha='center', va='bottom', fontweight='bold')
            
            # Expected Price Changes
            price_changes = [((self.forecasts[pair]['forecast_price'] - self.forecasts[pair]['current_price']) / 
                             self.forecasts[pair]['current_price']) * 100 for pair in pairs]
            colors = ['green' if change > 0 else 'red' for change in price_changes]
            
            bars = ax3.bar(pairs, price_changes, color=colors, alpha=0.7)
            ax3.set_title('Expected 24h Price Changes (%)', fontweight='bold')
            ax3.set_xlabel('Currency Pairs')
            ax3.set_ylabel('Price Change (%)')
            ax3.axhline(y=0, color='black', linestyle='-', alpha=0.3)
            ax3.tick_params(axis='x', rotation=45)
            ax3.grid(True, alpha=0.3)
            
            # Add change values on bars
            for bar, change in zip(bars, price_changes):
                height = bar.get_height()
                ax3.text(bar.get_x() + bar.get_width()/2., height + (0.01 if height > 0 else -0.03),
                        f'{change:+.2f}%', ha='center', va='bottom' if height > 0 else 'top', 
                        fontweight='bold', fontsize=9)
            
            # Risk Assessment Matrix
            risk_levels = []
            for pair in pairs:
                conf = self.forecasts[pair]['confidence']
                change = abs(price_changes[pairs.index(pair)])
                if conf > 0.7 and change > 0.5:
                    risk = 'High Opportunity'
                elif conf > 0.7 and change <= 0.5:
                    risk = 'Low Risk'
                elif conf <= 0.7 and change > 0.5:
                    risk = 'High Risk'
                else:
                    risk = 'Neutral'
                risk_levels.append(risk)
            
            risk_counts = {risk: risk_levels.count(risk) for risk in set(risk_levels)}
            colors_pie = ['#2ecc71', '#f39c12', '#e74c3c', '#95a5a6']
            
            ax4.pie(risk_counts.values(), labels=risk_counts.keys(), autopct='%1.1f%%',
                   colors=colors_pie[:len(risk_counts)], startangle=90)
            ax4.set_title('Risk Assessment Distribution', fontweight='bold')
            
            plt.tight_layout()
            output_path = os.path.join(OUTPUT_DIR, "forecast_summary_dashboard.png")
            plt.savefig(output_path, dpi=DPI, bbox_inches='tight')
            plt.close()
            logger.info(f"Summary dashboard saved to {output_path}")
            
        except Exception as e:
            logger.error(f"Error creating summary dashboard: {e}")

    async def update_forecasts(self):
        """Refresh the forecast for every configured pair, then rebuild the dashboard.

        Pairs are processed sequentially, in ``self.pairs`` order, by awaiting
        ``self.get_price_forecast(pair)``. A failure for one pair (an exception
        or an empty/falsy result) is logged and that pair is skipped, so the
        remaining pairs are still forecast and ``_create_summary_dashboard`` is
        still called with whatever forecasts are available.
        """
        logger.info("Updating price forecasts...")
        for pair in self.pairs:
            try:
                forecast = await self.get_price_forecast(pair)
            except Exception:
                # One bad pair must not abort the whole update (and the
                # dashboard below); keep the traceback for diagnosis.
                logger.exception(f"Failed to update forecast for {pair}")
                continue

            if not forecast:
                # Guard the subscripting below against a missing result.
                logger.warning(f"No forecast returned for {pair}")
                continue

            logger.info(f"Forecast for {pair}: Price={forecast['forecast_price']}, Confidence={forecast['confidence']:.2f}")

        # Create summary dashboard after all forecasts are updated
        self._create_summary_dashboard()

    def get_trading_signals(
        self,
        *,
        min_confidence: float = 0.7,
        strong_move_pct: float = 0.5,
        weak_move_pct: float = 0.1,
    ) -> pd.DataFrame:
        """Turn the cached forecasts into trading signals and save them to CSV.

        For every entry in ``self.forecasts`` the expected move is
        ``(forecast_price - current_price) / current_price * 100``. When the
        entry's ``confidence`` is strictly greater than ``min_confidence`` the
        move is bucketed as:

        * ``> strong_move_pct``  -> 'STRONG BUY'
        * ``> weak_move_pct``    -> 'BUY'
        * ``< -strong_move_pct`` -> 'STRONG SELL'
        * ``< -weak_move_pct``   -> 'SELL'
        * otherwise              -> 'HOLD'

        Otherwise the signal is 'NEUTRAL'. If the move cannot be computed
        (zero/missing current price, or a non-finite result) the pair is
        reported as 'NEUTRAL' with a NaN ``expected_change_pct`` instead of
        aborting the whole run.

        Args:
            min_confidence: Confidence threshold (exclusive) for directional signals.
            strong_move_pct: Absolute % move above which a signal is "STRONG".
            weak_move_pct: Absolute % move above which a signal is BUY/SELL
                rather than HOLD.

        Returns:
            One row per pair with columns ``pair``, ``signal``,
            ``current_price``, ``forecast_price``, ``expected_change_pct``,
            ``confidence`` and ``timestamp``; the same frame is written to
            ``OUTPUT_DIR/trading_signals.csv``. When there are no forecasts an
            empty DataFrame is returned and no file is written.
        """
        if not self.forecasts:
            logger.warning("No forecasts available for generating trading signals")
            return pd.DataFrame()

        signals = []
        for pair, forecast_data in self.forecasts.items():
            current_price = forecast_data['current_price']
            forecast_price = forecast_data['forecast_price']
            confidence = forecast_data['confidence']

            # A zero/missing price would raise ZeroDivisionError and kill
            # signal generation for every pair; mark this one as unusable.
            if current_price:
                price_change_pct = ((forecast_price - current_price) / current_price) * 100
            else:
                logger.warning(f"Invalid current price for {pair}: {current_price}; signal set to NEUTRAL")
                price_change_pct = float('nan')

            # NaN compares False against every threshold, which would
            # otherwise fall through to 'HOLD' despite having no usable move.
            if not np.isfinite(price_change_pct) or not confidence > min_confidence:
                signal = 'NEUTRAL'
            elif price_change_pct > strong_move_pct:
                signal = 'STRONG BUY'
            elif price_change_pct > weak_move_pct:
                signal = 'BUY'
            elif price_change_pct < -strong_move_pct:
                signal = 'STRONG SELL'
            elif price_change_pct < -weak_move_pct:
                signal = 'SELL'
            else:
                signal = 'HOLD'

            signals.append({
                'pair': pair,
                'signal': signal,
                'current_price': current_price,
                'forecast_price': forecast_price,
                'expected_change_pct': price_change_pct,
                'confidence': confidence,
                'timestamp': forecast_data['timestamp']
            })

        signals_df = pd.DataFrame(signals)
        output_path = os.path.join(OUTPUT_DIR, "trading_signals.csv")
        signals_df.to_csv(output_path, index=False)
        logger.info(f"Trading signals saved to {output_path}")

        return signals_df

    def __del__(self):
        """Best-effort ``mt5.shutdown()`` when the instance is garbage-collected.

        Python calls ``__del__`` even when ``__init__`` raised part-way through,
        and it may run during interpreter teardown, when module globals such as
        ``mt5`` can already have been cleared. Exceptions raised here never
        reach a caller (the interpreter only prints "Exception ignored in ..."),
        so any failure is deliberately suppressed.
        """
        try:
            mt5.shutdown()
        except Exception:
            # Nothing useful can be done this late; the logger may be torn
            # down too, so don't try to log either.
            pass


# Example usage
if __name__ == "__main__":
    # Build forecasts for a fixed set of pairs, derive trading signals,
    # print them, and list the artifacts written to OUTPUT_DIR.
    currency_pairs = ['EURUSD', 'GBPUSD', 'USDJPY', 'AUDUSD']

    try:
        forecast_module = CurrencyForecastModule(currency_pairs, days_history=30)

        # Update forecasts
        asyncio.run(forecast_module.update_forecasts())

        # Generate trading signals
        signals = forecast_module.get_trading_signals()
        print("\nTrading Signals:")
        print(signals.to_string(index=False))

        logger.info("Currency forecasting completed successfully!")
        logger.info("Generated charts:")
        logger.info("1. cot_net_positions.png - COT Net Non-Commercial positions")
        logger.info("2. tff_net_positions.png - TFF Net Leveraged Funds positions") 
        logger.info("3. cot_tff_comparison.png - COT vs TFF comparison for major currencies")
        logger.info("4. model_performance_dashboard.png - ML models performance dashboard")
        logger.info("5. forecast_summary_dashboard.png - Comprehensive forecast summary")
        logger.info("6. Individual forecast charts for each currency pair")
        logger.info("7. trading_signals.csv - Generated trading signals")

    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"Error in main execution: {e}")
        # Signal failure to the shell/scheduler; `finally` still runs first.
        raise SystemExit(1)

    finally:
        mt5.shutdown()
