//+------------------------------------------------------------------+
//|                                                     ROC_Demo.mq5 |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"
#property script_show_inputs
#include<logistic.mqh>
#include<ErrorDescription.mqh>
#include<Generic/SortedSet.mqh>
//---
//--- Selects which kind of classification task the script runs.
//--- OnStart only tests for BINARY_CLASS; any other value leaves the
//--- loaded dataset unrestricted.
//--- NOTE(review): in MQL5 the trailing comment on an enum member is used
//--- as its label in the inputs dialog, so those comments are left as is.
enum CLASSIFICATION_TYPE
  {
   BINARY_CLASS = 0,//binary classification problem
   MULITI_CLASS//multiclass classification problem
  };
//--- input parameters
//--- Fraction of the dataset rows placed in the training set; the row
//--- count is rounded up with ceil() and the remaining rows form the test set.
input double   Train_Test_Split = 0.5;
//--- Seed for the random generator that shuffles the row indices
//--- (passed to HQRndSeed as Random_Seed and Random_Seed+Random_Seed).
input int      Random_Seed = 125;
//--- When BINARY_CLASS, the dataset is cut down with sliceMatrixRows(data,0,100)
//--- before splitting; otherwise every loaded row is used.
input CLASSIFICATION_TYPE classification_problem = BINARY_CLASS;
//--- Averaging mode handed to ClassificationScore() and
//--- ReceiverOperatingCharacteristic().
input ENUM_AVERAGE_MODE av_mode = AVERAGE_BINARY;
//+------------------------------------------------------------------+
//| Script program start function                                    |
//|                                                                  |
//| Reads iris.csv, randomly splits its rows into a training and a   |
//| test set, fits a logistic model on the training rows, prints the |
//| ROC AUC obtained on the test rows and shows the ROC curve(s) on  |
//| the chart for a few seconds.                                     |
//| Any failure is reported with Print() and ends the script early.  |
//+------------------------------------------------------------------+
void OnStart()
  {
//--- seed the generator from the Random_Seed input
   CHighQualityRandStateShell rngstate;
   CHighQualityRand::HQRndSeed(Random_Seed,Random_Seed+Random_Seed,rngstate.GetInnerObj());
//--- load the dataset; an empty matrix means there is nothing to work with
   matrix data = np::readcsv("iris.csv");
   if(data.Rows()==0)
     {
      Print(__LINE__, "  error: iris.csv could not be read or is empty ", ErrorDescription(GetLastError()));
      return;
     }
//--- presumably keeps the columns from index 1 onwards (dropping a leading
//--- id column) - TODO confirm against np::sliceMatrixCols
   data = np::sliceMatrixCols(data,1);
//--- columns 0..3 are used as predictors and column 4 as the target below
   if(data.Rows()==0 || data.Cols()<5)
     {
      Print(__LINE__, "  error: dataset needs at least 5 columns after slicing, got ", data.Cols());
      return;
     }
//--- binary problem: restrict the data with sliceMatrixRows(data,0,100)
//--- (presumably the first 100 rows, i.e. two classes - TODO confirm)
   if(classification_problem == BINARY_CLASS)
      data = np::sliceMatrixRows(data,0,100);
//--- the split must leave rows on both sides
   if(Train_Test_Split<=0.0 || Train_Test_Split>=1.0)
     {
      Print(__LINE__, "  error: Train_Test_Split must be greater than 0 and less than 1, got ", Train_Test_Split);
      return;
     }
//--- build the list of row indices
   ulong rindices[],trainset[],testset[];
   np::arange(rindices,int(data.Rows()));
//--- number of rows going to the training set (rounded up)
   const int total_rows = int(rindices.Size());
   const int train_count = int(ceil(Train_Test_Split*total_rows));
   if(train_count<1 || train_count>=total_rows)
     {
      Print(__LINE__, "  error: split of ", total_rows, " rows leaves an empty train or test set");
      return;
     }
//--- shuffle, take the first train_count indices as the training set and sort
//--- them; ArrayCopy returns the number of copied elements, so anything other
//--- than train_count means the copy did not complete
   if(!np::shuffleArray(rindices,GetPointer(rngstate)) || ArrayCopy(trainset,rindices,0,0,train_count)!=train_count || !ArraySort(trainset))
     {
      Print(__LINE__, "  error ", ErrorDescription(GetLastError()));
      return;
     }
//--- test set = all indices minus the training indices
   CSortedSet<ulong> test_set(rindices);
   test_set.ExceptWith(trainset);
   if(test_set.CopyTo(testset)<=0)
     {
      Print(__LINE__, "  error: test set is empty ", ErrorDescription(GetLastError()));
      return;
     }
//--- split both subsets into predictors (columns 0..3) and targets (column 4)
   matrix testdata = np::selectMatrixRows(data,testset);
   matrix test_predictors = np::sliceMatrixCols(testdata,0,4);
   vector test_targets = testdata.Col(4);
   matrix traindata = np::selectMatrixRows(data,trainset);
   matrix train_preditors = np::sliceMatrixCols(traindata,0,4);
   vector train_targets = traindata.Col(4);
//--- train the logistic model
   logistic::Clogit logit;
   if(!logit.fit(train_preditors,train_targets))
     {
      Print(" error training logistic model ");
      return;
     }
//--- model output for the test rows
   matrix y_probas = logit.probas(test_predictors);
   vector y_preds = logit.predict(test_predictors);
//--- area under the ROC curve; an empty result is reported but is not fatal
   vector auc = test_targets.ClassificationScore(y_probas,CLASSIFICATION_ROC_AUC,av_mode);
   if(auc.Size()>0)
      Print(" AUC ", auc);
   else
      Print(" AUC error ", ErrorDescription(GetLastError()));
//--- points of the ROC curve(s)
   matrix fpr,tpr,threshs;
   if(!test_targets.ReceiverOperatingCharacteristic(y_probas,av_mode,fpr,tpr,threshs))
     {
      Print(" ROC error ", ErrorDescription(GetLastError()));
      return;
     }
//--- legend: comma separated AUC values, no trailing comma after the last one
   string legend;
   for(ulong i = 0; i<auc.Size(); i++)
     {
      string temp = (i+1<auc.Size())?StringFormat("%.3lf,",auc[i]):StringFormat("%.3lf",auc[i]);
      StringAdd(legend,temp);
     }
//--- draw the curve(s), keep them visible for 7 seconds, then clean up
   CGraphic* roc = np::plotMatrices(fpr, tpr,"ROC",false,"FPR","TPR",legend,true,0,0,10,10,600,500);
   if(CheckPointer(roc)!=POINTER_INVALID)
     {
      Sleep(7000);
      roc.Destroy();
      delete roc;
      ChartRedraw();
     }
  }
//+------------------------------------------------------------------+
