//+------------------------------------------------------------------+
//|                                                  ONNX Test       |
//|                                                   Copyright 2023 |
//|                                               Your Name Here     |
//+------------------------------------------------------------------+
#property copyright   "Copyright 2023, Your Name Here"
#property link        "https://www.mql5.com"
#property version     "1.00"

// Single-element float vectors declared at global scope.
// NOTE(review): neither is referenced in the visible part of this file —
// they look like leftovers; confirm before removing.
static vectorf ExtOutputData(1);
vectorf output_data(1);

// Standard library trade class; the `trade` instance below is what
// CheckForOpen()/CheckForClose() use to send position requests.
#include <Trade\Trade.mqh>
CTrade trade;

// User-configurable inputs.
// - InpLots is passed as the volume when a position is opened.
// - InpTakeProfit / InpStopLoss are multiplied by _Point when stop levels are
//   computed, i.e. they are distances in points.
// - When InpUseStops is true, positions are opened with SL/TP and the
//   signal-based close in CheckForClose() is skipped.
input double InpLots       = 1.0;    // Lot volume to open a position
input bool   InpUseStops   = true;   // Trade with stop orders
input int    InpTakeProfit = 500;    // Take Profit level
input int    InpStopLoss   = 500;    // Stop Loss level

// ONNX models embedded into the compiled program as byte arrays:
// the data processing pipeline and the prediction model (lstm_model.onnx).
#resource "\\LSTM\\data_processing_pipeline.onnx" as uchar DataProcessingModel[]
#resource "\\LSTM\\lstm_model.onnx" as uchar PredictionModel[]

// SAMPLE_SIZE_DATA: number of M15 bars fetched per run and the row count of
// the data processing model's input/output shape.
// SAMPLE_SIZE_PRED: number of consecutive rows handed to the prediction model
// for a single prediction.
#define SAMPLE_SIZE_DATA 160  // Adjusted to match the model's expected input size
#define SAMPLE_SIZE_PRED 60

// ONNX session handles; created in OnInit() and released in OnDeinit().
long     DataProcessingHandle = INVALID_HANDLE;
long     PredictionHandle = INVALID_HANDLE;
// Time before which OnTick() returns without doing any work.
datetime ExtNextBar = 0;
// Last signal written by DetermineSignal(); -1 until the first signal is computed.
int      ExtPredictedClass = -1;

// Values stored in ExtPredictedClass.
#define PRICE_UP   1
#define PRICE_SAME 2
#define PRICE_DOWN 0

// Expert Advisor initialization.
// Creates both ONNX sessions from the embedded resources and assigns their
// input/output shapes. Returns INIT_SUCCEEDED when every step succeeds; on the
// first failing step it prints that step together with GetLastError() and
// returns INIT_FAILED.
int OnInit()
{
    //--- data processing model: session
    DataProcessingHandle = OnnxCreateFromBuffer(DataProcessingModel, ONNX_DEFAULT);
    if (INVALID_HANDLE == DataProcessingHandle)
    {
        Print("Error creating data processing model OnnxCreateFromBuffer ", GetLastError());
        return(INIT_FAILED);
    }

    //--- data processing model: SAMPLE_SIZE_DATA rows of 5 values in ...
    const long pipeline_in_dims[] = {SAMPLE_SIZE_DATA, 5};
    const bool pipeline_in_ok = OnnxSetInputShape(DataProcessingHandle, ONNX_DEFAULT, pipeline_in_dims);
    if (!pipeline_in_ok)
    {
        Print("Error setting the input shape OnnxSetInputShape for data processing model ", GetLastError());
        return(INIT_FAILED);
    }

    //--- ... and the same layout out
    const long pipeline_out_dims[] = {SAMPLE_SIZE_DATA, 5};
    const bool pipeline_out_ok = OnnxSetOutputShape(DataProcessingHandle, 0, pipeline_out_dims);
    if (!pipeline_out_ok)
    {
        Print("Error setting the output shape OnnxSetOutputShape for data processing model ", GetLastError());
        return(INIT_FAILED);
    }

    //--- prediction model: session
    PredictionHandle = OnnxCreateFromBuffer(PredictionModel, ONNX_DEFAULT);
    if (INVALID_HANDLE == PredictionHandle)
    {
        Print("Error creating prediction model OnnxCreateFromBuffer ", GetLastError());
        return(INIT_FAILED);
    }

    //--- prediction model: a window of SAMPLE_SIZE_PRED rows of 5 values in ...
    const long predictor_in_dims[] = {SAMPLE_SIZE_PRED, 1, 5};
    const bool predictor_in_ok = OnnxSetInputShape(PredictionHandle, ONNX_DEFAULT, predictor_in_dims);
    if (!predictor_in_ok)
    {
        Print("Error setting the input shape OnnxSetInputShape for prediction model ", GetLastError());
        return(INIT_FAILED);
    }

    //--- ... and a single value out
    const long predictor_out_dims[] = {1};
    const bool predictor_out_ok = OnnxSetOutputShape(PredictionHandle, 0, predictor_out_dims);
    if (!predictor_out_ok)
    {
        Print("Error setting the output shape OnnxSetOutputShape for prediction model ", GetLastError());
        return(INIT_FAILED);
    }

    return(INIT_SUCCEEDED);
}

// Expert Advisor deinitialization.
// Releases whichever ONNX sessions were created and resets their handles to
// INVALID_HANDLE so a handle is never released twice. `reason` is not used.
void OnDeinit(const int reason)
{
    const bool pipeline_open = (DataProcessingHandle != INVALID_HANDLE);
    if (pipeline_open)
    {
        OnnxRelease(DataProcessingHandle);
        DataProcessingHandle = INVALID_HANDLE;
    }

    const bool predictor_open = (PredictionHandle != INVALID_HANDLE);
    if (predictor_open)
    {
        OnnxRelease(PredictionHandle);
        PredictionHandle = INVALID_HANDLE;
    }
}

// Tick handler. Does its work at most once per chart-period bar:
//   1. runs the data processing model on recent bars (ProcessData),
//   2. runs the prediction model over every sliding window and rescales each
//      output to a price using the current day's low/high,
//   3. derives a signal (DetermineSignal) and opens or closes a position.
// The bar is skipped (no signal, no trading) when data processing fails, when
// the daily range is unavailable, or when any prediction fails.
void OnTick()
{
    if (TimeCurrent() < ExtNextBar)
        return;

    // Schedule the next run for the start of the next bar of the chart period.
    // NOTE(review): the data itself is always taken from PERIOD_M15; confirm the
    // EA is meant to be attached to an M15 chart.
    ExtNextBar = TimeCurrent();
    ExtNextBar -= ExtNextBar % PeriodSeconds();
    ExtNextBar += PeriodSeconds();

    // Fetch new data and run the data processing ONNX model
    vectorf input_data = ProcessData(DataProcessingHandle);
    if (input_data.Size() == 0)
    {
        Print("Error processing data");
        return;
    }

    // Predictions are rescaled with the day's low/high. The range does not depend
    // on the loop index, so it is read once. iLow/iHigh return 0 on error, which
    // would otherwise turn every prediction into a bogus price.
    const double min_price = iLow(Symbol(), PERIOD_D1, 0);
    const double max_price = iHigh(Symbol(), PERIOD_D1, 0);
    if (min_price <= 0.0 || max_price < min_price)
    {
        Print("Daily low/high unavailable; skipping this bar");
        return;
    }

    // Make predictions using the prediction ONNX model
    double predictions[SAMPLE_SIZE_DATA - SAMPLE_SIZE_PRED + 1];
    for (int i = 0; i < SAMPLE_SIZE_DATA - SAMPLE_SIZE_PRED + 1; i++)
    {
        double prediction = MakePrediction(input_data, PredictionHandle, i, SAMPLE_SIZE_PRED);
        // MakePrediction reports failure with the sentinel -1.0; do not rescale
        // it into a price and trade on it.
        if (prediction == -1.0)
        {
            PrintFormat("Prediction failed at index %d; skipping this bar", i);
            return;
        }
        predictions[i] = prediction * (max_price - min_price) + min_price;
        PrintFormat("Predicted close price (index %d): %f", i, predictions[i]);
    }

    // Determine the trading signal
    DetermineSignal(predictions);

    // Execute trades based on the signal: manage the existing position if there
    // is one, otherwise look for an entry.
    if (ExtPredictedClass >= 0)
    {
        if (PositionSelect(_Symbol))
            CheckForClose();
        else
            CheckForOpen();
    }
}

// Derives the trading signal from the predicted prices and stores it in
// ExtPredictedClass.
//   predictions - predicted prices; only the element at index
//                 SAMPLE_SIZE_DATA - SAMPLE_SIZE_PRED (the last window) is used.
// The signal is PRICE_UP / PRICE_DOWN when the spread filter passes and the
// prediction is above / below the close of the last completed M15 bar, and
// PRICE_SAME in every other case (including missing data).
void DetermineSignal(double &predictions[])
{
    // Start from the non-trading signal so every early exit is safe.
    ExtPredictedClass = PRICE_SAME;

    const int last_index = SAMPLE_SIZE_DATA - SAMPLE_SIZE_PRED;
    if (ArraySize(predictions) <= last_index)
    {
        Print("DetermineSignal: predictions array is shorter than expected");
        return;
    }
    const double predicted = predictions[last_index]; // Use the last prediction for decision making

    // iClose returns 0 on error; comparing against 0 would make any prediction
    // look like a rise.
    const double last_close = iClose(Symbol(), PERIOD_M15, 1);
    if (last_close <= 0.0)
    {
        Print("DetermineSignal: reference close price unavailable");
        return;
    }

    // Spread filter: no directional signal unless the spread measure is below the threshold.
    const double spread = GetSpreadInPips(_Symbol);
    if (!(spread < 0.000005))
        return;

    if (predicted > last_close)
        ExtPredictedClass = PRICE_UP;
    else if (predicted < last_close)
        ExtPredictedClass = PRICE_DOWN;
}

// Opens a position when the current signal is directional.
// PRICE_DOWN maps to a sell at the bid, PRICE_UP to a buy at the ask; any other
// signal, or trading being disabled in the terminal, results in no action.
// With InpUseStops enabled, SL/TP are placed InpStopLoss / InpTakeProfit points
// away from the quotes; otherwise both are sent as 0.
void CheckForOpen(void)
{
    ENUM_ORDER_TYPE order_type;
    switch (ExtPredictedClass)
    {
        case PRICE_DOWN:
            order_type = ORDER_TYPE_SELL;
            break;
        case PRICE_UP:
            order_type = ORDER_TYPE_BUY;
            break;
        default:
            return;   // no directional signal
    }

    if (!TerminalInfoInteger(TERMINAL_TRADE_ALLOWED))
        return;

    const double bid = SymbolInfoDouble(_Symbol, SYMBOL_BID);
    const double ask = SymbolInfoDouble(_Symbol, SYMBOL_ASK);
    const bool selling = (order_type == ORDER_TYPE_SELL);

    // Sells are filled at the bid, buys at the ask.
    const double entry_price = selling ? bid : ask;

    double stop_loss = 0;
    double take_profit = 0;
    if (InpUseStops)
    {
        if (selling)
        {
            stop_loss   = NormalizeDouble(bid + InpStopLoss * _Point, _Digits);
            take_profit = NormalizeDouble(ask - InpTakeProfit * _Point, _Digits);
        }
        else
        {
            stop_loss   = NormalizeDouble(ask - InpStopLoss * _Point, _Digits);
            take_profit = NormalizeDouble(bid + InpTakeProfit * _Point, _Digits);
        }
    }

    trade.PositionOpen(_Symbol, order_type, InpLots, entry_price, stop_loss, take_profit);
}

// Closes the current position when the signal points against it, then looks
// for an entry in the new direction.
// Does nothing when InpUseStops is enabled (exits are then left to SL/TP).
// Reads the position type via PositionGetInteger, so it relies on the caller
// having selected the position first (OnTick does this with PositionSelect).
void CheckForClose(void)
{
    if (InpUseStops)
        return;

    const long type = PositionGetInteger(POSITION_TYPE);
    const bool reverse_signal =
        (type == POSITION_TYPE_BUY  && ExtPredictedClass == PRICE_DOWN) ||
        (type == POSITION_TYPE_SELL && ExtPredictedClass == PRICE_UP);

    if (!reverse_signal || !TerminalInfoInteger(TERMINAL_TRADE_ALLOWED))
        return;

    // Only reverse once the close request went through. Sending the opposite
    // order while the old position is still open would stack two opposing
    // positions on a hedging account.
    // NOTE(review): a true return means the request passed CTrade's checks;
    // confirm whether ResultRetcode() should also be inspected here.
    if (!trade.PositionClose(_Symbol, 3))
    {
        PrintFormat("PositionClose failed: retcode=%u (%s)",
                    trade.ResultRetcode(), trade.ResultRetcodeDescription());
        return;
    }

    CheckForOpen();
}

// Returns the spread measure used by the signal filter for `symbol`:
//     spread (points) * point size / number of price digits.
// Point size and digits are read for `symbol` itself, so the result is correct
// for symbols other than the chart symbol as well (for the chart symbol it is
// the same value as computing it from _Point and _Digits).
// NOTE(review): despite the name this is not a conventional pip count — the
// division by the digit count is unusual. The formula is kept because the
// caller's threshold (0.000005) is calibrated against it; confirm the intent.
// Returns DBL_MAX for symbols quoted with 0 digits, so the filter treats the
// spread as too wide instead of dividing by zero.
double GetSpreadInPips(string symbol)
{
    const long   spread_points = SymbolInfoInteger(symbol, SYMBOL_SPREAD);
    const double point_size    = SymbolInfoDouble(symbol, SYMBOL_POINT);
    const long   digits        = SymbolInfoInteger(symbol, SYMBOL_DIGITS);

    if (digits <= 0)
        return DBL_MAX;

    return (double)spread_points * point_size / (double)digits;
}

// Runs the data processing ONNX model on the most recent completed M15 bars.
//   data_processing_handle - ONNX session handle of the data processing model.
// Copies SAMPLE_SIZE_DATA bars starting at shift 1, flattens each bar into
// 5 consecutive floats (time, open, high, low, close) and feeds the result to
// the model. Returns the model output (SAMPLE_SIZE_DATA * 5 floats), or an
// empty vector when the bars could not be copied or the model run failed.
vectorf ProcessData(long data_processing_handle)
{
   vectorf empty_result;

   MqlRates bars[SAMPLE_SIZE_DATA];
   const int bars_copied = CopyRates(_Symbol, PERIOD_M15, 1, SAMPLE_SIZE_DATA, bars);
   if (bars_copied != SAMPLE_SIZE_DATA)
   {
      Print("Failed to copy the expected number of rates. Expected: ", SAMPLE_SIZE_DATA, ", Copied: ", bars_copied);
      return empty_result;
   }

   const int value_count = bars_copied * 5;

   // Fill the model input directly, one bar per group of 5 values.
   vectorf model_input;
   model_input.Resize(value_count);
   for (int bar = 0; bar < bars_copied; bar++)
   {
      const int base = bar * 5;
      // The bar time is passed as the raw datetime value cast to float;
      // no normalization is applied here.
      model_input[base + 0] = (float)bars[bar].time;
      model_input[base + 1] = (float)bars[bar].open;
      model_input[base + 2] = (float)bars[bar].high;
      model_input[base + 3] = (float)bars[bar].low;
      model_input[base + 4] = (float)bars[bar].close;
   }

   vectorf model_output;
   model_output.Resize(value_count);

   const bool run_ok = OnnxRun(data_processing_handle, ONNX_NO_CONVERSION, model_input, model_output);
   if (!run_ok)
   {
      Print("Error running the data processing ONNX model: ", GetLastError());
      return empty_result;
   }

   return model_output;
}


// Runs the prediction ONNX model on one window of the processed data.
//   input_data        - flattened processed data, 5 values per row.
//   prediction_handle - ONNX session handle of the prediction model.
//   start_index       - first row of the window.
//   size              - number of rows in the window.
// Returns the model's single output value, or -1.0 when the window does not
// fit inside input_data or the model run fails.
double MakePrediction(const vectorf& input_data, long prediction_handle, int start_index, int size)
{
   // Reject windows that would read outside input_data; an out-of-range vector
   // access is a critical runtime error.
   if (start_index < 0 || size <= 0)
   {
      Print("Invalid prediction window: start_index=", start_index, ", size=", size);
      return -1.0;
   }
   const ulong first_value = (ulong)start_index * 5;
   const ulong value_count = (ulong)size * 5;
   if (first_value + value_count > input_data.Size())
   {
      Print("Prediction window exceeds input data: start_index=", start_index,
            ", size=", size, ", available values=", input_data.Size());
      return -1.0;
   }

   // Copy the window's rows into a contiguous model input.
   vectorf input_subset;
   input_subset.Resize(value_count);
   for (ulong i = 0; i < value_count; i++)
   {
      input_subset[i] = input_data[first_value + i];
   }

   vectorf output_vector;
   output_vector.Resize(1);

   if (!OnnxRun(prediction_handle, ONNX_NO_CONVERSION, input_subset, output_vector))
   {
      Print("Error running the prediction ONNX model: ", GetLastError());
      return -1.0;
   }

   // The model emits a single value; callers rescale it to a price.
   double norm_close = output_vector[0];

   return norm_close;
}