//+------------------------------------------------------------------+
//|                                              EURUSD Polar EA.mq5 |
//|                                               Gamuchirai Ndawana |
//|                    https://www.mql5.com/en/users/gamuchiraindawa |
//+------------------------------------------------------------------+
#property copyright "Gamuchirai Ndawana"
#property link      "https://www.mql5.com/en/users/gamuchiraindawa"
#property version   "1.00"

//+------------------------------------------------------------------+
//| System Constants                                                 |
//+------------------------------------------------------------------+
//--- ONNX_INPUTS must equal the number of features assembled in get_model_prediction()
//--- and the length of mean_values[]/std_values[] (9). ONNX_INPUTS, ONNX_OUTPUTS and
//--- TF_1 are #undef'd at the end of the file. TRADING_VOLUME is the base lot size that
//--- get_signal() passes to Trade.Buy/Trade.Sell (it doubles it on strong signals).
#define ONNX_INPUTS 9                                              //The total number of inputs for our onnx model
#define ONNX_OUTPUTS 1                                             //The total number of outputs for our onnx model
#define TF_1  PERIOD_D1                                            //The system's primary time frame
#define TRADING_VOLUME 0.1                                         //The system's trading volume

//+------------------------------------------------------------------+
//| System Resources                                                 |
//+------------------------------------------------------------------+
//--- ONNX model files embedded as byte buffers. setup() turns each buffer into a
//--- session with OnnxCreateFromBuffer. Both models receive the same standardized
//--- feature vector: one forecasts r, the other theta (see get_model_prediction).
#resource "\\Files\\EURUSD D1 R Model.onnx" as uchar r_model_buffer[];
#resource "\\Files\\EURUSD D1 Theta Model.onnx" as uchar theta_model_buffer[];

//+------------------------------------------------------------------+
//| Global variables                                                 |
//+------------------------------------------------------------------+
//--- Standardization statistics for the model features, in the order in which
//--- get_model_prediction() assembles them:
//---   0 open, 1 high, 2 low, 3 close, 4 r, 5 theta,
//---   6 -r*sin(theta), 7 r*cos(theta), 8 1/cos(theta)^2
//--- NOTE(review): entry 8 repeats the open-price statistics, whereas 1/cos(theta)^2
//--- is about 2 when theta is near its mean of ~0.785 (pi/4). Confirm against the
//--- training script that these are the intended values for that feature.
double mean_values[] =
  {
   1.1884188643844635,    // open
   1.1920754015799868,    // high
   1.1847545720868993,    // low
   1.1883860236998025,    // close
   1.6806588395310122,    // r
   0.7853854898794739,    // theta
   -1.1883860236998025,   // -r*sin(theta)
   1.1884188643844635,    // r*cos(theta)
   1.1884188643844635     // 1/cos(theta)^2 (see NOTE above)
  };
double std_values[] =
  {
   0.09123896995032886,   // open
   0.09116171300874902,   // high
   0.0912656190371797,    // low
   0.09120265318308786,   // close
   0.1289537623737421,    // r
   0.0021932437785043796, // theta
   0.09120265318308786,   // -r*sin(theta)
   0.09123896995032886,   // r*cos(theta)
   0.09123896995032886    // 1/cos(theta)^2 (see NOTE above)
  };

//--- Polar coordinates of the last completed bar, written by get_model_prediction()
double current_r;
double current_theta;

//--- ONNX session handles, created in setup() and released in OnDeinit()
long r_model;
long theta_model;

//--- Single-value forecasts written by OnnxRun
vectorf r_model_output     = vectorf::Zeros(ONNX_OUTPUTS);
vectorf theta_model_output = vectorf::Zeros(ONNX_OUTPUTS);

//--- Latest quotes, refreshed on every tick by update()
double bid;
double ask;

//--- Moving-average handles (50- and 10-period SMAs, see setup()) and the
//--- direction of the last entry (1 buy, -1 sell, see get_signal())
int ma_o_handler;
int ma_c_handler;
int state;

//--- Latest moving-average values copied in update()
double ma_o_buffer[];
double ma_c_buffer[];

//+------------------------------------------------------------------+
//| Library                                                          |
//+------------------------------------------------------------------+
//--- Standard-library trade helper. This single CTrade instance is used for every
//--- Buy/Sell/PositionClose call in this file.
#include <Trade/Trade.mqh>
CTrade Trade;

//+------------------------------------------------------------------+
//| Expert initialization function                                   |
//+------------------------------------------------------------------+
int OnInit()
  {
//--- create the indicators and ONNX sessions; refuse to start if anything is missing
   if(!setup())
     {
      //--- fixed typo in the on-chart message, and mirror it to the journal so the
      //--- failure is still visible after the chart comment is overwritten
      Comment("Failed To Load Correctly");
      Print("EURUSD Polar EA: setup() failed, initialization aborted (see previous journal entries)");
      return(INIT_FAILED);
     }

   Comment("Started");
//---
   return(INIT_SUCCEEDED);
  }
//+------------------------------------------------------------------+
//| Expert deinitialization function                                 |
//+------------------------------------------------------------------+
void OnDeinit(const int reason)
  {
//--- release the ONNX sessions created in setup()
   if(r_model != INVALID_HANDLE)
      OnnxRelease(r_model);
   if(theta_model != INVALID_HANDLE)
      OnnxRelease(theta_model);

//--- release the moving-average indicator handles created in setup();
//--- previously these were never freed
   if(ma_o_handler != INVALID_HANDLE)
      IndicatorRelease(ma_o_handler);
   if(ma_c_handler != INVALID_HANDLE)
      IndicatorRelease(ma_c_handler);
  }
//+------------------------------------------------------------------+
//| Expert tick function                                             |
//+------------------------------------------------------------------+
void OnTick()
  {
//--- all per-tick work (quote refresh, once-per-bar prediction and trading)
//--- is delegated to update()
   update();
  }
//+------------------------------------------------------------------+

//+------------------------------------------------------------------+
//| Get a prediction from our models                                 |
//+------------------------------------------------------------------+
//--- Builds the 9-feature vector from the last completed TF_1 bar, standardizes it
//--- with mean_values/std_values, and runs both ONNX models.
//--- Writes: current_r, current_theta, r_model_output, theta_model_output.
//--- Returns false (and logs) when price data is unavailable or inference fails;
//--- in that case the previous forecasts are left untouched.
bool get_model_prediction(void)
  {
//--- read the last completed bar on the models' timeframe (TF_1), not the chart's,
//--- so the features do not depend on which timeframe the EA is attached to
   double o = iOpen(_Symbol,TF_1,1);
   double h = iHigh(_Symbol,TF_1,1);
   double l = iLow(_Symbol,TF_1,1);
   double c = iClose(_Symbol,TF_1,1);
   if(o <= 0.0 || c <= 0.0)
     {
      Print("get_model_prediction: price history for the model timeframe is not available yet");
      return(false);
     }

//--- polar coordinates of the (open, close) pair
   current_r     = MathSqrt(o * o + c * c);
   current_theta = MathArctan2(c,o);
   double cos_theta = MathCos(current_theta);

//--- feature order must match mean_values[]/std_values[]
   vectorf model_inputs =
     {
      (float) o,
      (float) h,
      (float) l,
      (float) c,
      (float) current_r,
      (float) current_theta,
      (float)(current_r * (-(MathSin(current_theta)))),
      (float)(current_r * cos_theta),
      (float)(1 / (cos_theta * cos_theta))
     };

//--- z-score each feature with the training statistics
   for(int i = 0; i < ONNX_INPUTS; i++)
      model_inputs[i] = (float)((model_inputs[i] - mean_values[i]) / std_values[i]);

//--- OnnxRun's second argument is a flags value (ENUM_ONNX_FLAGS), not a data type
   if(!OnnxRun(r_model,ONNX_DEFAULT,model_inputs,r_model_output))
     {
      PrintFormat("get_model_prediction: R model inference failed, error %d",GetLastError());
      return(false);
     }
   if(!OnnxRun(theta_model,ONNX_DEFAULT,model_inputs,theta_model_output))
     {
      PrintFormat("get_model_prediction: Theta model inference failed, error %d",GetLastError());
      return(false);
     }

//--- show the current coordinates next to the forecasts
   Comment(StringFormat("R: %f \nTheta: %f\nR Forecast: %f\nTheta Forecast: %f",current_r,current_theta,r_model_output[0],theta_model_output[0]));
   return(true);
  }

//+------------------------------------------------------------------+
//| Update system state                                              |
//+------------------------------------------------------------------+
//--- Refreshes bid/ask on every tick. Once per new TF_1 bar it copies the latest
//--- moving-average values, then runs the prediction, position management and,
//--- when flat, the entry logic.
void update(void)
  {
   static datetime time_stamp;
   datetime current_time = iTime(_Symbol,TF_1,0);

   bid = SymbolInfoDouble(_Symbol,SYMBOL_BID);
   ask = SymbolInfoDouble(_Symbol,SYMBOL_ASK);

//--- iTime returns 0 when the bar time is unavailable; nothing to decide on yet
   if(current_time == 0 || current_time == time_stamp)
      return;

//--- the indicators may not be calculated yet (e.g. right after start-up). Only
//--- mark this bar as processed once both values were copied, so we retry on the
//--- next tick instead of indexing empty buffers in manage_account()/get_signal()
   if(CopyBuffer(ma_o_handler,0,0,1,ma_o_buffer) != 1 ||
      CopyBuffer(ma_c_handler,0,0,1,ma_c_buffer) != 1)
      return;

   time_stamp = current_time;
   get_model_prediction();
   manage_account();
//--- NOTE(review): PositionsTotal() counts positions on every symbol, not only this one
   if(PositionsTotal() == 0)
      get_signal();
  }

//+------------------------------------------------------------------+
//| Manage the open positions we have in the market                  |
//+------------------------------------------------------------------+
//+------------------------------------------------------------------+
//| Close every open position on the chart symbol                    |
//+------------------------------------------------------------------+
//--- Walks the position list once, from the end so closing does not shift the
//--- indices still to be visited. Returns false if any close request was rejected.
//--- Bounded by design: unlike a while(PositionsTotal()>0) loop it cannot spin
//--- forever on positions of other symbols or on a failed close.
bool close_symbol_positions(void)
  {
   bool all_closed = true;
   for(int i = PositionsTotal() - 1; i >= 0; i--)
     {
      ulong ticket = PositionGetTicket(i);   // also selects the position
      if(ticket == 0 || PositionGetString(POSITION_SYMBOL) != _Symbol)
         continue;
      if(!Trade.PositionClose(ticket))
        {
         PrintFormat("close_symbol_positions: closing #%I64u failed: %u (%s)",
                     ticket,Trade.ResultRetcode(),Trade.ResultRetcodeDescription());
         all_closed = false;
        }
     }
   return(all_closed);
  }

//--- Exit rules:
//---   * account equity above balance (net floating profit) -> close this symbol's positions
//---   * long  (state ==  1) and the 10-SMA drops below the 50-SMA -> close
//---   * short (state == -1) and the 10-SMA rises above the 50-SMA -> close
//--- Closes by ticket so that both positions opened on a strong signal are closed.
void manage_account()
  {
   if(AccountInfoDouble(ACCOUNT_BALANCE) < AccountInfoDouble(ACCOUNT_EQUITY))
      close_symbol_positions();

//--- the moving averages may not have been copied yet
   if(ArraySize(ma_c_buffer) < 1 || ArraySize(ma_o_buffer) < 1)
      return;

   if(state == 1 && ma_c_buffer[0] < ma_o_buffer[0])
      close_symbol_positions();

   if(state == -1 && ma_c_buffer[0] > ma_o_buffer[0])
      close_symbol_positions();
  }

//+------------------------------------------------------------------+
//| Setup system variables                                           |
//+------------------------------------------------------------------+
//--- Creates the two moving averages and both ONNX sessions and fixes the model
//--- shapes. Returns false (after logging the cause) on the first failure.
bool setup(void)
  {
//--- if we bail out before creating the models, leave their handles explicitly invalid
   r_model     = INVALID_HANDLE;
   theta_model = INVALID_HANDLE;

//--- slow (50) and fast (10) SMAs on the system timeframe, both on close prices
   ma_o_handler = iMA(Symbol(),TF_1,50,0,MODE_SMA,PRICE_CLOSE);
   ma_c_handler = iMA(Symbol(),TF_1,10,0,MODE_SMA,PRICE_CLOSE);
   if(ma_o_handler == INVALID_HANDLE || ma_c_handler == INVALID_HANDLE)
     {
      PrintFormat("setup: failed to create the moving averages, error %d",GetLastError());
      return(false);
     }

//--- ONNX sessions from the embedded resources
   r_model     = OnnxCreateFromBuffer(r_model_buffer,ONNX_DEFAULT);
   theta_model = OnnxCreateFromBuffer(theta_model_buffer,ONNX_DEFAULT);
   if(r_model == INVALID_HANDLE || theta_model == INVALID_HANDLE)
     {
      PrintFormat("setup: failed to create the ONNX models, error %d",GetLastError());
      return(false);
     }

//--- each model takes one row of ONNX_INPUTS features and returns ONNX_OUTPUTS values
   ulong input_shape[]  = {1,ONNX_INPUTS};
   ulong output_shape[] = {1,ONNX_OUTPUTS};

   if(!OnnxSetInputShape(r_model,0,input_shape) ||
      !OnnxSetInputShape(theta_model,0,input_shape) ||
      !OnnxSetOutputShape(r_model,0,output_shape) ||
      !OnnxSetOutputShape(theta_model,0,output_shape))
     {
      PrintFormat("setup: failed to set the ONNX input/output shapes, error %d",GetLastError());
      return(false);
     }

   return(true);
  }

//+------------------------------------------------------------------+
//| Check if we have a trading signal                                |
//+------------------------------------------------------------------+
//+------------------------------------------------------------------+
//| Send one market order and log a failure                          |
//+------------------------------------------------------------------+
//--- Buys at `ask` or sells at `bid` with no SL/TP. Returns CTrade's result.
bool send_market_order(const bool is_buy,const double volume)
  {
   bool sent = is_buy ? Trade.Buy(volume,Symbol(),ask,0,0)
                      : Trade.Sell(volume,Symbol(),bid,0,0);
   if(!sent)
      PrintFormat("send_market_order: %s %.2f lots failed: %u (%s)",
                  is_buy ? "Buy" : "Sell",volume,
                  Trade.ResultRetcode(),Trade.ResultRetcodeDescription());
   return(sent);
  }

//--- Entry rules. The trend comes from the 10-SMA vs the 50-SMA; the models act as a filter:
//---   * both forecasts against the trend  -> stay out
//---   * both forecasts with the trend     -> two orders of 2 x TRADING_VOLUME
//---   * otherwise                         -> one order of TRADING_VOLUME
//--- `state` records the direction only when at least one order was accepted.
void get_signal(void)
  {
//--- the moving averages may not have been copied yet
   if(ArraySize(ma_c_buffer) < 1 || ArraySize(ma_o_buffer) < 1)
      return;

   bool r_up       = (r_model_output[0] > current_r);
   bool r_down     = (r_model_output[0] < current_r);
   bool theta_up   = (theta_model_output[0] > current_theta);
   bool theta_down = (theta_model_output[0] < current_theta);
   bool opened     = false;

//--- bullish: fast SMA above slow SMA
   if(ma_c_buffer[0] > ma_o_buffer[0])
     {
      if(r_down && theta_down)
         return;

      if(r_up && theta_up)
        {
         bool first  = send_market_order(true,TRADING_VOLUME * 2);
         bool second = send_market_order(true,TRADING_VOLUME * 2);
         opened = first || second;
        }
      else
         opened = send_market_order(true,TRADING_VOLUME);

      if(opened)
         state = 1;
      return;
     }

//--- bearish: fast SMA below slow SMA
   if(ma_c_buffer[0] < ma_o_buffer[0])
     {
      if(r_up && theta_up)
         return;

      if(r_down && theta_down)
        {
         bool first  = send_market_order(false,TRADING_VOLUME * 2);
         bool second = send_market_order(false,TRADING_VOLUME * 2);
         opened = first || second;
        }
      else
         opened = send_market_order(false,TRADING_VOLUME);

      if(opened)
         state = -1;
      return;
     }
  }

//+------------------------------------------------------------------+
//| Undefine system macros we no longer need                         |
//+------------------------------------------------------------------+
//--- drop every file-local system macro, including TRADING_VOLUME, which was left defined
#undef ONNX_INPUTS
#undef ONNX_OUTPUTS
#undef TF_1
#undef TRADING_VOLUME
//+------------------------------------------------------------------+