//+------------------------------------------------------------------+
//|                                                    SignalSAC.mqh |
//|                   Copyright 2009-2017, MetaQuotes Software Corp. |
//|                                              http://www.mql5.com |
//+------------------------------------------------------------------+
#include <Expert\ExpertSignal.mqh>
#include <My\Cql.mqh>
#resource "Python/EURUSD_H1_D1_critic2.onnx" as uchar __CRITIC_2[]
#resource "Python/EURUSD_H1_D1_critic1.onnx" as uchar __CRITIC_1[]
#resource "Python/EURUSD_H1_D1_actor.onnx" as uchar __ACTOR[]
#define  __ACTIONS 3
#define  __ENVIONMENTS 3
//+------------------------------------------------------------------+
// wizard description start
//+------------------------------------------------------------------+
//| Description of the class                                         |
//| Title=Signals based on Reinforcement-Learning with Soft Actor Critic.|
//| Type=SignalAdvanced                                              |
//| Name=Reinforcement-Learning with Soft Actor Critic               |
//| ShortName=SAC                                                    |
//| Class=CSignalSAC                                                 |
//| Page=signal_soft_actor_critic                                    |
//+------------------------------------------------------------------+
// wizard description end
//+------------------------------------------------------------------+
//| Class CSignalSAC.                                                |
//| Purpose: Soft Actor Critic for Reinforcement-Learning.           |
//|            Derives from class CExpertSignal.                     |
//+------------------------------------------------------------------+
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
class CSignalSAC   : public CExpertSignal
{
protected:
   //--- ONNX session handles for the actor network and the two critic networks
   long              m_actor_handle;
   long              m_critic_1_handle;
   long              m_critic_2_handle;

public:
   //--- construction / destruction
   void              CSignalSAC(void);
   void              ~CSignalSAC(void);

   //--- verification of settings, model handles and model input/output shapes
   virtual bool      ValidationSettings(void);
   //--- creation of indicators and timeseries (delegates to the base class)
   virtual bool      InitIndicators(CIndicators *indicators);
   //--- votes (0 or 100) that the price will grow / fall
   virtual int       LongCondition(void);
   virtual int       ShortCondition(void);

protected:
   //--- runs the actor model on the current state; returns one value per action
   vectorf           GetOutput();
   //--- Gaussian log-probabilities computed from the actor's mean / log-std outputs
   vectorf           LogProbabilities(vectorf &Mean, vectorf &Log_STD);
};
//+------------------------------------------------------------------+
//| Constructor                                                      |
//+------------------------------------------------------------------+
void CSignalSAC::CSignalSAC(void)
{
//--- timeseries requested from the base class: OHLC prices, spread and bar time
   m_used_series = USE_SERIES_OPEN | USE_SERIES_HIGH | USE_SERIES_LOW |
                   USE_SERIES_CLOSE | USE_SERIES_SPREAD | USE_SERIES_TIME;
//--- create the ONNX models from the embedded resource buffers;
//--- the resulting handles are checked later in ValidationSettings()
   m_critic_2_handle = OnnxCreateFromBuffer(__CRITIC_2, ONNX_DEFAULT);
   m_critic_1_handle = OnnxCreateFromBuffer(__CRITIC_1, ONNX_DEFAULT);
   m_actor_handle    = OnnxCreateFromBuffer(__ACTOR, ONNX_DEFAULT);
}
//+------------------------------------------------------------------+
//| Destructor                                                       |
//+------------------------------------------------------------------+
void CSignalSAC::~CSignalSAC(void)
{
//--- release the ONNX sessions created in the constructor so they are not leaked;
//--- handles that failed to create are INVALID_HANDLE and are skipped
   if(m_critic_2_handle != INVALID_HANDLE)
      OnnxRelease(m_critic_2_handle);
   if(m_critic_1_handle != INVALID_HANDLE)
      OnnxRelease(m_critic_1_handle);
   if(m_actor_handle != INVALID_HANDLE)
      OnnxRelease(m_actor_handle);
}
//+------------------------------------------------------------------+
//| Validation of protected data.                                    |
//+------------------------------------------------------------------+
// Validates the signal settings and prepares the three ONNX models.
// Returns false (after printing a diagnostic) when the base-class validation
// fails, the working timeframe is above H1, any model handle is invalid, or
// any input/output shape cannot be set; returns true otherwise.
bool CSignalSAC::ValidationSettings(void)
{  if(!CExpertSignal::ValidationSettings())
      return(false);
//--- initial data checks
   if(m_period > PERIOD_H1)
   {  Print(" time frame too large ");
      return(false);
   }
//--- the handles were created in the constructor; the last-error code is
//--- deliberately NOT reset before these checks, otherwise the printed code
//--- would always be 0 (it may still have been overwritten by a later call)
   if(m_critic_2_handle == INVALID_HANDLE)
   {  Print("Crit 2 OnnxCreateFromBuffer error ", GetLastError());
      return(false);
   }
   if(m_critic_1_handle == INVALID_HANDLE)
   {  Print("Crit 1 OnnxCreateFromBuffer error ", GetLastError());
      return(false);
   }
   if(m_actor_handle == INVALID_HANDLE)
   {  Print("Actor OnnxCreateFromBuffer error ", GetLastError());
      return(false);
   }
//--- from here on the error code should reflect only the shape calls below
   ResetLastError();
   // Set input shapes
   const long _critic_in_shape[] = {1, 4, 1};
   const long _actor_in_shape[] = {1};
   // Set output shapes
   const long _critic_out_shape[] = {1, 4, 1, 1};
   const long _actor_out_shape[] = {1, 6};
   // The second argument of OnnxSetInputShape/OnnxSetOutputShape is the
   // zero-based index of the model input/output, not a creation flag.
   if(!OnnxSetInputShape(m_critic_2_handle, 0, _critic_in_shape))
   {  Print("Crit 2  OnnxSetInputShape error ", GetLastError());
      return(false);
   }
   if(!OnnxSetOutputShape(m_critic_2_handle, 0, _critic_out_shape))
   {  Print("Crit 2  OnnxSetOutputShape error ", GetLastError());
      return(false);
   }
   if(!OnnxSetInputShape(m_critic_1_handle, 0, _critic_in_shape))
   {  Print("Crit 1 OnnxSetInputShape error ", GetLastError());
      return(false);
   }
   if(!OnnxSetOutputShape(m_critic_1_handle, 0, _critic_out_shape))
   {  Print("Crit 1 OnnxSetOutputShape error ", GetLastError());
      return(false);
   }
   if(!OnnxSetInputShape(m_actor_handle, 0, _actor_in_shape))
   {  Print("Actor OnnxSetInputShape error ", GetLastError());
      return(false);
   }
   if(!OnnxSetOutputShape(m_actor_handle, 0, _actor_out_shape))
   {  Print("Actor OnnxSetOutputShape error ", GetLastError());
      return(false);
   }
//--- ok
   return(true);
}
//+------------------------------------------------------------------+
//| Create indicators.                                               |
//+------------------------------------------------------------------+
bool CSignalSAC::InitIndicators(CIndicators *indicators)
{
//--- a NULL collection is rejected before the base class is consulted;
//--- otherwise the result of the base-class initialization is returned
   if(indicators == NULL || !CExpertSignal::InitIndicators(indicators))
      return(false);
   return(true);
}
//+------------------------------------------------------------------+
//| "Voting" that price will grow.                                   |
//+------------------------------------------------------------------+
int CSignalSAC::LongCondition(void)
{
//--- one value per action, as returned by GetOutput()
   vectorf _scores = GetOutput();
//--- diagnostic: difference between the |0-1| and |1-2| gaps
   double _gap = fabs(_scores[0]-_scores[1])-fabs(_scores[1]-_scores[2]);
   printf(__FUNCSIG__+" 0 & 0-2 gap are: %.5f & %.5f",_scores[0],_gap);
//--- full-weight vote when the value for action 0 is below that of action 2
   return((_scores[0] < _scores[2]) ? 100 : 0);
}
//+------------------------------------------------------------------+
//| "Voting" that price will fall.                                   |
//+------------------------------------------------------------------+
int CSignalSAC::ShortCondition(void)
{
//--- one value per action, as returned by GetOutput()
   vectorf _scores = GetOutput();
//--- diagnostic: difference between the |1-2| and |0-1| gaps
   double _gap = fabs(_scores[1]-_scores[2])-fabs(_scores[0]-_scores[1]);
   printf(__FUNCSIG__+" 2 & 0-2 gap are: %.5f & %.5f",_scores[2],_gap);
//--- full-weight vote when the value for action 2 is below that of action 0
   return((_scores[2] < _scores[0]) ? 100 : 0);
}
//+------------------------------------------------------------------+
//| This function calculates the next actions to be selected from    |
//| the Reinforcement Learning Cycle.                                |
//+------------------------------------------------------------------+
// Builds the current environment state from the latest H1 and D1 close-price
// changes, runs the actor model on it, and converts the model's mean/log-std
// outputs into per-action log-probabilities.
// Returns a vector of __ACTIONS values; it stays all zeros when the price data
// cannot be loaded, the Cql helper cannot be allocated, or inference fails.
vectorf CSignalSAC::GetOutput()
{  vectorf _out;
   int _load = 1;
   static vectorf _x_states(1);
   _out.Init(__ACTIONS);
   _out.Fill(0.0);
   vector _in_row, _in_row_old, _in_col, _in_col_old;
   //--- load the newest close and the one before it, on H1 (row) and D1 (col)
   if
   (
      _in_row.Init(_load) &&
      _in_row.CopyRates(m_symbol.Name(), PERIOD_H1, COPY_RATES_CLOSE, 0, _load) &&
      _in_row.Size() == _load
      &&
      _in_row_old.Init(_load) &&
      _in_row_old.CopyRates(m_symbol.Name(), PERIOD_H1, COPY_RATES_CLOSE, 1, _load) &&
      _in_row_old.Size() == _load
      &&
      _in_col.Init(_load) &&
      _in_col.CopyRates(m_symbol.Name(), PERIOD_D1, COPY_RATES_CLOSE, 0, _load) &&
      _in_col.Size() == _load
      &&
      _in_col_old.Init(_load) &&
      _in_col_old.CopyRates(m_symbol.Name(), PERIOD_D1, COPY_RATES_CLOSE, 1, _load) &&
      _in_col_old.Size() == _load
   )
   {  //--- one-bar close changes on each timeframe
      _in_row -= _in_row_old;
      _in_col -= _in_col_old;
      Cql *QL;
      Sql _RL;
      _RL.actions  = __ACTIONS;//buy, sell, do nothing
      _RL.environments = __ENVIONMENTS;//bullish, bearish, flat
      QL = new Cql(_RL);
      //--- guard against a failed allocation before dereferencing the pointer
      if(QL == NULL)
      {  Print("Cql allocation error ", GetLastError());
         return(_out);
      }
      vector _e(_load);
      //--- map the two price changes to an environment state
      QL.Environment(_in_row, _in_col, _e);
      delete QL;
      _x_states[0] = float(_e[0]);
      //--- actor output buffer: rows [0, __ACTIONS) are read as means,
      //--- rows [__ACTIONS, 2*__ACTIONS) as log standard deviations
      static matrixf _y_mu_logstd(6, 1);
//--- run the inference
      ResetLastError();
      if(!OnnxRun(m_actor_handle, ONNX_NO_CONVERSION, _x_states, _y_mu_logstd))
      {  Print("Actor OnnxRun error ", GetLastError());
         return(_out);
      }
      else
      {  vectorf _mu(__ACTIONS), _logstd(__ACTIONS);
         _mu.Fill(0.0); _logstd.Fill(0.0);
         for(int i=0;i<__ACTIONS;i++)
         {  _mu[i] = _y_mu_logstd[i][0];
            _logstd[i] = _y_mu_logstd[i+__ACTIONS][0];
         }
         _out = LogProbabilities(_mu, _logstd);
      }
   }
   return(_out);
}

//+------------------------------------------------------------------+
// Function to compute the Gaussian probability distribution and log
// probabilities
//+------------------------------------------------------------------+
// Samples one action per element with the reparameterization trick
// (action = mean + std * z, z ~ N(0, 1)) and returns the Gaussian
// log-probability of each sampled action.
//   Mean    - per-action means
//   Log_STD - per-action natural logarithms of the standard deviations
// Returns a vector with one log-probability per element of Mean.
// A single draw z is shared by all elements, so the (diff^2 / variance) term
// is mathematically the same for every action.
vectorf CSignalSAC::LogProbabilities(vectorf &Mean, vectorf &Log_STD)
{  vectorf _log_probs;
   // Compute standard deviations from log_std
   vectorf _std = exp(Log_STD);
   // Draw z ~ N(0, 1) with the Box-Muller transform. The previous expression
   // (rand() % USHORT_MAX / USHORT_MAX) was an integer division and always 0.
   // MathRand() returns an integer in [0, 32767]; shifting by one keeps both
   // uniforms in (0, 1] so the logarithm is always defined.
   double _u1 = (MathRand() + 1.0) / 32768.0;
   double _u2 = (MathRand() + 1.0) / 32768.0;
   float _z = float(MathSqrt(-2.0 * MathLog(_u1)) * MathCos(2.0 * M_PI * _u2));
   // Sample action using reparameterization trick: action = mean + std * N(0, 1)
   vectorf _actions = Mean + (_std * _z);
   // Compute log probability of the sampled action
   vectorf _variance = _std * _std;
   vectorf _diff = _actions - Mean;
   _log_probs = -0.5f * (log(2.0f * M_PI * _variance) + (_diff * _diff) / _variance);
   return(_log_probs);
}
//+------------------------------------------------------------------+
