import numpy as np
import pandas as pd
import sympy as sp
import MetaTrader5 as mt5
from datetime import datetime, timedelta
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
from sklearn.linear_model import LinearRegression, Ridge, LogisticRegression
from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split, cross_val_score
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

class SimpleSymbolicPredictor:
    """
    Упрощенная символьная модель с простой визуализацией 700x700
    """
    
    def __init__(self, symbol: str = "EURUSD", terminal_path: str = None):
        self.symbol = symbol
        self.terminal_path = terminal_path
        self.prediction_horizon = 24  # 24 бара H1
        self.lookback_bars = 5000    # Уменьшено для быстроты
        
        # Модели и уравнения
        self.price_equation = None
        self.binary_equation = None
        self.ridge_model = None
        self.logistic_model = None
        self.poly_features = None
        
        # Данные
        self.data = None
        self.features = None
        self.predictions_df = None
        self.scaler = StandardScaler()
        
        # Метрики
        self.metrics = {}
        
        print(f"SimpleSymbolicPredictor initialized for {symbol}")

    def connect_mt5(self) -> bool:
        """Подключение к MT5"""
        if self.terminal_path:
            if not mt5.initialize(path=self.terminal_path):
                print(f"MT5 initialization failed: {mt5.last_error()}")
                return False
        else:
            if not mt5.initialize():
                print(f"MT5 initialization failed: {mt5.last_error()}")
                return False
        return True

    def disconnect_mt5(self):
        """Отключение от MT5"""
        mt5.shutdown()

    def get_historical_data(self) -> pd.DataFrame:
        """Получение исторических данных"""
        try:
            if not self.connect_mt5():
                print("Failed to connect to MT5")
                return None
            
            rates = mt5.copy_rates_from_pos(
                self.symbol, 
                mt5.TIMEFRAME_H1, 
                0, 
                self.lookback_bars + self.prediction_horizon + 100
            )
            
            self.disconnect_mt5()
            
            if rates is None or len(rates) == 0:
                print(f"No data received for {self.symbol}")
                return None
            
            df = pd.DataFrame(rates)
            df['time'] = pd.to_datetime(df['time'], unit='s')
            df.set_index('time', inplace=True)
            
            print(f"Received {len(df)} historical bars")
            return df
            
        except Exception as e:
            print(f"Error getting historical data: {e}")
            self.disconnect_mt5()
            return None

    def calculate_simple_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
        """Расчет упрощенного набора технических индикаторов"""
        try:
            data = df.copy()
            
            # Базовые расчеты
            data['returns'] = data['close'].pct_change()
            data['volatility'] = data['returns'].rolling(20).std() * np.sqrt(24)
            
            # Скользящие средние (только ключевые)
            data['sma_20'] = data['close'].rolling(20).mean()
            data['sma_50'] = data['close'].rolling(50).mean()
            data['price_to_sma_20'] = data['close'] / data['sma_20'] - 1
            
            # RSI
            def calculate_rsi(prices, period=14):
                delta = prices.diff()
                gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
                loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
                rs = gain / loss
                rsi = 100 - (100 / (1 + rs))
                return rsi
            
            data['rsi'] = calculate_rsi(data['close'], 14)
            
            # MACD
            ema_12 = data['close'].ewm(span=12).mean()
            ema_26 = data['close'].ewm(span=26).mean()
            data['macd'] = ema_12 - ema_26
            
            # Momentum
            data['momentum'] = data['close'] / data['close'].shift(10) - 1
            
            print(f"Calculated simple technical indicators")
            return data
            
        except Exception as e:
            print(f"Error calculating indicators: {e}")
            return df

    def prepare_features_and_targets(self, data: pd.DataFrame):
        """Подготовка признаков и целевых переменных (упрощенная версия)"""
        try:
            # Выбираем только ключевые признаки для упрощения
            feature_columns = ['returns', 'volatility', 'price_to_sma_20', 'rsi', 'macd', 'momentum']
            
            # Создаем целевые переменные
            data['target_price'] = data['close'].shift(-self.prediction_horizon)
            data['target_direction'] = (data['target_price'] > data['close']).astype(int)
            
            # Убираем NaN
            data_clean = data.dropna()
            
            if len(data_clean) < 100:
                print("Insufficient clean data")
                return None
            
            # Формируем признаки и цели
            X = data_clean[feature_columns].values
            y_price = data_clean['target_price'].values
            y_direction = data_clean['target_direction'].values
            
            # Нормализация
            X_scaled = self.scaler.fit_transform(X)
            
            print(f"Features prepared: {X_scaled.shape}")
            print(f"Direction balance: {np.mean(y_direction)*100:.1f}% UP")
            
            return {
                'X': X_scaled,
                'y_price': y_price,
                'y_direction': y_direction,
                'feature_names': feature_columns,
                'data_clean': data_clean
            }
            
        except Exception as e:
            print(f"Error preparing features: {e}")
            return None

    def create_simple_symbolic_equations(self, features_data: dict):
        """Создание упрощенных символьных уравнений"""
        try:
            X = features_data['X']
            y_price = features_data['y_price']
            y_direction = features_data['y_direction']
            feature_names = features_data['feature_names']
            
            print("Creating simplified symbolic equations...")
            
            # === УПРОЩЕННОЕ УРАВНЕНИЕ ДЛЯ ЦЕНЫ (только линейные термы) ===
            print("Training simplified price model...")
            poly_features = PolynomialFeatures(degree=1, include_bias=True)  # Только линейные!
            X_poly = poly_features.fit_transform(X)
            
            ridge = Ridge(alpha=1.0)
            ridge.fit(X_poly, y_price)
            
            print(f"Price model R² score: {ridge.score(X_poly, y_price):.4f}")
            
            # Создаем символьные переменные
            symbols = {}
            for i, name in enumerate(feature_names):
                symbols[f"x{i}"] = sp.Symbol(f"x{i}", real=True)
            
            # Создаем упрощенное уравнение для цены (только линейные термы)
            price_equation = sp.S(ridge.coef_[0])  # константа
            for i, (name, coef) in enumerate(zip(feature_names, ridge.coef_[1:])):
                if abs(coef) > 1e-6:
                    price_equation += coef * symbols[f"x{i}"]
            
            # === УПРОЩЕННОЕ УРАВНЕНИЕ ДЛЯ НАПРАВЛЕНИЯ ===
            print("Training simplified binary model...")
            
            X_train, X_test, y_train, y_test = train_test_split(
                X, y_direction, test_size=0.3, random_state=42, stratify=y_direction
            )
            
            logistic = LogisticRegression(random_state=42, max_iter=1000, C=1.0)
            logistic.fit(X_train, y_train)
            
            test_accuracy = logistic.score(X_test, y_test)
            y_pred = logistic.predict(X_test)
            
            # Вычисляем метрики
            self.metrics = {
                'price_r2': ridge.score(X_poly, y_price),
                'direction_accuracy': test_accuracy,
                'precision': precision_score(y_test, y_pred),
                'recall': recall_score(y_test, y_pred),
                'f1_score': f1_score(y_test, y_pred),
                'confusion_matrix': confusion_matrix(y_test, y_pred)
            }
            
            print(f"Binary classification test accuracy: {test_accuracy:.4f}")
            
            # Создаем упрощенное символьное уравнение для логистической регрессии
            linear_combination = sp.S(logistic.intercept_[0])
            for i, coef in enumerate(logistic.coef_[0]):
                if abs(coef) > 1e-6:
                    linear_combination += coef * symbols[f"x{i}"]
            
            binary_equation = 1 / (1 + sp.exp(-linear_combination))
            
            print("Simplified symbolic equations created successfully")
            print(f"Price equation terms: {len(ridge.coef_)}")
            print(f"Binary equation terms: {len(logistic.coef_[0])}")
            
            # Сохраняем результаты
            self.price_equation = price_equation
            self.binary_equation = binary_equation
            self.ridge_model = ridge
            self.logistic_model = logistic
            self.poly_features = poly_features
            self.feature_names = feature_names
            
            # Генерируем предсказания для визуализации
            self.generate_predictions(features_data)
            
            return price_equation, binary_equation
            
        except Exception as e:
            print(f"Error creating symbolic equations: {e}")
            return None, None

    def generate_predictions(self, features_data: dict):
        """Генерация предсказаний для визуализации"""
        try:
            X = features_data['X']
            data_clean = features_data['data_clean']
            
            # Предсказания цены
            X_poly = self.poly_features.transform(X)
            price_predictions = self.ridge_model.predict(X_poly)
            
            # Предсказания направления
            direction_predictions = self.logistic_model.predict(X)
            direction_probabilities = self.logistic_model.predict_proba(X)[:, 1]
            
            # Создаем DataFrame с предсказаниями
            self.predictions_df = pd.DataFrame({
                'actual_price': data_clean['close'].values,
                'predicted_price': price_predictions,
                'actual_direction': features_data['y_direction'],
                'predicted_direction': direction_predictions,
                'direction_probability': direction_probabilities
            }, index=data_clean.index)
            
        except Exception as e:
            print(f"Error generating predictions: {e}")

    def create_simple_visualization(self):
        """Создание простой визуализации 700x700 одна панель с сохранением"""
        try:
            if self.predictions_df is None:
                print("No predictions available for visualization")
                return
            
            # Создаем одну фигуру точно 700x700 пикселей
            fig, ax = plt.subplots(1, 1, figsize=(7, 7), dpi=100)
            
            # Берем последние 1500 точек для наглядности
            recent_data = self.predictions_df.tail(1500)
            
            # Основной график - цена и предсказания
            ax.plot(recent_data.index, recent_data['actual_price'], 
                   label=f'Actual {self.symbol}', color='black', linewidth=2)
            ax.plot(recent_data.index, recent_data['predicted_price'], 
                   label='Predicted Price (24h ahead)', color='blue', linewidth=1.5, alpha=0.8)
            
            # Цветовое кодирование фона по направлениям
            up_mask = recent_data['actual_direction'] == 1
            down_mask = recent_data['actual_direction'] == 0
            
            ax.fill_between(recent_data.index, recent_data['actual_price'].min(), 
                          recent_data['actual_price'].max(), 
                          where=up_mask, color='green', alpha=0.1, label='Actual UP periods')
            ax.fill_between(recent_data.index, recent_data['actual_price'].min(), 
                          recent_data['actual_price'].max(), 
                          where=down_mask, color='red', alpha=0.1, label='Actual DOWN periods')
            
            # Добавляем точки предсказанных направлений
            up_pred_mask = recent_data['predicted_direction'] == 1
            down_pred_mask = recent_data['predicted_direction'] == 0
            
            ax.scatter(recent_data.index[up_pred_mask], recent_data['actual_price'][up_pred_mask], 
                      c='darkgreen', s=8, alpha=0.6, label='Predicted UP')
            ax.scatter(recent_data.index[down_pred_mask], recent_data['actual_price'][down_pred_mask], 
                      c='darkred', s=8, alpha=0.6, label='Predicted DOWN')
            
            # Настройка графика
            ax.set_title(f'{self.symbol} Symbolic Price Predictor Analysis', 
                        fontsize=16, fontweight='bold', pad=20)
            ax.set_xlabel('Time', fontsize=12)
            ax.set_ylabel('Price', fontsize=12)
            ax.legend(loc='upper left', fontsize=10)
            ax.grid(True, alpha=0.3)
            ax.tick_params(axis='x', rotation=45)
            
            # Добавляем текстовую информацию в правый верхний угол
            info_text = f"Model Metrics:\n"
            info_text += f"Price R²: {self.metrics.get('price_r2', 0):.4f}\n"
            info_text += f"Direction Accuracy: {self.metrics.get('direction_accuracy', 0):.4f}\n"
            info_text += f"Precision: {self.metrics.get('precision', 0):.4f}\n"
            info_text += f"F1-Score: {self.metrics.get('f1_score', 0):.4f}\n\n"
            info_text += f"Prediction Horizon: {self.prediction_horizon}H\n"
            info_text += f"Features: {len(self.feature_names)}\n"
            info_text += f"Data Points: {len(recent_data)}"
            
            ax.text(0.02, 0.98, info_text, transform=ax.transAxes, fontsize=10, 
                   verticalalignment='top', bbox=dict(boxstyle='round', 
                   facecolor='white', alpha=0.8, edgecolor='gray'))
            
            plt.tight_layout()
            
            # Сохраняем график в директории проекта точно 700x700 пикселей
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"{self.symbol}_symbolic_predictor_{timestamp}.png"
            
            # Сохраняем с точными параметрами 700x700
            plt.savefig(filename, 
                       dpi=100, 
                       bbox_inches='tight',
                       facecolor='white',
                       edgecolor='none')
            
            plt.show()
            
            print(f"Visualization saved as '{filename}' (700x700 pixels)")
            print("Single-panel visualization created successfully")
            
        except Exception as e:
            print(f"Error creating visualization: {e}")

    def display_simple_equations(self):
        """Отображение упрощенных уравнений"""
        if self.price_equation is None or self.binary_equation is None:
            print("No equations available")
            return
        
        print("\n" + "="*60)
        print("SIMPLIFIED SYMBOLIC EQUATIONS")
        print("="*60)
        
        print(f"\n1. PRICE PREDICTION (Linear):")
        print(f"P(t+24) = {self.price_equation}")
        
        print(f"\n2. DIRECTION PROBABILITY (Logistic):")
        print(f"Prob(UP) = {self.binary_equation}")
        
        print(f"\nVariable mapping:")
        for i, name in enumerate(self.feature_names):
            print(f"  x{i} = {name}")
        
        print(f"\nModel Statistics:")
        print(f"  Price R²: {self.metrics.get('price_r2', 0):.4f}")
        print(f"  Direction Accuracy: {self.metrics.get('direction_accuracy', 0):.4f}")
        print(f"  Features: {len(self.feature_names)}")
        
        print("="*60)

    def run_simple_analysis(self):
        """Запуск упрощенного анализа"""
        try:
            print(f"Starting simple analysis for {self.symbol}")
            print("-" * 40)
            
            # 1. Получение данных
            print("1. Getting historical data...")
            self.data = self.get_historical_data()
            if self.data is None:
                return None
            
            # 2. Расчет индикаторов
            print("2. Calculating simple indicators...")
            self.data = self.calculate_simple_indicators(self.data)
            
            # 3. Подготовка признаков
            print("3. Preparing features...")
            self.features = self.prepare_features_and_targets(self.data)
            if self.features is None:
                return None
            
            # 4. Создание упрощенных уравнений
            print("4. Creating simple symbolic equations...")
            price_eq, binary_eq = self.create_simple_symbolic_equations(self.features)
            if price_eq is None or binary_eq is None:
                return None
            
            # 5. Создание визуализации
            print("5. Creating simple visualization...")
            self.create_simple_visualization()
            
            # 6. Отображение результатов
            self.display_simple_equations()
            
            print("\nSimple analysis completed successfully!")
            
            return {
                'price_equation': price_eq,
                'binary_equation': binary_eq,
                'model': self,
                'predictions': self.predictions_df,
                'metrics': self.metrics
            }
            
        except Exception as e:
            print(f"Error in simple analysis: {e}")
            return None


# Пример использования
if __name__ == "__main__":
    # Создание и запуск упрощенного анализа
    predictor = SimpleSymbolicPredictor(symbol="EURUSD")
    
    result = predictor.run_simple_analysis()
    
    if result:
        print("\nSimple symbolic equations successfully created!")
        print("\nFeatures:")
        print("- Linear price prediction equation")
        print("- Logistic binary direction classification") 
        print("- Simple 700x700 pixel visualization")
        print("- 4-panel dashboard with key metrics")
        print("- No recursion issues with simplified equations")
        print("- Fast execution with reduced complexity")
    else:
        print("Failed to create symbolic equations")
