English Русский 中文 Español Deutsch 日本語 Português Français Italiano Türkçe
preview
클래스에서 ONNX 모델 래핑하기

클래스에서 ONNX 모델 래핑하기

MetaTrader 5 | 23 2월 2024, 13:26
121 0
MetaQuotes
MetaQuotes

소개

이전 글에서 우리는 투표 분류기를 배열하는 데 두 개의 ONNX 모델을 사용했습니다. 전체 소스 텍스트는 단일 MQ5 파일로 구성되었습니다. 전체 코드가 함수로 나뉘어져 있었습니다. 하지만 우리가 모델을 바꾸면 어떨까요? 혹은 다른 모델을 추가하면 어떨까요? 원본 텍스트가 더 커지게 될 것입니다. 이때 객체 지향 접근 방식을 사용해 보겠습니다.


1. 어떤 모델을 사용해 볼까요?

이전 투표 분류기에서 우리는 하나의 분류 모델과 하나의 회귀 모델을 사용했습니다. 회귀 모델에서는 예측된 가격 변동(하락, 상승, 변동 없음) 대신에 클래스를 계산하는 데 사용된 예측 가격을 우리는 얻게 됩니다. 그러나 이 경우 우리에게는 클래스별 확률 분포가 없기 때문에 소위 '소프트 투표'를 허용하지 않습니다.

여기 3가지 분류 모델을 준비했습니다. "MQL5에서 ONNX 모델을 앙상블하는 방법의 예" 문서에서 이미 두 가지 모델이 사용되었습니다. 첫 번째 모델(회귀)을 분류 모델로 변환했습니다. 10가지 OHLC 가격에 대한 학습이 진행되었습니다. 두 번째 모델은 분류 모델입니다. 일련의 63 종가에 대한 학습이 진행되었습니다.

마지막으로 모델이 하나 더 있습니다. 분류 모델은 일련의 30개의 종가와 평균 기간이 21 및 34인 일련의 단순 이동 평균에 대해 학습되었습니다. 우리는 이동 평균과 종가 차트의 교차점에 대해 어떤 가정도하지 않았습니다 - 모든 패턴은 레이어 간의 계수 행렬 형태로 네트워크에 의해 계산되고 기억됩니다.

모든 모델은 2010.01.01~2023.01.01의 EURUSD D1 MetaQuotes-Demo server data로 학습 되었습니다. 세 가지 모델 모두에 대한 학습 스크립트는 Python으로 작성되었으며 이 문서에 첨부되어 있습니다. 이 글의 주요 주제로부터 독자분들의 주의를 분산시키지 않기 위해 여기서는 소스 코드를 제공하지 않을 것입니다.


2. 모든 모델에 하나의 기본 클래스가 필요합니다.

세 가지 모델이 있습니다. 입력 데이터의 크기와 준비 방식이 각각 다릅니다. 모든 모델이 동일한 인터페이스를 갖습니다. 모든 모델의 클래스는 동일한 기본 클래스에서 상속되어야 합니다.

베이스 클래스를 표현해 보겠습니다.

//+------------------------------------------------------------------+
//|                                            ModelSymbolPeriod.mqh |
//|                                  Copyright 2023, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+

//--- price movement prediction
#define PRICE_UP   0
#define PRICE_SAME 1
#define PRICE_DOWN 2

//+------------------------------------------------------------------+
//| Base class for models based on trained symbol and period         |
//+------------------------------------------------------------------+
class CModelSymbolPeriod
  {
protected:
   long              m_handle;           // created model session handle
   string            m_symbol;           // symbol of trained data
   ENUM_TIMEFRAMES   m_period;           // timeframe of trained data
   datetime          m_next_bar;         // time of next bar (we work at bar begin only)
   double            m_class_delta;      // delta to recognize "price the same" in regression models

public:
   //+------------------------------------------------------------------+
   //| Constructor                                                      |
   //+------------------------------------------------------------------+
   CModelSymbolPeriod(const string symbol,const ENUM_TIMEFRAMES period,const double class_delta=0.0001)
     {
      m_handle=INVALID_HANDLE;
      m_symbol=symbol;
      m_period=period;
      m_next_bar=0;
      m_class_delta=class_delta;
     }

   //+------------------------------------------------------------------+
   //| Destructor                                                       |
   //+------------------------------------------------------------------+
   ~CModelSymbolPeriod(void)
     {
      Shutdown();
     }

   //+------------------------------------------------------------------+
   //| virtual stub for Init                                            |
   //+------------------------------------------------------------------+
   virtual bool Init(const string symbol,const ENUM_TIMEFRAMES period)
     {
      return(false);
     }

   //+------------------------------------------------------------------+
   //| Check for initialization, create model                           |
   //+------------------------------------------------------------------+
   bool CheckInit(const string symbol,const ENUM_TIMEFRAMES period,const uchar& model[])
     {
      //--- check symbol, period
      if(symbol!=m_symbol || period!=m_period)
        {
         PrintFormat("Model must work with %s,%s",m_symbol,EnumToString(m_period));
         return(false);
        }

      //--- create a model from static buffer
      m_handle=OnnxCreateFromBuffer(model,ONNX_DEFAULT);
      if(m_handle==INVALID_HANDLE)
        {
         Print("OnnxCreateFromBuffer error ",GetLastError());
         return(false);
        }

      //--- ok
      return(true);
     }

   //+------------------------------------------------------------------+
   //| Release ONNX session                                             |
   //+------------------------------------------------------------------+
   void Shutdown(void)
     {
      if(m_handle!=INVALID_HANDLE)
        {
         OnnxRelease(m_handle);
         m_handle=INVALID_HANDLE;
        }
     }

   //+------------------------------------------------------------------+
   //| Check for continue OnTick                                        |
   //+------------------------------------------------------------------+
   virtual bool CheckOnTick(void)
     {
      //--- check new bar
      if(TimeCurrent()<m_next_bar)
         return(false);
      //--- set next bar time
      m_next_bar=TimeCurrent();
      m_next_bar-=m_next_bar%PeriodSeconds(m_period);
      m_next_bar+=PeriodSeconds(m_period);

      //--- work on new day bar
      return(true);
     }

   //+------------------------------------------------------------------+
   //| virtual stub for PredictPrice (regression model)                 |
   //+------------------------------------------------------------------+
   virtual double PredictPrice(void)
     {
      return(DBL_MAX);
     }

   //+------------------------------------------------------------------+
   //| Predict class (regression -> classification)                     |
   //+------------------------------------------------------------------+
   virtual int PredictClass(void)
     {
      double predicted_price=PredictPrice();
      if(predicted_price==DBL_MAX)
         return(-1);

      int    predicted_class=-1;
      double last_close=iClose(m_symbol,m_period,1);
      //--- classify predicted price movement
      double delta=last_close-predicted_price;
      if(fabs(delta)<=m_class_delta)
         predicted_class=PRICE_SAME;
      else
        {
         if(delta<0)
            predicted_class=PRICE_UP;
         else
            predicted_class=PRICE_DOWN;
        }

      //--- return predicted class
      return(predicted_class);
     }
  };
//+------------------------------------------------------------------+

기본 클래스는 회귀 및 분류 모델 모두에 사용할 수 있습니다. 우리는 하위 클래스에서 적절한 메서드(PredictPrice 또는 PredictClass)를 구현하기만 하면 됩니다.

기본 클래스가 모델이 작업할 심볼 기간(모델의 학습이 이루어진 데이터)을 설정합니다. 또한 기본 클래스는 필요한 심볼 기간에 모델을 사용하는 EA가 작동하는지 확인하고 모델을 실행하기 위해 ONNX 세션을 생성합니다. 기본 클래스는 새로운 바가 시작될 때만 작업을 제공합니다.


3. 첫 번째 모델 클래스

첫 번째 모델은 model.eurusd.D1.10.class.onnx로, 일련의 10개의 OHLC 가격에 대한 EURUSD D1에 대해 학습된 분류 모델입니다.

//+------------------------------------------------------------------+
//|                                        ModelEurusdD1_10Class.mqh |
//|                                  Copyright 2023, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#include "ModelSymbolPeriod.mqh"

#resource "Python/model.eurusd.D1.10.class.onnx" as uchar model_eurusd_D1_10_class[]

//+------------------------------------------------------------------+
//| ONNX-model wrapper class                                         |
//+------------------------------------------------------------------+
class CModelEurusdD1_10Class : public CModelSymbolPeriod
  {
private:
   int               m_sample_size;

public:
   //+------------------------------------------------------------------+
   //| Constructor                                                      |
   //+------------------------------------------------------------------+
   CModelEurusdD1_10Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1)
     {
      m_sample_size=10;
     }

   //+------------------------------------------------------------------+
   //| ONNX-model initialization                                        |
   //+------------------------------------------------------------------+
   virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period)
     {
      //--- check symbol, period, create model
      if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_10_class))
        {
         Print("model_eurusd_D1_10_class : initialization error");
         return(false);
        }

      //--- since not all sizes defined in the input tensor we must set them explicitly
      //--- first index - batch size, second index - series size, third index - number of series (OHLC)
      const long input_shape[] = {1,m_sample_size,4};
      if(!OnnxSetInputShape(m_handle,0,input_shape))
        {
         Print("model_eurusd_D1_10_class : OnnxSetInputShape error ",GetLastError());
         return(false);
        }
   
      //--- since not all sizes defined in the output tensor we must set them explicitly
      //--- first index - batch size, must match the batch size of the input tensor
      //--- second index - number of classes (up, same or down)
      const long output_shape[] = {1,3};
      if(!OnnxSetOutputShape(m_handle,0,output_shape))
        {
         Print("model_eurusd_D1_10_class : OnnxSetOutputShape error ",GetLastError());
         return(false);
        }
      //--- ok
      return(true);
     }

   //+------------------------------------------------------------------+
   //| Predict class                                                    |
   //+------------------------------------------------------------------+
   virtual int PredictClass(void)
     {
      static matrixf input_data(m_sample_size,4);    // matrix for prepared input data
      static vectorf output_data(3);                 // vector to get result
      static matrix  mm(m_sample_size,4);            // matrix of horizontal vectors Mean
      static matrix  ms(m_sample_size,4);            // matrix of horizontal vectors Std
      static matrix  x_norm(m_sample_size,4);        // matrix for prices normalize
   
      //--- prepare input data
      matrix rates;
      //--- request last bars
      if(!rates.CopyRates(m_symbol,m_period,COPY_RATES_OHLC,1,m_sample_size))
         return(-1);
      //--- get series Mean
      vector m=rates.Mean(1);
      //--- get series Std
      vector s=rates.Std(1);
      //--- prepare matrices for prices normalization
      for(int i=0; i<m_sample_size; i++)
        {
         mm.Row(m,i);
         ms.Row(s,i);
        }
      //--- the input of the model must be a set of vertical OHLC vectors
      x_norm=rates.Transpose();
      //--- normalize prices
      x_norm-=mm;
      x_norm/=ms;
   
      //--- run the inference
      input_data.Assign(x_norm);
      if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data))
         return(-1);
      //--- evaluate prediction
      return(int(output_data.ArgMax()));
     }
  };
//+------------------------------------------------------------------+

위에서 이미 언급했듯이: "세 가지 모델이 있습니다. 각각의 모델은 입력 데이터의 크기와 준비가 각각 다릅니다." 우리는 Init과 PredictClass의 두 가지 메서드만 재정의했습니다. 동일한 메서드가 다른 두 모델에 다른 두 클래스에서 재정의됩니다.

Init 메서드는 ONNX 모델에 대한 세션이 생성되고 입력 및 출력 텐서의 크기가 명시적으로 설정되는 CheckInit 베이스 클래스 메서드를 호출합니다. 여기에는 코드보다 코멘트가 더 많습니다.

PredictClass 메서드는 모델을 훈련할 때와 정확히 동일한 입력 데이터 준비를 제공합니다. 입력은 정규화된 OHLC 가격의 행렬입니다.


4. 이제 작동 방식을 확인해 보겠습니다.

클래스의 성능을 테스트하기 위해 매우 간단한 Expert Advisor가 만들어졌습니다.

//+------------------------------------------------------------------+
//|                                    ONNX.eurusd.D1.Prediction.mq5 |
//|                                  Copyright 2023, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright   "Copyright 2023, MetaQuotes Ltd."
#property link        "https://www.mql5.com"
#property version     "1.00"

#include "ModelEurusdD1_10Class.mqh"
#include <Trade\Trade.mqh>

input double InpLots = 1.0;    // Lots amount to open position

CModelEurusdD1_10Class ExtModel;
CTrade                 ExtTrade;

//+------------------------------------------------------------------+
//| Expert initialization function                                   |
//+------------------------------------------------------------------+
int OnInit()
  {
   if(!ExtModel.Init(_Symbol,_Period))
      return(INIT_FAILED);
//---
   return(INIT_SUCCEEDED);
  }
//+------------------------------------------------------------------+
//| Expert deinitialization function                                 |
//+------------------------------------------------------------------+
void OnDeinit(const int reason)
  {
   ExtModel.Shutdown();
  }
//+------------------------------------------------------------------+
//| Expert tick function                                             |
//+------------------------------------------------------------------+
void OnTick()
  {
   if(!ExtModel.CheckOnTick())
      return;

//--- predict next price movement
   int predicted_class=ExtModel.PredictClass();
//--- check trading according to prediction
   if(predicted_class>=0)
      if(PositionSelect(_Symbol))
         CheckForClose(predicted_class);
      else
         CheckForOpen(predicted_class);
  }
//+------------------------------------------------------------------+
//| Check for open position conditions                               |
//+------------------------------------------------------------------+
void CheckForOpen(const int predicted_class)
  {
   ENUM_ORDER_TYPE signal=WRONG_VALUE;
//--- check signals

   if(predicted_class==PRICE_DOWN)
      signal=ORDER_TYPE_SELL;    // sell condition
   else
     {
      if(predicted_class==PRICE_UP)
         signal=ORDER_TYPE_BUY;  // buy condition
     }

//--- open position if possible according to signal
   if(signal!=WRONG_VALUE && TerminalInfoInteger(TERMINAL_TRADE_ALLOWED))
     {
      double price=SymbolInfoDouble(_Symbol,(signal==ORDER_TYPE_SELL) ? SYMBOL_BID : SYMBOL_ASK);
      ExtTrade.PositionOpen(_Symbol,signal,InpLots,price,0,0);
     }
  }
//+------------------------------------------------------------------+
//| Check for close position conditions                              |
//+------------------------------------------------------------------+
void CheckForClose(const int predicted_class)
  {
   bool bsignal=false;
//--- position already selected before
   long type=PositionGetInteger(POSITION_TYPE);
//--- check signals
   if(type==POSITION_TYPE_BUY && predicted_class==PRICE_DOWN)
      bsignal=true;
   if(type==POSITION_TYPE_SELL && predicted_class==PRICE_UP)
      bsignal=true;

//--- close position if possible
   if(bsignal && TerminalInfoInteger(TERMINAL_TRADE_ALLOWED))
     {
      ExtTrade.PositionClose(_Symbol,3);
      //--- open opposite
      CheckForOpen(predicted_class);
     }
  }
//+------------------------------------------------------------------+

이 모델은 2023년까지 가격 데이터로 학습되었습니다. 우리는 2023년 1월 1일부터 테스트를 시작하겠습니다.

테스트 설정

결과가 아래와 같이 표시됩니다:

테스트 결과

보시다시피 이 모델은 완벽하게 작동합니다.


5. 두 번째 모델 클래스

두 번째 모델은 model.eurusd.D1.30.class.onnx입니다. 분류 모델은 일련의 30개의 종가와 평균 기간이 21, 34인 두 개의 단순 이동 평균에 대해 EURUSD D1을 학습시켰습니다.

//+------------------------------------------------------------------+
//|                                        ModelEurusdD1_30Class.mqh |
//|                                  Copyright 2023, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#include "ModelSymbolPeriod.mqh"

#resource "Python/model.eurusd.D1.30.class.onnx" as uchar model_eurusd_D1_30_class[]

//+------------------------------------------------------------------+
//| ONNX-model wrapper class                                         |
//+------------------------------------------------------------------+
class CModelEurusdD1_30Class : public CModelSymbolPeriod
  {
private:
   int               m_sample_size;
   int               m_fast_period;
   int               m_slow_period;
   int               m_sma_fast;
   int               m_sma_slow;

public:
   //+------------------------------------------------------------------+
   //| Constructor                                                      |
   //+------------------------------------------------------------------+
   CModelEurusdD1_30Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1)
     {
      m_sample_size=30;
      m_fast_period=21;
      m_slow_period=34;
      m_sma_fast=INVALID_HANDLE;
      m_sma_slow=INVALID_HANDLE;
     }

   //+------------------------------------------------------------------+
   //| ONNX-model initialization                                        |
   //+------------------------------------------------------------------+
   virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period)
     {
      //--- check symbol, period, create model
      if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_30_class))
        {
         Print("model_eurusd_D1_30_class : initialization error");
         return(false);
        }

      //--- since not all sizes defined in the input tensor we must set them explicitly
      //--- first index - batch size, second index - series size, third index - number of series (Close, MA fast, MA slow)
      const long input_shape[] = {1,m_sample_size,3};
      if(!OnnxSetInputShape(m_handle,0,input_shape))
        {
         Print("model_eurusd_D1_30_class : OnnxSetInputShape error ",GetLastError());
         return(false);
        }
   
      //--- since not all sizes defined in the output tensor we must set them explicitly
      //--- first index - batch size, must match the batch size of the input tensor
      //--- second index - number of classes (up, same or down)
      const long output_shape[] = {1,3};
      if(!OnnxSetOutputShape(m_handle,0,output_shape))
        {
         Print("model_eurusd_D1_30_class : OnnxSetOutputShape error ",GetLastError());
         return(false);
        }
      //--- indicators
      m_sma_fast=iMA(m_symbol,m_period,m_fast_period,0,MODE_SMA,PRICE_CLOSE);
      m_sma_slow=iMA(m_symbol,m_period,m_slow_period,0,MODE_SMA,PRICE_CLOSE);
      if(m_sma_fast==INVALID_HANDLE || m_sma_slow==INVALID_HANDLE)
        {
         Print("model_eurusd_D1_30_class : cannot create indicator");
         return(false);
        }
      //--- ok
      return(true);
     }

   //+------------------------------------------------------------------+
   //| Predict class                                                    |
   //+------------------------------------------------------------------+
   virtual int PredictClass(void)
     {
      static matrixf input_data(m_sample_size,3);    // matrix for prepared input data
      static vectorf output_data(3);                 // vector to get result
      static matrix  x_norm(m_sample_size,3);        // matrix for prices normalize
      static vector  vtemp(m_sample_size);
      static double  ma_buffer[];
   
      //--- request last bars
      if(!vtemp.CopyRates(m_symbol,m_period,COPY_RATES_CLOSE,1,m_sample_size))
         return(-1);
      //--- get series Mean
      double m=vtemp.Mean();
      //--- get series Std
      double s=vtemp.Std();
      //--- normalize
      vtemp-=m;
      vtemp/=s;
      x_norm.Col(vtemp,0);
      //--- fast sma
      if(CopyBuffer(m_sma_fast,0,1,m_sample_size,ma_buffer)!=m_sample_size)
         return(-1);
      vtemp.Assign(ma_buffer);
      m=vtemp.Mean();
      s=vtemp.Std();
      vtemp-=m;
      vtemp/=s;
      x_norm.Col(vtemp,1);
      //--- slow sma
      if(CopyBuffer(m_sma_slow,0,1,m_sample_size,ma_buffer)!=m_sample_size)
         return(-1);
      vtemp.Assign(ma_buffer);
      m=vtemp.Mean();
      s=vtemp.Std();
      vtemp-=m;
      vtemp/=s;
      x_norm.Col(vtemp,2);
   
      //--- run the inference
      input_data.Assign(x_norm);
      if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data))
         return(-1);
      //--- evaluate prediction
      return(int(output_data.ArgMax()));
     }
  };
//+------------------------------------------------------------------+

이전 클래스에서와 마찬가지로, CheckInit 베이스 클래스 메서드는 Init 메서드에서 호출됩니다. 베이스 클래스 메서드에서는 ONNX 모델에 대한 세션이 생성되고 입력 및 출력 텐서의 크기가 명시적으로 설정됩니다.

PredictClass 메서드는 30개의 이전 종가 및 계산된 이동 평균을 제공합니다. 데이터는 학습시와 마찬가지의 방식으로 정규화됩니다.

이 모델이 어떻게 작동하는지 살펴봅시다. 이를 위해 테스트 EA의 두 문자열만 변경해 보겠습니다.

#include "ModelEurusdD1_30Class.mqh"
#include <Trade\Trade.mqh>

input double InpLots = 1.0;    // Lots amount to open position

CModelEurusdD1_30Class ExtModel;
CTrade                 ExtTrade;

테스트 매개변수는 동일합니다.

두 번째 모델 테스트 결과

모델이 작동합니다.


6. 세 번째 모델 클래스

마지막 모델은 model.eurusd.D1.63.class.onnx입니다. 분류 모델은 일련의 63 종가에 대해 EURUSD D1을 학습시켰습니다.

//+------------------------------------------------------------------+
//|                                             ModelEurusdD1_63.mqh |
//|                                  Copyright 2023, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#include "ModelSymbolPeriod.mqh"

#resource "Python/model.eurusd.D1.63.class.onnx" as uchar model_eurusd_D1_63_class[]

//+------------------------------------------------------------------+
//| ONNX-model wrapper class                                         |
//+------------------------------------------------------------------+
class CModelEurusdD1_63Class : public CModelSymbolPeriod
  {
private:
   int               m_sample_size;

public:
   //+------------------------------------------------------------------+
   //| Constructor                                                      |
   //+------------------------------------------------------------------+
   CModelEurusdD1_63Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1,0.0001)
     {
      m_sample_size=63;
     }

   //+------------------------------------------------------------------+
   //| ONNX-model initialization                                        |
   //+------------------------------------------------------------------+
   virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period)
     {
      //--- check symbol, period, create model
      if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_63_class))
        {
         Print("model_eurusd_D1_63_class : initialization error");
         return(false);
        }

      //--- since not all sizes defined in the input tensor we must set them explicitly
      //--- first index - batch size, second index - series size
      const long input_shape[] = {1,m_sample_size};
      if(!OnnxSetInputShape(m_handle,0,input_shape))
        {
         Print("model_eurusd_D1_63_class : OnnxSetInputShape error ",GetLastError());
         return(false);
        }
   
      //--- since not all sizes defined in the output tensor we must set them explicitly
      //--- first index - batch size, must match the batch size of the input tensor
      //--- second index - number of classes (up, same or down)
      const long output_shape[] = {1,3};
      if(!OnnxSetOutputShape(m_handle,0,output_shape))
        {
         Print("model_eurusd_D1_63_class : OnnxSetOutputShape error ",GetLastError());
         return(false);
        }
      //--- ok
      return(true);
     }

   //+------------------------------------------------------------------+
   //| Predict class                                                    |
   //+------------------------------------------------------------------+
   virtual int PredictClass(void)
     {
      static vectorf input_data(m_sample_size);  // vector for prepared input data
      static vectorf output_data(3);             // vector to get result
   
      //--- request last bars
      if(!input_data.CopyRates(m_symbol,m_period,COPY_RATES_CLOSE,1,m_sample_size))
         return(-1);
      //--- get series Mean
      float m=input_data.Mean();
      //--- get series Std
      float s=input_data.Std();
      //--- normalize prices
      input_data-=m;
      input_data/=s;
   
      //--- run the inference
      if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data))
         return(-1);
      //--- evaluate prediction
      return(int(output_data.ArgMax()));
     }
  };
//+------------------------------------------------------------------+

세 가지 모델 중 가장 간단한 모델입니다. 이것이 바로 PredictClass 메서드의 코드가 매우 간결한 이유입니다.

EA에서 두 개의 문자열을 다시 변경해 보겠습니다.

#include "ModelEurusdD1_63Class.mqh"
#include <Trade\Trade.mqh>

input double InpLots = 1.0;    // Lots amount to open position

CModelEurusdD1_63Class ExtModel;
CTrade                 ExtTrade;

동일한 설정으로 테스트를 시작합니다.

세 번째 모델 테스트 결과

모델이 작동합니다.



7. 모든 모델을 하나의 EA에 수집합니다. 하드 투표

세 모델 모두 작업 능력이 있음을 보여주었습니다. 이제 이들의 능력을 합쳐 보겠습니다. 모델 투표를 준비해 봅시다.

선언 및 정의 전달

#include "ModelEurusdD1_10Class.mqh"
#include "ModelEurusdD1_30Class.mqh"
#include "ModelEurusdD1_63Class.mqh"
#include <Trade\Trade.mqh>

input double  InpLots  = 1.0;    // Lots amount to open position

CModelSymbolPeriod *ExtModels[3];
CTrade              ExtTrade;

OnInit 함수

int OnInit()
  {
   ExtModels[0]=new CModelEurusdD1_10Class;
   ExtModels[1]=new CModelEurusdD1_30Class;
   ExtModels[2]=new CModelEurusdD1_63Class;

   for(long i=0; i<ExtModels.Size(); i++)
      if(!ExtModels[i].Init(_Symbol,_Period))
         return(INIT_FAILED);
//---
   return(INIT_SUCCEEDED);
  }

OnTick 기능

void OnTick()
  {
   for(long i=0; i<ExtModels.Size(); i++)
      if(!ExtModels[i].CheckOnTick())
         return;

//--- predict next price movement
   int returned[3]={0,0,0};
//--- collect returned classes
   for(long i=0; i<ExtModels.Size(); i++)
     {
      int pred=ExtModels[i].PredictClass();
      if(pred>=0)
         returned[pred]++;
     }
//--- get one prediction for all models
   int predicted_class=-1;
//--- count votes for predictions
   for(int n=0; n<3; n++)
     {
      if(returned[n]>=2)
        {
         predicted_class=n;
         break;
        }
     }

//--- check trading according to prediction
   if(predicted_class>=0)
      if(PositionSelect(_Symbol))
         CheckForClose(predicted_class);
      else
         CheckForOpen(predicted_class);
  }

과반수 득표는 <총 득표수>/2 + 1이라는 공식에 따라 계산됩니다. 총 3표 중 과반수는 2표입니다. 이를 이른바 '하드 투표'라고 합니다.

테스트 결과는 여전히 동일한 설정으로 한 것입니다.

하드 투표 테스트 결과

세 가지 모델의 작업, 즉 수익성 있는 거래와 수익성 없는 거래의 수를 개별적으로 다시 살펴 보겠습니다. 첫 번째 모델 - 11 : 3, second — 6 : 1, third — 16 : 10.

하드 투표를 통해 결과를 개선한 것 같습니다 - 16 : 4. 하지만 물론 전체 보고서와 테스트 차트를 살펴봐야 합니다.


8. 소프트 투표

소프트 투표는 투표 수가 아니라 세 가지 모델에서 세 가지 클래스 모두의 확률의 합을 중요시 한다는 점에서 하드 투표와 다릅니다. 가장 높은 확률로 클래스가 선택됩니다.

소프트 투표를 보장하려면 몇 가지 사항을 변경해야 합니다.

기본 클래스에서:

   //+------------------------------------------------------------------+
   //| Predict class (regression -> classification)                     |
   //+------------------------------------------------------------------+
   virtual int PredictClass(vector& probabilities)
     {
...
      //--- set predicted probability as 1.0
      probabilities.Fill(0);
      if(predicted_class<(int)probabilities.Size())
         probabilities[predicted_class]=1;
      //--- and return predicted class
      return(predicted_class);
     }

자식 클래스에서:

   //+------------------------------------------------------------------+
   //| Predict class                                                    |
   //+------------------------------------------------------------------+
   virtual int PredictClass(vector& probabilities)
     {
...
      //--- evaluate prediction
      probabilities.Assign(output_data);
      return(int(output_data.ArgMax()));
     }

EA에서:

#include "ModelEurusdD1_10Class.mqh"
#include "ModelEurusdD1_30Class.mqh"
#include "ModelEurusdD1_63Class.mqh"
#include <Trade\Trade.mqh>

enum EnVotes
  {
   Two=2,    // Two votes
   Three=3,  // Three votes
   Soft=4    // Soft voting
  };

input double  InpLots  = 1.0;    // Lots amount to open position
input EnVotes InpVotes = Two;    // Votes to make trade decision

CModelSymbolPeriod *ExtModels[3];
CTrade              ExtTrade;
void OnTick()
  {
   for(long i=0; i<ExtModels.Size(); i++)
      if(!ExtModels[i].CheckOnTick())
         return;

//--- predict next price movement
   int    returned[3]={0,0,0};
   vector soft=vector::Zeros(3);
//--- collect returned classes
   for(long i=0; i<ExtModels.Size(); i++)
     {
      vector prob(3);
      int    pred=ExtModels[i].PredictClass(prob);
      if(pred>=0)
        {
         returned[pred]++;
         soft+=prob;
        }
     }
//--- get one prediction for all models
   int predicted_class=-1;
//--- soft or hard voting
   if(InpVotes==Soft)
      predicted_class=(int)soft.ArgMax();
   else
     {
      //--- count votes for predictions
      for(int n=0; n<3; n++)
        {
         if(returned[n]>=InpVotes)
           {
            predicted_class=n;
            break;
           }
        }
     }

//--- check trading according to prediction
   if(predicted_class>=0)
      if(PositionSelect(_Symbol))
         CheckForClose(predicted_class);
      else
         CheckForOpen(predicted_class);
  }

테스트 설정은 동일합니다. 입력에서 Soft를 선택합니다.

입력 설정

결과는 다음과 같습니다.

소프트 투표 테스트 결과

수익성 있는 거래 - 15, 수익성 없는 거래 - 3. 수익의 관점에서도 하드 투표가 소프트 투표보다 더 나은 것으로 나타났습니다.


만장일치인 투표 결과, 즉 투표 수가 3인 경우를 살펴봅시다.

만장일치 투표 테스트 결과

매우 보수적인 트레이딩. 유일하게 수익성이 없는 거래는 테스트 종료 시 종료되었습니다(수익성이 없는 거래는 아닐 수도 있습니다).

만장일치 투표 테스트 그래프


중요 참고 사항: 이 문서에 사용된 모델은 MQL5 언어를 사용하여 ONNX 모델로 작업하는 방법을 시연하기 위한 용도로만 제공되었습니다. Expert Advisor는 실제 계좌에서 거래용으로 만든 것이 아닙니다.


결론

이 글에서 우리는 객체 지향 프로그래밍을 통해 프로그램을 더 쉽게 작성할 수 있는 방법을 알아보았습니다. 모델의 모든 복잡성은 해당 클래스에 숨겨져 있습니다(예시로 제시한 모델보다 훨씬 더 복잡할 수 있습니다). 나머지 '복잡성'은 OnTick 함수의 45개 문자열에 들어가 있습니다.


MetaQuotes 소프트웨어 사를 통해 러시아어가 번역됨.
원본 기고글: https://www.mql5.com/ru/articles/12484

파일 첨부됨 |
MQL5.zip (190.3 KB)
모집단 최적화 알고리즘: 침입성 잡초 최적화(IWO) 모집단 최적화 알고리즘: 침입성 잡초 최적화(IWO)
다양한 조건에서 살아남는 잡초의 놀라운 능력은 강력한 최적화 알고리즘을 만들기 위한 아이디어가 되었습니다. IWO는 앞서 검토한 알고리즘 중 가장 우수한 알고리즘 중 하나입니다.
Expert Advisor 개발 기초부터(25부): 시스템 견고성 확보(II) Expert Advisor 개발 기초부터(25부): 시스템 견고성 확보(II)
이 글에서는 EA의 성능을 향상하기 위한 마지막 단계를 밟아보겠습니다. 그러니 오랫동안 읽을 준비를 하세요. Expert Advisor의 신뢰성을 높이기 위해 우리는 코드에서 모든 것을 제거합니다. 이 코드는 거래 시스템의 일부가 아닌 코드입니다.
프랙탈로 트레이딩 시스템 설계하는 방법 알아보기 프랙탈로 트레이딩 시스템 설계하는 방법 알아보기
이 글은 가장 인기 있는 보조지표를 기반으로 트레이딩 시스템을 설계하는 방법에 대한 시리즈의 새로운 글입니다. 우리는 프랙탈 지표인 새로운 지표에 대해 배우고 이를 기반으로 MetaTrader 5 터미널에서 실행될 거래 시스템을 설계하는 방법을 알아볼 것입니다.
회귀 메트릭을 사용하여 ONNX 모델 평가하기 회귀 메트릭을 사용하여 ONNX 모델 평가하기
회귀는 레이블이 지정되지 않은 예제에서 실제의 값을 예측하는 작업입니다. 회귀 메트릭은 회귀 모델 예측의 정확도를 평가하는 데 사용됩니다.