import MetaTrader5 as mt5
import pandas as pd
import numpy as np
import onnxruntime as ort
from sklearn.preprocessing import MinMaxScaler
from ta.trend import PSARIndicator, SMAIndicator
from ta.momentum import RSIIndicator
from ta.volatility import AverageTrueRange
import matplotlib.pyplot as plt

# Initialize the MetaTrader5 connection. Everything below needs a live
# terminal session, so abort here instead of continuing with a closed one.
if not mt5.initialize():
    print("Inicialización fallida")
    init_error = mt5.last_error()
    mt5.shutdown()
    raise SystemExit(f"MetaTrader5 initialize() failed: {init_error}")

def get_historical_data(symbol, timeframe, start_date, end_date):
    """Download bars for *symbol* from MetaTrader5 as a DataFrame indexed by bar time.

    Args:
        symbol: Instrument name, e.g. "EURUSD".
        timeframe: An ``mt5.TIMEFRAME_*`` constant.
        start_date: Start of the range passed to ``mt5.copy_rates_range``.
        end_date: End of the range passed to ``mt5.copy_rates_range``.

    Returns:
        DataFrame with one row per bar, indexed by the 'time' column
        converted from Unix seconds to datetimes.

    Raises:
        RuntimeError: If MetaTrader5 returns no bars (``None`` or an empty
            array). Without this check a ``None`` result fails later with an
            unhelpful ``KeyError: 'time'``.
    """
    rates = mt5.copy_rates_range(symbol, timeframe, start_date, end_date)
    if rates is None or len(rates) == 0:
        raise RuntimeError(
            f"No data returned for {symbol} from {start_date} to {end_date}: "
            f"{mt5.last_error()}"
        )
    df = pd.DataFrame(rates)
    df['time'] = pd.to_datetime(df['time'], unit='s')
    df.set_index('time', inplace=True)
    return df

def calculate_heikin_ashi(df):
    """Add Heikin Ashi candle columns to *df* in place and return it.

    Uses the standard recursive definition:
        ha_close[t] = (open[t] + high[t] + low[t] + close[t]) / 4
        ha_open[0]  = (open[0] + close[0]) / 2
        ha_open[t]  = (ha_open[t-1] + ha_close[t-1]) / 2
        ha_high[t]  = max(high[t], ha_open[t], ha_close[t])
        ha_low[t]   = min(low[t],  ha_open[t], ha_close[t])

    Args:
        df: DataFrame with 'open', 'high', 'low' and 'close' columns.

    Returns:
        The same DataFrame with 'ha_close', 'ha_open', 'ha_high' and
        'ha_low' columns added.
    """
    ha_close = (df['open'] + df['high'] + df['low'] + df['close']) / 4

    # ha_open depends on the previous *Heikin Ashi* candle (not the raw
    # previous bar), so it has to be built sequentially.
    ha_close_values = ha_close.to_numpy(dtype=float)
    ha_open_values = np.empty(len(df), dtype=float)
    if len(df) > 0:
        ha_open_values[0] = (df['open'].iloc[0] + df['close'].iloc[0]) / 2
        for i in range(1, len(df)):
            ha_open_values[i] = (ha_open_values[i - 1] + ha_close_values[i - 1]) / 2

    df['ha_close'] = ha_close
    df['ha_open'] = ha_open_values
    # The HA high/low combine the raw extreme with the HA open/close.
    df['ha_high'] = df[['high', 'ha_open', 'ha_close']].max(axis=1)
    df['ha_low'] = df[['low', 'ha_open', 'ha_close']].min(axis=1)
    return df

def add_indicators(df):
    """Add Heikin Ashi candles and technical indicators to *df*.

    Added columns: ha_close/ha_open/ha_high/ha_low (via calculate_heikin_ashi),
    psar, sma (50-bar SMA of close), rsi (14 bars), atr (14 bars) and trend.

    Args:
        df: DataFrame with 'open', 'high', 'low' and 'close' columns.

    Returns:
        The DataFrame returned by calculate_heikin_ashi, with the indicator
        columns assigned on it.
    """
    # Heikin Ashi candles
    df = calculate_heikin_ashi(df)

    # Parabolic SAR with acceleration step 0.02, capped at 0.2
    psar = PSARIndicator(df['high'], df['low'], df['close'], step=0.02, max_step=0.2)
    df['psar'] = psar.psar()

    # 50-bar simple moving average of the close
    sma = SMAIndicator(df['close'], window=50)
    df['sma'] = sma.sma_indicator()

    # 14-bar RSI
    rsi = RSIIndicator(df['close'], window=14)
    df['rsi'] = rsi.rsi()

    # 14-bar ATR, used by the backtest to size stop-loss / take-profit distances
    atr = AverageTrueRange(df['high'], df['low'], df['close'], window=14)
    df['atr'] = atr.average_true_range()

    # Simple trend filter: 1 when close is above the SMA, -1 otherwise.
    # While the SMA is still warming up (NaN) the trend is undefined, so mark
    # it 0 instead of labelling those bars as a downtrend.
    df['trend'] = np.where(
        df['sma'].isna(), 0, np.where(df['close'] > df['sma'], 1, -1)
    )

    return df

def prepare_data(df, window_size=120):
    """Scale close prices to [0, 1] and slice them into overlapping windows.

    Args:
        df: DataFrame with a 'close' column.
        window_size: Number of consecutive bars in each window.

    Returns:
        A tuple ``(X, scaler)``. ``X`` is a NumPy array that stacks, for every
        bar index i >= window_size, the scaled closes of bars
        i-window_size .. i-1. ``scaler`` is the fitted MinMaxScaler.
    """
    # NOTE(review): the scaler is fitted on the whole series, and this script
    # reuses it in backtest() on the same bars, so the scaling sees future
    # min/max values (look-ahead). Consider fitting on training data only.
    scaler = MinMaxScaler(feature_range=(0, 1))
    scaled_close = scaler.fit_transform(df[['close']])
    windows = [
        scaled_close[end - window_size:end]
        for end in range(window_size, len(scaled_close))
    ]
    return np.array(windows), scaler

def load_onnx_model(model_path):
    """Open the ONNX model file at *model_path* as an onnxruntime InferenceSession."""
    session = ort.InferenceSession(model_path)
    return session

def predict_with_onnx(model, input_data):
    """Run *model* on *input_data* and return its first output.

    The array is fed to the session's first declared input, and only the
    first declared output is requested.
    """
    feed = {model.get_inputs()[0].name: input_data}
    fetches = [model.get_outputs()[0].name]
    outputs = model.run(fetches, feed)
    return outputs[0]

def _rolling_predictions(df, model, scaler, window_size):
    """Return one model prediction per bar in price units (NaN for warm-up bars).

    The prediction stored at bar i is computed from the scaled closes of bars
    i-window_size .. i-1, so it does not use bar i itself.
    """
    scaled_close = scaler.transform(df[['close']])
    # Pre-filled with NaN so the result always has len(df) entries, even when
    # there are fewer bars than window_size.
    predictions = np.full(len(df), np.nan)
    for i in range(window_size, len(scaled_close)):
        window = scaled_close[i - window_size:i].reshape(1, window_size, 1)
        raw = predict_with_onnx(model, window.astype(np.float32))
        predictions[i] = scaler.inverse_transform(raw.reshape(-1, 1))[0, 0]
    return predictions


def _signal_positions(df):
    """Return the raw entry signal per bar: 1 (long), -1 (short) or 0 (flat)."""
    long_condition = (
        (df['close'] > df['predictions']) &
        (df['close'] > df['psar']) &
        (df['close'] > df['sma']) &
        (df['rsi'] < 60) &  # Relaxed RSI condition
        (df['ha_close'] > df['ha_open']) &
        (df['ha_close'].shift(1) > df['ha_open'].shift(1)) &
        (df['trend'] == 1)  # Only buy in an uptrend
    )
    short_condition = (
        (df['close'] < df['predictions']) &
        (df['close'] < df['psar']) &
        (df['close'] < df['sma']) &
        (df['rsi'] > 40) &  # Relaxed RSI condition
        (df['ha_close'] < df['ha_open']) &
        (df['ha_close'].shift(1) < df['ha_open'].shift(1)) &
        (df['trend'] == -1)  # Only sell in a downtrend
    )
    position = pd.Series(0, index=df.index)
    position[long_condition] = 1   # Buy
    position[short_condition] = -1  # Sell
    return position


def _apply_atr_exits(df, window_size, sl_atr_multiple, tp_atr_multiple):
    """Flatten positions whose ATR-based stop-loss or take-profit is hit.

    SL/TP levels are anchored at the close and ATR of the bar where the
    current position was opened, and stay fixed while it is held. When bar
    i's high/low crosses a level of the position held from bar i-1,
    position[i] is set to 0, so the strategy is flat for the next bar.
    """
    position = df['position'].to_numpy(copy=True)
    close = df['close'].to_numpy()
    high = df['high'].to_numpy()
    low = df['low'].to_numpy()
    atr = df['atr'].to_numpy()

    entry_price = entry_atr = np.nan
    # Start at >= 1 so position[i - 1] never wraps around to the last bar.
    for i in range(max(window_size, 1), len(position)):
        side = position[i - 1]
        if side == 0:
            continue
        if i < 2 or position[i - 2] != side:
            # The position was opened (or reversed) at bar i-1: fix its levels.
            entry_price = close[i - 1]
            entry_atr = atr[i - 1]
        if side == 1:  # Long position
            stop = entry_price - sl_atr_multiple * entry_atr
            target = entry_price + tp_atr_multiple * entry_atr
            hit = low[i] < stop or high[i] > target
        else:  # Short position
            stop = entry_price + sl_atr_multiple * entry_atr
            target = entry_price - tp_atr_multiple * entry_atr
            hit = high[i] > stop or low[i] < target
        if hit:
            position[i] = 0
    return position


def backtest(df, model, scaler, window_size=120, initial_balance=10000,
             sl_atr_multiple=2, tp_atr_multiple=3):
    """Run the model + indicator strategy over *df*, adding result columns in place.

    Args:
        df: DataFrame produced by add_indicators (needs close/high/low, psar,
            sma, rsi, atr, trend and the Heikin Ashi columns).
        model: ONNX InferenceSession taking a (1, window_size, 1) float32 window.
        scaler: Fitted scaler used to scale inputs and unscale predictions.
        window_size: Number of past bars fed to the model.
        initial_balance: Starting account balance.
        sl_atr_multiple: Stop-loss distance from entry, in ATRs.
        tp_atr_multiple: Take-profit distance from entry, in ATRs.

    Returns:
        *df* with 'predictions', 'position', 'returns', 'strategy_returns',
        'cumulative_returns' and 'balance' columns added.
    """
    df['predictions'] = _rolling_predictions(df, model, scaler, window_size)
    df['position'] = _signal_positions(df)
    df['position'] = _apply_atr_exits(df, window_size, sl_atr_multiple, tp_atr_multiple)

    df['returns'] = df['close'].pct_change()
    # The position decided at the close of bar i-1 earns bar i's return.
    # NOTE(review): returns are close-to-close, so a stop hit inside bar i
    # does not cap that bar's loss at the stop price.
    df['strategy_returns'] = df['position'].shift(1) * df['returns']

    # Compound into a balance curve
    df['cumulative_returns'] = (1 + df['strategy_returns']).cumprod()
    df['balance'] = initial_balance * df['cumulative_returns']

    return df

def visualize_results(df):
    """Plot backtest results in three stacked panels, save to graph.png and show.

    Panels: (1) close, model predictions, PSAR, SMA and long/short signal
    markers; (2) Heikin Ashi close vs. open; (3) account balance. Rows before
    the first non-NaN prediction are left out.
    """
    first_prediction = df['predictions'].first_valid_index()
    view = df.loc[first_prediction:].copy()
    dates = view.index
    is_long = view['position'] == 1
    is_short = view['position'] == -1

    _fig, (price_ax, ha_ax, balance_ax) = plt.subplots(3, 1, figsize=(7, 20), sharex=True)

    # Panel 1: price, predictions, indicators and signals
    price_ax.plot(dates, view['close'], label='Precio de cierre', color='blue')
    price_ax.plot(dates, view['predictions'], label='Predicciones', color='red', alpha=0.7)
    price_ax.scatter(dates, view['psar'], label='PSAR', color='green', alpha=0.5, s=5)
    price_ax.plot(dates, view['sma'], label='SMA', color='orange', alpha=0.7)
    price_ax.plot(dates[is_long], view.loc[is_long, 'close'],
                  '^', markersize=10, color='g', label='Compra')
    price_ax.plot(dates[is_short], view.loc[is_short, 'close'],
                  'v', markersize=10, color='r', label='Venta')
    price_ax.set_title('Backtesting Results')
    price_ax.set_ylabel('Precio')

    # Panel 2: Heikin Ashi open/close
    ha_ax.plot(dates, view['ha_close'], label='HA Close', color='blue')
    ha_ax.plot(dates, view['ha_open'], label='HA Open', color='red', alpha=0.7)
    ha_ax.set_title('Heikin Ashi')
    ha_ax.set_ylabel('Precio')

    # Panel 3: balance over time
    balance_ax.plot(dates, view['balance'], label='Balance', color='purple')
    balance_ax.set_title('Balance a lo largo del tiempo')
    balance_ax.set_xlabel('Fecha')
    balance_ax.set_ylabel('Balance')

    for ax in (price_ax, ha_ax, balance_ax):
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.savefig("graph.png")
    plt.show()

# Main parameters
symbol = "EURUSD"
timeframe = mt5.TIMEFRAME_H1
start_date = pd.to_datetime("2024-01-01")
end_date = pd.to_datetime("2024-06-01")
# NOTE(review): the model file is named "..._D1_..." while `timeframe` is H1 —
# confirm the model was trained on the same bar size it is evaluated on.
onnx_model_path = "C:/Users/jsgas/OneDrive/TRADING/ARTICULOS/24_PSAR_Heiken/EURUSD_D1_2024.onnx"  # Make sure this path is correct
window_size = 120
initial_balance = 10000

try:
    # Fetch data and add indicators
    df = get_historical_data(symbol, timeframe, start_date, end_date)
    df = add_indicators(df)

    # Prepare data and load the ONNX model
    X, scaler = prepare_data(df, window_size)
    onnx_model = load_onnx_model(onnx_model_path)

    # Run the backtest
    results = backtest(df, onnx_model, scaler, window_size, initial_balance)

    # Performance metrics
    total_return = (results['strategy_returns'] + 1).prod() - 1
    # Annualize the Sharpe ratio with the observed number of bars per year.
    # A fixed sqrt(252) only fits daily bars and understates the figure when
    # there are many bars per day (as with H1 data).
    span_years = (results.index[-1] - results.index[0]).total_seconds() / (365.25 * 24 * 3600)
    periods_per_year = len(results) / span_years if span_years > 0 else 252
    sharpe_ratio = (
        results['strategy_returns'].mean()
        / results['strategy_returns'].std()
        * np.sqrt(periods_per_year)
    )

    print(f"Retorno total: {total_return:.2%}")
    print(f"Ratio de Sharpe: {sharpe_ratio:.2f}")
    print(f"Balance final: ${results['balance'].iloc[-1]:.2f}")

    # Visualize results
    visualize_results(results)
finally:
    # Close the MetaTrader5 connection even if any step above raised
    mt5.shutdown()