//+------------------------------------------------------------------+
//|                                                   SignalTRPO.mqh |
//|                             Copyright 2000-2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#include <Expert\ExpertSignal.mqh>
#include <SRI\61_X.mqh>
// wizard description start
//+------------------------------------------------------------------+
//| Description of the class                                         |
//| Title=Signals of Reinforcement Learning with ADX & CCI.          |
//| Type=SignalAdvanced                                              |
//| Name=Reinforcement Learning with TRPO                            |
//| ShortName=TRPO                                                   |
//| Class=CSignal_TRPO                                               |
//| Page=signal_trpo                                                 |
//| Parameter=Pattern_2,int,50,Pattern 2 [0...100]                   |
//| Parameter=Pattern_3,int,50,Pattern 3 [0...100]                   |
//| Parameter=Pattern_4,int,50,Pattern 4 [0...100]                   |
//| Parameter=PatternUsed,uchar,4,Feature Used [0...8]               |
//| Parameter=Reinforce,bool,false,Use Reinforcement [false...true]  |
//| Parameter=PeriodUsed,int,14,Used Period [3...55]                 |
//+------------------------------------------------------------------+
// wizard description end
//+------------------------------------------------------------------+
//| Class CSignal_TRPO.                                              |
//| Purpose: Class of generator of trade signals based on            |
//|          Reinforcement Learning with ADX & CCI.                  |
//| Is derived from the CExpertSignal class.                         |
//+------------------------------------------------------------------+
//
//--- models loaded into m_handles[] and run by Supervise(), one per pattern (2, 3, 4)
#resource "Python/61_2.onnx" as uchar __61_2[]
#resource "Python/61_3.onnx" as uchar __61_3[]
#resource "Python/61_4.onnx" as uchar __61_4[]
//--- policy models loaded into m_handles_a[] and run by Reinforce(), one per pattern
#resource "Python/62_policy_2.onnx" as uchar __62_policy_2[]
#resource "Python/62_policy_3.onnx" as uchar __62_policy_3[]
#resource "Python/62_policy_4.onnx" as uchar __62_policy_4[]
//--- value models loaded into m_handles_c[], one per pattern
#resource "Python/62_value_2.onnx" as uchar __62_value_2[]
#resource "Python/62_value_3.onnx" as uchar __62_value_3[]
#resource "Python/62_value_4.onnx" as uchar __62_value_4[]
//--- number of patterns (and therefore of models in each group above)
#define __PATTERNS 3
//--- per-pattern size of the last input dimension of the m_handles[] models;
//--- also the size passed to Get() when the input vector is built in Supervise()
int __IN_SHAPES[__PATTERNS] = {  5, 5, 5 };
//
//
class CSignal_TRPO : public CExpertSignal
{
protected:
   //--- indicator objects created in InitIndicator()
   CiADX             m_adx;
   CiCCI             m_cci;
   //--- ONNX session handles, one slot per pattern (slot 0 = pattern 2, ... slot 2 = pattern 4)
   long              m_handles[__PATTERNS];     // models run by Supervise()
   long              m_handles_a[__PATTERNS];   // policy models run by Reinforce()
   long              m_handles_c[__PATTERNS];   // value models; only their shapes are configured in this file
   //--- "weights" of market models (0-100)
   int               m_pattern_2;               // model 2
   int               m_pattern_3;               // model 3
   int               m_pattern_4;               // model 4
   //--- adjustable parameters
   int               m_periods;                 // averaging period handed to both indicators
   uchar             m_pattern_used;            // compared with 4 / 8 / 16 by the condition methods
   bool              m_reinforce;               // true: pass the Supervise() output through Reinforce()
   //--- note: m_patterns_usage is not declared here; it is used as an inherited member

public:
                     CSignal_TRPO(void);
                    ~CSignal_TRPO(void);
   //--- setters for the "weights" of the market models
   void              Pattern_2(int value)      { m_pattern_2 = value; }
   void              Pattern_3(int value)      { m_pattern_3 = value; }
   void              Pattern_4(int value)      { m_pattern_4 = value; }
   //--- setters for the remaining adjustable parameters
   void              PatternUsed(uchar value)
   {
      // keep the local copy and the pattern-usage setting in step
      m_pattern_used = value;
      PatternsUsage(value);
   }
   void              Reinforce(bool value)     { m_reinforce = value; }
   void              PeriodUsed(int value)     { m_periods = value; }
   //--- verification of settings
   virtual bool      ValidationSettings(void);
   //--- creation of the indicators and timeseries
   virtual bool      InitIndicators(CIndicators *indicators);
   //--- checks whether the market models are formed
   virtual int       LongCondition(void);
   virtual int       ShortCondition(void);

protected:
   //--- creates and registers the ADX and CCI indicator objects
   bool              InitIndicator(CIndicators *indicators);
   //--- price accessors: each one refreshes its series before reading bar 'ind'
   double            Close(int ind)
   {
      m_close.Refresh(-1);
      return(m_close.GetData(ind));
   }
   double            High(int ind)
   {
      m_high.Refresh(-1);
      return(m_high.GetData(ind));
   }
   double            Low(int ind)
   {
      m_low.Refresh(-1);
      return(m_low.GetData(ind));
   }
   //--- shorthand for the bar index returned by StartIndex()
   int               X()                       { return(StartIndex()); }
   //--- model forward passes
   double            Supervise(int Index, ENUM_POSITION_TYPE T);
   double            Reinforce(int Index, ENUM_POSITION_TYPE T, double State);
};
//+------------------------------------------------------------------+
//| Constructor                                                      |
//+------------------------------------------------------------------+
CSignal_TRPO::CSignal_TRPO(void) : m_pattern_2(50),
   m_pattern_3(50),
   m_pattern_4(50),
   m_periods(14),       // matches the wizard default of PeriodUsed
   m_pattern_used(0),   // no pattern selected until PatternUsed() is called
   m_reinforce(false)   // matches the wizard default of Reinforce
{
//--- initialization of protected data
   m_used_series = USE_SERIES_CLOSE + USE_SERIES_TIME;
   PatternsUsage(m_patterns_usage);
//--- create the models run by Supervise() from the embedded resource buffers
   m_handles[0] = OnnxCreateFromBuffer(__61_2, ONNX_DEFAULT);
   m_handles[1] = OnnxCreateFromBuffer(__61_3, ONNX_DEFAULT);
   m_handles[2] = OnnxCreateFromBuffer(__61_4, ONNX_DEFAULT);
//--- create the policy models run by Reinforce()
   m_handles_a[0] = OnnxCreateFromBuffer(__62_policy_2, ONNX_DEFAULT);
   m_handles_a[1] = OnnxCreateFromBuffer(__62_policy_3, ONNX_DEFAULT);
   m_handles_a[2] = OnnxCreateFromBuffer(__62_policy_4, ONNX_DEFAULT);
//--- create the value models
   m_handles_c[0] = OnnxCreateFromBuffer(__62_value_2, ONNX_DEFAULT);
   m_handles_c[1] = OnnxCreateFromBuffer(__62_value_3, ONNX_DEFAULT);
   m_handles_c[2] = OnnxCreateFromBuffer(__62_value_4, ONNX_DEFAULT);
//--- a constructor cannot report failure, so at least name the slot whose
//--- model could not be created instead of leaving it to fail later
   for(int i = 0; i < __PATTERNS; i++)
   {  if(m_handles[i] == INVALID_HANDLE || m_handles_a[i] == INVALID_HANDLE || m_handles_c[i] == INVALID_HANDLE)
      {  Print(__FUNCTION__, ": failed to create one or more ONNX models for feature: ", i);
      }
   }
}
//+------------------------------------------------------------------+
//| Destructor                                                       |
//+------------------------------------------------------------------+
CSignal_TRPO::~CSignal_TRPO(void)
{
//--- release every ONNX session created in the constructor; without this
//--- the sessions stay allocated after the signal object is destroyed
   for(int i = 0; i < __PATTERNS; i++)
   {  if(m_handles[i] != INVALID_HANDLE)
      {  OnnxRelease(m_handles[i]);
         m_handles[i] = INVALID_HANDLE;
      }
      if(m_handles_a[i] != INVALID_HANDLE)
      {  OnnxRelease(m_handles_a[i]);
         m_handles_a[i] = INVALID_HANDLE;
      }
      if(m_handles_c[i] != INVALID_HANDLE)
      {  OnnxRelease(m_handles_c[i]);
         m_handles_c[i] = INVALID_HANDLE;
      }
   }
}
//+------------------------------------------------------------------+
//| Validation settings protected data.                              |
//+------------------------------------------------------------------+
bool CSignal_TRPO::ValidationSettings(void)
{
//--- validation settings of additional filters
   if(!CExpertSignal::ValidationSettings())
      return(false);
//--- exactly one pattern may be selected; 4, 8 and 16 are the values that
//--- LongCondition()/ShortCondition() test for patterns 2, 3 and 4
   if(m_patterns_usage != 4 && m_patterns_usage != 8 && m_patterns_usage != 16)
   {  printf(__FUNCSIG__ + " selected pattern: " + IntegerToString(m_patterns_usage) + " is out of scope; should be 4, 8 or 16 (pattern 2, 3 or 4)! ");
      return(false);
   }
//--- shapes that are identical for every pattern are declared once
   const long _unit_shape[] = {1, 1, 1};      // policy/value input, and every single-value output
   const long _out_shape_cov_mat[] = {1};     // second output of the policy model
//--- returns false on the first shape that cannot be set
   for(int i = 0; i < __PATTERNS; i++)
   {  // input of the model run by Supervise(): last dimension is per-pattern
      const long _in_shape[] = {1, 1, __IN_SHAPES[i]};
      if(!OnnxSetInputShape(m_handles[i], ONNX_DEFAULT, _in_shape))
      {  Print("OnnxSetInputShape error ", GetLastError(), " for feature: ", i);
         return(false);
      }
      if(!OnnxSetOutputShape(m_handles[i], 0, _unit_shape))
      {  Print("OnnxSetOutputShape error ", GetLastError(), " for feature: ", i);
         return(false);
      }
      // policy model: one input, two outputs (index 0 and index 1)
      if(!OnnxSetInputShape(m_handles_a[i], ONNX_DEFAULT, _unit_shape))
      {  Print("OnnxSetInputShape POLICY error ", GetLastError());
         return(false);
      }
      if(!OnnxSetOutputShape(m_handles_a[i], 0, _unit_shape))
      {  Print("OnnxSetOutputShape POLICY MU error ", GetLastError());
         return(false);
      }
      if(!OnnxSetOutputShape(m_handles_a[i], 1, _out_shape_cov_mat))
      {  Print("OnnxSetOutputShape POLICY COV-MAT error ", GetLastError());
         return(false);
      }
      // value model: one input, one output
      if(!OnnxSetInputShape(m_handles_c[i], ONNX_DEFAULT, _unit_shape))
      {  Print("OnnxSetInputShape VALUE error ", GetLastError());
         return(false);
      }
      if(!OnnxSetOutputShape(m_handles_c[i], 0, _unit_shape))
      {  Print("OnnxSetOutputShape VALUE error ", GetLastError());
         return(false);
      }
   }
//--- ok
   return(true);
}
//+------------------------------------------------------------------+
//| Create indicators.                                               |
//+------------------------------------------------------------------+
bool CSignal_TRPO::InitIndicators(CIndicators *indicators)
{
   // Succeeds only when the collection pointer is valid, the base class has
   // set up its own indicators/timeseries, and the ADX and CCI objects of
   // this signal have been created - evaluated strictly in that order.
   if(indicators == NULL)
   {
      return(false);
   }
   if(!CExpertSignal::InitIndicators(indicators))
   {
      return(false);
   }
   if(!InitIndicator(indicators))
   {
      return(false);
   }
   return(true);
}
//+------------------------------------------------------------------+
//| Initialize ADX and CCI indicators.                               |
//+------------------------------------------------------------------+
bool CSignal_TRPO::InitIndicator(CIndicators *indicators)
{
   // Nothing can be registered without a collection.
   if(indicators == NULL)
   {
      return(false);
   }
   // ADX: register in the collection first, then create it on the signal's
   // symbol and timeframe with the configured period.
   if(!indicators.Add(GetPointer(m_adx)))
   {
      printf(__FUNCTION__ + ": error adding object");
      return(false);
   }
   if(!m_adx.Create(m_symbol.Name(), m_period, m_periods))
   {
      printf(__FUNCTION__ + ": error initializing object");
      return(false);
   }
   // CCI: same sequence, applied to close prices.
   if(!indicators.Add(GetPointer(m_cci)))
   {
      printf(__FUNCTION__ + ": error adding object");
      return(false);
   }
   if(!m_cci.Create(m_symbol.Name(), m_period, m_periods, PRICE_CLOSE))
   {
      printf(__FUNCTION__ + ": error initializing object");
      return(false);
   }
   return(true);
}
//+------------------------------------------------------------------+
//| "Voting" that price will grow.                                   |
//+------------------------------------------------------------------+
int CSignal_TRPO::LongCondition(void)
{
   // Map the selected pattern to its model slot and weight. The selector can
   // match at most one case, so the vote is a single weighted model output.
   int slot   = 0;
   int weight = 0;
   switch(m_pattern_used)
   {
      case 4:     // model 2
         slot   = 2 - 2;
         weight = m_pattern_2;
         break;
      case 8:     // model 3
         slot   = 3 - 2;
         weight = m_pattern_3;
         break;
      case 16:    // model 4
         slot   = 4 - 2;
         weight = m_pattern_4;
         break;
      default:    // no recognised pattern: no vote
         return(0);
   }
   // Supervised output first; optionally passed through the reinforcement step.
   double strength = Supervise(slot, POSITION_TYPE_BUY);
   if(m_reinforce)
   {
      strength = Reinforce(slot, POSITION_TYPE_BUY, strength);
   }
   return(int(round(weight * strength)));
}
//+------------------------------------------------------------------+
//| "Voting" that price will fall.                                   |
//+------------------------------------------------------------------+
int CSignal_TRPO::ShortCondition(void)
{
   // Map the selected pattern to its model slot and weight. The selector can
   // match at most one case, so the vote is a single weighted model output.
   int slot   = 0;
   int weight = 0;
   switch(m_pattern_used)
   {
      case 4:     // model 2
         slot   = 2 - 2;
         weight = m_pattern_2;
         break;
      case 8:     // model 3
         slot   = 3 - 2;
         weight = m_pattern_3;
         break;
      case 16:    // model 4
         slot   = 4 - 2;
         weight = m_pattern_4;
         break;
      default:    // no recognised pattern: no vote
         return(0);
   }
   // Supervised output first; optionally passed through the reinforcement step.
   double strength = Supervise(slot, POSITION_TYPE_SELL);
   if(m_reinforce)
   {
      strength = Reinforce(slot, POSITION_TYPE_SELL, strength);
   }
   return(int(round(weight * strength)));
}
//+------------------------------------------------------------------+
//| Supervised Learning Model Forward Pass.                          |
//+------------------------------------------------------------------+
double CSignal_TRPO::Supervise(int Index, ENUM_POSITION_TYPE T)
{
//--- Runs the model in slot 'Index' and converts its single output into a
//--- strength for direction 'T'.
//---   Index : slot in m_handles[] / __IN_SHAPES[] (0 .. __PATTERNS-1)
//---   T     : POSITION_TYPE_BUY or POSITION_TYPE_SELL
//--- Returns 0.0 when the forward pass fails or when the output does not
//--- favour 'T'; otherwise the distance of the output from 0.5, doubled.
//--- NOTE(review): this treats 0.5 as the neutral output, which presumes the
//--- model emits values in [0, 1] - confirm against the exported model.
   vectorf _x = Get(Index, m_time.GetData(X()), m_adx, m_cci, __IN_SHAPES[Index]);
   vectorf _y(1);
   _y.Fill(0.0);
   ResetLastError();
   if(!OnnxRun(m_handles[Index], ONNX_NO_CONVERSION, _x, _y))
   {  printf(__FUNCSIG__ + " failed to get y forecast, err: %i", GetLastError());
      return(double(_y[0]));   // still the 0.0 it was filled with
   }
   if(T == POSITION_TYPE_BUY && _y[0] > 0.5f)
   {  _y[0] = 2.0f * (_y[0] - 0.5f);
   }
   else if(T == POSITION_TYPE_SELL && _y[0] < 0.5f)
   {  _y[0] = 2.0f * (0.5f - _y[0]);
   }
   else
   {  // the output sits on the other side of 0.5 (or exactly on it): returning
      // the raw value here would count it as strength for the wrong direction
      _y[0] = 0.0f;
   }
   return(double(_y[0]));
}
//+------------------------------------------------------------------+
//| Reinforcement Learning Model Forward Pass.                       |
//+------------------------------------------------------------------+
double CSignal_TRPO::Reinforce(int Index, ENUM_POSITION_TYPE T, double State)
{
   // Feeds 'State' to the policy model in slot 'Index', draws one action from
   // the normal distribution described by the two model outputs, and returns
   // the second model output multiplied by 'State'.
   vectorf state_in(1);
   state_in.Fill(float(State));
   vectorf mu(1), cov(1);
   mu.Fill(0.0);
   cov.Fill(0.0);
   ResetLastError();
   if(!OnnxRun(m_handles_a[Index], ONNX_NO_CONVERSION, state_in, mu, cov))
   {
      // a failed run is only logged; mu and cov keep their zero fill
      printf(__FUNCSIG__ + " failed to get y action forecast, err: %i", GetLastError());
   }
   // the draw advances the shared random sequence even though its value is unused below
   vectorf action = multivariate_normal(mu, cov, 1);
   // normalize action output to be 0.0-1.0 range
   const bool long_side  = (T == POSITION_TYPE_BUY  && State >= 0.5f && action[0] >= 0.5f);
   const bool short_side = (T == POSITION_TYPE_SELL && State <= 0.5f && action[0] <= 0.5f);
   if(long_side)
   {
      action[0] = 2.0f * (action[0] - 0.5f);
   }
   else if(short_side)
   {
      action[0] = 2.0f * (0.5f - action[0]);
   }
   else
   {
      action[0] = 0.0f;
   }
   // NOTE(review): the normalised action computed above is never returned; the
   // result is the covariance output scaled by State. Confirm this is intended.
   return(double(cov[0] * State));
}

// Function prototypes

// Generate samples from multivariate normal distribution
// Draws n_samples vectors from N(mean, cov).
//   mean      : vector of length dim
//   cov       : dim x dim covariance matrix flattened row-major (dim*dim values)
//   n_samples : number of vectors to draw
// Returns n_samples*dim values; sample i occupies [i*dim, i*dim + dim).
vectorf multivariate_normal(vectorf &mean, vectorf &cov, int n_samples)
{  int dim = int(mean.Size());
// Lower-triangular factor L with L * L^T = cov
   vectorf L = cholesky_decomposition(cov, dim);
// Allocate memory for samples (n_samples x dim)
   vectorf samples;
   samples.Init(n_samples * dim);
   samples.Fill(0.0);
// One independent standard-normal draw per dimension, refreshed per sample
   vectorf z;
   z.Init(dim);
   for(int i = 0; i < n_samples; i++)
   {  for(int j = 0; j < dim; j++)
      {  z[j] = rand_normal(0.0, 1.0);
      }
      // sample = mean + L * z. Each column k must use its own draw z[k];
      // reusing one draw for the whole row does not reproduce cov when dim > 1.
      for(int j = 0; j < dim; j++)
      {  float value = mean[j];
         for(int k = 0; k <= j; k++)
         {  value += L[j * dim + k] * z[k];
         }
         samples[i * dim + j] = value;
      }
   }
   return samples;
}

// Cholesky decomposition of a symmetric positive-definite matrix
// Cholesky factor of a symmetric matrix.
//   A         : n x n matrix flattened row-major (reads the lower triangle only)
//   mean_size : n
// Returns the n x n lower-triangular factor L (row-major) with L * L^T = A;
// entries above the diagonal are 0. A non-positive diagonal radicand yields a
// 0 on the diagonal, and the entries below such a pivot are set to 0, so a
// matrix that is not positive-definite neither produces NaN nor divides by zero.
vectorf cholesky_decomposition(vectorf &A, int mean_size)
{  int n = mean_size;
   vectorf L;
   L.Init(n * n);
   L.Fill(0.0);   // the loops below never write above the diagonal
   for(int i = 0; i < n; i++)
   {  for(int j = 0; j <= i; j++)
      {  // dot product of the already computed parts of rows i and j
         float sum = 0.0;
         for(int k = 0; k < j; k++)
         {  sum += L[i * n + k] * L[j * n + k];
         }
         if(j == i)
         {  float radicand = A[i * n + i] - sum;
            L[i * n + i] = (radicand > 0.0f) ? float(sqrt(radicand)) : 0.0f;
         }
         else
         {  float pivot = L[j * n + j];
            L[i * n + j] = (pivot != 0.0f) ? (A[i * n + j] - sum) / pivot : 0.0f;
         }
      }
   }
   return L;
}

// Box-Muller transform to generate normal random variates
// Normal random variate via the polar form of the Box-Muller transform.
//   mean, stddev : parameters of the distribution to sample
// Each accepted point yields two deviates: the first is returned immediately
// and the second is kept in a static and returned by the next call (scaled
// with that call's mean and stddev).
float rand_normal(float mean, float stddev)
{  static float n2 = 0.0;
   static int n2_cached = 0;
   if (!n2_cached)
   {  float x, y, r;
      do
      {  // uniform on [-1, 1]: MathRand() spans 0..32767. Without the "- 1.0"
         // the point lies in [0, 2] and only non-negative deviates come out.
         x = float(2.0 * MathRand() / 32767.0 - 1.0);
         y = float(2.0 * MathRand() / 32767.0 - 1.0);
         r = x * x + y * y;
      }
      while (r == 0.0 || r > 1.0);   // keep only points inside the unit disc, excluding the origin
      float d = float(sqrt(-2.0 * log(r) / r));
      n2 = y * d;
      n2_cached = 1;
      return mean + x * d * stddev;
   }
   else
   {  n2_cached = 0;
      return mean + n2 * stddev;
   }
}
//+------------------------------------------------------------------+
