//+------------------------------------------------------------------+
//|                                                        lfspy.mqh |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#include<JAson.mqh>
#include<Files/FileTxt.mqh>
#include<np.mqh>
//+------------------------------------------------------------------+
//|structure of model parameters                                     |
//+------------------------------------------------------------------+
// Plain container for the settings and dimensions of a loaded model.
// Clfspy::fromJSON fills every field from the JSON key of the same name;
// only knn, gamma and num_features are read again by the inference code
// visible in this file.
struct LFS_PARAMS
{
 int alpha;          // read from JSON key "alpha"; presumably a training-time setting — TODO confirm
 int tau;            // read from JSON key "tau"; presumably a training-time setting — TODO confirm
 int n_beta;         // read from JSON key "n_beta"; presumably a training-time setting — TODO confirm
 int nrrp;           // read from JSON key "nrrp"; presumably a training-time setting — TODO confirm
 int knn;            // neighbour-count threshold passed to class_sim
 int rr_seed;        // read from JSON key "rr_seed"; presumably a random seed — TODO confirm
 int sigma;          // read from JSON key "sigma"; presumably a training-time setting — TODO confirm
 ulong num_features; // row count of the stored training data and fstar matrices
 double gamma;       // multiplier applied to the same-class fraction in class_sim's radius test
};
//+------------------------------------------------------------------+
//|  class encapsulates LFSpy model                                  |
//+------------------------------------------------------------------+
class Clfspy
  {
private:
   bool              loaded;
   LFS_PARAMS        model_params;
   matrix train_data,
          fstar;
   vector train_labels;
   //+------------------------------------------------------------------+
   //|  helper function for parsing model from file                     |
   //+------------------------------------------------------------------+
   // Populates model_params, train_data, train_labels and fstar from a
   // deserialized model document.
   //   jsonmodel - parsed JSON object holding the scalar settings plus the
   //               arrays "training_labels", "training_data" and "fstar".
   // Returns true on success. Returns false (after printing a message) when
   // the dimensions are not positive, an array is missing or too short, or a
   // container cannot be resized.
   bool              fromJSON(CJAVal &jsonmodel)
     {
      //--- scalar settings
      model_params.alpha   = (int)jsonmodel["alpha"].ToInt();
      model_params.tau     = (int)jsonmodel["tau"].ToInt();
      model_params.sigma   = (int)jsonmodel["sigma"].ToInt();
      model_params.n_beta  = (int)jsonmodel["n_beta"].ToInt();
      model_params.nrrp    = (int)jsonmodel["nrrp"].ToInt();
      model_params.knn     = (int)jsonmodel["knn"].ToInt();
      model_params.rr_seed = (int)jsonmodel["rr_seed"].ToInt();
      model_params.gamma   = jsonmodel["gamma"].ToDbl();

      //--- dimensions: validate as signed values before converting to unsigned,
      //--- otherwise a negative number would turn into a huge size
      long n_obs  = jsonmodel["num_observations"].ToInt();
      long n_feat = jsonmodel["num_features"].ToInt();
      if(n_obs<=0 || n_feat<=0)
        {
         Print(__FUNCTION__, " invalid model dimensions, num_features ", n_feat, " num_observations ", n_obs);
         return false;
        }
      ulong observations = (ulong)n_obs;
      model_params.num_features = (ulong)n_feat;

      //--- locate the three arrays once instead of searching by key per element
      CJAVal *labels_js = jsonmodel.FindKey("training_labels");
      CJAVal *data_js   = jsonmodel.FindKey("training_data");
      CJAVal *fstar_js  = jsonmodel.FindKey("fstar");
      if(labels_js==NULL || data_js==NULL || fstar_js==NULL)
        {
         Print(__FUNCTION__, " model is missing training_labels, training_data or fstar");
         return false;
        }
      //--- sizes are checked up front so a malformed file is rejected explicitly
      if(labels_js.Size()<n_obs || data_js.Size()<n_feat || fstar_js.Size()<n_feat)
        {
         Print(__FUNCTION__, " model arrays are smaller than the declared dimensions");
         return false;
        }

      ResetLastError();
      if(!train_data.Resize(model_params.num_features,observations) || !train_labels.Resize(observations) ||
         !fstar.Resize(model_params.num_features,observations))
        {
         Print(__FUNCTION__, " error ", GetLastError());
         return false;
        }

      //--- labels do not depend on the feature index, so fill them on their own
      for(int j=0; j<(int)n_obs; j++)
         train_labels[j] = labels_js[j].ToDbl();

      //--- both matrices are stored as features (rows) x observations (columns)
      for(int i=0; i<(int)n_feat; i++)
        {
         CJAVal *data_row  = data_js[i];
         CJAVal *fstar_row = fstar_js[i];
         if(data_row==NULL || fstar_row==NULL || data_row.Size()<n_obs || fstar_row.Size()<n_obs)
           {
            Print(__FUNCTION__, " row ", i, " of training_data or fstar is shorter than num_observations");
            return false;
           }
         for(int j=0; j<(int)n_obs; j++)
           {
            train_data[i][j] = data_row[j].ToDbl();
            fstar[i][j]      = fstar_row[j].ToDbl();
           }
        }

      return true;
     }
   //+------------------------------------------------------------------+
   //| helper classification function                                   |
   //+------------------------------------------------------------------+
   // Scores every column of testing_data with class_sim.
   //   testing_data - one sample per column (predict passes the transposed input).
   // Returns a matrix with one row per sample and two columns holding the
   // vector produced by class_sim for that sample. If a row cannot be stored,
   // an error is printed and a 1x1 zero matrix is returned as failure marker.
   matrix            classification(matrix &testing_data)
     {
      ulong samples = testing_data.Cols();

      matrix out(samples,2);

      //--- make sure the code printed on failure belongs to this call
      ResetLastError();
      for(ulong s = 0; s<samples; s++)
        {
         vector column = testing_data.Col(s);
         vector result = class_sim(column,train_data,train_labels,fstar,model_params.gamma,model_params.knn);
         if(!out.Row(result,s))
           {
            Print(__FUNCTION__, " row insertion failure ", GetLastError());
            return matrix::Zeros(1,1);
           }
        }

      return out;
     }
   //+------------------------------------------------------------------+
   //| internal feature classification function                         |
   //+------------------------------------------------------------------+
   // Scores a single test sample against all training patterns.
   //   test     - the sample, one value per feature
   //   patterns - training data, features in rows and patterns in columns
   //   targets  - one label per training pattern; the code compares against
   //              1.0 and 0.0, so labels are assumed to be exactly 0/1 — TODO confirm
   //   f_star   - one column per training pattern; zero entries are treated as
   //              "feature not selected" for that pattern
   //   gamma    - multiplier on the same-class fraction in the radius test
   //   knn      - minimum number of neighbours to collect
   // Returns a 2-element vector {sc1, sc2}: the share of class-1 (resp. class-0)
   // training patterns whose local vote went to their own class.
   // NOTE(review): the semantics of the np:: helpers and of Sum(0) are not
   // visible here; comments about them below are assumptions to be confirmed.
   vector            class_sim(vector &test,matrix &patterns,vector& targets, matrix &f_star, double gamma, int knn)
     {
      int N = int(targets.Size());
      // sum of the labels; this is the class-1 count only for 0/1 labels
      int n_nt_cls_1 = (int)targets.Sum();
      int n_nt_cls_2 = N - n_nt_cls_1;
      int M = int(patterns.Rows());
      // NC1/NC2 count the class-1/class-0 patterns that have at least one non-zero f_star entry
      int NC1 = 0;
      int NC2 = 0;
      // S is written below but never read in this function
      vector S = vector::Zeros(N);

      S.Fill(double("inf"));

      // *knn: neighbour counts per class from the unconditional k-NN step
      // NoNNC1/NoNNC2: neighbour counts per class from the step that only runs
      // when the test sample lies inside pattern i's radius
      vector NoNNC1knn = vector::Zeros(N);
      vector NoNNC2knn = vector::Zeros(N);
      vector NoNNC1 = vector::Zeros(N);
      vector NoNNC2 = vector::Zeros(N);
      // radious is written at the end of each iteration but never read here
      vector radious = vector::Zeros(N);
      // r and k live outside the loop, so r keeps the value of the previous
      // iteration whenever an iteration does not recompute it
      double r = 0;
      int k = 0;
      for(int i = 0; i<N; i++)
        {
         // weight patterns and test sample by pattern i's f_star column
         // (presumably the helper tiles fs so the product is element-wise — TODO confirm)
         vector fs = f_star.Col(i);
         matrix xpatterns = patterns * np::repeat_vector_as_rows_cols(fs,patterns.Cols(),false);
         vector testpr = test * fs;
         vector mtestpr = (-1.0 * testpr);
         matrix testprmat = np::repeat_vector_as_rows_cols(mtestpr,xpatterns.Cols(),false);
         // Euclidean distance from the test sample to each pattern
         // (assumes Sum(0) yields one value per column — TODO confirm)
         vector dist = MathAbs(sqrt((pow(testprmat + xpatterns,2.0)).Sum(0)));
         vector min1 = dist;
         np::sort(min1);
         vector min_uniq = np::unique(min1);
         int m = -1;
         int no_nereser = 0;
         vector NN(dist.Size());
         // widen the threshold through the distinct distances until at least
         // knn patterns are within it (ties are all included)
         // NOTE(review): m is not bounded; if knn exceeds the number of
         // patterns min_uniq[m] would run past the end — confirm
         while(no_nereser<int(knn))
           {
            m+=1;
            double a1  = min_uniq[m];
            for(ulong j = 0; j<dist.Size(); j++)
               NN[j]=(dist[j]<=a1)?1.0:0.0;
            no_nereser = (int)NN.Sum();
           }
         // split the neighbours by label
         vector bitNN = np::bitwiseAnd(NN,targets);
         vector Not = np::bitwiseNot(targets);
         NoNNC1knn[i] = bitNN.Sum();
         bitNN = np::bitwiseAnd(NN,Not);
         NoNNC2knn[i] = bitNN.Sum();
         // A marks the zero entries of fs
         vector A(fs.Size());
         for(ulong v =0; v<A.Size(); v++)
             A[v] = (fs[v]==0.0)?1.0:0.0;
         // f1 is filled below but never read; only f2 is used
         vector f1(patterns.Cols());
         vector f2(patterns.Cols());
         // continue only if fs has at least one non-zero entry
         if(A.Sum()<double(M))
           {
            // invert A so that it marks the non-zero entries instead
            for(ulong v =0; v<A.Size(); v++)
             A[v] = (A[v]==1.0)?0.0:1.0;
            // zero out the rows of the features with a zero fs entry
            matrix amask = matrix::Ones(patterns.Rows(), patterns.Cols());
            amask *= np::repeat_vector_as_rows_cols(A,patterns.Cols(),false);
            matrix patternsp = patterns*amask;
            vector testp = test*(amask.Col(0));
            // testa is computed but never read
            vector testa = patternsp.Col(i) - testp;
            vector col = patternsp.Col(i);
            matrix colmat = np::repeat_vector_as_rows_cols(col,patternsp.Cols(),false);
            // distance from pattern i to the test sample, and to every pattern
            double Dist_test = MathAbs(sqrt((pow(col - testp,2.0)).Sum()));
            vector Dist_pat  = MathAbs(sqrt((pow(patternsp - colmat,2.0)).Sum(0)));
            vector eerep = Dist_pat; 
             np::sort(eerep);
            // remove is never changed, so the `remove!=1` checks below always pass
            int remove = 0;
            if(targets[i] == 1.0)
              {
               vector unq = np::unique(eerep);
               k = -1;
               NC1+=1;
               if(remove!=1)
                 {
                  int Next = 1;
                  // grow the radius through the distinct distances until the
                  // gamma-scaled class-1 fraction inside it drops below the
                  // class-0 fraction inside it
                  // NOTE(review): k is not bounded; if the test never holds,
                  // unq[k] would run past the end — confirm
                  while(Next == 1)
                    {
                     k+=1;
                     r = unq[k];
                     for(ulong j = 0; j<Dist_pat.Size(); j++)
                       {
                        if(Dist_pat[j] == r)
                           f1[j] = 1.0;
                        else
                           f1[j] = 0.0;
                        if(Dist_pat[j]<=r)
                           f2[j] = 1.0;
                        else
                           f2[j] = 0.0;
                       }
                     vector f2t = np::bitwiseAnd(f2,targets);
                     vector tn = np::bitwiseNot(targets);
                     vector f2tn = np::bitwiseAnd(f2,tn);
                     // the -1.0 excludes pattern i itself from its own class count
                     double nocls1clst = f2t.Sum() - 1.0;
                     double nocls2clst = f2tn.Sum();
                     // NOTE(review): the divisors are zero when a class has one
                     // (n_nt_cls_1-1) or no (n_nt_cls_2) members — confirm
                     if(gamma *(nocls1clst/double(n_nt_cls_1-1)) < (nocls2clst/(double(n_nt_cls_2))))
                       {
                        Next = 0 ;
                        // (k-1)==0 means k==1: keep unq[1]; otherwise use the
                        // midpoint between the last two distinct distances
                        // NOTE(review): when this is reached with k==0 the
                        // else-branch indexes unq[k-1] at -1 — confirm
                        if((k-1) == 0)
                           r = unq[k];
                        else
                           r = 0.5 * (unq[k-1] + unq[k]);
                        // avoid a zero radius
                        if(r==0.0)
                           r = pow(10.0,-6.0);
                        r = 1.0*r;
                        for(ulong j = 0; j<Dist_pat.Size(); j++)
                          {
                           if(Dist_pat[j]<=r)
                              f2[j] = 1.0;
                           else
                              f2[j] = 0.0;
                          }
                        // recomputed for the final radius; the values are not used afterwards
                        f2t = np::bitwiseAnd(f2,targets);
                        f2tn = np::bitwiseAnd(f2,tn);
                        nocls1clst = f2t.Sum() - 1.0;
                        nocls2clst = f2tn.Sum();
                       }
                    }
                  // the test sample lies inside pattern i's radius: repeat the
                  // neighbour search and record the per-class counts
                  if(Dist_test<r)
                    {
                     patternsp = patterns * np::repeat_vector_as_rows_cols(fs,patterns.Cols(),false);
                     testp = test * fs;
                     dist = MathAbs(sqrt((pow(patternsp - np::repeat_vector_as_rows_cols(testp,patternsp.Cols(),false),2.0)).Sum(0)));
                     min1 = dist;
                     np::sort(min1);
                     min_uniq = np::unique(min1);
                     m = -1;
                     no_nereser = 0;
                     // NOTE(review): this search uses a strict `<`, unlike the
                     // `<=` of the first search above — confirm this is intended
                     while(no_nereser<int(knn))
                       {
                        m+=1;
                        double a1  = min_uniq[m];
                        for(ulong j = 0; j<dist.Size(); j++)
                           NN[j]=(dist[j]<a1)?1.0:0.0;
                        no_nereser = (int)NN.Sum();
                       }
                     bitNN = np::bitwiseAnd(NN,targets);
                     Not = np::bitwiseNot(targets);
                     NoNNC1[i] = bitNN.Sum();
                     bitNN = np::bitwiseAnd(NN,Not);
                     NoNNC2[i] = bitNN.Sum();
                     if(NoNNC1[i]>NoNNC2[i])
                        S[i] = 1.0;
                    }
                 }
              }
            // same procedure for a class-0 pattern, with the class roles swapped
            if(targets[i] == 0.0)
              {
               vector unq = np::unique(eerep);
               k=-1;
               NC2+=1;
               int Next;
               if(remove!=1)
                 {
                  Next =1;
                  // NOTE(review): k is not bounded here either — confirm
                  while(Next==1)
                    {
                     k+=1;
                     r = unq[k];
                     for(ulong j = 0; j<Dist_pat.Size(); j++)
                       {
                        if(Dist_pat[j] == r)
                           f1[j] = 1.0;
                        else
                           f1[j] = 0.0;
                        if(Dist_pat[j]<=r)
                           f2[j] = 1.0;
                        else
                           f2[j] = 0.0;
                       }
                     vector f2t = np::bitwiseAnd(f2,targets);
                     vector tn = np::bitwiseNot(targets);
                     vector f2tn = np::bitwiseAnd(f2,tn);
                     // here the -1.0 excludes pattern i from the class-0 count
                     double nocls1clst = f2t.Sum() ;
                     double nocls2clst = f2tn.Sum() -1.0;
                     if(gamma *(nocls2clst/double(n_nt_cls_2-1)) < (nocls1clst/(double(n_nt_cls_1))))
                       {
                        Next = 0 ;
                        // NOTE(review): same k==0 concern for unq[k-1] as in the class-1 branch
                        if((k-1) == 0)
                           r = unq[k];
                        else
                           r = 0.5 * (unq[k-1] + unq[k]);
                        if(r==0.0)
                           r = pow(10.0,-6.0);
                        r = 1.0*r;
                        for(ulong j = 0; j<Dist_pat.Size(); j++)
                          {
                           if(Dist_pat[j]<=r)
                              f2[j] = 1.0;
                           else
                              f2[j] = 0.0;
                          }
                        f2t = np::bitwiseAnd(f2,targets);
                        f2tn = np::bitwiseAnd(f2,tn);
                        nocls1clst = f2t.Sum();
                        nocls2clst = f2tn.Sum() -1.0;
                       }
                    }
                  if(Dist_test<r)
                    {
                     patternsp = patterns * np::repeat_vector_as_rows_cols(fs,patterns.Cols(),false);
                     testp = test * fs;
                     dist = MathAbs(sqrt((pow(patternsp - np::repeat_vector_as_rows_cols(testp,patternsp.Cols(),false),2.0)).Sum(0)));
                     min1 = dist;
                     np::sort(min1);
                     min_uniq = np::unique(min1);
                     m = -1;
                     no_nereser = 0;
                     while(no_nereser<int(knn))
                       {
                        m+=1;
                        double a1  = min_uniq[m];
                        for(ulong j = 0; j<dist.Size(); j++)
                           NN[j]=(dist[j]<a1)?1.0:0.0;
                        no_nereser = (int)NN.Sum();
                       }
                     bitNN = np::bitwiseAnd(NN,targets);
                     Not = np::bitwiseNot(targets);
                     NoNNC1[i] = bitNN.Sum();
                     bitNN = np::bitwiseAnd(NN,Not);
                     NoNNC2[i] = bitNN.Sum();
                     if(NoNNC2[i]>NoNNC1[i])
                        S[i] = 1.0;
                    }
                 }
              }
           }
         radious[i] = r;
        }
      // q1/q2 flag the patterns whose in-radius count for one class exceeds
      // the plain k-NN count for the other class
      vector q1 = vector::Zeros(N);
      vector q2 = vector::Zeros(N);
      for(int i = 0; i<N; i++)
        {
         if(NoNNC1[i] > NoNNC2knn[i])
            q1[i] = 1.0;
         if(NoNNC2[i] > NoNNC1knn[i])
            q2[i] = 1.0;
        }

      // keep only the flags that sit on a pattern of the matching class
      vector ntargs = np::bitwiseNot(targets);
      vector c1 = np::bitwiseAnd(q1,targets);
      vector c2 = np::bitwiseAnd(q2,ntargs);

      // NOTE(review): NC1 or NC2 is zero when no pattern of that class has a
      // non-zero f_star column, which makes these divisions by zero — confirm
      double sc1 = c1.Sum()/NC1;
      double sc2 = c2.Sum()/NC2;

      // nothing voted through the radius step: fall back to the plain k-NN counts
      if(sc1==0.0 && sc2==0.0)
        {
         q1.Fill(0.0);
         q2.Fill(0.0);

         for(int i = 0; i<N; i++)
           {
            if(NoNNC1knn[i] > NoNNC2knn[i])
               q1[i] = 1.0;
            if(NoNNC2knn[i] > NoNNC1knn[i])
               q2[i] = 1.0;
               
            // rebuild the negated labels element by element
            if(!targets[i])
               ntargs[i] = 1.0;
            else
               ntargs[i] = 0.0;
           }
         
         c1 = np::bitwiseAnd(q1,targets);
         c2 = np::bitwiseAnd(q2,ntargs);

         sc1 = c1.Sum()/NC1;
         sc2 = c2.Sum()/NC2;
        }

      vector out(2);

      out[0] = sc1;
      out[1] = sc2;

      return out;
     }
public:
   //+------------------------------------------------------------------+
   //|    constructor                                                   |
   //+------------------------------------------------------------------+
                     // Creates an empty object: no model is loaded and all
                     // parameters are zero until load() succeeds.
                     Clfspy(void) : loaded(false)
     {
      //--- getmodelparams() may be called before load(); give it defined values
      ZeroMemory(model_params);
     }
   //+------------------------------------------------------------------+
   //|  destructor                                                      |
   //+------------------------------------------------------------------+
                    // Nothing to release explicitly; the destructor body is empty.
                    ~Clfspy(void) { }
   //+------------------------------------------------------------------+
   //|  load a LFSpy trained model from file                            |
   //+------------------------------------------------------------------+
   // Reads a model file and parses it into this object.
   //   file_name                - path of the JSON model file
   //   FILE_IN_COMMON_DIRECTORY - when true the file is opened with FILE_COMMON
   // Returns true when the file was opened, deserialized and accepted by
   // fromJSON. On any failure a message is printed, false is returned and the
   // object stays in the "not loaded" state.
   bool              load(const string file_name, bool FILE_IN_COMMON_DIRECTORY = false)
     {
      loaded = false;
      CFileTxt modelFile;
      CJAVal js;
      int open_flags = FILE_IN_COMMON_DIRECTORY ? (FILE_READ|FILE_COMMON) : FILE_READ;

      ResetLastError();
      if(modelFile.Open(file_name,open_flags,0)==INVALID_HANDLE)
        {
         Print(__FUNCTION__," failed to open file ",file_name," .Error - ",::GetLastError());
         return false;
        }

      //--- ReadString returns a single line, so collect every line up to the end
      //--- of the file; a one-line file yields exactly the same text as before
      string content = "";
      while(!modelFile.IsEnding() && !IsStopped())
         content += modelFile.ReadString();
      modelFile.Close();

      if(!js.Deserialize(content))
        {
         Print("failed to read from ",file_name,".Error -",::GetLastError());
         return false;
        }

      loaded = fromJSON(js);
      return loaded;
     }
   //+------------------------------------------------------------------+
   //|   make a prediction based on specific inputs                     |
   //+------------------------------------------------------------------+
   // Predicts a class label for every row of inputs.
   //   inputs - one sample per row; the column count must equal the number of
   //            rows (features) of the stored training data.
   // Returns a vector with one entry per input row: 1.0 when the first score
   // returned by classification is greater than the second, 0.0 otherwise.
   // On error (no model loaded, shape mismatch, classification failure) a
   // message is printed and a single-element zero vector is returned.
   vector            predict(matrix &inputs)
     {
      if(!loaded)
        {
         Print(__FUNCTION__, " No model available, Load a model first before calling this method ");
         return vector::Zeros(1);
        }

      if(inputs.Cols()!=train_data.Rows())
        {
         Print(__FUNCTION__, " input matrix does not match with shape of expected model inputs (columns)");
         return vector::Zeros(1);
        }

      //--- classification expects one sample per column
      matrix testdata = inputs.Transpose();
      matrix probs = classification(testdata);

      //--- classification signals failure with a 1x1 matrix; reading its second
      //--- column below would be out of range, so stop here instead
      if(probs.Cols()!=2 || probs.Rows()!=inputs.Rows())
        {
         Print(__FUNCTION__, " classification failed, no predictions available");
         return vector::Zeros(1);
        }

      vector classes = vector::Zeros(probs.Rows());

      for(ulong i = 0; i<classes.Size(); i++)
         if(probs[i][0] > probs[i][1])
            classes[i] = 1.0;

      return classes;
     }
   //+------------------------------------------------------------------+
   //| get the parameters of the loaded model                           |
   //+------------------------------------------------------------------+
   // Hands back a copy of the parameter structure currently held by the object.
   LFS_PARAMS      getmodelparams(void)
     {
      LFS_PARAMS snapshot = model_params;
      return snapshot;
     }
   

  };
//+------------------------------------------------------------------+
