//+------------------------------------------------------------------+
//|                                                      CMatrix.mqh |
//|                        Copyright 2012, MetaQuotes Software Corp. |
//|                                              http://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2012, MetaQuotes Software Corp."
#property link      "http://www.mql5.com"
//+------------------------------------------------------------------+
//| CMatrix class                                                    |
//| Dense matrix of doubles with 1-based element access [i,j]        |
//| (i = row 1..m_rows, j = column 1..m_columns) and a Gauss solver  |
//| for augmented linear systems of size n x (n+1).                  |
//+------------------------------------------------------------------+
class CMatrix
  {
   double            m_matrix[];    // element storage (layout defined by Get/Set)
   int               m_rows;        // number of rows, 0 until SetSize succeeds
   int               m_columns;     // number of columns, 0 until SetSize succeeds
public:
   //--- start as an empty 0x0 matrix so Get/Set/GaussSolve are safe before SetSize
                     CMatrix(void) { m_rows=0; m_columns=0; }
   void              SetSize(int nrows,int ncolumns);
   double            Get(int i,int j);
   void              Set(int i,int j,double val);
   void              GaussSolve(double &v[]);
   void              Test();
  };
//+------------------------------------------------------------------+
//| Sets matrix size (m,n)                                           |
//| Allocates exactly nrows*ncolumns elements and resets all of them |
//| to 0. Non-positive sizes are ignored (matrix left unchanged).    |
//| If the allocation fails the matrix becomes empty (0x0).          |
//+------------------------------------------------------------------+
void CMatrix::SetSize(int nrows,int ncolumns)
  {
   if(nrows<=0 || ncolumns<=0) return;
//--- one slot per element, row-major: [i,j] -> (i-1)*ncolumns+(j-1)
   if(ArrayResize(m_matrix,nrows*ncolumns)<0)
     {
      //--- keep dimensions consistent with the (empty) storage
      ArrayFree(m_matrix);
      m_rows=0;
      m_columns=0;
      return;
     }
   m_rows=nrows;
   m_columns=ncolumns;
//--- values kept by ArrayResize belonged to the previous shape; start clean
   ArrayInitialize(m_matrix,0.0);
  }
//+------------------------------------------------------------------+
//| Gets matrix element [i,j] (1-based)                              |
//| Returns 0 when i or j is outside the current matrix bounds.      |
//+------------------------------------------------------------------+
double CMatrix::Get(int i,int j)
  {
   if(i<1 || i>m_rows) return(0);
   if(j<1 || j>m_columns) return(0);
//--- row-major: the row stride is the number of columns
   return(m_matrix[(i-1)*m_columns+(j-1)]);
  }
//+------------------------------------------------------------------+
//| Sets matrix element [i,j] (1-based)                              |
//| Out-of-range indices are ignored.                                |
//+------------------------------------------------------------------+
void CMatrix::Set(int i,int j,double val)
  {
   if(i<1 || i>m_rows) return;
   if(j<1 || j>m_columns) return;
//--- row-major: the row stride is the number of columns
   m_matrix[(i-1)*m_columns+(j-1)]=val;
  }
//+------------------------------------------------------------------+
//| Linear system solver (Gauss method)                              |
//| Solves A*x=b, where the matrix holds the augmented system [A|b]: |
//| n rows and n+1 columns, column n+1 being the right-hand side.    |
//| Uses Gaussian elimination with partial pivoting followed by back |
//| substitution. The matrix contents are overwritten.               |
//| On success v receives the n solution values (v[0]=x1 ... ).      |
//| If the matrix is not n x (n+1), or a zero pivot column shows the |
//| system is singular, v is resized to 0 elements (the singular     |
//| case also prints "Error").                                       |
//+------------------------------------------------------------------+
void CMatrix::GaussSolve(double &v[])
  {
   int n=m_rows;
//--- the matrix must hold an augmented system of size n x (n+1)
   if(n<=0 || m_columns!=n+1)
     {
      ArrayResize(v,0);
      return;
     }
//--- forward elimination with partial pivoting
   for(int i=1; i<=n; i++)
     {
      //--- pick the row (i..n) with the largest |a[row,i]| as pivot
      int    pivot=i;
      double best=MathAbs(Get(i,i));
      for(int r=i+1; r<=n; r++)
        {
         double cand=MathAbs(Get(r,i));
         if(cand>best) { best=cand; pivot=r; }
        }
      if(best==0.0)
        {
         //--- column i is zero on and below the diagonal: singular system
         Print("Error");
         ArrayResize(v,0);
         return;
        }
      //--- move the pivot row into position i
      if(pivot!=i)
        {
         for(int k=1; k<=n+1; k++)
           {
            double tmp=Get(i,k);
            Set(i,k,Get(pivot,k));
            Set(pivot,k,tmp);
           }
        }
      //--- eliminate column i from the rows below the pivot
      double p=Get(i,i);
      for(int r=i+1; r<=n; r++)
        {
         double factor=Get(r,i)/p;
         if(factor==0.0) continue;
         Set(r,i,0.0);   // exact zero instead of a rounding residue
         for(int k=i+1; k<=n+1; k++)
            Set(r,k,Get(r,k)-factor*Get(i,k));
        }
     }
//--- back substitution (every diagonal element is nonzero here)
   double x[];
   ArrayResize(x,n+1);   // 1-based, x[0] unused
   for(int i=n; i>=1; i--)
     {
      double sum=Get(i,n+1);
      for(int k=i+1; k<=n; k++) sum-=Get(i,k)*x[k];
      x[i]=sum/Get(i,i);
     }
//--- set values
   ArrayResize(v,n);
   for(int i=0; i<n; i++) v[i]=x[i+1];
  }
//+------------------------------------------------------------------+
//| TestGauss                                                        |
//| Loads a 3x4 augmented system with known solution x=1, y=2, z=1   |
//| into this matrix (replacing its contents), solves it and prints  |
//| the three solution values. Prints nothing if the solver returns  |
//| fewer than 3 values.                                             |
//+------------------------------------------------------------------+
void CMatrix::Test()
  {
//--- test system
//   x + 2y -3z = 2
//   x + 4y -5z = 4
// -3x + 2y -1z = 0
//-- solution x=1, y=2, z=1;
   double system[3][4]=
     {
        {  1, 2, -3, 2 },
        {  1, 4, -5, 4 },
        { -3, 2, -1, 0 }
     };
   SetSize(3,4);
   for(int row=0; row<3; row++)
      for(int col=0; col<4; col++)
         Set(row+1,col+1,system[row][col]);
//--- solution vector
   double v[];
   GaussSolve(v);
//--- guard: on failure v has fewer elements, and indexing it would be out of range
   if(ArraySize(v)<3) return;
   for(int i=0; i<3; i++) Print(v[i]);
  }
//+------------------------------------------------------------------+
