//+------------------------------------------------------------------+
//|                                                   roc_curves.mqh |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#include<np.mqh>
//+------------------------------------------------------------------+
//|  confusion matrix stats                                          |
//+------------------------------------------------------------------+
struct conf_stats
  {
   //--- raw confusion matrix cell counts (kept as double so ratios need no casts)
   double            tn;                // count of true negatives
   double            tp;                // count of true positives
   double            fn;                // count of false negatives
   double            fp;                // count of false positives
   //--- class totals
   double            num_targets;       // actual positive (target) samples: tp + fn
   double            num_non_targets;   // actual negative (non-target) samples: fp + tn
   //--- per-class rates
   double            tp_rate;           // true positive rate (hit rate, recall, sensitivity)
   double            fp_rate;           // false positive rate (fall-out, type 1 error)
   double            fn_rate;           // false negative rate (miss rate, type 2 error)
   double            tn_rate;           // true negative rate (specificity)
   //--- predictive values
   double            precision;         // precision (positive predictive value)
   double            null_precision;    // 1 - precision (false discovery rate)
   double            prevalence;        // share of target samples among all samples
   double            lr_plus;           // positive likelihood ratio: tp_rate / fp_rate
   double            lr_neg;            // negative likelihood ratio: fn_rate / tn_rate
   double            for_rate;          // false omission rate
   double            npv;               // negative predictive value: 1 - for_rate
   //--- summary scores
   double            acc;               // accuracy
   double            b_acc;             // balanced accuracy: mean of tp_rate and tn_rate
   double            f1_score;          // F1 score
   double            mean_error;        // mean of fp_rate and fn_rate
  };
//+------------------------------------------------------------------+
//| compute confusion matrix statistics at a probability threshold   |
//+------------------------------------------------------------------+
//--- Fills cmat with the confusion matrix counts and the derived rates obtained
//--- by classifying every sample as "target" when its probability is >= threshold.
//---   cmat             [out] receives counts, class totals, rates and summary scores
//---   targets          [in]  actual labels; must contain exactly two distinct values
//---   probas           [in]  predicted score per sample, same length as targets
//---   threshold        [in]  decision threshold (score >= threshold => predicted target)
//---   target_label     [in]  value of the positive label; must be the larger label
//---   non_target_label [in]  value of the negative label; must be the smaller label
//--- Returns false (after printing a message) when the inputs are rejected; cmat is
//--- left untouched in that case.
bool roc_stats(conf_stats &cmat, vector &targets, vector &probas, double threshold, long target_label = 1, long non_target_label= 0)
  {
//--- every label needs a matching score, otherwise probas[i] below is read out of range
   if(targets.Size()!=probas.Size())
     {
      Print(__FUNCTION__, " ", __LINE__, " invalid inputs ");
      return false;
     }
//--- both labels have to be present and ordered non_target < target
   vector all_labels = np::unique(targets);

   if(all_labels.Size()!=2 || long(all_labels[all_labels.ArgMin()])!=non_target_label || long(all_labels[all_labels.ArgMax()])!=target_label || target_label<=non_target_label)
     {
      Print(__FUNCTION__, " ", __LINE__, " invalid inputs ");
      return false;
     }
//---
   cmat.tp=cmat.fn=cmat.tn=cmat.fp = 0.0;
//--- Branch on the actual label first so that a sample always lands in the row of
//--- its own class. A score that is not >= threshold (including a non-comparable
//--- one) counts as a negative prediction, keeping tp+fn and fp+tn equal to the
//--- real class sizes.
   for(ulong i = 0; i<targets.Size(); i++)
     {
      bool predicted_target = (probas[i]>=threshold);
      bool actual_target = (long(targets[i]) == target_label);

      if(actual_target)
        {
         if(predicted_target)
            cmat.tp++;
         else
            cmat.fn++;
        }
      else
        {
         if(predicted_target)
            cmat.fp++;
         else
            cmat.tn++;
        }
     }
//---
   cmat.num_targets = cmat.tp+cmat.fn;
   cmat.num_non_targets = cmat.fp+cmat.tn;
//--- Ratios fall back to double("na") when their denominator is zero.
//--- NOTE(review): double("na") is meant as an "undefined" marker - confirm which
//--- value this string-to-double cast actually produces in the terminal.
   cmat.tp_rate = (cmat.tp+cmat.fn>0.0)?(cmat.tp/(cmat.tp+cmat.fn)):double("na");
   cmat.fp_rate = (cmat.tn+cmat.fp>0.0)?(cmat.fp/(cmat.tn+cmat.fp)):double("na");
   cmat.fn_rate = (cmat.tp+cmat.fn>0.0)?(cmat.fn/(cmat.tp+cmat.fn)):double("na");
   cmat.tn_rate = (cmat.tn+cmat.fp>0.0)?(cmat.tn/(cmat.tn+cmat.fp)):double("na");
   cmat.precision = (cmat.tp+cmat.fp>0.0)?(cmat.tp/(cmat.tp+cmat.fp)):double("na");
//--- complement of precision (false discovery rate); inherits the fallback value
   cmat.null_precision = 1.0 - cmat.precision;
   cmat.for_rate = (cmat.tn+cmat.fn>0.0)?(cmat.fn/(cmat.tn+cmat.fn)):double("na");
//--- complement of the false omission rate; inherits the fallback value
   cmat.npv = 1.0 - cmat.for_rate;
   cmat.lr_plus = (cmat.fp_rate>0.0)?(cmat.tp_rate/cmat.fp_rate):double("na");
   cmat.lr_neg = (cmat.tn_rate>0.0)?(cmat.fn_rate/cmat.tn_rate):double("na");
   cmat.prevalence = (cmat.num_non_targets+cmat.num_targets>0.0)?(cmat.num_targets/(cmat.num_non_targets+cmat.num_targets)):double("na");
   cmat.acc = (cmat.num_non_targets+cmat.num_targets>0.0)?((cmat.tp+cmat.tn)/(cmat.num_non_targets+cmat.num_targets)):double("na");
   cmat.b_acc = ((cmat.tp_rate+cmat.tn_rate)/2.0);
   cmat.f1_score = (cmat.tp+cmat.fp+cmat.fn>0.0)?((2.0*cmat.tp)/(2.0*cmat.tp+cmat.fp+cmat.fn)):double("na");
   cmat.mean_error = ((cmat.fp_rate+cmat.fn_rate)/2.0);
//---
   return true;
//---
  }
//+------------------------------------------------------------------+
//| roc table                                                        |
//+------------------------------------------------------------------+
//--- Builds a table with one row per sample, using each sample's score in turn as
//--- the decision threshold and recording the resulting statistics from roc_stats().
//---   true_targets     [in] actual labels, one per row of probas
//---   probas           [in] predicted scores, one row per sample
//---   target_probs_col [in] column of probas holding the target-class score
//---   target_label     [in] positive label value, forwarded to roc_stats()
//---   non_target_label [in] negative label value, forwarded to roc_stats()
//--- Columns of the result: 0 tp_rate, 1 fp_rate, 2 fn_rate, 3 tn_rate, 4 precision,
//--- 5 null_precision, 6 mean_error, 7 acc, 8 b_acc, 9 threshold.
//--- Returns a 1x1 zero matrix when the inputs are rejected or a helper fails.
matrix roc_table(vector &true_targets,matrix &probas,ulong target_probs_col = 1, long target_label = 1, long non_target_label= 0)
  {
//--- reject inputs that would be indexed out of range further down; an empty
//--- matrix would also make the unsigned "Size()-1" below wrap around
   if(probas.Rows()<1 || target_probs_col>=probas.Cols() || true_targets.Size()!=probas.Rows())
     {
      Print(__FUNCTION__, " ", __LINE__, " invalid inputs ");
      return matrix::Zeros(1,1);
     }
//--- extract the score column once instead of on every loop iteration
   vector scores = probas.Col(target_probs_col);
//--- sorted copy of the scores provides the thresholds
//--- (second argument false presumably requests descending order - confirm in np.mqh)
   vector probs = scores;

   if(!np::quickSort(probs,false,0,probs.Size()-1))
      return matrix::Zeros(1,1);

   matrix roctable(probs.Size(),10);

   conf_stats mts;

   for(ulong i = 0; i<roctable.Rows(); i++)
     {
      if(!roc_stats(mts,true_targets,scores,probs[i],target_label,non_target_label))
         return matrix::Zeros(1,1);
      roctable[i][0] = mts.tp_rate;
      roctable[i][1] = mts.fp_rate;
      roctable[i][2] = mts.fn_rate;
      roctable[i][3] = mts.tn_rate;
      roctable[i][4] = mts.precision;
      roctable[i][5] = mts.null_precision;
      roctable[i][6] = mts.mean_error;
      roctable[i][7] = mts.acc;
      roctable[i][8] = mts.b_acc;
      roctable[i][9] = probs[i];
     }
//---
   return roctable;
  }
//+------------------------------------------------------------------+
//| roc curve table display                                          |
//+------------------------------------------------------------------+
//--- Formats a table produced by roc_table() as text: a header line followed by
//--- one line per table row.
//---   out [in] matrix with at least 10 columns laid out as by roc_table()
//--- Returns the formatted text; the string is left unassigned when the matrix
//--- has fewer than 10 columns (e.g. the 1x1 failure result of roc_table()).
string roc_table_display(matrix &out)
  {
   string output,temp;

//--- the row formatter reads columns 0..9, so the column count is what must be
//--- checked; short tables (fewer than 10 rows) are valid and are printed too
   if(out.Cols()>=10)
     {
      output = "TPR   FPR   FNR   TNR   PREC  NPREC  M_E   ACC   B_ACC THRESH";
      for(ulong i = 0; i<out.Rows(); i++)
        {
         temp = StringFormat("\n%5.3lf %5.3lf %5.3lf %5.3lf %5.3lf %5.3lf %5.3lf %5.3lf %5.3lf %5.5lf",
                             out[i][0],out[i][1],out[i][2],out[i][3],out[i][4],out[i][5],out[i][6],out[i][7],out[i][8],out[i][9]);
         StringAdd(output,temp);
        }
     }
   return output;
  }
//+------------------------------------------------------------------+
//|  area under the curve                                            |
//+------------------------------------------------------------------+
//--- Area under the ROC curve for a two-class problem, optionally restricted to the
//--- false positive rate range [0, max_fpr].
//---   true_classes    [in] actual labels; must contain exactly two distinct values
//---   predicted_probs [in] predicted scores passed to the built-in ROC routines
//---   target_label    [in] currently not used by this function
//---   max_fpr         [in] upper bound of the fpr range, must be in (0, 1]
//--- With max_fpr == 1.0 the built-in ClassificationScore() result is returned.
//--- Otherwise the curve is cut at max_fpr (the tpr at the cut is linearly
//--- interpolated), integrated with the trapezoid rule and rescaled with
//--- 0.5*(1+(area-min)/(max-min)), where min = 0.5*max_fpr^2 and max = max_fpr.
//--- Returns EMPTY_VALUE (after printing a message) on any failure.
double roc_auc(vector &true_classes, matrix &predicted_probs,long target_label=1,double max_fpr=1.0)
  {
   vector all_labels = np::unique(true_classes);

   if(all_labels.Size()!=2|| max_fpr<=0.0 || max_fpr>1.0)
     {
      Print(__FUNCTION__, " ", __LINE__, " invalid inputs ");
      return EMPTY_VALUE;
     }

//--- full curve: delegate to the built-in score
   if(max_fpr == 1.0)
     {
      vector auc = true_classes.ClassificationScore(predicted_probs,CLASSIFICATION_ROC_AUC,AVERAGE_BINARY);
      //--- guard the element access in case no score was produced
      if(auc.Size()<1)
        {
         Print(__FUNCTION__, " ", __LINE__, " invalid inputs ");
         return EMPTY_VALUE;
        }
      return auc[0];
     }

   matrix tpr,fpr,threshs;

   if(!true_classes.ReceiverOperatingCharacteristic(predicted_probs,AVERAGE_BINARY,fpr,tpr,threshs))
     {
      Print(__FUNCTION__, " ", __LINE__, " invalid inputs ");
      return EMPTY_VALUE;
     }

//--- only row 0 of the curve matrices is used below
   if(fpr.Rows()<1 || tpr.Rows()<1 || fpr.Cols()!=tpr.Cols())
     {
      Print(__FUNCTION__, " ", __LINE__, " invalid inputs ");
      return EMPTY_VALUE;
     }

   vector xp(1);
   xp[0] = max_fpr;
   vector stop;

//--- locate the position of max_fpr within the fpr values
//--- (third argument true presumably selects the right-hand insertion point - confirm in np.mqh)
   if(!np::searchsorted(fpr.Row(0),xp,true,stop))
     {
      Print(__FUNCTION__, " ", __LINE__, " searchsorted failed ");
      return EMPTY_VALUE;
     }

//--- the cut index is used as [cut-1] and [cut], so it has to leave both in range
   if(stop.Size()<1 || long(stop[0])<1 || long(stop[0])>=long(fpr.Cols()))
     {
      Print(__FUNCTION__, " ", __LINE__, " searchsorted failed ");
      return EMPTY_VALUE;
     }

   long cut = long(stop[0]);

//--- the two curve points that bracket max_fpr
   vector xpts(2);
   vector ypts(2);

   xpts[0] = fpr[0][cut-1];
   xpts[1] = fpr[0][cut];
   ypts[0] = tpr[0][cut-1];
   ypts[1] = tpr[0][cut];

//--- keep the leading part of the curve and append one extra point at max_fpr
   vector vtpr = tpr.Row(0);
   vtpr = np::sliceVector(vtpr,0,cut);
   vector vfpr = fpr.Row(0);
   vfpr = np::sliceVector(vfpr,0,cut);

   if(!vtpr.Resize(vtpr.Size()+1) || !vfpr.Resize(vfpr.Size()+1))
     {
      Print(__FUNCTION__, " ", __LINE__, " error  ", GetLastError());
      return EMPTY_VALUE;
     }

   vfpr[vfpr.Size()-1] = max_fpr;

   vector yint = np::interp(xp,xpts,ypts);

   vtpr[vtpr.Size()-1] = yint[0];

//--- flip the sign of the integral when the x values contain a decreasing step
   double direction = 1.0;

   vector dx = np::diff(vfpr);

   if(dx[dx.ArgMin()]<0.0)
      direction = -1.0;

//--- test the failure sentinel on the raw result: after multiplying by a negative
//--- direction it would no longer compare equal to EMPTY_VALUE
   double area = np::trapezoid(vtpr,vfpr);

   if(area == EMPTY_VALUE)
     {
      Print(__FUNCTION__, " ", __LINE__, " trapz failed ");
      return EMPTY_VALUE;
     }

   double partial_auc = direction*area;

//--- rescale the partial area using the bounds 0.5*max_fpr^2 and max_fpr
   double minarea = 0.5*(max_fpr*max_fpr);
   double maxarea = max_fpr;

   return 0.5*(1+(partial_auc-minarea)/(maxarea-minarea));
  }
//+------------------------------------------------------------------+

//+------------------------------------------------------------------+
