//+------------------------------------------------------------------+
//|                                      SP500 X Treasury Yields.mq5 |
//|                                        Gamuchirai Zororo Ndawana |
//|                          https://www.mql5.com/en/gamuchiraindawa |
//+------------------------------------------------------------------+
#property copyright "Gamuchirai Zororo Ndawana"
#property link      "https://www.mql5.com/en/gamuchiraindawa"
#property version   "1.00"
#property tester_file "sp500_treasury_yields_scale.csv"

//+------------------------------------------------------------------+
//| Require the ONNX model                                           |
//+------------------------------------------------------------------+
#resource "\\Files\\SP500_ONNX_FLOAT_M1.onnx" as const uchar ModelBuffer[];

//+------------------------------------------------------------------+
//| Libraries we need                                                |
//+------------------------------------------------------------------+
#include <Trade/Trade.mqh>
CTrade Trade;

//+------------------------------------------------------------------+
//| Inputs for our EA                                                |
//+------------------------------------------------------------------+
input int lot_multiple = 1; //How many times bigger than minimum lot?
input double sl_width = 1; //How wide should our stop loss be?

//+------------------------------------------------------------------+
//| Global variables                                                 |
//+------------------------------------------------------------------+
long model; //Handle of our ONNX SGDRegressor model
vectorf prediction(1); //Our model's prediction
float mean_values[8],variance_values[8];//We need this data to normalise and scale model inputs (input = (price - mean) / variance value)
double trading_volume; //How big should our positions be?
int state = 0; //Direction of the last trade this EA opened: 0 = none yet, 1 = buy, 2 = sell

//+------------------------------------------------------------------+
//| Expert initialization function                                   |
//+------------------------------------------------------------------+
int OnInit()
  {
//--- Prepares everything the EA needs before the first tick: the ONNX model,
//--- its input/output shapes, the scaling parameters, the two symbols and the lot size.
//--- Returns INIT_SUCCEEDED when all steps pass, otherwise a failure code.

//--- The instrument every order is placed on
   const string traded_symbol = "US500";
//--- The instrument that only provides model inputs
   const string yield_symbol = "UST05Y_U4";

//--- Create the ONNX model from the model buffer we have
   model = OnnxCreateFromBuffer(ModelBuffer,ONNX_DEFAULT);

//--- Ensure the model is valid
   if(model == INVALID_HANDLE)
     {
      Comment("[ERROR] Failed to initialize the model: ",GetLastError());
      return(INIT_FAILED);
     }

//--- The model takes one row of 8 features
   ulong input_shape[] = {1,8};

//--- Check that the input shape we defined is accepted
   if(!OnnxSetInputShape(model,0,input_shape))
     {
      Comment("[ERROR] Incorrect input shape specified: ",GetLastError(),"\nThe model's inputs are: ",OnnxGetInputCount(model));
      return(INIT_FAILED);
     }

//--- The model returns a single value
   ulong output_shape[] = {1,1};

//--- Check that the output shape we defined is accepted
   if(!OnnxSetOutputShape(model,0,output_shape))
     {
      Comment("[ERROR] Incorrect output shape specified: ",GetLastError(),"\nThe model's outputs are: ",OnnxGetOutputCount(model));
      return(INIT_FAILED);
     }

//--- Read the configuration file; FileOpen without FILE_COMMON looks inside <data path>\MQL5\Files
   if(!read_configuration_file())
     {
      Comment("Failed to load the configuration file, ensure it is stored here: ",TerminalInfoString(TERMINAL_DATA_PATH),"\\MQL5\\Files");
      return(INIT_FAILED);
     }

//--- Select the symbols; without them there are no prices to feed the model
   if(!SymbolSelect(traded_symbol,true) || !SymbolSelect(yield_symbol,true))
     {
      Comment("[ERROR] Failed to select ",traded_symbol," and ",yield_symbol," in Market Watch: ",GetLastError());
      return(INIT_FAILED);
     }

//--- A multiple below 1 would give a zero or negative order volume
   if(lot_multiple < 1)
     {
      Comment("[ERROR] lot_multiple must be at least 1");
      return(INIT_PARAMETERS_INCORRECT);
     }

//--- Calculate the lot size from the symbol we actually trade, not from the chart symbol
   trading_volume = SymbolInfoDouble(traded_symbol,SYMBOL_VOLUME_MIN) * lot_multiple;
   if(trading_volume <= 0)
     {
      Comment("[ERROR] Failed to read the minimum volume of ",traded_symbol,": ",GetLastError());
      return(INIT_FAILED);
     }

//--- Return init succeeded
   return(INIT_SUCCEEDED);
  }
//+------------------------------------------------------------------+
//| Expert deinitialization function                                 |
//+------------------------------------------------------------------+
void OnDeinit(const int reason)
  {
//--- Free up the resources we used for our ONNX model (skipped if it was never created)
   if(model != INVALID_HANDLE)
     {
      OnnxRelease(model);
      model = INVALID_HANDLE;
     }
//--- ExpertRemove() is deliberately not called here: OnDeinit only runs when the
//--- terminal is already unloading or re-initialising the EA, so it is never needed.
//--- NOTE(review): it would presumably also detach the EA on deinitialisations that
//--- are followed by a re-init (timeframe or input change) - confirm in the terminal.
  }
//+------------------------------------------------------------------+
//| Expert tick function                                             |
//+------------------------------------------------------------------+
void OnTick()
  {
//--- Refresh the forecast (predict() also acts on it) on every incoming tick
   predict();
//--- Show the latest forecast on the chart, with a separator between label and value
   Comment("Model forecast: ",prediction[0]);
  }

//+------------------------------------------------------------------+
//| A function responsible for reading the CSV config file           |
//+------------------------------------------------------------------+
bool read_configuration_file(void)
  {
//--- Loads the 8 mean values and 8 scale values used to normalise the model inputs
//--- into mean_values[] and variance_values[].
//--- Returns true only when the file was opened, all 16 values were read and every
//--- scale value is a usable (finite, non-zero) divisor; otherwise false.

//--- Read the config file
   Print("Reading in the config file");

//--- Config file name
   string file_name = "sp500_treasury_yields_scale.csv";

//--- Positions (1-based cell numbers, counted across rows) of the values we need.
//--- NOTE(review): this assumes a header row plus a label cell in front of each data
//--- row - confirm against the script that exports the CSV.
   const int values_per_row   = 8;
   const int first_mean_cell  = 11;
   const int first_scale_cell = 20;
   const int max_cells        = 60; //Safety limit so a malformed file cannot keep us looping

//--- Try open the file
   int file_handle = FileOpen(file_name,FILE_READ|FILE_CSV|FILE_ANSI,",");

//--- Nothing to read if the file could not be opened
   if(file_handle == INVALID_HANDLE)
     {
      Print("Failed to read the file");
      return(false);
     }

   Print("Opened the file");

//--- Prepare to read the file
   int cells_read  = 0;
   int means_read  = 0;
   int scales_read = 0;
   string value = "";

//--- Make sure we can proceed
   while(!FileIsEnding(file_handle) && !IsStopped())
     {
      if(cells_read > max_cells)
         break;
      //--- Read in the next cell
      value = FileReadString(file_handle);
      Print("Reading: ",value);
      //--- Have we reached the end of the line?
      if(FileIsLineEnding(file_handle))
         Print("row++");
      cells_read++;
      //--- The first cells hold the column titles and a row label, we ignore those
      if((cells_read >= first_mean_cell) && (cells_read < first_mean_cell + values_per_row))
        {
         mean_values[cells_read - first_mean_cell] = (float) value;
         means_read++;
        }
      else
         if((cells_read >= first_scale_cell) && (cells_read < first_scale_cell + values_per_row))
           {
            variance_values[cells_read - first_scale_cell] = (float) value;
            scales_read++;
           }
     }

//--- Close the file
   FileClose(file_handle);

   Print("Mean values");
   ArrayPrint(mean_values);
   Print("Variance values");
   ArrayPrint(variance_values);

//--- A truncated file would leave part of the arrays unfilled
   if((means_read != values_per_row) || (scales_read != values_per_row))
     {
      Print("The config file is incomplete, values read: ",means_read," means and ",scales_read," scale values");
      return(false);
     }

//--- Every scale value is later used as a divisor, so it must be finite and non-zero
   for(int i = 0; i < values_per_row; i++)
     {
      if(!MathIsValidNumber(variance_values[i]) || (variance_values[i] == 0))
        {
         Print("The config file holds an unusable scale value at position ",i);
         return(false);
        }
     }

   return(true);
  }


//+------------------------------------------------------------------+
//| A function responsible for getting a forecast from our model     |
//+------------------------------------------------------------------+
void predict(void)
  {
//--- Builds the 8 normalised model inputs from the current M1 bar of both symbols,
//--- runs the ONNX model and, only when that succeeds, hands over to intepret_prediction().

//--- Raw prices in the order the model inputs are filled: UST05Y_U4 open, close, high, low, then US500 open, close, high, low
   double raw_prices[8];
   raw_prices[0] = iOpen("UST05Y_U4",PERIOD_M1,0);
   raw_prices[1] = iClose("UST05Y_U4",PERIOD_M1,0);
   raw_prices[2] = iHigh("UST05Y_U4",PERIOD_M1,0);
   raw_prices[3] = iLow("UST05Y_U4",PERIOD_M1,0);
   raw_prices[4] = iOpen("US500",PERIOD_M1,0);
   raw_prices[5] = iClose("US500",PERIOD_M1,0);
   raw_prices[6] = iHigh("US500",PERIOD_M1,0);
   raw_prices[7] = iLow("US500",PERIOD_M1,0);

//--- Let's prepare our inputs
   vectorf input_data = vectorf::Zeros(8);
   for(int i = 0; i < 8; i++)
     {
      //--- The price functions return 0 on failure (e.g. history not loaded yet)
      if(raw_prices[i] == 0)
        {
         Print("[ERROR] Price data is not available for model input ",i,": ",GetLastError());
         return;
        }
      //--- Never divide by a zero scale value
      if(variance_values[i] == 0)
        {
         Print("[ERROR] The scale value of model input ",i," is zero");
         return;
        }
      input_data[i] = (float)((raw_prices[i] - mean_values[i]) / variance_values[i]);
     }

//--- Show the inputs
   Print("Inputs: ",input_data);

//--- Obtain a prediction from our model; do not act on a stale forecast if the run fails
   if(!OnnxRun(model,ONNX_DEFAULT,input_data,prediction))
     {
      Print("[ERROR] Failed to obtain a prediction from the model: ",GetLastError());
      return;
     }

//--- Interpret the forecast
   intepret_prediction();
  }

//+------------------------------------------------------------------+
//| Turn the model's forecast into trades or reversal alerts         |
//+------------------------------------------------------------------+
void intepret_prediction(void)
  {
//--- Compares the latest forecast with the current M1 close of US500.
//--- With no open positions it opens a buy (forecast above the close) or a sell
//--- (forecast below the close) and records the direction in 'state'.
//--- With a position open it alerts when the forecast turns against that direction.

//--- Read the reference price once so every comparison uses the same value
   const double last_close = iClose("US500",PERIOD_M1,0);

//--- 0 means the price could not be read; comparing against it would always signal a buy
   if(last_close == 0)
     {
      Print("[ERROR] Failed to read the US500 close price: ",GetLastError());
      return;
     }

   if(PositionsTotal() == 0)
     {
      if(prediction[0] > last_close)
        {
         const double ask = SymbolInfoDouble("US500",SYMBOL_ASK);
         //--- Only remember the direction if the order request went through
         if(Trade.Buy(trading_volume,"US500",ask,(ask - sl_width),(ask + sl_width),"SP500 X Treasury Yields"))
            state = 1;
         else
            Print("[ERROR] Buy order failed: ",Trade.ResultRetcodeDescription());
        }
      else
         if(prediction[0] < last_close)
           {
            const double bid = SymbolInfoDouble("US500",SYMBOL_BID);
            if(Trade.Sell(trading_volume,"US500",bid,(bid + sl_width),(bid - sl_width),"SP500 X Treasury Yields"))
               state = 2;
            else
               Print("[ERROR] Sell order failed: ",Trade.ResultRetcodeDescription());
           }
      return;
     }

//--- A position is open: a reversal is a forecast that now points against it
   if((state == 1) && (prediction[0] < last_close))
     {
      Alert("Reversal predicted, consider closing your buy position");
     }

   if((state == 2) && (prediction[0] > last_close))
     {
      Alert("Reversal predicted, consider closing your sell position");
     }
  }
//+------------------------------------------------------------------+
