//+------------------------------------------------------------------+
//|                                        ROC_curves_table_demo.mq5 |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"
#property script_show_inputs
#include<logistic.mqh>
#include<roc_curves.mqh>
#include<Generic/SortedSet.mqh>
//---
//+------------------------------------------------------------------+
//| Iris species selectable as the positive class.                   |
//| The ordinal is also used as a 50-row block index into the data   |
//| (OnStart assumes iris.csv is grouped by species in this order).  |
//| NOTE: the same-line comments are the names shown in the inputs   |
//| dialog - keep them unchanged.                                    |
//+------------------------------------------------------------------+
enum IRIS_TARGET
  {
   Setosa     = 0, //Setosa
   Versicolor = 1, //Versicolor
   Virginica  = 2  //Virginica
  };
//--- input parameters
//--- (declaration order = order in the inputs dialog; do not add same-line
//--- comments here, MQL5 would show them as the parameter captions)
// Fraction of the shuffled rows used for training: ceil(Train_Test_Split*N)
// rows go to the training set, the remaining rows form the test set.
input double   Train_Test_Split = 0.5;
// Seed for the high-quality RNG that shuffles the row indices, so the
// train/test split is reproducible for a given value.
input int      Random_Seed = 125;
// Species treated as the positive class: its rows get label 1 and all
// other rows label 0 in the label column (presumably column 4 - see OnStart).
input IRIS_TARGET Target_class = Setosa;
//+------------------------------------------------------------------+
//| Script program start function                                    |
//| Builds a one-vs-rest binary problem from iris.csv (Target_class  |
//| rows labelled 1, all other rows 0), splits it into train/test    |
//| sets with a seeded shuffle, fits a logistic regression on the    |
//| training rows and prints the ROC table for the test rows.        |
//+------------------------------------------------------------------+
void OnStart()
  {
//--- both partitions must be non-empty for training and ROC evaluation
   if(Train_Test_Split<=0.0 || Train_Test_Split>=1.0)
     {
      Print(" invalid Train_Test_Split ",Train_Test_Split,", expected a value in (0,1) ");
      return;
     }
//--- seeded RNG so the shuffle, and therefore the split, is reproducible
   CHighQualityRandStateShell rngstate;
   CHighQualityRand::HQRndSeed(Random_Seed,Random_Seed+Random_Seed,rngstate.GetInnerObj());
//--- load the data set; guard against an empty result (e.g. missing file)
   matrix data = np::readcsv("iris.csv");
   if(data.Rows()==0)
     {
      Print(" failed to read iris.csv ");
      return;
     }
//--- drop column 0 (NOTE(review): presumably a row id in iris.csv - confirm);
//--- afterwards columns 0..3 are used as predictors and column 4 as the label
   data = np::sliceMatrixCols(data,1);
//--- rows [from,to) hold the target species; assumes iris.csv is grouped
//--- by species in IRIS_TARGET order with 50 rows per species
   long from = 50*long(Target_class);
   long to   = from+50;
   if(data.Cols()<5 || data.Rows()<ulong(to))
     {
      Print(" unexpected iris.csv layout: ",data.Rows()," rows x ",data.Cols()," columns ");
      return;
     }
//--- one-vs-rest relabelling of the label column: 0 everywhere, then 1 for the target rows
   if(!np::matrixFill(data,0.0,0,long(data.Rows()),1,4) || !np::matrixFill(data,1.0,from,to,1,4))
     {
      Print(" matrixFill() error ");
      return;
     }
//--- shuffle all row indices; the first ceil(split*N) become the (sorted) training rows
   ulong rindices[],trainset[],testset[];
   np::arange(rindices,int(data.Rows()));
   if(!np::shuffleArray(rindices,GetPointer(rngstate)))
     {
      Print(__LINE__,"  error ",GetLastError());
      return;
     }
   int train_size = int(ceil(Train_Test_Split*rindices.Size()));
//--- ArrayCopy returns the number of copied elements (never negative), so compare with the request
   if(ArrayCopy(trainset,rindices,0,0,train_size)!=train_size || !ArraySort(trainset))
     {
      Print(__LINE__,"  error ",GetLastError());
      return;
     }
//--- test rows = all rows not selected for training
   CSortedSet<ulong> test_set(rindices);
   test_set.ExceptWith(trainset);
   test_set.CopyTo(testset);
   if(testset.Size()==0)
     {
      Print(" empty test set, reduce Train_Test_Split ");
      return;
     }
//--- separate predictors from labels for both partitions
   matrix testdata         = np::selectMatrixRows(data,testset);
   matrix test_predictors  = np::sliceMatrixCols(testdata,0,4);
   vector test_targets     = testdata.Col(4);
   matrix traindata        = np::selectMatrixRows(data,trainset);
   matrix train_predictors = np::sliceMatrixCols(traindata,0,4);
   vector train_targets    = traindata.Col(4);
//--- fit the logistic model on the training rows
   logistic::Clogit logit;
   if(!logit.fit(train_predictors,train_targets))
     {
      Print(" error training logistic model ");
      return;
     }
//--- ROC table computed from predicted probabilities on the held-out rows
   matrix y_probas = logit.probas(test_predictors);
   matrix roc_curve_table = roc_table(test_targets,y_probas);
   Print(roc_table_display(roc_curve_table));
  }
//+------------------------------------------------------------------+
