import MetaTrader5 as mt5
import pandas as pd
import numpy as np
import onnxruntime as ort
from sklearn.preprocessing import MinMaxScaler
from ta.trend import PSARIndicator, SMAIndicator
from ta.momentum import RSIIndicator
from ta.volatility import AverageTrueRange
import matplotlib.pyplot as plt

# Initialize connection with MetaTrader5. Every later step needs the terminal,
# so abort immediately instead of failing later with an unrelated error.
if not mt5.initialize():
    print("Initialization failed")
    init_error = mt5.last_error()  # read before shutdown() resets the error state
    mt5.shutdown()
    raise SystemExit(f"MetaTrader5 initialize() failed: {init_error}")

def get_historical_data(symbol, timeframe, start_date, end_date):
    """Download bars for *symbol* from MetaTrader 5 as a time-indexed DataFrame.

    Args:
        symbol: Instrument name, e.g. ``"EURUSD"``.
        timeframe: An ``mt5.TIMEFRAME_*`` constant.
        start_date: First bar time to request.
        end_date: Last bar time to request.

    Returns:
        DataFrame of the fields returned by ``mt5.copy_rates_range`` indexed by
        ``time`` (converted from epoch seconds to datetimes).

    Raises:
        RuntimeError: if the terminal returns no data (None on error, or an
            empty result), including ``mt5.last_error()`` for diagnosis.
    """
    rates = mt5.copy_rates_range(symbol, timeframe, start_date, end_date)
    # copy_rates_range signals failure with None; without this check the
    # failure surfaces later as a confusing KeyError on 'time'.
    if rates is None or len(rates) == 0:
        raise RuntimeError(
            f"No rates returned for {symbol} between {start_date} and {end_date}: "
            f"{mt5.last_error()}"
        )
    df = pd.DataFrame(rates)
    df['time'] = pd.to_datetime(df['time'], unit='s')
    df.set_index('time', inplace=True)
    return df

def calculate_heikin_ashi(df):
    """Add Heikin Ashi candle columns to *df* and return it.

    Uses the standard recursive definition:

    * ``ha_close`` = (open + high + low + close) / 4
    * ``ha_open``  = (previous ha_open + previous ha_close) / 2, seeded on the
      first bar with (open + close) / 2
    * ``ha_high``  = max(high, ha_open, ha_close)
    * ``ha_low``   = min(low, ha_open, ha_close)

    Args:
        df: DataFrame with ``open``, ``high``, ``low`` and ``close`` columns.
            It is modified in place.

    Returns:
        The same DataFrame with ``ha_close``, ``ha_open``, ``ha_high`` and
        ``ha_low`` columns added.
    """
    ha_close = (df['open'] + df['high'] + df['low'] + df['close']) / 4

    # HA open depends on the previous *Heikin Ashi* candle, not on the raw
    # previous open/close, so it has to be built sequentially.
    # (A NaN in the input propagates to every later ha_open.)
    opens = df['open'].to_numpy(dtype=float)
    closes = df['close'].to_numpy(dtype=float)
    ha_close_values = ha_close.to_numpy(dtype=float)
    ha_open_values = np.empty(len(df), dtype=float)
    if len(df):
        ha_open_values[0] = (opens[0] + closes[0]) / 2
        for i in range(1, len(df)):
            ha_open_values[i] = (ha_open_values[i - 1] + ha_close_values[i - 1]) / 2

    df['ha_close'] = ha_close
    df['ha_open'] = pd.Series(ha_open_values, index=df.index)
    # HA high/low bound the raw extremes together with the HA body.
    df['ha_high'] = df[['high', 'ha_open', 'ha_close']].max(axis=1)
    df['ha_low'] = df[['low', 'ha_open', 'ha_close']].min(axis=1)
    return df

def add_indicators(df):
    """Add Heikin Ashi candles and technical indicators to *df*.

    Columns added:

    * the Heikin Ashi columns from ``calculate_heikin_ashi``
    * ``psar`` – Parabolic SAR (acceleration step 0.02, capped at 0.2)
    * ``sma``  – 50-bar simple moving average of close
    * ``rsi``  – 14-bar RSI of close
    * ``atr``  – 14-bar Average True Range
    * ``trend`` – 1 when close > SMA, -1 when close <= SMA, 0 while the SMA
      is still undefined (NaN)

    Args:
        df: DataFrame with ``open``, ``high``, ``low`` and ``close`` columns.

    Returns:
        The DataFrame with the indicator columns added.
    """
    df = calculate_heikin_ashi(df)

    # Parabolic SAR
    psar = PSARIndicator(df['high'], df['low'], df['close'], step=0.02, max_step=0.2)
    df['psar'] = psar.psar()

    # Simple moving average
    df['sma'] = SMAIndicator(df['close'], window=50).sma_indicator()

    # RSI
    df['rsi'] = RSIIndicator(df['close'], window=14).rsi()

    # ATR, used as the volatility unit for stops
    atr = AverageTrueRange(df['high'], df['low'], df['close'], window=14)
    df['atr'] = atr.average_true_range()

    # Simple trend filter. Comparisons with NaN are False, so without the
    # isna() branch the SMA warm-up rows would be labelled as a downtrend.
    df['trend'] = np.select(
        [df['sma'].isna(), df['close'] > df['sma']],
        [0, 1],
        default=-1,
    )

    return df


def prepare_data(df, window_size=120):
    """Scale closing prices to [0, 1] and cut them into model input windows.

    NOTE(review): the scaler is fit on every row of *df*, so its min/max
    include prices from the period that is later backtested (look-ahead).
    If the ONNX model was trained with its own scaler, that one should be
    reused instead.

    Args:
        df: DataFrame with a ``close`` column.
        window_size: Number of consecutive bars per window; must be >= 1.

    Returns:
        ``(X, scaler)``. ``X`` has shape
        ``(len(df) - window_size, window_size, 1)``, or ``(0, window_size, 1)``
        when there are not enough rows. ``X[k]`` holds the scaled closes of
        rows ``k .. k + window_size - 1``. ``scaler`` is the fitted
        ``MinMaxScaler``.

    Raises:
        ValueError: if ``window_size`` < 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    scaler = MinMaxScaler(feature_range=(0, 1))
    scaled_data = scaler.fit_transform(df[['close']])

    windows = [scaled_data[i - window_size:i] for i in range(window_size, len(scaled_data))]
    # reshape keeps X 3-D even when no window fits (np.array([]) is 1-D).
    X = np.array(windows, dtype=float).reshape(-1, window_size, 1)
    return X, scaler

def load_onnx_model(model_path, providers=None):
    """Create an ONNX Runtime inference session for the model at *model_path*.

    Args:
        model_path: Filesystem path to the ``.onnx`` file.
        providers: Execution providers to request, in priority order.
            Defaults to ``["CPUExecutionProvider"]``. Since ONNX Runtime 1.9,
            builds exposing several providers (e.g. GPU builds) raise
            ``ValueError`` when none is passed explicitly.

    Returns:
        An ``onnxruntime.InferenceSession``.
    """
    if providers is None:
        providers = ["CPUExecutionProvider"]
    return ort.InferenceSession(model_path, providers=providers)

def predict_with_onnx(model, input_data):
    """Run a single inference through *model* and return its first output.

    Only the session's first declared input is fed (with *input_data*), and
    only its first declared output is requested.

    Args:
        model: An ONNX Runtime session (anything exposing ``get_inputs``,
            ``get_outputs`` and ``run``).
        input_data: Array passed as the first model input.

    Returns:
        The value of the first model output.
    """
    first_input = model.get_inputs()[0]
    first_output = model.get_outputs()[0]
    feeds = {first_input.name: input_data}
    outputs = model.run([first_output.name], feeds)
    return outputs[0]

def backtest(df, model, scaler, window_size=120, initial_balance=10000,
             sl_atr_multiple=2, tp_atr_multiple=3):
    """Run the ONNX + indicator strategy over *df* and compute the equity curve.

    *df* must already contain the columns produced by ``add_indicators``
    (``psar``, ``sma``, ``rsi``, ``atr``, ``trend`` and the Heikin Ashi
    columns), as well as raw ``open``/``high``/``low``/``close``.

    Each bar's ``position`` (1 long, -1 short, 0 flat) is decided from that
    bar's data and held over the *next* bar. A position held over bar i is
    treated as entered at ``close[i-1]``. Its stop-loss sits
    ``sl_atr_multiple * atr[i-1]`` away and its take-profit
    ``tp_atr_multiple * atr[i-1]`` away. If bar i touches either level:

    * bar i's return is realised at that level, or at the open when the bar
      gaps through it;
    * the stop is assumed to fill first when both levels are touched;
    * the signal on bar i is cancelled.

    Args:
        df: Indicator-enriched OHLC DataFrame; modified in place.
        model: ONNX Runtime session fed a ``(1, window_size, 1)`` float32
            window of scaled closes.
        scaler: Fitted scaler used to scale closes and invert predictions.
        window_size: Number of past bars fed to the model.
        initial_balance: Starting account balance.
        sl_atr_multiple: Stop-loss distance in ATRs.
        tp_atr_multiple: Take-profit distance in ATRs.

    Returns:
        *df* with ``predictions``, ``position``, ``returns``,
        ``strategy_returns``, ``cumulative_returns`` and ``balance`` columns.
    """
    df['predictions'] = _predict_closes(df, model, scaler, window_size)

    # Trading logic. NOTE(review): each bar's close is compared with the
    # model's prediction *for that same bar* (made from the preceding
    # window) - confirm this is the intended signal.
    df['position'] = 0
    long_condition = (
        (df['close'] > df['predictions']) &
        (df['close'] > df['psar']) &
        (df['close'] > df['sma']) &
        (df['rsi'] < 60) &  # relaxed RSI condition
        (df['ha_close'] > df['ha_open']) &
        (df['ha_close'].shift(1) > df['ha_open'].shift(1)) &
        (df['trend'] == 1)  # only buy in an uptrend
    )
    short_condition = (
        (df['close'] < df['predictions']) &
        (df['close'] < df['psar']) &
        (df['close'] < df['sma']) &
        (df['rsi'] > 40) &  # relaxed RSI condition
        (df['ha_close'] < df['ha_open']) &
        (df['ha_close'].shift(1) < df['ha_open'].shift(1)) &
        (df['trend'] == -1)  # only sell in a downtrend
    )

    df.loc[long_condition, 'position'] = 1  # buy
    df.loc[short_condition, 'position'] = -1  # sell

    df['returns'] = df['close'].pct_change()

    # Adaptive ATR stop-loss / take-profit, reflected in the P&L of the exit bar.
    positions, strategy_returns = _apply_atr_exits(
        df, window_size, sl_atr_multiple, tp_atr_multiple
    )
    df['position'] = positions
    df['strategy_returns'] = strategy_returns

    # Compute balance
    df['cumulative_returns'] = (1 + df['strategy_returns']).cumprod()
    df['balance'] = initial_balance * df['cumulative_returns']

    return df


def _predict_closes(df, model, scaler, window_size):
    """Return model close predictions aligned with *df* (NaN for the first *window_size* rows)."""
    scaled_data = scaler.transform(df[['close']])
    predictions = np.full(len(df), np.nan)  # stays aligned even if len(df) < window_size
    for i in range(window_size, len(scaled_data)):
        window = scaled_data[i - window_size:i].reshape(1, window_size, 1)
        pred = predict_with_onnx(model, window.astype(np.float32))
        predictions[i] = scaler.inverse_transform(pred.reshape(-1, 1))[0, 0]
    return predictions


def _apply_atr_exits(df, window_size, sl_atr_multiple, tp_atr_multiple):
    """Apply ATR stop/target exits; return (positions, per-bar strategy returns)."""
    open_ = df['open'].to_numpy(dtype=float)
    high = df['high'].to_numpy(dtype=float)
    low = df['low'].to_numpy(dtype=float)
    close = df['close'].to_numpy(dtype=float)
    atr = df['atr'].to_numpy(dtype=float)
    positions = df['position'].to_numpy(copy=True)
    exit_prices = np.full(len(df), np.nan)

    # Sequential: cancelling bar i's signal affects whether bar i+1 is checked.
    for i in range(max(window_size, 1), len(df)):
        side = positions[i - 1]
        if side == 0:
            continue
        entry = close[i - 1]  # the position held over bar i was opened at the previous close
        sl_distance = sl_atr_multiple * atr[i - 1]
        tp_distance = tp_atr_multiple * atr[i - 1]
        if side == 1:
            sl_price, tp_price = entry - sl_distance, entry + tp_distance
            stop_hit, target_hit = low[i] < sl_price, high[i] > tp_price
            gapped = open_[i] < sl_price or open_[i] > tp_price
        else:
            sl_price, tp_price = entry + sl_distance, entry - tp_distance
            stop_hit, target_hit = high[i] > sl_price, low[i] < tp_price
            gapped = open_[i] > sl_price or open_[i] < tp_price

        if gapped:
            exit_prices[i] = open_[i]  # bar opened beyond a level: filled at the open
        elif stop_hit:
            exit_prices[i] = sl_price  # conservative: stop assumed before target
        elif target_hit:
            exit_prices[i] = tp_price
        else:
            continue
        positions[i] = 0  # cancel this bar's signal after an exit

    held = pd.Series(positions, index=df.index).shift(1)
    strategy_returns = held * df['returns']
    # Exit bars earn the move from the entry (previous close) to the fill price,
    # not the full close-to-close move.
    realised = held * (exit_prices / df['close'].shift(1) - 1)
    exited = ~np.isnan(exit_prices)
    return positions, strategy_returns.where(~exited, realised)

def visualize_results(df):
    """Render the backtest as three PNG files in the working directory.

    Only rows from the first non-NaN prediction onwards are plotted.

    * ``graph1_backtesting_results.png`` - close, predictions, PSAR, SMA and
      buy/sell signal markers.
    * ``graph2_heikin_ashi.png`` - Heikin Ashi open and close.
    * ``graph3_balance_over_time.png`` - account balance.

    Args:
        df: Output of ``backtest`` (needs ``predictions``, ``close``, ``psar``,
            ``sma``, ``position``, ``ha_open``, ``ha_close`` and ``balance``).
    """
    def save_chart(fig, ax, title, ylabel, filename, xlabel=None):
        # Shared finishing touches for every chart, then write it and free it.
        ax.set_title(title)
        if xlabel is not None:
            ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)
        fig.savefig(filename)
        plt.close(fig)

    first_valid = df['predictions'].first_valid_index()
    view = df.loc[first_valid:].copy()
    dates = view.index
    is_long = view['position'] == 1
    is_short = view['position'] == -1

    # Chart 1: price with model output, indicators and signals
    fig, ax = plt.subplots(figsize=(7, 6))
    ax.plot(dates, view['close'], label='Closing Price', color='blue')
    ax.plot(dates, view['predictions'], label='Predictions', color='red', alpha=0.7)
    ax.scatter(dates, view['psar'], label='PSAR', color='green', alpha=0.5, s=5)
    ax.plot(dates, view['sma'], label='SMA', color='orange', alpha=0.7)
    ax.plot(dates[is_long], view.loc[is_long, 'close'],
            '^', markersize=10, color='g', label='Buy')
    ax.plot(dates[is_short], view.loc[is_short, 'close'],
            'v', markersize=10, color='r', label='Sell')
    save_chart(fig, ax, 'Backtesting Results', 'Price', "graph1_backtesting_results.png")

    # Chart 2: Heikin Ashi open vs close
    fig, ax = plt.subplots(figsize=(7, 6))
    ax.plot(dates, view['ha_close'], label='HA Close', color='blue')
    ax.plot(dates, view['ha_open'], label='HA Open', color='red', alpha=0.7)
    save_chart(fig, ax, 'Heikin Ashi', 'Price', "graph2_heikin_ashi.png")

    # Chart 3: equity curve
    fig, ax = plt.subplots(figsize=(7, 6))
    ax.plot(dates, view['balance'], label='Balance', color='purple')
    save_chart(fig, ax, 'Balance Over Time', 'Balance',
               "graph3_balance_over_time.png", xlabel='Date')

# Main parameters
symbol = "EURUSD"
timeframe = mt5.TIMEFRAME_H1
start_date = pd.to_datetime("2024-01-01")
end_date = pd.to_datetime("2024-06-01")
# NOTE(review): the model file name says D1 while `timeframe` is H1 - confirm
# the model was trained on the same bar size it is evaluated on.
onnx_model_path = "C:/Users/jsgas/OneDrive/TRADING/ARTICULOS/24_PSAR_Heiken/EURUSD_D1_2024.onnx"  # Make sure to put the correct path here
window_size = 120
initial_balance = 10000

try:
    # Get data and add indicators
    df = get_historical_data(symbol, timeframe, start_date, end_date)
    df = add_indicators(df)

    # Prepare data (only the fitted scaler is used below) and load ONNX model
    X, scaler = prepare_data(df, window_size)
    onnx_model = load_onnx_model(onnx_model_path)

    # Perform backtesting
    results = backtest(df, onnx_model, scaler, window_size, initial_balance)

    # Calculate performance metrics
    strategy_returns = results['strategy_returns'].dropna()
    total_return = (strategy_returns + 1).prod() - 1

    # Annualise with the observed bar frequency. sqrt(252) is only right for
    # daily bars; H1 FX data has roughly 6,000 bars per year.
    elapsed_years = (results.index[-1] - results.index[0]) / pd.Timedelta(days=365.25)
    periods_per_year = len(strategy_returns) / elapsed_years if elapsed_years > 0 else np.nan
    returns_std = strategy_returns.std()
    if returns_std > 0:
        sharpe_ratio = strategy_returns.mean() / returns_std * np.sqrt(periods_per_year)
    else:
        sharpe_ratio = np.nan  # no trades / no variance: Sharpe is undefined

    print(f"Total return: {total_return:.2%}")
    print(f"Sharpe ratio: {sharpe_ratio:.2f}")
    print(f"Final balance: ${results['balance'].iloc[-1]:.2f}")

    # Visualize results
    visualize_results(results)
finally:
    # Close connection with MetaTrader5 even if any step above fails
    mt5.shutdown()