//+------------------------------------------------------------------+
//|                   ErrorVarianceEstimation_ClassificationDemo.mq5 |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"
#property script_show_inputs
#include<error_variance_estimation.mqh>
#include<OLS.mqh>
//--- input parameters
//--- training cases generated per replication; must be greater than 3 (checked in OnStart)
input ulong      NumSamples=15;
//--- bootstrap resamples used by the boot/E0/E632 estimators; must be non-zero
input ulong      NumBootStraps = 1000;
//--- number of Monte-Carlo replications; must be non-zero
input ulong      NumReplications = 100;
//--- class shift applied to both predictors (must be >= 0). Class +1 gets -level on
//--- column 0 and +level on column 1, class -1 the opposite, so LARGER values move
//--- the classes further apart (an easier problem despite the parameter name)
input double     PredictionDifficultyLevel = 0.0;
//---
//+------------------------------------------------------------------+
//|  normal(rngstate)                                                |
//|  Returns one deviate from ALGLIB's HQRndNormal generator (a      |
//|  standard normal per the ALGLIB docs), advancing `state`.        |
//+------------------------------------------------------------------+
double normal(CHighQualityRandStateShell &state)
  {
   const double draw = CAlglib::HQRndNormal(state);
   return draw;
  }
//+------------------------------------------------------------------+
//|   unifrand(rngstate)                                             |
//|   Returns one deviate from ALGLIB's HQRndUniformR generator      |
//|   (uniform on (0,1) per the ALGLIB docs), advancing `state`.     |
//+------------------------------------------------------------------+
double unifrand(CHighQualityRandStateShell &state)
  {
   const double draw = CAlglib::HQRndUniformR(state);
   return draw;
  }
//+------------------------------------------------------------------+
//| ordinary least squares class                                     |
//| IModel adapter around OLS: train() fits using the first column   |
//| of `targets` as the response; forecast() returns the OLS         |
//| prediction for a single predictor row.                           |
//+------------------------------------------------------------------+
class COrdReg:public IModel
  {
private:
   OLS*              m_ols;   // regression engine; created in the ctor, released in the dtor
public:
                     COrdReg(void)
     {
      m_ols = new OLS();
     }
                    ~COrdReg(void)
     {
      //--- only release a live, dynamically created object
      if(CheckPointer(m_ols) != POINTER_DYNAMIC)
         return;
      delete m_ols;
     }
   bool              train(matrix &predictors,matrix& targets)
     {
      //--- only the first target column is used as the response
      vector response = targets.Col(0);
      return m_ols.Fit(response,predictors);
     }
   double            forecast(vector &predictors)
     {
      const double estimate = m_ols.Predict(predictors);
      return estimate;
     }
  };
//+------------------------------------------------------------------+
//| error variance for classification models                         |
//| Supplies a zero-one loss to CErrorVar: a prediction counts as    |
//| correct only when its product with the true label is strictly    |
//| positive (same sign, neither value zero).                        |
//+------------------------------------------------------------------+
class CErrorVarC:public CErrorVar
  {
public:
                     CErrorVarC(void)
     {
     }
                    ~CErrorVarC(void)
     {
     }
   //--- 0.0 for a correct sign, 1.0 otherwise (zero or NaN products count as errors)
   virtual double    error_fun(const double truevalue,const double predictedvalue)
     {
      const bool same_side = (truevalue*predictedvalue > 0.0);
      return same_side ? 0.0 : 1.0;
     }
  };
//--- run configuration and bookkeeping (filled in by OnStart)
ulong  nreplications;        // Monte-Carlo replications requested
ulong  itry;                 // NOTE(review): never referenced in this file
ulong  nsamps;               // training cases per replication
ulong  nboots;               // bootstrap resamples per estimator
ulong  divisor;              // progress-report interval (computed but not used for output here)
ulong  ndone;                // replication count reported in the summary
//--- per-replication estimates from each error-estimation method
vector computed_err_cv;
vector computed_err_boot;
vector predictions;          // model outputs for the current test set
vector computed_err_E0;
vector computed_err_E632;
//--- scalar accumulators
double temperr;              // summed test loss for the current replication
double sum_observed_error;   // running sum of per-replication test error rates
double mean_computed_err;    // NOTE(review): assigned but never read in this file
double var_computed_err;     // NOTE(review): assigned but never read in this file
double dfactor;              // class-separation shift (PredictionDifficultyLevel)
double dif;                  // NOTE(review): never referenced in this file
//--- generated data and their predictor/target slices
matrix xdata;
matrix testdata;
matrix trainpreds;
matrix traintargs;
matrix testpreds;
matrix testtargs;
//+------------------------------------------------------------------+
//| Fill every row of `data` with a synthetic two-class case.        |
//| Column 0 and column 1 are correlated normal predictors and       |
//| column 2 is the class label (+1 or -1). Label +1 shifts column 0 |
//| by -diff and column 1 by +diff; label -1 gets the opposite       |
//| shifts, so a larger diff moves the classes further apart.        |
//| `data` must already have at least 3 columns.                    |
//+------------------------------------------------------------------+
void GenerateClassificationData(matrix &data,const double diff,CHighQualityRandStateShell &state)
  {
   for(ulong i = 0; i < data.Rows(); i++)
     {
      data[i][0] = normal(state);
      data[i][1] = 0.7071 * data[i][0] + 0.7071 * normal(state);
      if(unifrand(state) > 0.5)
        {
         data[i][0] -= diff;
         data[i][1] += diff;
         data[i][2] = 1.0;
        }
      else
        {
         data[i][0] += diff;
         data[i][1] -= diff;
         data[i][2] = -1.0;
        }
     }
  }
//+------------------------------------------------------------------+
//| Script program start function                                    |
//| Monte-Carlo comparison of error estimators for a sign-classifier |
//| built on OLS. Each replication draws a training set and a test   |
//| set ten times larger. It records the observed test error and the |
//| CV / bootstrap / E0 / E632 estimates computed from the training  |
//| set alone, then prints the mean observed error and the mean/std  |
//| of each estimator over the completed replications.               |
//+------------------------------------------------------------------+
void OnStart()
  {
   CHighQualityRandStateShell rngstate;
   CHighQualityRand::HQRndRandomize(rngstate.GetInnerObj());
//---
   nboots = NumBootStraps;
   nsamps = NumSamples;
   nreplications = NumReplications;
   dfactor = PredictionDifficultyLevel;

//--- counts are unsigned, so "non-positive" simply means zero
   if(nsamps <= 3 || nreplications == 0 || nboots == 0 || dfactor < 0.0)
     {
      Alert(" Invalid inputs ");
      return;
     }

//--- progress-report interval; dividing twice equals 1000000/(nsamps*nboots)
//--- for positive integers but cannot overflow the intermediate product
   divisor = 1000000 / nsamps / nboots;
   if(divisor < 2)
      divisor = 2;

   xdata = matrix::Zeros(nsamps,3);
   testdata = matrix::Zeros(nsamps*10,3);
   predictions = vector::Zeros(testdata.Rows());

   sum_observed_error = 0.0;
   computed_err_cv = vector::Zeros(nreplications);
   computed_err_E0 = vector::Zeros(nreplications);
   computed_err_E632 = vector::Zeros(nreplications);
   computed_err_boot = vector::Zeros(nreplications);

   CErrorVarC errorvar;
   COrdReg olsmodel;

   ndone = 0;
   for(ulong irep = 0; irep < nreplications; irep++)
     {
      //--- honour a user stop request; completed replications are still reported
      if(IsStopped())
         break;

      //--- RNG call order is training data first, then test data
      GenerateClassificationData(xdata,dfactor,rngstate);
      GenerateClassificationData(testdata,dfactor,rngstate);

      trainpreds = np::sliceMatrixCols(xdata,0,2);
      traintargs = np::sliceMatrixCols(xdata,2);

      if(!olsmodel.train(trainpreds,traintargs))
        {
         Print(" fitting first model failed ");
         return;
        }

      testpreds = np::sliceMatrixCols(testdata,0,2);
      testtargs = np::sliceMatrixCols(testdata,2);
      temperr = 0.0;
      for(ulong i = 0; i < testpreds.Rows(); i++)
        {
         predictions[i] = olsmodel.forecast(testpreds.Row(i));
         temperr += errorvar.error_fun(testtargs[i][0],predictions[i]);
        }

      //--- observed misclassification rate on this replication's test set
      sum_observed_error += temperr / double(testpreds.Rows());

      if(!errorvar.cross_validation(trainpreds,traintargs,olsmodel,computed_err_cv[irep]) ||
         !errorvar.boot_strap(nboots,trainpreds,traintargs,olsmodel,computed_err_boot[irep]) ||
         !errorvar.efrons_0(nboots,trainpreds,traintargs,olsmodel,computed_err_E0[irep]) ||
         !errorvar.efrons_632(nboots,trainpreds,traintargs,olsmodel,computed_err_E632[irep])
        )
        {
         Print(" error variance calculation failed ");
         return;
        }

      //--- count the replication only once every estimate for it is stored
      ndone = irep + 1;
     }

   if(ndone == 0)
     {
      Print(" no replications completed ");
      return;
     }

//--- after an early stop, drop unfilled zero slots so Mean()/Std() cover completed replications only
   if(ndone < nreplications)
     {
      computed_err_cv.Resize(ndone);
      computed_err_boot.Resize(ndone);
      computed_err_E0.Resize(ndone);
      computed_err_E632.Resize(ndone);
     }

//--- ndone is a ulong, so use the 64-bit unsigned specifier
   PrintFormat("Number of Iterations %I64u   Observed error = %.5lf",ndone, sum_observed_error / double(ndone)) ;
//---
   PrintFormat("CV: computed error  mean=%10.5lf      std=%10.5lf",computed_err_cv.Mean(), computed_err_cv.Std()) ;
//---
   PrintFormat("BOOT: computed error  mean=%10.5lf    std=%10.5lf",computed_err_boot.Mean(), computed_err_boot.Std()) ;
//---
   PrintFormat("E0: computed error  mean=%10.5lf      std=%10.5lf",computed_err_E0.Mean(), computed_err_E0.Std()) ;
//---
   PrintFormat("E632: computed error  mean=%10.5lf    std=%10.5lf",computed_err_E632.Mean(), computed_err_E632.Std()) ;
  }
//+--------------------------------------------------------------------+
