//+------------------------------------------------------------------+
//|                                                       mlffnn.mqh |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
//+------------------------------------------------------------------+
//|class for a basic feedforward neural network                      |
//+------------------------------------------------------------------+
class FFNN
  {
protected:

   bool              m_trained;             // flag noting if neural net successfully trained
   matrix            m_weights[];           // layer weights; the last row of each matrix multiplies the appended bias column
   matrix            m_outputs[];           // hidden layer outputs before the activation function is applied
   matrix            m_result;              // network output of the most recent training pass
   uint              m_epochs;              // number of epochs
   ulong             m_num_inputs;          // number of input variables for nn
   ulong             m_layers;              // number of layers of neural net (hidden layers + output layer)
   ulong             m_hidden_layers;       // number of hidden layers
   ulong             m_hidden_layer_size[]; // node config for hidden layers
   double            m_learn_rate;          // learning rate
   ENUM_ACTIVATION_FUNCTION m_act_fn;       // activation function applied to hidden layers
   //+------------------------------------------------------------------+
   //| initialize the neural network structure                          |
   //| Sizes one weight matrix per layer. Each matrix has one row per   |
   //| node feeding the layer plus one extra row for the bias term.     |
   //| The output layer always has a single column.                     |
   //| Returns false (after printing a message) if the configuration is |
   //| inconsistent or a matrix cannot be sized.                        |
   //+------------------------------------------------------------------+

   virtual bool      create(void)
     {
      // m_layers < 1 is rejected explicitly: with unsigned arithmetic a wrapped
      // hidden-layer count could otherwise satisfy the difference check below
      if(m_layers < 1 || m_layers - m_hidden_layers != 1)
        {
         Print(__FUNCTION__,"  Network structure misconfiguration ");
         return false;
        }

      // every array indexed below must have been sized by the constructor
      if(ulong(ArraySize(m_weights)) < m_layers ||
         ulong(ArraySize(m_outputs)) < m_hidden_layers ||
         ulong(ArraySize(m_hidden_layer_size)) < m_hidden_layers)
        {
         Print(__FUNCTION__,"  Layer size array does not match the number of layers ");
         return false;
        }

      for(ulong i = 0; i<m_hidden_layers; i++)
        {
         if(m_hidden_layer_size[i] == 0)
           {
            Print(__FUNCTION__,"  Hidden layer ", i, " has no nodes ");
            return false;
           }
        }

      // NOTE(review): Init(rows, cols) is called without an initializer function;
      // confirm what starting values the weights hold - TODO confirm
      for(ulong i = 0; i<m_layers; i++)
        {
         // rows: nodes feeding this layer (inputs for the first layer) + 1 for the bias
         ulong rows = ((i == 0) ? m_num_inputs : m_hidden_layer_size[i-1]) + 1;
         // cols: nodes in this layer; the output layer has exactly one
         ulong cols = (i == m_layers-1) ? 1 : m_hidden_layer_size[i];

         if(!m_weights[i].Init(rows, cols))
           {
            Print(__FUNCTION__," ",__LINE__," ", GetLastError());
            return false;
           }
        }

      return true;
     }
   //+------------------------------------------------------------------+
   //| calculate output  from all layers                                |
   //| data   : one sample per row, one input variable per column       |
   //| Stores the pre-activation output of every hidden layer in        |
   //| m_outputs so that backprop() can reuse it.                       |
   //| Returns the output layer result (no activation is applied to it) |
   //| or an empty 0x0 matrix on any error.                             |
   //+------------------------------------------------------------------+

   virtual matrix    calculate(matrix &data)
     {
      if(data.Cols() != m_weights[0].Rows()-1)
        {
         Print(__FUNCTION__," input data not compatible with network structure ");
         return matrix::Zeros(0,0);
        }

      matrix temp = data;

      for(ulong i = 0; i<m_hidden_layers; i++)
        {
         // widen by one column and fill it with ones: the bias input
         if(!temp.Resize(temp.Rows(), m_weights[i].Rows()) ||
            !temp.Col(vector::Ones(temp.Rows()), m_weights[i].Rows() - 1))
           {
            Print(__FUNCTION__," ",__LINE__," ", GetLastError());
            // abort here: continuing would multiply a malformed matrix
            return matrix::Zeros(0,0);
           }

         m_outputs[i]=temp.MatMul(m_weights[i]);

         // temp receives the activated values and becomes the next layer's input
         if(!m_outputs[i].Activation(temp, m_act_fn))
           {
            Print(__FUNCTION__," ",__LINE__," ", GetLastError());
            return matrix::Zeros(0,0);
           }

        }

      // output layer: append the bias column, then a plain linear combination
      if(!temp.Resize(temp.Rows(), m_weights[m_hidden_layers].Rows()) ||
         !temp.Col(vector::Ones(temp.Rows()), m_weights[m_hidden_layers].Rows() - 1))
        {
         Print(__FUNCTION__," ",__LINE__," ", GetLastError());
         return matrix::Zeros(0,0);
        }

      return temp.MatMul(m_weights[m_hidden_layers]);
     }
   //+------------------------------------------------------------------+
   //|   backpropagation method                                         |
   //| data    : inputs that were passed to the preceding calculate()   |
   //| targets : expected outputs, same shape as result                 |
   //| result  : network output returned by calculate() for data        |
   //| Relies on m_outputs as filled by that calculate() call and       |
   //| updates every weight matrix in place.                            |
   //| Returns false if the shapes disagree or a matrix operation fails.|
   //+------------------------------------------------------------------+

   virtual bool      backprop(matrix &data, matrix& targets, matrix &result)
     {
      if(targets.Rows() != result.Rows() ||
         targets.Cols() != result.Cols())
        {
         Print(__FUNCTION__," invalid function parameters ");
         return false;
        }
      // 2*(targets - result) is the negative derivative of the squared error with
      // respect to the output, which is why the weight updates below are additions
      matrix loss = (targets - result) * 2;
      // error pushed back through the output weights (taken before they are updated)
      matrix gradient = loss.MatMul(m_weights[m_hidden_layers].Transpose());
      matrix temp;

      // walk the hidden layers from last to first; signed index so the loop can stop below zero
      for(long i = long(m_hidden_layers-1); i>-1; i--)
        {
         // rebuild the activated output of hidden layer i (the input of layer i+1)
         if(!m_outputs[i].Activation(temp, m_act_fn))
           {
            Print(__FUNCTION__," ",__LINE__," ", GetLastError());
            return false;
           }

         // append the bias column exactly as calculate() did
         if(!temp.Resize(temp.Rows(), m_weights[i+1].Rows()) ||
            !temp.Col(vector::Ones(temp.Rows()), m_weights[i+1].Rows() - 1))
           {
            Print(__FUNCTION__," ",__LINE__," ", GetLastError());
            return false;
           }

         m_weights[i+1] = m_weights[i+1] + temp.Transpose().MatMul(loss) * m_learn_rate;

         // temp now holds the activation derivative for hidden layer i
         if(!m_outputs[i].Derivative(temp, m_act_fn))
           {
            Print(__FUNCTION__," ",__LINE__," ", GetLastError());
            return false;
           }

         // drop the last column: it belongs to the bias row of the weights, not to a node
         if(!gradient.Resize(gradient.Rows(), gradient.Cols() - 1))
           {
            Print(__FUNCTION__," ",__LINE__," ", GetLastError());
            return false;
           }

         // element-wise product gives the error term of hidden layer i
         loss = gradient * temp;

         // propagate further back only while another hidden layer remains
         gradient = (i>0)?loss.MatMul(m_weights[i].Transpose()):gradient;

        }

      // first layer: its input is the raw data plus the bias column
      temp = data;
      if(!temp.Resize(temp.Rows(), m_weights[0].Rows()) ||
         !temp.Col(vector::Ones(temp.Rows()), m_weights[0].Rows() - 1))
        {
         Print(__FUNCTION__," ",__LINE__," ", GetLastError());
         return false;
        }

      m_weights[0] = m_weights[0] + temp.Transpose().MatMul(loss) * m_learn_rate;

      return true;

     }

public:
   //+------------------------------------------------------------------+
   //|  constructor                                                     |
   //| layersizes : node count of each hidden layer; the first          |
   //|              num_layers-1 elements are used                      |
   //| num_layers : total number of layers (hidden layers + output)     |
   //| Configuration errors are reported later by fit(), which returns  |
   //| false when the structure cannot be created.                      |
   //+------------------------------------------------------------------+

                     FFNN(ulong &layersizes[], ulong num_layers = 3)
     {
      m_trained = false;
      m_epochs = 0;
      m_num_inputs = 0;
      m_learn_rate = 0.0;
      m_layers = num_layers;
      // avoid unsigned wrap-around when num_layers is 0; create() rejects that case
      m_hidden_layers = (m_layers > 0) ? m_layers - 1 : 0;

      if(m_hidden_layers > 0)
         ArrayCopy(m_hidden_layer_size,layersizes,0,0,int(m_hidden_layers));

      ArrayResize(m_weights,int(m_layers));

      ArrayResize(m_outputs,int(m_hidden_layers));
     }
   //+------------------------------------------------------------------+
   //| destructor                                                       |
   //+------------------------------------------------------------------+

                    ~FFNN(void)
     {
     }
   //+------------------------------------------------------------------+
   //|  neural net training method                                      |
   //| data          : training inputs, one sample per row              |
   //| targets       : expected outputs, one row per sample             |
   //| learning_rate : step size applied to every weight update         |
   //| act_fn        : activation function for the hidden layers        |
   //| num_epochs    : number of full passes over the data              |
   //| Rebuilds the weight matrices on every call, then alternates      |
   //| calculate() and backprop() for num_epochs passes.                |
   //| Returns true when all epochs complete, false otherwise.          |
   //+------------------------------------------------------------------+

   bool              fit(matrix &data, matrix &targets,double learning_rate, ENUM_ACTIVATION_FUNCTION act_fn, uint num_epochs)
     {
      m_learn_rate = learning_rate;
      m_act_fn = act_fn;
      m_epochs = num_epochs;
      m_num_inputs = data.Cols();
      m_trained = false;


      if(!create())
         return false;

      for(uint ep = 0; ep < m_epochs; ep++)
        {
         m_result = calculate(data);

         // a failed calculate() yields an empty matrix, which backprop() rejects
         if(!backprop(data, targets,m_result))
            return m_trained;
        }

      m_trained = true;
      return m_trained;
     }
   //+------------------------------------------------------------------+
   //|  predict method                                                  |
   //| data : inputs, one sample per row                                |
   //| Returns the network output for data, or an empty 0x0 matrix if   |
   //| the network has not been trained successfully.                   |
   //+------------------------------------------------------------------+

   matrix            predict(matrix &data)
     {
      if(m_trained)
         return calculate(data);
      else
         return matrix::Zeros(0,0);
     }




  };
//+------------------------------------------------------------------+
