//+------------------------------------------------------------------+
//|                                             GMDH_Price_Model.mq5 |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"
#property script_show_inputs
#include <GMDH\combi.mqh>
#include <GMDH\mia.mqh>
#include <GMDH\multi.mqh>
#include <ErrorDescription.mqh>
#include <Graphics\Graphic.mqh>
//+------------------------------------------------------------------+
//|  enumeration of gmdh model type                                  |
//+------------------------------------------------------------------+
// Selects which GMDH algorithm the script trains (switched on in OnStart).
// The comment trailing each member must stay on that member's line:
// MQL5 uses it as the label shown for the value in the inputs dialog.
enum ENUM_GMDH_MODEL
{
  Combi = 0, //COMBI
  Mia   = 1, //MIA
  Multi = 2  //MULTI
};
//--- input parameters
//--- NOTE: a comment placed on the same line as an input declaration is shown
//--- as that input's label in the inputs dialog, so the notes below are
//--- deliberately written on the line above each declaration.
//--- symbol to model; an empty string makes OnStart pass NULL to iBarShift
input string   SetSymbol="";
//--- which GMDH model OnStart trains (COMBI, MIA or MULTI)
input ENUM_GMDH_MODEL modelType = Combi;
//--- date range of the bars used for training
input datetime TrainingSampleStartDate=D'2019.12.31';
input datetime TrainingSampleStopDate=D'2022.12.31';
//--- date range of the bars used for out-of-sample testing
input datetime TestSampleStartDate = D'2023.01.01';
input datetime TestSampleStopDate = D'2023.12.31';
input ENUM_TIMEFRAMES tf=PERIOD_D1;    //time frame
//--- number of lags: passed to fit() and used as the length of each input window in MakePredictions
input int Numlags = 3;
//--- criterion type passed to fit() of all three models (semantics defined by the GMDH library)
input CriterionType critType = stab;
//--- polynomial type, passed to MIA fit() only
input PolynomialType polyType = linear_cov;
//--- Average, NumBest and critLimit are passed to MIA and MULTI fit() only (semantics defined by the GMDH library)
input int Average  = 10;
input int NumBest  = 10;
//--- passed to fit() of all three models; presumably the share of data held out for validation - confirm in the GMDH library
input double DataSplitSize = 0.2;
input double critLimit = 0;
//--- number of test samples to plot; 0 falls back to 20 in OnStart
input ulong NumTestSamplesPlot = 20;
//--- x axis for plot: bar times of the plotted test samples, filled in OnStart and read by TimeFormat
double xaxis[];
//+------------------------------------------------------------------+
//|global integer variables                                          |
//+------------------------------------------------------------------+
int size_outsample,                //number of bars spanned by the test date range
    size_observations;             //number of bars spanned by the training date range
//+------------------------------------------------------------------+
//|matrix and vector global variables                                |
//+------------------------------------------------------------------+
//--- test set rates; OnStart reads column 0 as close prices and column 1 as bar times
matrix testprices;
vector prices,                   //training close prices
       predictions;               //out-of-sample predictions, one per test row after the first Numlags rows    
//+------------------------------------------------------------------+
//| Script program start function                                    |
//+------------------------------------------------------------------+
void OnStart()
  {
//--- Script entry point:
//---   1. validate the inputs and the requested date ranges,
//---   2. download training close prices and test close prices + bar times,
//---   3. fit the selected GMDH model and predict every test row that has
//---      Numlags predecessors,
//---   4. plot the first NumTestSamplesPlot true/predicted pairs.
//--- Every failure path prints a message and returns.

//--- Numlags is used as an unsigned window length below; reject values that would wrap around
   if(Numlags<1)
     {
      Print("Invalid inputs: Numlags must be at least 1");
      return;
     }
//--- resolve the symbol once so that iBarShift and CopyRates always agree on it
   string symbol=(SetSymbol!="")?SetSymbol:_Symbol;
//--- get relative shift of IS and OOS sets
   ResetLastError();
   int trainstart=iBarShift(symbol,tf,TrainingSampleStartDate);
   int trainstop =iBarShift(symbol,tf,TrainingSampleStopDate);
   int teststart =iBarShift(symbol,tf,TestSampleStartDate);
   int teststop  =iBarShift(symbol,tf,TestSampleStopDate);
//--- check for errors from ibarshift calls
   if(trainstart<0 || trainstop<0 || teststart<0 || teststop<0)
     {
      Print(ErrorDescription(GetLastError()));
      return;
     }
//--- set the size of the sample sets (shifts grow towards the past, hence start - stop)
   size_observations=(trainstart-trainstop)+1;
   size_outsample=(teststart-teststop)+1;
//--- check for input errors (start date later than stop date)
   if(size_observations<=0 || size_outsample<=0)
     {
      Print("Invalid inputs ");
      return;
     }
//--- download insample prices for training: up to 10 attempts, 5 seconds apart
   int attempts=10;
   while(!prices.CopyRates(symbol,tf,COPY_RATES_CLOSE,TrainingSampleStartDate,TrainingSampleStopDate))
     {
      attempts--;
      if(attempts<=0)
        {
         Print("error copying to prices  ",GetLastError());
         return;
        }
      Sleep(5000);
     }
//--- download out of sample prices for testing (column 0 = close, column 1 = time)
   attempts=10;
   while(!testprices.CopyRates(symbol,tf,COPY_RATES_CLOSE|COPY_RATES_TIME|COPY_RATES_VERTICAL,TestSampleStartDate,TestSampleStopDate))
     {
      attempts--;
      if(attempts<=0)
        {
         Print("error copying to testprices  ",GetLastError());
         return;
        }
      Sleep(5000);
     }
//--- the first Numlags test rows only serve as inputs, so there must be more rows than lags;
//--- without this check the unsigned subtraction below would wrap around
   ulong test_rows=testprices.Rows();
   if(test_rows<=(ulong)Numlags)
     {
      Print("Not enough test samples (",test_rows,") for ",Numlags," lags");
      return;
     }
   ulong num_predictions=test_rows-(ulong)Numlags;
//--- resize vector of predictions
   predictions.Resize(num_predictions);
//--- train and make predictions
   switch(modelType)
     {
      case Combi:
        {
         COMBI combi;
         if(!combi.fit(prices,Numlags,DataSplitSize,critType))
           {
            Print("Model training failed");
            return;
           }
         Print("Model ", combi.getBestPolynomial());
         MakePredictions(combi,testprices.Col(0),predictions);
        }
      break;

      case Mia:
        {
         MIA mia;
         if(!mia.fit(prices,Numlags,DataSplitSize,polyType,critType,NumBest,Average,critLimit))
           {
            Print("Model training failed");
            return;
           }
         Print("Model ", mia.getBestPolynomial());
         MakePredictions(mia,testprices.Col(0),predictions);
        }
      break;

      case Multi:
        {
         MULTI multi;
         if(!multi.fit(prices,Numlags,DataSplitSize,critType,NumBest,Average,critLimit))
           {
            Print("Model training failed");
            return;
           }
         Print("Model ", multi.getBestPolynomial());
         MakePredictions(multi,testprices.Col(0),predictions);
        }
      break;

      default:
         Print("Invalid GMDH model type ");
         return;
     }
//--- number of samples to plot: 0 means 20, and never more than the predictions available
   ulong TestSamplesPlot=(NumTestSamplesPlot>0)?NumTestSamplesPlot:20;
   if(TestSamplesPlot>num_predictions)
      TestSamplesPlot=num_predictions;
//--- cut out the plotted window; prediction i belongs to test row i+Numlags.
//--- The first argument is the initial vector size handed to the slice initializer
//--- (defined outside this file); it is set to the exact number of elements requested.
   vector testsample(TestSamplesPlot,slice,testprices.Col(0),Numlags,Numlags+TestSamplesPlot-1);
   vector testpredictions(TestSamplesPlot,slice,predictions,0,TestSamplesPlot-1);
   vector dates(TestSamplesPlot,slice,testprices.Col(1),Numlags,Numlags+TestSamplesPlot-1);
//--- convert to plain arrays for CGraphic; xaxis is read by the TimeFormat axis callback
   double y[], y_hat[];
   if(vecToArray(testpredictions,y_hat) && vecToArray(testsample,y) && vecToArray(dates,xaxis))
     {
      PlotPrices(y_hat,y);
     }
   else
     {
      Print("Failed to prepare data for plotting");
     }
//---
   ChartRedraw();
  }
//+------------------------------------------------------------------+
//| Plot price predictions curves                                    |
//+------------------------------------------------------------------+
void PlotPrices(double &predicted_values[],double &true_values[],int displaytime_seconds=20)
{
//--- Draws the true and predicted price curves on a CGraphic object placed on the
//--- current chart, keeps it on screen for displaytime_seconds, then removes it.
//---   predicted_values    - model predictions (red curve)
//---   true_values         - actual prices (blue curve)
//---   displaytime_seconds - how long the plot stays visible
//--- The X axis labels come from the TimeFormat callback, which reads the global xaxis array.
   long chart=0;
   string name="GMDH price model";
//--- nothing to draw: leave the chart untouched
   if(ArraySize(predicted_values)<1 || ArraySize(true_values)<1)
     {
      Print("PlotPrices: nothing to plot");
      return;
     }
//--- hide the price chart so that only the graphic is visible
   ChartSetInteger(chart,CHART_SHOW,false);
//--- create the graphic object, or reuse one left on the chart under the same name
   CGraphic graphic;
   ResetLastError();
   bool ready=false;
   if(ObjectFind(chart,name)<0)
      ready=graphic.Create(chart,name,0,0,0,700,500);
   else
      ready=graphic.Attach(chart,name);
//--- without a usable canvas there is nothing to show: restore the chart right away
//--- instead of leaving it hidden for the whole display time
   if(!ready)
     {
      Print("PlotPrices: failed to create or attach the graphic object, error ",GetLastError());
      ChartSetInteger(chart,CHART_SHOW,true);
      ChartRedraw();
      return;
     }
//--- title
   graphic.BackgroundMainColor(ColorToARGB(clrBlack));
   graphic.BackgroundMainSize(30);
   graphic.BackgroundMain(EnumToString(modelType)+" GMDH Price Model");
//--- legend
   graphic.HistoryNameSize(15);
   graphic.HistoryNameWidth(100);
//--- curves
   graphic.CurveAdd(true_values,ColorToARGB(clrBlue),CURVE_POINTS_AND_LINES,"True Prices");
   graphic.CurveAdd(predicted_values,ColorToARGB(clrRed),CURVE_POINTS_AND_LINES,"Predicted Prices");
//--- get the Y-axis
   CAxis *yAxis=graphic.YAxis();
   yAxis.NameSize(15);
   yAxis.Name("Price");
//--- get the X-axis
   CAxis *xAxis=graphic.XAxis();
   xAxis.NameSize(15);
   xAxis.Name("Date");
//--- sets the X-axis properties: custom labels produced by TimeFormat
   xAxis.AutoScale(false);
   xAxis.Type(AXIS_TYPE_CUSTOM);
   xAxis.ValuesFunctionFormat(TimeFormat);
   xAxis.DefaultStep(4.0);
//--- render
   graphic.CurvePlotAll();
   graphic.Update();
//--- keep the plot on screen
   Sleep(displaytime_seconds*1000);
//--- restore the price chart and remove the graphic
   ChartSetInteger(chart,CHART_SHOW,true);
   graphic.Destroy();
   ChartRedraw();
}
//+------------------------------------------------------------------+
//| Make out of sample predictions                                   |
//+------------------------------------------------------------------+
void MakePredictions(GmdhModel &model, vector &inputs, vector &output)
{
//--- Slides a window of Numlags consecutive values over inputs and stores the model's
//--- one-step prediction for each window in output.
//---   model  - fitted GMDH model
//---   inputs - price series to predict from
//---   output - receives inputs.Size()-Numlags predictions; output[i] is the prediction
//---            made from inputs[i..i+Numlags-1]. It is enlarged if it is too small.
//--- Nothing is written when inputs holds no more than Numlags values.
   if(Numlags<1)
     {
      Print("MakePredictions: Numlags must be at least 1");
      return;
     }
   ulong total=inputs.Size();
//--- Size() is unsigned: subtracting a larger lag count would wrap around, so check first
   if(total<=(ulong)Numlags)
     {
      Print("MakePredictions: not enough input values (",total,") for ",Numlags," lags");
      return;
     }
   ulong count=total-(ulong)Numlags;
//--- make sure every prediction has a slot to be written to
   if(output.Size()<count)
      output.Resize(count);
   for(ulong i=0; i<count; i++)
     {
      vector in(ulong(Numlags),slice,inputs,i,i+Numlags-1);
      vector prediction=model.predict(in,1);
      //--- guard in case predict hands back an empty vector (e.g. on failure)
      if(prediction.Size()<1)
        {
         Print("MakePredictions: model returned no prediction at index ",i);
         return;
        }
      output[i]=prediction[0];
     }
}
//+------------------------------------------------------------------+
//| Custom function for create values on X-axis                      |
//+------------------------------------------------------------------+
string TimeFormat(double x,void *cbdata)
  {
//--- Axis label callback registered with xAxis.ValuesFunctionFormat in PlotPrices.
//---   x      - axis value, used as an index into the global xaxis array of bar times
//---   cbdata - unused
//--- Returns the bar time at that index as text, or an empty label when the axis
//--- asks for a position that has no matching sample.
   int index=(int)x;
   if(index<0 || index>=ArraySize(xaxis))
      return("");
   return(TimeToString((datetime)xaxis[index]));
  }
//+------------------------------------------------------------------+