//+------------------------------------------------------------------+
//|                                                    GBPUSD AI.mq5 |
//|                                        Gamuchirai Zororo Ndawana |
//|                          https://www.mql5.com/en/gamuchiraindawa |
//+------------------------------------------------------------------+
#property copyright "Gamuchirai Zororo Ndawana"
#property link      "https://www.mql5.com/en/gamuchiraindawa"
#property version   "1.00"

//+------------------------------------------------------------------+
//| Load our resources                                               |
//| Both ONNX model files are declared as resource variables, which  |
//| exposes their bytes as the constant uchar arrays named below.    |
//| setup() passes these arrays to OnnxCreateFromBuffer().           |
//| NOTE(review): the file header names this program "GBPUSD AI"     |
//| while both model files are named "AUDJPY D1" - confirm which     |
//| symbol the models are meant for.                                 |
//+------------------------------------------------------------------+
//--- Model fed with the two moving-average inputs built in model_predict()
#resource  "\\Files\\AUDJPY D1 MA AI F22 P40.onnx" as const uchar onnx_buffer[];
//--- Model fed with the three RSI inputs built in model_predict()
#resource  "\\Files\\AUDJPY D1 RSI AI F22 P40.onnx" as const uchar rsi_onnx_buffer[];

//+------------------------------------------------------------------+
//| Libraries                                                        |
//+------------------------------------------------------------------+
//--- CTrade wrapper; the single global instance below is used for every
//--- Buy / Sell / PositionClose / PositionModify call in this file
#include <Trade\Trade.mqh>
CTrade Trade;
//--- NOTE(review): no COrderInfo object is created anywhere in this file as shown,
//--- and the forward declaration below is presumably redundant after the include - confirm
#include <Trade\OrderInfo.mqh>
class COrderInfo;

//+------------------------------------------------------------------+
//| Global variables                                                 |
//| Shared state for the whole expert. Several of these are declared |
//| but never referenced in this file as shown; they are marked      |
//| below so a reader does not go looking for their use.             |
//+------------------------------------------------------------------+
//--- Handle of the moving-average ONNX model (created in setup(), released in OnDeinit())
long     onnx_model;
//--- ma_handler: iMA indicator handle; state: set to 1 / -1 by check_setup(), never read here
int      ma_handler,state;
//--- bid / ask are refreshed every tick by update(); vol is the order size computed in setup()
double   bid,ask,vol;
//--- Single-element output vectors filled by OnnxRun() in model_predict()
vectorf  model_forecast   = vectorf::Zeros(1);
vectorf  rsi_model_output = vectorf::Zeros(1);
//--- min_volume, max_volume_increase: set in setup(); buy_stop_loss / sell_stop_loss: reset to 0 by update();
//--- atr_stop: distance computed in update() and used by check_atr_stop();
//--- volume_step, risk_equity: not referenced in this file as shown
double   min_volume,max_volume_increase, volume_step, buy_stop_loss, sell_stop_loss,atr_stop,risk_equity;
//--- Not referenced in this file as shown
double   take_profit = 0;
//--- close_price: filled by check_price(); atr_reading / ma_buffer: CopyBuffer() targets
double   close_price[3],atr_reading[],ma_buffer[];
//--- min_distance: SYMBOL_TRADE_STOPS_LEVEL read in setup(); login: account number read in setup()
long     min_distance,login;
//--- atr: iATR indicator handle; close_average, ticket_1, ticket_2: not referenced in this file as shown
int      atr,close_average,ticket_1,ticket_2;
//--- Not referenced in this file as shown
bool     authorized = false;
//--- margin: written by OrderCalcMargin() in update(); lot_step: SYMBOL_VOLUME_STEP read in setup()
double   margin,lot_step;
//--- Account currency and trade server name, read in setup()
string   currency,server;
//--- Not referenced in this file as shown
bool     all_closed =true;
//--- iRSI indicator handle
int      rsi_handler;
//--- Handle of the RSI ONNX model (created in setup(), released in OnDeinit())
long     rsi_onnx_model;
//--- CopyBuffer() target for the RSI reading in model_predict()
double   indicator_reading[];
//--- Not referenced in this file as shown
ENUM_ACCOUNT_TRADE_MODE account_type;
//--- Not referenced in this file as shown
const double  stop_percent = 1;

//+------------------------------------------------------------------+
//| User inputs                                                      |
//| The trailing comment on each input line is left untouched        |
//| because the terminal displays it as the parameter's label.       |
//| profit_target, loss_target, bb_period and ma_period are not read |
//| anywhere in this file as shown; the moving average created in    |
//| setup() uses a fixed period of 40, not ma_period.                |
//+------------------------------------------------------------------+
input group "Money Management"
input int    lot_multiple     = 10; // How big should the lot size be?
input double profit_target = 0;     // Profit Target
input double loss_target   = 0;     // Max Loss Allowed

//--- Indicator settings get their own group: the original reused the
//--- "Money Management" label here, which mislabelled them in the inputs dialog
input group "Technical Indicators"
input int    bb_period = 36;        //Bollinger band period
input int    ma_period = 4;         //Moving average period
const int    atr_period = 200;      //ATR Period
input double atr_multiple =2.5;      //ATR Multiple

//+------------------------------------------------------------------+
//| Expert initialization function                                   |
//| Checks that both the terminal and this program are allowed to    |
//| trade, then loads indicators and ONNX models through setup().    |
//| Returns INIT_SUCCEEDED only when every step succeeded; otherwise |
//| INIT_FAILED, with the reason shown on the chart via Comment().   |
//+------------------------------------------------------------------+
int OnInit()
  {
//--- The terminal must permit automated trading
   if(!TerminalInfoInteger(TERMINAL_TRADE_ALLOWED))
     {
      Comment("Press Ctrl + E To Give The Robot Permission To Trade And Reload The Program");
      return(INIT_FAILED);
     }

//--- This program must be permitted to trade as well
   if(!MQLInfoInteger(MQL_TRADE_ALLOWED))
     {
      Comment("Reload The Program And Make Sure You Clicked Allow Algo Trading");
      return(INIT_FAILED);
     }

   Comment("This License is Genuine");

//--- setup() replaces the chart comment with its own error text on failure.
//--- Its result was previously ignored, so the expert kept running without
//--- usable models or indicators.
   if(!setup())
      return(INIT_FAILED);

//--- Everything was okay
   return(INIT_SUCCEEDED);
  }
//+------------------------------------------------------------------+
//| Expert deinitialization function                                 |
//| Releases the two ONNX model sessions and the three indicator     |
//| handles created in setup().                                      |
//| reason - deinitialization reason code (not used)                 |
//+------------------------------------------------------------------+
void OnDeinit(const int reason)
  {
//--- Free the ONNX models
   OnnxRelease(onnx_model);
   OnnxRelease(rsi_onnx_model);

//--- Free the indicator handles, which were previously never released
   if(atr != INVALID_HANDLE)
      IndicatorRelease(atr);
   if(ma_handler != INVALID_HANDLE)
      IndicatorRelease(ma_handler);
   if(rsi_handler != INVALID_HANDLE)
      IndicatorRelease(rsi_handler);
  }
//+------------------------------------------------------------------+
//| Expert tick function                                             |
//| Refreshes market data, looks for an entry when no position is    |
//| open, and maintains the ATR stop when at least one is open.      |
//+------------------------------------------------------------------+
void OnTick()
  {
//--- Refresh prices and indicator readings first
   update();

//--- Flat: ask the models whether to open a trade
   if(PositionsTotal() < 1)
      check_setup();

//--- The count is queried again on purpose, so a position opened by
//--- check_setup() above receives its stop on this same tick
   if(PositionsTotal() >= 1)
      check_atr_stop();
  }
//+------------------------------------------------------------------+
//| Get a prediction from our model                                  |
//| Builds one-hot inputs from the moving average and the RSI, runs  |
//| both ONNX models and combines their outputs.                     |
//| Returns  1 when both model outputs are above zero (buy),         |
//|         -1 when both are below zero (sell),                      |
//|          0 when they disagree or when indicator data / model     |
//|            inference is unavailable.                             |
//+------------------------------------------------------------------+
int model_predict(void)
  {
   vectorf  model_inputs = vectorf::Zeros(2);
   vectorf  rsi_model_inputs = vectorf::Zeros(3);

//--- MA Forecast: all 40 values are required because index 39 is read below
   if(CopyBuffer(ma_handler,0,0,40,ma_buffer) < 40)
     {
      Print("[ERROR] Failed to copy 40 moving average values: ",GetLastError());
      return(0);
     }

//--- NOTE(review): ma_buffer is never flagged as a series in this file, so index 0
//--- is presumably the oldest of the 40 values and index 39 the current bar -
//--- confirm this matches how the model was trained
   if(ma_buffer[0] > ma_buffer[39])
     {
      model_inputs[0] = 1;
      model_inputs[1] = 0;
     }
   else
      if(ma_buffer[0] < ma_buffer[39])
        {
         model_inputs[1] = 1;
         model_inputs[0] = 0;
        }
//--- When both ends are equal, both inputs stay at zero

//--- RSI Forecast
   if(CopyBuffer(rsi_handler,0,0,1,indicator_reading) < 1)
     {
      Print("[ERROR] Failed to copy the RSI reading: ",GetLastError());
      return(0);
     }

//--- One-hot encode the RSI zone: [below 30, above 70, in between]
   if(indicator_reading[0] < 30)
     {
      rsi_model_inputs[0] = 1;
      rsi_model_inputs[1] = 0;
      rsi_model_inputs[2] = 0;
     }
   else
      if(indicator_reading[0] >70)
        {
         rsi_model_inputs[0] = 0;
         rsi_model_inputs[1] = 1;
         rsi_model_inputs[2] = 0;
        }
      else
        {
         rsi_model_inputs[0] = 0;
         rsi_model_inputs[1] = 0;
         rsi_model_inputs[2] = 1;
        }

//--- Model predictions: never act on stale outputs from an earlier call
   if(!OnnxRun(onnx_model,ONNX_DEFAULT,model_inputs,model_forecast))
     {
      Print("[ERROR] Moving average model inference failed: ",GetLastError());
      return(0);
     }

   if(!OnnxRun(rsi_onnx_model,ONNX_DEFAULT,rsi_model_inputs,rsi_model_output))
     {
      Print("[ERROR] RSI model inference failed: ",GetLastError());
      return(0);
     }

//--- Evaluate model output for buy setup
   if((rsi_model_output[0] > 0) && (model_forecast[0] > 0))
     {
      Comment("AI Forecast: UP");
      return(1);
     }

//--- Evaluate model output for a sell setup
   if((rsi_model_output[0] < 0) && (model_forecast[0] < 0))
     {
      Comment("AI Forecast: DOWN");
      return(-1);
     }

//--- Otherwise no position was found
   return(0);
  }

//+------------------------------------------------------------------+
//| Check for valid trade setups                                     |
//| Asks model_predict() for a direction and sends a market order of |
//| size vol with no stop loss or take profit attached. state is     |
//| updated only when the order request was accepted; a failed       |
//| request is logged together with the trade server return code.    |
//+------------------------------------------------------------------+
void check_setup(void)
  {
   int res = model_predict();

   if(res == -1)
     {
      if(Trade.Sell(vol,Symbol(),bid,0,0,"VD V75 AI"))
         state = -1;
      else
         Print("[ERROR] Sell order failed: ",Trade.ResultRetcode()," ",Trade.ResultRetcodeDescription());
     }
   else
      if(res == 1)
        {
         if(Trade.Buy(vol,Symbol(),ask,0,0,"VD V75 AI"))
            state = 1;
         else
            Print("[ERROR] Buy order failed: ",Trade.ResultRetcode()," ",Trade.ResultRetcodeDescription());
        }
  }

//+------------------------------------------------------------------+
//| Update our market data                                           |
//| Runs on every tick: refreshes bid/ask, the recent closes and the |
//| ATR-based stop distance. Once per new candle it also recomputes  |
//| the margin needed for a minimum-volume buy order.                |
//+------------------------------------------------------------------+
void update(void)
  {
//--- Open time of the last candle that was processed
   static datetime time_stamp;

   ask = SymbolInfoDouble(_Symbol,SYMBOL_ASK);
   bid = SymbolInfoDouble(_Symbol,SYMBOL_BID);
   buy_stop_loss = 0;
   sell_stop_loss = 0;
   datetime time = iTime(_Symbol,PERIOD_CURRENT,0);
   check_price(3);

//--- Kept from the original; model_predict() copies its own 40 values before reading ma_buffer
   CopyBuffer(ma_handler,0,0,1,ma_buffer);

//--- Only touch atr_reading[0] when a value was actually copied; otherwise
//--- keep the previous atr_stop instead of reading past the end of the array
   ArraySetAsSeries(atr_reading,true);
   if(CopyBuffer(atr,0,0,1,atr_reading) > 0)
     {
      //--- NOTE(review): adding min_volume (a lot size) to the ATR (a price distance)
      //--- looks unintended, but the formula is left unchanged - confirm
      atr_stop = ((min_volume + atr_reading[0]) * atr_multiple);
     }
   else
      Print("[ERROR] Failed to copy the ATR reading: ",GetLastError());

//--- On Every Candle
   if(time_stamp != time)
     {
      //--- Mark the candle
      time_stamp = time;
      if(!OrderCalcMargin(ORDER_TYPE_BUY,_Symbol,min_volume,ask,margin))
         Print("[ERROR] Failed to calculate margin: ",GetLastError());
     }
  }

//+------------------------------------------------------------------+

//+------------------------------------------------------------------+
//| Load resources                                                   |
//| Reads account and symbol information, creates the ATR, moving    |
//| average and RSI indicators, and builds both ONNX models from the |
//| embedded buffers.                                                |
//| Returns true when everything was created and configured, false   |
//| otherwise (the reason is shown on the chart via Comment()).      |
//+------------------------------------------------------------------+
bool setup(void)
  {
//--- Account Info
   currency = AccountInfoString(ACCOUNT_CURRENCY);
   server = AccountInfoString(ACCOUNT_SERVER);
   login = AccountInfoInteger(ACCOUNT_LOGIN);

//--- Technical indicators. The moving average period is fixed at 40 here
//--- (the ma_period input is not used) and the RSI period at 30.
   atr          = iATR(_Symbol,PERIOD_CURRENT,atr_period);
   ma_handler   = iMA(Symbol(),PERIOD_CURRENT,40,0,MODE_SMA,PRICE_LOW);
   rsi_handler  = iRSI(Symbol(),PERIOD_CURRENT,30,PRICE_CLOSE);

//--- Without valid handles every later CopyBuffer() call would fail
   if((atr == INVALID_HANDLE) || (ma_handler == INVALID_HANDLE) || (rsi_handler == INVALID_HANDLE))
     {
      Comment("[ERROR] Failed to load technical indicators: ",GetLastError());
      return(false);
     }

//--- Market Information
   min_volume = SymbolInfoDouble(_Symbol,SYMBOL_VOLUME_MIN);
   max_volume_increase = SymbolInfoDouble(_Symbol,SYMBOL_VOLUME_MAX) / SymbolInfoDouble(_Symbol,SYMBOL_VOLUME_MIN);
   min_distance = SymbolInfoInteger(_Symbol,SYMBOL_TRADE_STOPS_LEVEL);
   lot_step = SymbolInfoDouble(_Symbol,SYMBOL_VOLUME_STEP);

//--- Order size: a multiple of the smallest tradable volume
   vol = min_volume * lot_multiple;

//--- Define our ONNX model shapes
   ulong ma_input_shape [] = {1,2};
   ulong rsi_input_shape [] = {1,3};
   ulong output_shape [] = {1,1};

//--- Create the models from the embedded resource buffers
   onnx_model     = OnnxCreateFromBuffer(onnx_buffer,ONNX_DEFAULT);
   rsi_onnx_model = OnnxCreateFromBuffer(rsi_onnx_buffer,ONNX_DEFAULT);

   if((onnx_model == INVALID_HANDLE) || (rsi_onnx_model == INVALID_HANDLE))
     {
      Comment("[ERROR] Failed to load AI module correctly");
      return(false);
     }

//--- Validate I/O
   if((!OnnxSetInputShape(onnx_model,0,ma_input_shape)) || (!OnnxSetInputShape(rsi_onnx_model,0,rsi_input_shape)))
     {
      Comment("[ERROR] Failed to set input shape correctly: ",GetLastError());
      return(false);
     }

//--- The message below used to repeat the "load AI module" text, hiding which step failed
   if((!OnnxSetOutputShape(onnx_model,0,output_shape)) || (!OnnxSetOutputShape(rsi_onnx_model,0,output_shape)))
     {
      Comment("[ERROR] Failed to set output shape correctly: ",GetLastError());
      return(false);
     }

//--- Everything went fine
   return(true);
  }

//+------------------------------------------------------------------+
//| Close all our open positions                                     |
//| Closes every open position on the account, whatever its symbol.  |
//| The list is walked from the last index down to 0: closing a      |
//| position shifts the remaining ones down, so walking forward      |
//| skipped every other position.                                    |
//+------------------------------------------------------------------+
void close_all()
  {
   for(int i = PositionsTotal() - 1; i >= 0; i--)
     {
      ulong ticket = PositionGetTicket(i);

      //--- 0 means the position could not be selected
      if(ticket == 0)
         continue;

      if(!Trade.PositionClose(ticket))
         Print("[ERROR] Failed to close position ",ticket,": ",Trade.ResultRetcode()," ",Trade.ResultRetcodeDescription());
     }
  }

//+------------------------------------------------------------------+
//| Update our trailing ATR stop                                     |
//| For every open position on the chart symbol, places the stop     |
//| loss and take profit atr_stop away from the current price. The   |
//| stop only moves in the position's favour, or is set when the     |
//| position has none. Prices are normalized to the symbol's digits  |
//| before comparing and sending, and failed modifications are       |
//| logged.                                                          |
//+------------------------------------------------------------------+
void check_atr_stop()
  {
//--- With no ATR distance yet, the stop would land on the current price
   if(atr_stop <= 0)
      return;

   for(int i = PositionsTotal() -1; i >= 0; i--)
     {
      //--- PositionGetSymbol() also selects the position for the PositionGet* calls below
      if(PositionGetSymbol(i) != _Symbol)
         continue;

      ulong  ticket            = (ulong)PositionGetInteger(POSITION_TICKET);
      long   type              = PositionGetInteger(POSITION_TYPE);
      double current_stop_loss = PositionGetDouble(POSITION_SL);

      if(type == POSITION_TYPE_BUY)
        {
         double atr_stop_loss   = NormalizeDouble(ask - atr_stop,_Digits);
         double atr_take_profit = NormalizeDouble(ask + atr_stop,_Digits);

         //--- Raise the stop only, or set it when there is none
         if((current_stop_loss < atr_stop_loss) || (current_stop_loss == 0))
           {
            if(!Trade.PositionModify(ticket,atr_stop_loss,atr_take_profit))
               Print("[ERROR] Failed to modify buy position ",ticket,": ",Trade.ResultRetcode()," ",Trade.ResultRetcodeDescription());
           }
        }
      else
         if(type == POSITION_TYPE_SELL)
           {
            double atr_stop_loss   = NormalizeDouble(bid + atr_stop,_Digits);
            double atr_take_profit = NormalizeDouble(bid - atr_stop,_Digits);

            //--- Lower the stop only, or set it when there is none
            if((current_stop_loss > atr_stop_loss) || (current_stop_loss == 0))
              {
               if(!Trade.PositionModify(ticket,atr_stop_loss,atr_take_profit))
                  Print("[ERROR] Failed to modify sell position ",ticket,": ",Trade.ResultRetcode()," ",Trade.ResultRetcodeDescription());
              }
           }
     }
  }

//+------------------------------------------------------------------+
//| Close our open buy positions                                     |
//| Closes every buy position on the chart symbol. The list is       |
//| walked from the last index down to 0 because closing a position  |
//| shifts the remaining ones down, which made a forward walk skip   |
//| entries.                                                         |
//+------------------------------------------------------------------+
void close_buy()
  {
   for(int i = PositionsTotal() - 1; i >= 0; i--)
     {
      if(PositionGetSymbol(i) != _Symbol)
         continue;

      ulong ticket = PositionGetTicket(i);
      if(ticket == 0)
         continue;

      if(PositionGetInteger(POSITION_TYPE) == POSITION_TYPE_BUY)
        {
         if(!Trade.PositionClose(ticket))
            Print("[ERROR] Failed to close buy position ",ticket,": ",Trade.ResultRetcode()," ",Trade.ResultRetcodeDescription());
        }
     }
  }

//+------------------------------------------------------------------+
//| Close our open sell positions                                    |
//| Closes every sell position on the chart symbol. The list is      |
//| walked from the last index down to 0 because closing a position  |
//| shifts the remaining ones down, which made a forward walk skip   |
//| entries.                                                         |
//+------------------------------------------------------------------+
void close_sell()
  {
   for(int i = PositionsTotal() - 1; i >= 0; i--)
     {
      if(PositionGetSymbol(i) != _Symbol)
         continue;

      ulong ticket = PositionGetTicket(i);
      if(ticket == 0)
         continue;

      if(PositionGetInteger(POSITION_TYPE) == POSITION_TYPE_SELL)
        {
         if(!Trade.PositionClose(ticket))
            Print("[ERROR] Failed to close sell position ",ticket,": ",Trade.ResultRetcode()," ",Trade.ResultRetcodeDescription());
        }
     }
  }

//+------------------------------------------------------------------+
//| Get the most recent price values                                 |
//| Stores the close of the most recent bars in close_price, where   |
//| index 0 is the current bar.                                      |
//| candles - number of bars requested; capped at the size of        |
//|           close_price so a larger request cannot write past the  |
//|           end of the fixed-size array.                           |
//+------------------------------------------------------------------+
void check_price(int candles)
  {
   int capacity = ArraySize(close_price);
   int limit    = (candles < capacity) ? candles : capacity;

   for(int i = 0; i < limit; i++)
     {
      close_price[i] = iClose(_Symbol,PERIOD_CURRENT,i);
     }
  }
//+------------------------------------------------------------------+
