import matplotlib
matplotlib.use('Agg')  # Set the backend to 'Agg' before importing pyplot
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
import MetaTrader5 as mt5
from sklearn.preprocessing import StandardScaler
from typing import Dict, Tuple
from sklearn.model_selection import train_test_split
from pathlib import Path

# Constants
V_rest = -65.0  # Resting potential (mV)
Cm = 1.0  # Membrane capacitance (μF/cm²)
g_Na = 120.0  # Maximum Na+ channel conductance (mS/cm²)
g_K = 36.0  # Maximum K+ channel conductance (mS/cm²)
g_L = 0.3  # Leak conductance (mS/cm²)
E_Na = 50.0  # Na+ equilibrium potential (mV)
E_K = -77.0  # K+ equilibrium potential (mV)
E_L = -54.4  # Leak equilibrium potential (mV)

# Plasma parameters
plasma_strength = 1.0  # Plasma influence strength
plasma_decay = 20.0  # Plasma influence decay time
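# Plasma influence as applied in HodgkinHuxleyNeuron.plasma_influence:
#   I_plasma(t) = plasma_strength * exp(-(t - t_last_spike) / plasma_decay)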

# STDP parameters
A_plus = 0.1  # Enhancement coefficient for positive Δt
A_minus = 0.1  # Weakening coefficient for negative Δt
tau_plus = 20.0  # Decay time for positive Δt
tau_minus = 20.0  # Decay time for negative Δt
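# STDP weight change as implemented in BioTradingModel.update_weights_stdp:
#   Δw = +A_plus  * exp(-Δt / tau_plus)   if Δt = t_post - t_pre > 0
#   Δw = -A_minus * exp( Δt / tau_minus)  if Δt < 0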

# Market features calculation class
class MarketFeatures:
    def __init__(self, window_size=20):
        self.window_size = window_size
        self.price_history = []
        self.scaler = StandardScaler()
        
    def add_price(self, price: float, ohlc_data: pd.DataFrame) -> Dict[str, float]:
        self.price_history.append(price)
        if len(self.price_history) < self.window_size:
            return self._get_default_features()
        
        features = {}
        
        # Moving averages
        features['sma_10'] = self._calculate_sma(ohlc_data['close'], window=10)
        features['sma_20'] = self._calculate_sma(ohlc_data['close'], window=20)
        features['ema_10'] = self._calculate_ema(ohlc_data['close'], window=10)
        features['ema_20'] = self._calculate_ema(ohlc_data['close'], window=20)
        
        # RSI
        features['rsi'] = self._calculate_rsi(ohlc_data['close'], window=14)
        
        # MACD
        macd, signal = self._calculate_macd(ohlc_data['close'])
        features['macd'] = macd
        features['macd_signal'] = signal
        
        # Bollinger Bands
        upper_band, lower_band = self._calculate_bollinger_bands(ohlc_data['close'], window=20)
        features['bollinger_upper'] = upper_band
        features['bollinger_lower'] = lower_band
        
        # ATR
        features['atr'] = self._calculate_atr(ohlc_data, window=14)
        
        # Momentum
        features['momentum'] = self._calculate_momentum(ohlc_data['close'], window=10)
        
        # Volume indicators
        features['volume_sma_10'] = self._calculate_sma(ohlc_data['tick_volume'], window=10)
        features['volume_sma_20'] = self._calculate_sma(ohlc_data['tick_volume'], window=20)
        
        # Time features
        features['day_of_week'] = ohlc_data.index[-1].dayofweek
        features['hour'] = ohlc_data.index[-1].hour
        features['month'] = ohlc_data.index[-1].month
        
        # Feature normalization: update running statistics, then transform
        # (a per-row fit_transform would zero out every feature)
        feature_values = np.array(list(features.values())).reshape(1, -1)
        self.scaler.partial_fit(feature_values)
        normalized_features = self.scaler.transform(feature_values)
        return {k: normalized_features[0][i] for i, k in enumerate(features.keys())}
    
    def _get_default_features(self) -> Dict[str, float]:
        # Same keys as the full feature set, so the model input size stays constant
        keys = ['sma_10', 'sma_20', 'ema_10', 'ema_20', 'rsi', 'macd', 'macd_signal',
                'bollinger_upper', 'bollinger_lower', 'atr', 'momentum',
                'volume_sma_10', 'volume_sma_20', 'day_of_week', 'hour', 'month']
        return {k: 0.0 for k in keys}
    
    def _calculate_sma(self, series: pd.Series, window: int) -> float:
        if len(series) < window:
            return 0.0
        return series.rolling(window=window).mean().iloc[-1]
    
    def _calculate_ema(self, series: pd.Series, window: int) -> float:
        if len(series) < window:
            return 0.0
        return series.ewm(span=window, adjust=False).mean().iloc[-1]
    
    def _calculate_rsi(self, series: pd.Series, window: int) -> float:
        if len(series) < window + 1:
            return 50.0
        delta = series.diff()
        gain = (delta.where(delta > 0, 0)).rolling(window=window).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=window).mean()
        rs = gain / loss
        return 100 - (100 / (1 + rs.iloc[-1]))
    
    def _calculate_macd(self, series: pd.Series) -> Tuple[float, float]:
        if len(series) < 26:
            return 0.0, 0.0
        ema_12 = series.ewm(span=12, adjust=False).mean()
        ema_26 = series.ewm(span=26, adjust=False).mean()
        macd = ema_12 - ema_26
        signal = macd.ewm(span=9, adjust=False).mean()
        return macd.iloc[-1], signal.iloc[-1]
    
    def _calculate_bollinger_bands(self, series: pd.Series, window: int) -> Tuple[float, float]:
        if len(series) < window:
            return 0.0, 0.0
        sma = series.rolling(window=window).mean()
        std = series.rolling(window=window).std()
        upper_band = sma + (2 * std)
        lower_band = sma - (2 * std)
        return upper_band.iloc[-1], lower_band.iloc[-1]
    
    def _calculate_atr(self, ohlc_data: pd.DataFrame, window: int) -> float:
        if len(ohlc_data) < window + 1:
            return 0.0
        high = ohlc_data['high']
        low = ohlc_data['low']
        close = ohlc_data['close']
        tr = np.maximum(high - low, np.maximum(abs(high - close.shift()), abs(low - close.shift())))
        atr = tr.rolling(window=window).mean()
        return atr.iloc[-1]
    
    def _calculate_momentum(self, series: pd.Series, window: int) -> float:
        if len(series) < window:
            return 0.0
        return series.iloc[-1] - series.iloc[-window]


# Hodgkin-Huxley Neuron Model
class HodgkinHuxleyNeuron:
    def __init__(self):
        self.V = V_rest  # Membrane potential
        self.m = 0.05  # Na+ activation variable
        self.h = 0.6   # Na+ inactivation variable
        self.n = 0.32  # K+ activation variable
        self.t = 0.0   # Internal simulation time (ms)
        self.last_spike_time = -float('inf')  # Last spike time

    def alpha_m(self, V):
        return 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10))

    def beta_m(self, V):
        return 4.0 * np.exp(-(V + 65) / 18)

    def alpha_h(self, V):
        return 0.07 * np.exp(-(V + 65) / 20)

    def beta_h(self, V):
        return 1.0 / (1 + np.exp(-(V + 35) / 10))

    def alpha_n(self, V):
        return 0.01 * (V + 55) / (1 - np.exp(-(V + 55) / 10))

    def beta_n(self, V):
        return 0.125 * np.exp(-(V + 65) / 80)

    def update_gates(self, V, dt):
        self.m += dt * (self.alpha_m(V) * (1 - self.m) - self.beta_m(V) * self.m)
        self.h += dt * (self.alpha_h(V) * (1 - self.h) - self.beta_h(V) * self.h)
        self.n += dt * (self.alpha_n(V) * (1 - self.n) - self.beta_n(V) * self.n)

    def ion_currents(self, V):
        I_Na = g_Na * self.m**3 * self.h * (V - E_Na)
        I_K = g_K * self.n**4 * (V - E_K)
        I_L = g_L * (V - E_L)
        return I_Na, I_K, I_L

    def plasma_influence(self, current_time):
        return plasma_strength * np.exp(-(current_time - self.last_spike_time) / plasma_decay)

    # Membrane equation integrated with a forward-Euler step:
    #   Cm * dV/dt = -(I_Na + I_K + I_L) + I_ext
    def update_potential(self, I_ext, dt):
        I_Na, I_K, I_L = self.ion_currents(self.V)
        dV_dt = (-I_Na - I_K - I_L + I_ext) / Cm
        self.V += dV_dt * dt
        self.t += dt  # Advance the neuron's internal clock
        if self.V > 30:  # Spike threshold crossed
            self.V = V_rest
            self.last_spike_time = self.t  # Record the spike time, not the step size

# PyTorch Model with STDP
class BioTradingModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(BioTradingModel, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(0.2)
        self.activation = nn.Tanh()
        self.neurons = [HodgkinHuxleyNeuron() for _ in range(hidden_size)]
        self.last_spike_times = torch.zeros(hidden_size)

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.dropout(x)
        x = self.activation(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x
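
    # Note: the forward pass above is a plain MLP (Tanh + dropout); the Hodgkin-Huxley
    # neurons created in __init__ are kept for spike-time bookkeeping, and fc1 weights
    # can additionally be adjusted with the STDP rule below.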

    def update_weights_stdp(self, pre_spike_times, post_spike_times):
        # Weight updates bypass autograd: in-place edits of a leaf tensor that
        # requires grad would otherwise raise an error
        with torch.no_grad():
            for i in range(self.fc1.weight.shape[0]):
                for j in range(self.fc1.weight.shape[1]):
                    delta_t = float(post_spike_times[i] - pre_spike_times[j])
                    if delta_t > 0:
                        self.fc1.weight[i, j] += A_plus * np.exp(-delta_t / tau_plus)
                    elif delta_t < 0:
                        self.fc1.weight[i, j] -= A_minus * np.exp(delta_t / tau_minus)

# Enhanced Trading System
class EnhancedPlasmaBrainTrader:
    def __init__(self, input_size, hidden_size, output_size):
        self.model = BioTradingModel(input_size, hidden_size, output_size)
        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        self.scaler = StandardScaler()
        self.price_history = []
        self.predictions = []
        
    def predict(self, price: float, features: Dict[str, float]):
        self.price_history.append(price)
        feature_values = np.array(list(features.values())).reshape(1, -1)
        self.scaler.partial_fit(feature_values)  # update running stats; a per-row fit_transform would zero the input
        normalized_features = self.scaler.transform(feature_values)
        features_tensor = torch.tensor(normalized_features, dtype=torch.float32)
        
        # Model prediction
        self.model.eval()
        with torch.no_grad():
            prediction = self.model(features_tensor)
        self.predictions.append(prediction.item())
        
        # Model training
        if len(self.price_history) > 5:
            self.model.train()
            self.optimizer.zero_grad()
            # Predict the price itself, not its change
            target = torch.tensor([price], dtype=torch.float32).view(1, 1)
            prediction_train = self.model(features_tensor)
            loss = self.criterion(prediction_train, target)
            loss.backward()
            self.optimizer.step()
        
        return prediction.item()

    def get_stats(self):
        if len(self.predictions) < 2:
            return "Insufficient data"
        # Compare actual prices with the predictions made one step earlier
        actual_prices = np.array(self.price_history[1:])  # skip the first price
        predicted_prices = np.array(self.predictions[:-1])
        correlation = np.corrcoef(actual_prices, predicted_prices)[0, 1]
        mse = np.mean((actual_prices - predicted_prices) ** 2)
        return {
            'correlation': correlation,
            'mse': mse,
            'total_trades': len(self.predictions) - 1,
        }

# Main code
if __name__ == "__main__":
    # MT5 initialization
    if not mt5.initialize():
        print("Error: MT5 initialization failed")
        quit()

    # Data loading
    symbol = "EURUSD"
    timeframe = mt5.TIMEFRAME_D1
    start_date = datetime.now() - timedelta(days=365 * 8)
    end_date = datetime.now()
    rates = mt5.copy_rates_range(symbol, timeframe, start_date, end_date)
    df = pd.DataFrame(rates)
    df['time'] = pd.to_datetime(df['time'], unit='s')
    df.set_index('time', inplace=True)
    prices = df['close'].values

    # Split data into training and testing sets
    train_size = int(len(prices) * 0.8)
    train_prices, test_prices = prices[:train_size], prices[train_size:]

    # Iterative improvement
    best_stats = None
    best_config = None
    best_trader = None

    print("\nStarting training process...")
    for iteration in range(20):
        print(f"\nIteration {iteration + 1}/20")
        input_size = 16   # Number of features produced by MarketFeatures
        hidden_size = 64  # Hidden layer size
        output_size = 1   # Output layer (prediction)
        trader = EnhancedPlasmaBrainTrader(input_size, hidden_size, output_size)
        market_features = MarketFeatures()  # single instance so the price history accumulates
        predictions = []
        
        print("Processing training data...")
        for i in range(1, len(train_prices)):
            price = train_prices[i]
            ohlc_data = df.iloc[:i]
            features = market_features.add_price(price, ohlc_data)
            pred = trader.predict(price, features)
            predictions.append(pred)
            
            if i % 100 == 0:
                print(f"Processed {i}/{len(train_prices)} samples")
        
        stats = trader.get_stats()
        if best_stats is None or stats['correlation'] > best_stats['correlation']:
            best_stats = stats
            best_config = (input_size, hidden_size, output_size)
            best_trader = trader  # keep the best-performing model for the test run
        print(f"Current correlation: {stats['correlation']:.3f}, MSE: {stats['mse']:.6f}")

    print("\nStarting testing process...")
    test_predictions = []
    test_features = MarketFeatures()  # fresh single instance so the price history accumulates
    for i in range(train_size, len(prices)):
        price = prices[i]
        ohlc_data = df.iloc[:i]
        features = test_features.add_price(price, ohlc_data)
        pred = best_trader.predict(price, features)
        test_predictions.append(pred)
        
        if (i - train_size) % 100 == 0:
            print(f"Processed {i - train_size}/{len(prices) - train_size} test samples")
    
    # Compare next-bar actual prices with the predictions made one step earlier
    test_actual_prices = np.array(test_prices[1:])
    test_predicted_prices = np.array(test_predictions[:-1])
    test_correlation = np.corrcoef(test_actual_prices, test_predicted_prices)[0, 1]
    test_mse = np.mean((test_actual_prices - test_predicted_prices) ** 2)

    # Create charts directory
    Path("./charts").mkdir(exist_ok=True)

    # Training process visualization
    plt.figure(figsize=(8, 4))
    plt.plot(range(len(predictions)), predictions, label='Prediction', alpha=0.7)
    plt.plot(range(len(train_prices[1:])), train_prices[1:], label='Actual', alpha=0.7)
    plt.title('Training Process: Prediction vs Reality')
    plt.legend()
    plt.grid(True)
    plt.gcf().set_size_inches(6, 4)
    plt.savefig('./charts/training_process.png', dpi=100, bbox_inches='tight')
    plt.close()

    # Test results visualization
    plt.figure(figsize=(8, 4))
    plt.plot(range(len(test_predictions)), test_predictions, 
             label='Prediction', color='blue', alpha=0.7)
    plt.plot(range(len(test_prices)), test_prices, 
             label='Actual', color='red', alpha=0.7)
    plt.title('Test Results: Prediction vs Reality')
    plt.legend()
    plt.grid(True)
    plt.gcf().set_size_inches(6, 4)
    plt.savefig('./charts/test_results.png', dpi=100, bbox_inches='tight')
    plt.close()

    # Training error visualization
    plt.figure(figsize=(8, 4))
    errors = np.abs(np.array(predictions) - train_prices[1:])
    plt.plot(range(len(errors)), errors, label='Error', color='red')
    plt.title('Training Error Dynamics')
    plt.xlabel('Iteration')
    plt.ylabel('Absolute Error')
    plt.grid(True)
    plt.gcf().set_size_inches(6, 4)
    plt.savefig('./charts/training_error.png', dpi=100, bbox_inches='tight')
    plt.close()

    print("\nBest configuration:")
    print(f"Architecture: {best_config}")
    print(f"Test data correlation: {test_correlation:.3f}, MSE: {test_mse:.6f}")
    print("\nCharts saved in ./charts/ directory")

    # Close MT5
    mt5.shutdown()
