//+------------------------------------------------------------------+
//|                                                  LSTMNetwork.mqh |
//|                                    Copyright 2018, Mukachi Corp. |
//|                                          https://www.mukachi.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2018, Mukachi Corp."
#property link      "https://www.mukachi.com"
#property version   "1.00"
//+------------------------------------------------------------------+
#include "TimeStep.mqh"
//+------------------------------------------------------------------+
//| Long Short-Term Memory Neural Network                            |
//+------------------------------------------------------------------+
//--- Single-cell LSTM unrolled over m_timestep_cnt steps, trained with
//--- momentum SGD. Every timestep shares one set of per-gate weights;
//--- per-timestep activations/deltas live in the owned CTimeStep objects.
class CLSTMNetwork
  {
private:
   int               m_gate_cnt;          // gates per timestep (set to 4 by the constructor)
   int               m_input_size;        // input values fed to each timestep
   int               m_weight_cnt;        // m_gate_cnt*inputs; not read anywhere else in this file
   int               m_timestep_cnt;      // number of unrolled timesteps per pattern

   int               m_patterns;          // teaching patterns consumed per epoch by Learn()

   double            m_mse;               // error accumulated by Learn() over the current epoch (DBL_MAX before training)
   long              m_epoch;             // epoch counter driven by Learn()

   //--- Momentum Stochastic Gradient Descent
   double            m_alpha;             // delta step
   double            m_gamma;             // velocity decay rate

   //--- shared weights, one column per gate: bias (b), weight on the
   //--- previous timestep's output (u) and per-input weights (x)[input][gate].
   //--- Gate 0 is seeded with tanh_rand(), gates 1-3 with sigma_rand();
   //--- presumably gate 0 is the tanh candidate gate (confirm in CTimeStep).
   double            m_b_weights[4];
   double            m_u_weights[4];
   double            m_x_weights[][4];

   //--- gradients summed over all timesteps by compute_gradients()
   double            m_b_deltas[4];
   double            m_u_deltas[4];
   double            m_x_deltas[][4];

   //--- momentum terms; reset to zero at the start of Learn()
   double            m_b_velocity[4];
   double            m_u_velocity[4];
   double            m_x_velocity[][4];

   CTimeStep        *m_timesteps[];       // owned: created in the constructor, deleted in the destructor

   double            dtan_h(double out);  // 1 - tanh(out)^2

   void              forward_pass(double &inputs[]);
   void              backward_pass(double &targets[]);

   void              update_weights(double &inputs[]);
   void              compute_gradients(double &inputs[]);

   double            get_mse(double &targets[]);   // 0.5*(out-target)^2 at the last timestep
   double            sigma_rand();
   double            tanh_rand();

public:
   //--- (patterns, inputs per timestep, timesteps)
                     CLSTMNetwork(int,int,int);
                    ~CLSTMNetwork();
   //--- inputs: patterns*timesteps*inputs values, targets: patterns*timesteps values
   void              Learn(double &inputs[],double &targets[],
                           double mse_limit,long epochs);
   //--- inputs: timesteps*inputs values; returns the last timestep's output
   double            Calculate(double &inputs[]);
   double            MSE() const { return(m_mse); }
   //--- m_epoch-1, i.e. -1 before Learn() has ever run
   long              Epochs() const { return(m_epoch-1); }

   //--- TODO: Implement Save and Load Functions for the Model

  };
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
CLSTMNetwork::CLSTMNetwork(int patterns,int inputs,int timesteps)
   : m_gate_cnt(4),m_patterns(patterns),m_input_size(inputs),
     m_timestep_cnt(timesteps),m_weight_cnt(m_gate_cnt*inputs),
     m_epoch(0),m_mse(DBL_MAX),m_alpha(0.15),m_gamma(0.85)
  {
//--- Builds an untrained network: sizes the per-input matrices, allocates
//--- one CTimeStep per unrolled step and draws random recurrent/input
//--- weights (biases start at 1.0).

//--- the per-input matrices must be sized before they can be filled
   ArrayResize(m_x_weights,m_input_size);
   ArrayResize(m_x_deltas,m_input_size);
   ArrayResize(m_x_velocity,m_input_size);
   ArrayResize(m_timesteps,m_timestep_cnt);

//--- bias-related arrays start at 1.0, everything else at 0.0
   ArrayInitialize(m_b_weights,1.0);
   ArrayInitialize(m_b_deltas,1.0);
   ArrayInitialize(m_b_velocity,1.0);

   ArrayInitialize(m_u_weights,0.0);
   ArrayInitialize(m_u_deltas,0.0);
   ArrayInitialize(m_u_velocity,0.0);

   ArrayInitialize(m_x_weights,0.0);
   ArrayInitialize(m_x_deltas,0.0);
   ArrayInitialize(m_x_velocity,0.0);

//--- allocate the owned timestep objects (released in the destructor)
   int step=0;
   while(step<m_timestep_cnt)
     {
      m_timesteps[step]=new CTimeStep();
      step++;
     }

//--- seed the generator, then draw recurrent weights before input weights;
//--- gate 0 gets a tanh-range draw, the remaining gates a [0,1] draw
   MathSrand((uint)TimeLocal());

   for(int gate=0;gate<m_gate_cnt;gate++)
     {
      if(gate>0)
         m_u_weights[gate]=sigma_rand();
      else
         m_u_weights[gate]=tanh_rand();
     }

   for(int row=0;row<m_input_size;row++)
      for(int gate=0;gate<m_gate_cnt;gate++)
        {
         if(gate>0)
            m_x_weights[row][gate]=sigma_rand();
         else
            m_x_weights[row][gate]=tanh_rand();
        }
  }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
CLSTMNetwork::~CLSTMNetwork()
  {
//--- release every dynamically created timestep, then the pointer array
   for(int k=0;k<m_timestep_cnt;k++)
     {
      CTimeStep *step=m_timesteps[k];
      if(CheckPointer(step)!=POINTER_DYNAMIC)
         continue;
      delete step;
     }
   ArrayFree(m_timesteps);
  }
//+------------------------------------------------------------------+
void CLSTMNetwork::Learn(double &inputs[],double &targets[],
                         double mse_limit,long epochs)
  {
//--- Trains the network with online momentum SGD (weights are updated
//--- after every pattern).
//---   inputs : m_patterns*m_timestep_cnt*m_input_size values, laid out
//---            pattern by pattern, then timestep by timestep.
//---   targets: m_patterns*m_timestep_cnt values, one per timestep.
//---   mse_limit: training stops after the first COMPLETE epoch whose
//---            accumulated error (sum over patterns of get_mse(), see
//---            MSE()) is below this value.
//---   epochs : maximum number of epochs to run.
   int pattern_inputs=m_input_size*m_timestep_cnt;
   int in_needed=m_patterns*pattern_inputs;
   int tg_needed=m_patterns*m_timestep_cnt;

//--- reject short buffers instead of aborting on an out-of-range access
   if(ArraySize(inputs)<in_needed || ArraySize(targets)<tg_needed)
     {
      PrintFormat("%s: expected at least %d inputs and %d targets, got %d and %d",
                  __FUNCTION__,in_needed,tg_needed,
                  ArraySize(inputs),ArraySize(targets));
      return;
     }

//--- start every training run without momentum carried over
   ArrayInitialize(m_x_velocity,0.0);
   ArrayInitialize(m_b_velocity,0.0);
   ArrayInitialize(m_u_velocity,0.0);

//--- per-pattern work buffers
   double tg[],in[];
   ArrayResize(tg,m_timestep_cnt);
   ArrayResize(in,pattern_inputs);

   m_epoch=1;
   while(m_epoch<=epochs)
     {
      int in_cnt=0;
      int tg_cnt=0;
      m_mse=0.0;

      for(int p=0;p<m_patterns;p++)
        {
         //--- copy this pattern's inputs & targets into the work buffers
         for(int k=0;k<pattern_inputs;k++)
            in[k]=inputs[in_cnt++];
         for(int t=0;t<m_timestep_cnt;t++)
            tg[t]=targets[tg_cnt++];

         forward_pass(in);
         m_mse+=get_mse(tg);
         backward_pass(tg);
         update_weights(in);
        }

      //--- the epoch is finished: count it (Epochs() == m_epoch-1), then
      //--- test the stop criterion against the WHOLE epoch's error rather
      //--- than a partial sum taken in the middle of the epoch
      m_epoch++;
      if(m_mse<mse_limit)
         return;
     }
  }
//+------------------------------------------------------------------+
double CLSTMNetwork::Calculate(double &inputs[])
  {
//--- run the unrolled network on one pattern (m_timestep_cnt groups of
//--- m_input_size values) and report the final timestep's output
   forward_pass(inputs);
   CTimeStep *final_step=m_timesteps[m_timestep_cnt-1];
   return(final_step.out());
  }
//+------------------------------------------------------------------+
void CLSTMNetwork::forward_pass(double &inputs[])
  {
//--- Propagates one pattern through all timesteps. Timestep t reads
//--- inputs[t*m_input_size .. t*m_input_size+m_input_size-1] plus the
//--- output and state of timestep t-1 (zero for the first step).
   for(int t=0;t<m_timestep_cnt;t++)
     {
      CTimeStep *step=m_timesteps[t];
      CTimeStep *prev=NULL;
      double prev_out=0.0;
      if(t>0)
        {
         prev=m_timesteps[t-1];
         prev_out=prev.out();
        }

      int base=t*m_input_size;

      for(int g=0;g<m_gate_cnt;g++)
        {
         //--- weighted inputs first, then recurrent term + bias
         double net=0.0;
         for(int i=0;i<m_input_size;i++)
            net+=m_x_weights[i][g]*inputs[base+i];
         double recurrent=m_u_weights[g]*prev_out+m_b_weights[g];
         net+=recurrent;

         //--- gate applies its activation to the net input
         step.gate(g).out(net);
        }

      //--- the cell state carries over from the previous step (0 at t=0)
      if(prev==NULL)
         step.update_state(0.0);
      else
         step.update_state(prev.state());
     }
  }
//+------------------------------------------------------------------+
void CLSTMNetwork::backward_pass(double &targets[])
  {
//--- Back-propagation through time: walks the timesteps from last to
//--- first. Each step combines its own output error with the error terms
//--- its successor stored during the previous iteration.
   int last=m_timestep_cnt-1;

   for(int t=last;t>=0;t--)
     {
      CTimeStep *cur=m_timesteps[t];
      double total_error=cur.out()-targets[t];

      //--- terms flowing back from timestep t+1 (none for the final step)
      double carried_error=0.0;
      double next_f=0.0;
      double next_state_delta=0.0;
      if(t<last)
        {
         CTimeStep *succ=m_timesteps[t+1];
         carried_error=succ.out_error();
         next_f=succ.f().out();
         next_state_delta=succ.state_delta();
        }

      double out_delta=total_error+carried_error;

      //--- out = o*tanh(state), so d(out)/d(state) = o*(1-tanh(state)^2);
      //--- add the state error leaking through the successor's f gate
      double state_delta=out_delta*cur.o().out()*
                         dtan_h(cur.state())+
                         next_state_delta*next_f;

      //--- gate deltas need the predecessor step (none at t=0)
      CTimeStep *pred=NULL;
      if(t>0)
         pred=m_timesteps[t-1];
      cur.update_delta(pred,state_delta,out_delta);

      //--- error this step hands back to timestep t-1 via the u weights
      double step_out_error=0.0;
      for(int g=0;g<m_gate_cnt;g++)
         step_out_error+=m_u_weights[g]*cur.gate(g).delta();
      cur.out_error(step_out_error);
     }
  }
//+------------------------------------------------------------------+
void CLSTMNetwork::update_weights(double &inputs[])
  {
//--- Applies one momentum-SGD step to every shared weight:
//---   v = gamma*v + alpha*grad;   w -= v
//--- Gradients come from compute_gradients() on the current timestep state.
   compute_gradients(inputs);

//--- per-input weights
   for(int i=0;i<m_input_size;i++)
      for(int g=0;g<m_gate_cnt;g++)
        {
         m_x_velocity[i][g]=m_x_velocity[i][g]*m_gamma+m_alpha*m_x_deltas[i][g];
         m_x_weights[i][g]-=m_x_velocity[i][g];
        }

//--- recurrent and bias weights are per gate only. They are updated in
//--- their own loop so the update does not depend on m_input_size > 0
//--- (previously it was nested in the input loop under i==0).
   for(int g=0;g<m_gate_cnt;g++)
     {
      m_u_velocity[g]=m_u_velocity[g]*m_gamma+m_alpha*m_u_deltas[g];
      m_u_weights[g]-=m_u_velocity[g];

      m_b_velocity[g]=m_b_velocity[g]*m_gamma+m_alpha*m_b_deltas[g];
      m_b_weights[g]-=m_b_velocity[g];
     }
  }
//+------------------------------------------------------------------+
void CLSTMNetwork::compute_gradients(double &inputs[])
  {
//--- Sums, over all timesteps, the gradient of every shared weight from
//--- the gate deltas left by backward_pass(). inputs uses the same layout
//--- as forward_pass().
   ArrayInitialize(m_x_deltas,0.0);
   ArrayInitialize(m_u_deltas,0.0);
   ArrayInitialize(m_b_deltas,0.0);

   for(int t=0;t<m_timestep_cnt;t++)
     {
      CTimeStep *step=m_timesteps[t];
      int base=t*m_input_size;

      //--- the recurrent input is only non-zero after the first step
      bool has_prev=(t>0);
      double prev_out=0.0;
      if(has_prev)
         prev_out=m_timesteps[t-1].out();

      for(int g=0;g<m_gate_cnt;g++)
        {
         CGate *gate=step.gate(g);
         double d=gate.delta();

         for(int i=0;i<m_input_size;i++)
            m_x_deltas[i][g]+=d*inputs[base+i];

         if(has_prev)
            m_u_deltas[g]+=d*prev_out;

         m_b_deltas[g]+=d;
        }
     }
  }
//+------------------------------------------------------------------+
double CLSTMNetwork::get_mse(double &targets[])
  {
//--- half squared error between the final timestep output and its target
   int last=m_timestep_cnt-1;
   double diff=m_timesteps[last].out()-targets[last];
   return(diff*diff*0.5);
  }
//+------------------------------------------------------------------+
double CLSTMNetwork::dtan_h(double out)
  {
//--- derivative of tanh evaluated at `out`: 1 - tanh(out)^2
   double sq=tanh(out);
   sq*=sq;
   return(1.0-sq);
  }
//+------------------------------------------------------------------+
double CLSTMNetwork::tanh_rand(void)
  {
//--- centred draw: [-0.25, 0.25] assuming rand() spans 0..SHORT_MAX
   double unit=double(rand())/double(SHORT_MAX);
   return((unit-0.5)*0.5);
  }
//+------------------------------------------------------------------+
double CLSTMNetwork::sigma_rand()
  {
//--- draw in [0, 1] assuming rand() spans 0..SHORT_MAX
   double numerator=rand();
   return(numerator/double(SHORT_MAX));
  }
//+------------------------------------------------------------------+
