//+------------------------------------------------------------------+
//|                                                         grnn.mqh |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#define EPS1 1.e-180
#include<Math/Stat/Normal.mqh>
//+------------------------------------------------------------------+
//| General regression neural network                                |
//+------------------------------------------------------------------+
class CGrnn
  {
public:
   //--- default settings: 10 outer passes, 100 inner trials, start std 3.0
                     CGrnn(void);
   //--- custom search schedule for the kernel-width optimisation
                     CGrnn(int num_outer, int num_inner, double start_std);
                    ~CGrnn(void);
   //--- store the training set and optimise per-column kernel widths;
   //--- returns false on invalid inputs or when training fails
   bool              fit(matrix &predictors,matrix &targets);
   //--- kernel-weighted prediction for a single row of predictors;
   //--- returns vector::Zeros(1) when no model is available or the size mismatches
   vector            predict(vector &predictors);
   //double            get_mse(void);
private:
   //--- random search over log kernel widths, shrinking the search radius each outer pass
   bool              train(void);
   //--- leave-one-out mean squared error for the current m_sigma
   double            execute(void);
   ulong             m_inputs;      // NOTE(review): not read anywhere in this file
   ulong             m_outputs;     // NOTE(review): not read anywhere in this file
   int               m_inner;       // random trials per outer pass
   int               m_outer;       // number of outer passes
   double            m_start_std;   // initial std dev of the random search
   ulong             m_rows;        // number of stored training cases
   ulong             m_cols;        // number of predictor columns
   bool              m_trained;     // true only after a successful fit()
   vector            m_sigma;       // per-column kernel width
   matrix            m_targets;     // stored training targets (one row per case)
   matrix            m_preds;       // stored training predictors (one row per case)
  };
//+------------------------------------------------------------------+
//| default constructor                                              |
//+------------------------------------------------------------------+
CGrnn::CGrnn(void)
  {
//--- search schedule defaults
   m_inner     = 100;   // random trials per outer pass
   m_outer     = 10;    // outer passes; the search std dev shrinks after each
   m_start_std = 3.0;   // initial std dev of the random search
//--- model state: explicitly reset so predict() before fit() is well defined
//--- (simple-type members are not guaranteed to be initialised otherwise)
   m_inputs    = 0;
   m_outputs   = 0;
   m_rows      = 0;
   m_cols      = 0;
   m_trained   = false;
  }
//+------------------------------------------------------------------+
//|parametric constructor                                            |
//+------------------------------------------------------------------+
CGrnn::CGrnn(int num_outer,int num_inner,double start_std)
  {
//--- caller-supplied search schedule
   m_inner     = num_inner;   // random trials per outer pass
   m_outer     = num_outer;   // outer passes; the search std dev shrinks after each
   m_start_std = start_std;   // initial std dev of the random search
//--- model state: explicitly reset so predict() before fit() is well defined
//--- (simple-type members are not guaranteed to be initialised otherwise)
   m_inputs    = 0;
   m_outputs   = 0;
   m_rows      = 0;
   m_cols      = 0;
   m_trained   = false;
  }
//+------------------------------------------------------------------+
//| destructor                                                       |
//+------------------------------------------------------------------+
CGrnn::~CGrnn(void)
  {
//--- nothing to release: every member is a scalar, vector or matrix held by value
  }
//+------------------------------------------------------------------+
//| fit data to a model                                              |
//+------------------------------------------------------------------+
bool CGrnn::fit(matrix &predictors,matrix &targets)
  {
//--- any previously fitted model is discarded up front
   m_trained = false;
//--- keep private copies of the training data
   m_targets = targets;
   m_preds   = predictors;
   m_rows    = m_preds.Rows();
   m_cols    = m_preds.Cols();
   m_inputs  = m_cols;
   m_outputs = m_targets.Cols();
   m_sigma   = vector::Zeros(m_cols);
//--- validate: targets must pair row-for-row with predictors; execute() uses
//--- leave-one-out, so at least two cases are needed (otherwise its weight sum
//--- is 0); and zero target columns would make its error normaliser 0
   if(m_targets.Rows() != m_rows || m_rows < 2 || m_cols == 0 || m_targets.Cols() == 0)
     {
      Print(__FUNCTION__, " invalid inputs ");
      return false;
     }

   m_trained = train();

   return m_trained;
  }
//+------------------------------------------------------------------+
//| make a prediction with a trained model                           |
//+------------------------------------------------------------------+
vector CGrnn::predict(vector &predictors)
  {
//--- refuse to predict without a fitted model or with a mismatched input length
   if(!m_trained)
     {
      Print(__FUNCTION__, " no trained model available for predictions ");
      return vector::Zeros(1);
     }
   if(predictors.Size() != m_cols)
     {
      Print(__FUNCTION__, " invalid inputs ");
      return vector::Zeros(1);
     }
//--- Gaussian-kernel weighted average of the stored training targets
   const ulong n_out = m_targets.Cols();
   vector weighted = vector::Zeros(n_out);
   double weight_sum = 0.0;

   for(ulong row = 0; row < m_rows; row++)
     {
      //--- squared distance scaled per column by the fitted kernel widths
      double sq = 0.0;
      for(ulong col = 0; col < m_cols; col++)
        {
         double z = (predictors[col] - m_preds[row][col]) / m_sigma[col];
         sq += z * z;
        }
      double w = exp(-sq);
      //--- floor each weight at EPS1 so it stays strictly positive
      if(w < EPS1)
         w = EPS1;
      for(ulong k = 0; k < n_out; k++)
         weighted[k] += w * m_targets[row][k];
      weight_sum += w;
     }

   weighted /= weight_sum;
   return weighted;
  }
//+------------------------------------------------------------------+
//| main training routine                                            |
//+------------------------------------------------------------------+
bool CGrnn::train(void)
  {
//--- random search in log(sigma) space: each outer pass samples m_inner
//--- candidates around the current centre, recentres on the best one found,
//--- then shrinks the search std dev by 0.7
   const ulong n_vars = m_preds.Cols();
   double best_error  = -1.0;          // negative sentinel: no candidate accepted yet
   double search_std  = m_start_std;
   double error;
   int    er = 0;

   vector center   = vector::Zeros(n_vars);
   vector test_wts = vector::Zeros(n_vars);
//--- fallback of log-sigma 0 (sigma 1) if no candidate is ever accepted,
//--- e.g. when m_outer or m_inner is not positive
   vector best_wts = vector::Zeros(n_vars);

   for(int outer=0; outer<m_outer; outer++)
     {
      for(int inner=0; inner<m_inner; inner++)
        {
         for(ulong i=0; i<n_vars; i++)
           {
            test_wts[i] = center[i] + search_std*MathRandomNormal(0.0,1.0,er);
            if(er)
              {
               Print(__FUNCTION__, " Normal error ", er);
               return false;
              }
            m_sigma[i] = exp(test_wts[i]);
           }
         error = execute();
         //--- a NaN/inf error would otherwise become an unbeatable "best"
         if(!MathIsValidNumber(error))
            continue;
         if((best_error<0.0) || (error<best_error))
           {
            best_error = error;
            best_wts = test_wts;
           }
        }
      center = best_wts;
      search_std*=0.7;
     }

   m_sigma = exp(best_wts);

   return true;
  }
//+------------------------------------------------------------------+
//|  get the error                                                   |
//+------------------------------------------------------------------+
double CGrnn::execute(void)
  {
//--- leave-one-out error: predict each stored case from all the others using
//--- the current m_sigma, then average the squared residuals over cases*outputs
   const ulong n_cases = m_preds.Rows();
   const ulong n_vars  = m_preds.Cols();
   const ulong n_out   = m_targets.Cols();
   double total_sq_err = 0.0;
   vector estimate(n_out);

   for(ulong held_out = 0; held_out < n_cases; held_out++)
     {
      estimate.Fill(0.0);
      double weight_sum = 0.0;
      //--- kernel-weighted sum over every case except the held-out one
      for(ulong other = 0; other < n_cases; other++)
        {
         if(other == held_out)
            continue;
         double sq = 0.0;
         for(ulong iv = 0; iv < n_vars; iv++)
           {
            double z = (m_preds[held_out][iv] - m_preds[other][iv]) / m_sigma[iv];
            sq += z * z;
           }
         double w = exp(-sq);
         if(w < EPS1)
            w = EPS1;
         for(ulong k = 0; k < n_out; k++)
            estimate[k] += w * m_targets[other][k];
         weight_sum += w;
        }
      //--- normalise and accumulate squared residuals of the held-out prediction
      for(ulong k = 0; k < n_out; k++)
        {
         estimate[k] /= weight_sum;
         double resid = estimate[k] - m_targets[held_out][k];
         total_sq_err += resid * resid;
        }
     }

   total_sq_err /= double(m_targets.Rows() * m_targets.Cols());
   return total_sq_err;
  }
//+------------------------------------------------------------------+
