import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.optimize import minimize
import MetaTrader5 as mt5
from datetime import datetime
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_percentage_error
import time

class PriceEquationModel:
    """Non-linear autoregressive model predicting a price from the two previous prices.

    The functional form is fixed (see ``equation``); only its seven coefficients
    are fitted, by minimising in-sample mean squared error with Nelder-Mead.
    """

    def __init__(self):
        self.coefficients = None          # fitted 7-element coefficient array; None until fit()
        self.training_scores = []         # not written by any code in this file
        self.optimization_progress = []   # one {'mse', 'r2', 'coeffs'} dict per loss evaluation

    def equation(self, x_prev, coeffs):
        """
        Non-linear equation for EURUSD price forecast.

        x_prev: [p(t-1), p(t-2)] -- most recent price first. Each entry may be a
                scalar or a 1-D numpy array; every operation is element-wise, so
                passing arrays evaluates many samples in one call.
        coeffs: 7 coefficients, in the order of the terms below.
        Returns the forecast p(t) (scalar or array, matching the x_prev entries).
        """
        x_t1, x_t2 = x_prev[0], x_prev[1]

        return (coeffs[0] * x_t1 +          # linear term in p(t-1)
                coeffs[1] * x_t1**2 +       # quadratic term in p(t-1)
                coeffs[2] * x_t2 +          # linear term in p(t-2)
                coeffs[3] * x_t2**2 +       # quadratic term in p(t-2)
                coeffs[4] * (x_t1 - x_t2) + # last price change
                coeffs[5] * np.sin(x_t1) +  # sinusoidal term in p(t-1)
                coeffs[6])                  # constant

    def loss_function(self, coeffs, X_train, y_train):
        """Return the MSE of ``coeffs`` on the training set and record the evaluation.

        X_train: (n, 2) rows of [p(t-1), p(t-2)]; y_train: the n targets p(t).
        Each call appends {'mse', 'r2', 'coeffs'} to ``self.optimization_progress``.
        """
        # One vectorised call instead of a Python loop: X.T[0] is p(t-1), X.T[1] is p(t-2)
        y_pred = self.equation(np.asarray(X_train, dtype=float).T, coeffs)
        mse = np.mean((y_pred - y_train) ** 2)
        r2 = r2_score(y_train, y_pred)
        self.optimization_progress.append({'mse': mse, 'r2': r2, 'coeffs': coeffs.copy()})
        return mse

    def fetch_data(self, symbol="EURUSD", timeframe=mt5.TIMEFRAME_H1, bars=10000):
        """Download the last ``bars`` closing prices of ``symbol`` from MetaTrader5.

        The MT5 connection opened here is always shut down before returning or
        raising.

        Returns:
            1-D numpy array of close prices.
        Raises:
            ConnectionError: if MT5 cannot be initialized.
            ValueError: if no bars are returned.
        """
        print("\nConnection to MetaTrader5...")
        if not mt5.initialize():
            raise ConnectionError("MT5 initialization error")

        try:
            print(f"Download data {symbol}...")
            rates = mt5.copy_rates_from_pos(symbol, timeframe, 0, bars)
            if rates is None or len(rates) == 0:
                raise ValueError("Failed to receive data")
        finally:
            # Release the terminal connection on the error path too
            mt5.shutdown()

        df = pd.DataFrame(rates)
        print(f"Downloaded {len(df)} entries")
        return df['close'].values

    def fit(self, prices):
        """Fit the coefficients on the first two thirds of ``prices``.

        Training pairs are ([p(i), p(i-1)], p(i+1)) for i = 2 .. n_train-2, where
        n_train = len(prices) * 2 // 3; the remaining third is left for forward
        testing. Progress and final in-sample R²/MSE/MAPE are printed.

        Returns:
            The fitted coefficient array (also stored in ``self.coefficients``).
        Raises:
            ValueError: if ``prices`` is too short to form two training pairs.
        """
        print("\nPreparing data for training...")
        prices = np.asarray(prices, dtype=float)
        n_train = len(prices) * 2 // 3
        if n_train < 5:  # n_train - 3 pairs are produced; R² needs at least two
            raise ValueError(f"Not enough data to train: got {len(prices)} prices")

        # Column 0 = p(i), column 1 = p(i-1); target = p(i+1)
        X_train = np.column_stack((prices[2:n_train - 1], prices[1:n_train - 2]))
        y_train = prices[3:n_train]
        X_cols = X_train.T  # (2, n) layout for vectorised evaluation of equation()

        print(f"Training sample size: {len(X_train)} entries")

        # Initial ratio values
        initial_coeffs = np.array([0.5, 0.1, 0.3, 0.1, 0.2, 0.1, 0.0])

        # Start each fit with a clean history so evaluation counts refer to this run only
        self.optimization_progress = []

        print("\nStart optimizing ratios...")
        start_time = time.perf_counter()
        iteration = 0

        def callback(xk):
            # Called once per optimizer iteration with the current parameter vector xk.
            # The last optimization_progress entry is merely the most recent trial
            # point, which need not be xk, so score xk itself.
            nonlocal iteration
            iteration += 1
            y_pred = self.equation(X_cols, xk)
            print(f"\nIteration {iteration} ({len(self.optimization_progress)} evaluations):")
            print(f"Time: {time.perf_counter() - start_time:.2f} s")
            print(f"MSE: {np.mean((y_pred - y_train) ** 2):.8f}")
            print(f"R²: {r2_score(y_train, y_pred):.4f}")
            print("Ratios:", xk)

        result = minimize(
            self.loss_function,
            initial_coeffs,
            args=(X_train, y_train),
            method='Nelder-Mead',
            callback=callback,
            options={'maxiter': 10000}
        )
        if not result.success:
            print(f"Warning: optimization did not converge: {result.message}")

        self.coefficients = result.x

        # Final assessment on a training sample
        y_pred_train = self.equation(X_cols, self.coefficients)
        train_r2 = r2_score(y_train, y_pred_train)
        train_mse = mean_squared_error(y_train, y_pred_train)
        train_mape = mean_absolute_percentage_error(y_train, y_pred_train) * 100

        print("\nTraining complete!")
        print(f"Spent time: {time.perf_counter() - start_time:.2f} s")
        print("\nTraining sample metrics:")
        print(f"R²: {train_r2:.4f}")
        print(f"MSE: {train_mse:.8f}")
        print(f"MAPE: {train_mape:.2f}%")

        return result.x

    def predict(self, last_prices):
        """Forecast the next price from ``last_prices`` = [p(t-1), p(t-2)].

        Raises:
            ValueError: if the model has not been fitted yet.
        """
        if self.coefficients is None:
            raise ValueError("Model not trained")
        return self.equation(last_prices, self.coefficients)

def analyze_predictions(model, prices, start_index):
    """Walk-forward one-step-ahead evaluation of ``model`` on the tail of ``prices``.

    For each i in [start_index, len(prices) - 2], predicts prices[i + 1] from
    [prices[i], prices[i - 1]] via ``model.predict`` and prints R², MSE and MAPE.

    Args:
        model: object whose ``predict([p(t-1), p(t-2)])`` returns the next price.
        prices: 1-D sequence of prices.
        start_index: first index used as p(t-1). Must be >= 1; with 0 the
            "previous" price prices[-1] would silently wrap to the end.

    Returns:
        (predictions, actuals) as equal-length 1-D numpy arrays.

    Raises:
        ValueError: if start_index < 1 or no forward sample remains.
    """
    print("\nStart forward test...")
    if start_index < 1:
        raise ValueError(f"start_index must be >= 1, got {start_index}")
    total = len(prices) - start_index - 1
    if total < 1:
        raise ValueError("Not enough data for the forward test")

    predictions = []
    actuals = []
    for step, i in enumerate(range(start_index, len(prices) - 1)):
        if step % 100 == 0:  # Progress every 100 steps
            print(f"Forward test progress: {step / total * 100:.1f}%")
        predictions.append(model.predict([prices[i], prices[i - 1]]))
        actuals.append(prices[i + 1])

    predictions = np.array(predictions)
    actuals = np.array(actuals)

    # Calculate metrics
    r2 = r2_score(actuals, predictions)
    mse = mean_squared_error(actuals, predictions)
    mape = mean_absolute_percentage_error(actuals, predictions) * 100

    print("\nForward test results:")
    print(f"R²: {r2:.4f}")
    print(f"MSE: {mse:.8f}")
    print(f"MAPE: {mape:.2f}%")

    return predictions, actuals

def _save_line_plot(series, title, filename, width_in, height_in, dpi, zero_line=False):
    """Draw each ``(values, label)`` pair in ``series`` on one figure and save it to ``filename``.

    ``zero_line`` adds a dashed red horizontal line at y=0. The figure is closed
    even if drawing or saving raises, so a failure never leaves a figure open.
    """
    fig = plt.figure(figsize=(width_in, height_in))
    try:
        for values, label in series:
            plt.plot(values, label=label)
        plt.title(title)
        if zero_line:
            plt.axhline(y=0, color='r', linestyle='--')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(filename, dpi=dpi, bbox_inches='tight')
    finally:
        plt.close(fig)


def main():
    """Train the model on MT5 history, print the equation, forward-test it and save plots.

    Writes forecast_comparison.png and forecast_error.png to the working
    directory. This is a top-level boundary: any exception is reported (message
    plus traceback) instead of being propagated.
    """
    model = PriceEquationModel()

    try:
        # Load data
        prices = model.fetch_data()

        # Train model
        coeffs = model.fit(prices)

        # Print resulting equation
        print("\nResulting equation:")
        print(f"X(t) = {coeffs[0]:.4f}·X(t-1) + {coeffs[1]:.4f}·X(t-1)² + ")
        print(f"       {coeffs[2]:.4f}·X(t-2) + {coeffs[3]:.4f}·X(t-2)² + ")
        print(f"       {coeffs[4]:.4f}·(X(t-1) - X(t-2)) + ")
        print(f"       {coeffs[5]:.4f}·sin(X(t-1)) + {coeffs[6]:.4f}")

        # Forward test on the third of the data that fit() did not train on
        start_forward_test = len(prices) * 2 // 3
        predictions, actuals = analyze_predictions(model, prices, start_forward_test)

        # 750 px wide at the current DPI (inches = pixels / DPI), 15:10 aspect ratio
        width_px = 750
        dpi = plt.rcParams['figure.dpi']
        width_in = width_px / dpi
        height_in = width_in * 10 / 15

        _save_line_plot(
            [(actuals, 'Actual Prices'), (predictions, 'Forecast')],
            'Comparison of Forecast vs Actual Prices on Forward Test',
            'forecast_comparison.png',
            width_in, height_in, dpi,
        )

        # Relative error: (forecast - actual) / actual
        errors = (predictions - actuals) / actuals
        _save_line_plot(
            [(errors, 'Forecast Error')],
            'Forecast Error',
            'forecast_error.png',
            width_in, height_in, dpi,
            zero_line=True,
        )

    except Exception as e:
        import traceback  # local: only needed on the error path
        print(f"An error occurred: {str(e)}")
        traceback.print_exc()

class TradingSystem:
    """Single-symbol trading loop driven by a fitted PriceEquationModel.

    Keeps at most one open position on ``symbol``: BUY when the forecast is
    above the current ask, otherwise SELL, with the stop loss 1 ATR and the take
    profit 1/3 ATR away from the entry price.
    """

    def __init__(self, model: PriceEquationModel, symbol="EURUSD", lot_size=0.1):
        self.model = model
        self.symbol = symbol
        self.lot_size = lot_size
        self._setup_connection()

    def _setup_connection(self):
        """Initialize the MT5 connection and make sure ``self.symbol`` is selected.

        Raises:
            ConnectionError: if MT5 cannot be initialized.
            ValueError: if the symbol cannot be selected in Market Watch.
        """
        if not mt5.initialize():
            raise ConnectionError("MT5 connection error")

        symbol_info = mt5.symbol_info(self.symbol)
        if symbol_info is None or not symbol_info.visible:
            # Fail fast: an unselectable symbol would otherwise make every later call fail
            if not mt5.symbol_select(self.symbol, True):
                raise ValueError(f"Failed to select symbol {self.symbol}")

    def calculate_atr(self, period=14, timeframe=mt5.TIMEFRAME_H1):
        """Average True Range: simple mean of the last ``period`` true ranges.

        period + 1 bars are requested so that every one of the ``period`` true
        ranges has a previous close. The bars start at position 0, so the most
        recent bar is included. This is a plain average, not Wilder smoothing.

        Raises:
            ValueError: if ``period`` < 1 or fewer than period + 1 bars are returned.
        """
        if period < 1:
            raise ValueError(f"period must be >= 1, got {period}")
        rates = mt5.copy_rates_from_pos(self.symbol, timeframe, 0, period + 1)
        if rates is None or len(rates) < period + 1:
            raise ValueError("Failed to receive ATR calculation data")

        df = pd.DataFrame(rates)
        prev_close = df['close'].shift(1)
        df['high_low'] = df['high'] - df['low']
        df['high_close'] = np.abs(df['high'] - prev_close)
        df['low_close'] = np.abs(df['low'] - prev_close)
        df['tr'] = df[['high_low', 'high_close', 'low_close']].max(axis=1)
        # Row 0 has no previous close (its TR would be just high-low), so drop it
        return df['tr'].iloc[1:].mean()

    def get_model_prediction(self):
        """Forecast the close of the most recent H1 bar from the two bars before it.

        Raises:
            ValueError: if fewer than three bars are available.
        """
        rates = mt5.copy_rates_from_pos(self.symbol, mt5.TIMEFRAME_H1, 0, 3)
        if rates is None or len(rates) < 3:
            raise ValueError("Failed to get forecast data")

        prices = pd.DataFrame(rates)['close'].values
        last_prices = [prices[1], prices[0]]  # [p(t-1), p(t-2)]
        return self.model.predict(last_prices)

    def open_position(self):
        """Open one market position on ``self.symbol`` following the model's signal.

        Does nothing if a position on the symbol is already open. SL/TP are
        rounded to the symbol's quoted digits.

        Returns:
            The order ticket on success, otherwise None. Errors are printed,
            not raised.
        """
        try:
            # Cheapest check first: only one position per symbol is allowed
            positions = mt5.positions_get(symbol=self.symbol)
            if positions:
                print(f"Open position already present at {self.symbol}")
                return None

            predicted_price = self.get_model_prediction()
            atr = self.calculate_atr()

            symbol_info = mt5.symbol_info(self.symbol)
            tick = mt5.symbol_info_tick(self.symbol)
            if symbol_info is None or tick is None:
                print(f"Error opening position: no market data for {self.symbol} "
                      f"({mt5.last_error()})")
                return None

            # Direction and entry price come from the same tick snapshot
            signal = "BUY" if predicted_price > tick.ask else "SELL"
            digits = symbol_info.digits

            if signal == "BUY":
                price = tick.ask
                sl_level = round(price - atr, digits)      # stop 1 ATR below entry
                tp_level = round(price + atr / 3, digits)  # target 1/3 ATR above entry
                order_type = mt5.ORDER_TYPE_BUY
            else:
                price = tick.bid
                sl_level = round(price + atr, digits)
                tp_level = round(price - atr / 3, digits)
                order_type = mt5.ORDER_TYPE_SELL

            request = {
                "action": mt5.TRADE_ACTION_DEAL,
                "symbol": self.symbol,
                "volume": self.lot_size,
                "type": order_type,
                "price": price,
                "sl": sl_level,
                "tp": tp_level,
                "deviation": 20,
                "magic": 234000,
                "comment": f"pred:{predicted_price:.6f}",
                "type_time": mt5.ORDER_TIME_GTC,
                # NOTE(review): FOK filling must be allowed by the broker for this symbol
                "type_filling": mt5.ORDER_FILLING_FOK,
            }

            result = mt5.order_send(request)
            if result is None:
                print(f"Error opening position: order_send returned None ({mt5.last_error()})")
                return None
            if result.retcode != mt5.TRADE_RETCODE_DONE:
                print(f"Error opening position: {result.retcode}")
                return None

            print(f"Opened position {signal}: price={price:.5f}, SL={sl_level:.5f}, "
                  f"TP={tp_level:.5f}, ATR={atr:.5f}, lot={self.lot_size}")

            return result.order

        except Exception as e:
            print(f"Error opening position: {str(e)}")
            return None

    def trading_loop(self, poll_interval=60, error_backoff=300):
        """Run forever: every ``poll_interval`` seconds, open a position if none is open.

        After an unexpected error the loop waits ``error_backoff`` seconds before
        continuing. It never returns on its own; KeyboardInterrupt is not an
        Exception subclass and so propagates to the caller.
        """
        print("\nLaunch trading system...")

        while True:
            try:
                positions = mt5.positions_get(symbol=self.symbol)
                if not positions:
                    print("\nSearching for a new trading opportunity...")
                    self.open_position()

                time.sleep(poll_interval)

            except Exception as e:
                print(f"Trading loop error: {str(e)}")
                time.sleep(error_backoff)

def run_trading_system():
    """Train a fresh model on MT5 history, then run the live trading loop.

    The trading loop never returns on its own; Ctrl+C (KeyboardInterrupt) is
    the stop mechanism and is reported as a normal stop. The MT5 connection is
    shut down on every exit path.
    """
    try:
        # Create and train model
        model = PriceEquationModel()
        prices = model.fetch_data()
        model.fit(prices)

        # Create a trading system and launch its (blocking) loop
        trading_system = TradingSystem(model)
        trading_system.trading_loop()

    except KeyboardInterrupt:
        print("\nTrading system stopped by user")
    except Exception as e:
        print(f"Critical error: {str(e)}")
    finally:
        mt5.shutdown()

if __name__ == "__main__":
    # Research stage first (fit, forward test, plots), then the live stage.
    # main() catches its own exceptions, so the live stage starts regardless;
    # run_trading_system() retrains its own model and sends market orders via MT5.
    for stage in (main, run_trading_system):
        stage()
