English Русский 中文 Español Deutsch Português 한국어 Français Italiano Türkçe
preview
ONNXモデルをクラスでラップする

ONNXモデルをクラスでラップする

MetaTrader 5 | 13 9月 2023, 10:45
443 0
MetaQuotes
MetaQuotes

はじめに

前回の記事では、投票分類器をアレンジするために2つのONNXモデルを使用しました。ソーステキスト全体は、1つのMQ5ファイルとして整理されました。コード全体が関数に分割されました。ただし、モデルを入れ替えたり、別のモデルを追加する場合、原文がさらに大きくなります。オブジェクト指向のアプローチを試してみましょう。


1.どんなモデルを使うのか

前回の投票分類器では、1つの分類モデルと1つの回帰モデルを使用しました。回帰モデルでは、予測された価格の動き(下降、上昇、変化なし)の代わりに、クラスを計算するために使用される予測価格が得られます。しかしこの場合、クラス別の確率分布がないため、いわゆる「ソフト投票」ができません。

は3つの分類モデルを用意しました。MQL5でONNXモデルをアンサンブルする方法の例」稿では、すでに2つのモデルが使用されています。最初のモデル(回帰)は分類モデルに変換されました。トレーニングはOHLCの10価格シリーズでおこなわれました。l2番目のモデルは分類モデルです。トレーニングは63の終値でおこなわれました。

最後に、もうひとつモデルがあります。分類モデルは、30の終値と、平均期間が21と34の単純移動平均のシリーズで訓練されました。移動平均線と終値チャートの交点や、移動平均線同士の交点については一切想定していません。すべてのパターンは、レイヤー間の係数行列の形で計算され、ネットワークに記憶されます。

すべてのモデルは、2010.01.01から2023.01.01までのMetaQuotes-Demoサーバーのデータ、EURUSD D1で学習されました。3つのモデルすべてのトレーニングスクリプトはPythonで書かれており、この記事に添付されています。本稿の本題から読者の注意を逸らさないために、ここではそれらのソースコードは提供しません。


2.すべてのモデルに1つの基本クラスが必要

3つのモデルがあります。それぞれの違いは、入力データのサイズと準備にあります。どのモデルもインターフェイスは同じです。すべてのモデルのクラスは、同じ基本クラスから継承されなければなりません。

基本クラスを表現してみましょう。

//+------------------------------------------------------------------+
//|                                            ModelSymbolPeriod.mqh |
//|                                  Copyright 2023, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+

//--- price movement prediction
#define PRICE_UP   0
#define PRICE_SAME 1
#define PRICE_DOWN 2

//+------------------------------------------------------------------+
//| Base class for models based on trained symbol and period         |
//+------------------------------------------------------------------+
class CModelSymbolPeriod
  {
protected:
   long              m_handle;           // created model session handle
   string            m_symbol;           // symbol of trained data
   ENUM_TIMEFRAMES   m_period;           // timeframe of trained data
   datetime          m_next_bar;         // time of next bar (we work at bar begin only)
   double            m_class_delta;      // delta to recognize "price the same" in regression models

public:
   //+------------------------------------------------------------------+
   //| Constructor                                                      |
   //+------------------------------------------------------------------+
   CModelSymbolPeriod(const string symbol,const ENUM_TIMEFRAMES period,const double class_delta=0.0001)
     {
      m_handle=INVALID_HANDLE;
      m_symbol=symbol;
      m_period=period;
      m_next_bar=0;
      m_class_delta=class_delta;
     }

   //+------------------------------------------------------------------+
   //| Destructor                                                       |
   //+------------------------------------------------------------------+
   ~CModelSymbolPeriod(void)
     {
      Shutdown();
     }

   //+------------------------------------------------------------------+
   //| virtual stub for Init                                            |
   //+------------------------------------------------------------------+
   virtual bool Init(const string symbol,const ENUM_TIMEFRAMES period)
     {
      return(false);
     }

   //+------------------------------------------------------------------+
   //| Check for initialization, create model                           |
   //+------------------------------------------------------------------+
   bool CheckInit(const string symbol,const ENUM_TIMEFRAMES period,const uchar& model[])
     {
      //--- check symbol, period
      if(symbol!=m_symbol || period!=m_period)
        {
         PrintFormat("Model must work with %s,%s",m_symbol,EnumToString(m_period));
         return(false);
        }

      //--- create a model from static buffer
      m_handle=OnnxCreateFromBuffer(model,ONNX_DEFAULT);
      if(m_handle==INVALID_HANDLE)
        {
         Print("OnnxCreateFromBuffer error ",GetLastError());
         return(false);
        }

      //--- ok
      return(true);
     }

   //+------------------------------------------------------------------+
   //| Release ONNX session                                             |
   //+------------------------------------------------------------------+
   void Shutdown(void)
     {
      if(m_handle!=INVALID_HANDLE)
        {
         OnnxRelease(m_handle);
         m_handle=INVALID_HANDLE;
        }
     }

   //+------------------------------------------------------------------+
   //| Check for continue OnTick                                        |
   //+------------------------------------------------------------------+
   virtual bool CheckOnTick(void)
     {
      //--- check new bar
      if(TimeCurrent()<m_next_bar)
         return(false);
      //--- set next bar time
      m_next_bar=TimeCurrent();
      m_next_bar-=m_next_bar%PeriodSeconds(m_period);
      m_next_bar+=PeriodSeconds(m_period);

      //--- work on new day bar
      return(true);
     }

   //+------------------------------------------------------------------+
   //| virtual stub for PredictPrice (regression model)                 |
   //+------------------------------------------------------------------+
   virtual double PredictPrice(void)
     {
      return(DBL_MAX);
     }

   //+------------------------------------------------------------------+
   //| Predict class (regression -> classification)                     |
   //+------------------------------------------------------------------+
   virtual int PredictClass(void)
     {
      double predicted_price=PredictPrice();
      if(predicted_price==DBL_MAX)
         return(-1);

      int    predicted_class=-1;
      double last_close=iClose(m_symbol,m_period,1);
      //--- classify predicted price movement
      double delta=last_close-predicted_price;
      if(fabs(delta)<=m_class_delta)
         predicted_class=PRICE_SAME;
      else
        {
         if(delta<0)
            predicted_class=PRICE_UP;
         else
            predicted_class=PRICE_DOWN;
        }

      //--- return predicted class
      return(predicted_class);
     }
  };
//+------------------------------------------------------------------+

基本クラスは回帰モデルにも分類モデルにも使えます。子孫クラス(PredictPriceまたはPredictClass)に適切なメソッドを実装するだけです。

基本クラスは、モデルが扱う銘柄と期間(モデルが訓練されたデータ)を設定します。基本クラスはまた、モデルを使用するEAが必要な銘柄と期間で動作することを確認し、モデルを実行するためのONNXセッションを作成します。基本クラスは、新しいバーの始まりにのみ仕事を提供します。


3.最初のモデルクラス

最初のモデルは、model.eurusd.D1.10.class.onnxと呼ばれるもので、EURUSD D1の10のOHLC価格シリーズで訓練した分類モデルです。

//+------------------------------------------------------------------+
//|                                        ModelEurusdD1_10Class.mqh |
//|                                  Copyright 2023, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#include "ModelSymbolPeriod.mqh"

#resource "Python/model.eurusd.D1.10.class.onnx" as uchar model_eurusd_D1_10_class[]

//+------------------------------------------------------------------+
//| ONNX-model wrapper class                                         |
//+------------------------------------------------------------------+
class CModelEurusdD1_10Class : public CModelSymbolPeriod
  {
private:
   int               m_sample_size;

public:
   //+------------------------------------------------------------------+
   //| Constructor                                                      |
   //+------------------------------------------------------------------+
   CModelEurusdD1_10Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1)
     {
      m_sample_size=10;
     }

   //+------------------------------------------------------------------+
   //| ONNX-model initialization                                        |
   //+------------------------------------------------------------------+
   virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period)
     {
      //--- check symbol, period, create model
      if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_10_class))
        {
         Print("model_eurusd_D1_10_class : initialization error");
         return(false);
        }

      //--- since not all sizes defined in the input tensor we must set them explicitly
      //--- first index - batch size, second index - series size, third index - number of series (OHLC)
      const long input_shape[] = {1,m_sample_size,4};
      if(!OnnxSetInputShape(m_handle,0,input_shape))
        {
         Print("model_eurusd_D1_10_class : OnnxSetInputShape error ",GetLastError());
         return(false);
        }
   
      //--- since not all sizes defined in the output tensor we must set them explicitly
      //--- first index - batch size, must match the batch size of the input tensor
      //--- second index - number of classes (up, same or down)
      const long output_shape[] = {1,3};
      if(!OnnxSetOutputShape(m_handle,0,output_shape))
        {
         Print("model_eurusd_D1_10_class : OnnxSetOutputShape error ",GetLastError());
         return(false);
        }
      //--- ok
      return(true);
     }

   //+------------------------------------------------------------------+
   //| Predict class                                                    |
   //+------------------------------------------------------------------+
   virtual int PredictClass(void)
     {
      static matrixf input_data(m_sample_size,4);    // matrix for prepared input data
      static vectorf output_data(3);                 // vector to get result
      static matrix  mm(m_sample_size,4);            // matrix of horizontal vectors Mean
      static matrix  ms(m_sample_size,4);            // matrix of horizontal vectors Std
      static matrix  x_norm(m_sample_size,4);        // matrix for prices normalize
   
      //--- prepare input data
      matrix rates;
      //--- request last bars
      if(!rates.CopyRates(m_symbol,m_period,COPY_RATES_OHLC,1,m_sample_size))
         return(-1);
      //--- get series Mean
      vector m=rates.Mean(1);
      //--- get series Std
      vector s=rates.Std(1);
      //--- prepare matrices for prices normalization
      for(int i=0; i<m_sample_size; i++)
        {
         mm.Row(m,i);
         ms.Row(s,i);
        }
      //--- the input of the model must be a set of vertical OHLC vectors
      x_norm=rates.Transpose();
      //--- normalize prices
      x_norm-=mm;
      x_norm/=ms;
   
      //--- run the inference
      input_data.Assign(x_norm);
      if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data))
         return(-1);
      //--- evaluate prediction
      return(int(output_data.ArgMax()));
     }
  };
//+------------------------------------------------------------------+

すでに述べたとおり、モデルは3つあり、それぞれの違いは、入力データのサイズと準備にあります。InitとPredictClassの2つのメソッドだけを再定義しました。同じメソッドが、他の2つのモデルのために他の2つのクラスで再定義されます。

Initメソッドは、ONNXモデルのセッションが作成され、入力テンソルと出力テンソルのサイズが明示的に設定されるCheckInit基底クラスメソッドを呼び出します。ここではコードよりもコメントの方が多くなります。

PredictClassメソッドは、モデルをトレーニングするときとまったく同じ入力データを準備します。入力は正規化されたOHLC価格の行列です。


4.どのように機能するのか確認してみましょう。

このクラスのパフォーマンスをテストするために、非常にコンパクトなエキスパートアドバイザー(EA)が作成されました。

//+------------------------------------------------------------------+
//|                                    ONNX.eurusd.D1.Prediction.mq5 |
//|                                  Copyright 2023, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright   "Copyright 2023, MetaQuotes Ltd."
#property link        "https://www.mql5.com"
#property version     "1.00"

#include "ModelEurusdD1_10Class.mqh"
#include <Trade\Trade.mqh>

input double InpLots = 1.0;    // Lots amount to open position

CModelEurusdD1_10Class ExtModel;
CTrade                 ExtTrade;

//+------------------------------------------------------------------+
//| Expert initialization function                                   |
//+------------------------------------------------------------------+
int OnInit()
  {
   if(!ExtModel.Init(_Symbol,_Period))
      return(INIT_FAILED);
//---
   return(INIT_SUCCEEDED);
  }
//+------------------------------------------------------------------+
//| Expert deinitialization function                                 |
//+------------------------------------------------------------------+
void OnDeinit(const int reason)
  {
   ExtModel.Shutdown();
  }
//+------------------------------------------------------------------+
//| Expert tick function                                             |
//+------------------------------------------------------------------+
void OnTick()
  {
   if(!ExtModel.CheckOnTick())
      return;

//--- predict next price movement
   int predicted_class=ExtModel.PredictClass();
//--- check trading according to prediction
   if(predicted_class>=0)
      if(PositionSelect(_Symbol))
         CheckForClose(predicted_class);
      else
         CheckForOpen(predicted_class);
  }
//+------------------------------------------------------------------+
//| Check for open position conditions                               |
//+------------------------------------------------------------------+
void CheckForOpen(const int predicted_class)
  {
   ENUM_ORDER_TYPE signal=WRONG_VALUE;
//--- check signals

   if(predicted_class==PRICE_DOWN)
      signal=ORDER_TYPE_SELL;    // sell condition
   else
     {
      if(predicted_class==PRICE_UP)
         signal=ORDER_TYPE_BUY;  // buy condition
     }

//--- open position if possible according to signal
   if(signal!=WRONG_VALUE && TerminalInfoInteger(TERMINAL_TRADE_ALLOWED))
     {
      double price=SymbolInfoDouble(_Symbol,(signal==ORDER_TYPE_SELL) ? SYMBOL_BID : SYMBOL_ASK);
      ExtTrade.PositionOpen(_Symbol,signal,InpLots,price,0,0);
     }
  }
//+------------------------------------------------------------------+
//| Check for close position conditions                              |
//+------------------------------------------------------------------+
void CheckForClose(const int predicted_class)
  {
   bool bsignal=false;
//--- position already selected before
   long type=PositionGetInteger(POSITION_TYPE);
//--- check signals
   if(type==POSITION_TYPE_BUY && predicted_class==PRICE_DOWN)
      bsignal=true;
   if(type==POSITION_TYPE_SELL && predicted_class==PRICE_UP)
      bsignal=true;

//--- close position if possible
   if(bsignal && TerminalInfoInteger(TERMINAL_TRADE_ALLOWED))
     {
      ExtTrade.PositionClose(_Symbol,3);
      //--- open opposite
      CheckForOpen(predicted_class);
     }
  }
//+------------------------------------------------------------------+

モデルは2023年までの価格データで学習されたので、2023年1月1日からテストを開始しましょう。

テスト設定

結果は以下のように表示されます。

テスト結果

見ての通り、このモデルは完全に機能します。


5.2つ目のモデルクラス

2つ目のモデルはmodel.eurusd.D1.30.class.onnxと呼ばれます。EURUSDのD1で訓練された分類モデルは、30の終値と、平均期間が21と34の2つの単純移動平均のシリーズで訓練されました。

//+------------------------------------------------------------------+
//|                                        ModelEurusdD1_30Class.mqh |
//|                                  Copyright 2023, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#include "ModelSymbolPeriod.mqh"

#resource "Python/model.eurusd.D1.30.class.onnx" as uchar model_eurusd_D1_30_class[]

//+------------------------------------------------------------------+
//| ONNX-model wrapper class                                         |
//+------------------------------------------------------------------+
class CModelEurusdD1_30Class : public CModelSymbolPeriod
  {
private:
   int               m_sample_size;
   int               m_fast_period;
   int               m_slow_period;
   int               m_sma_fast;
   int               m_sma_slow;

public:
   //+------------------------------------------------------------------+
   //| Constructor                                                      |
   //+------------------------------------------------------------------+
   CModelEurusdD1_30Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1)
     {
      m_sample_size=30;
      m_fast_period=21;
      m_slow_period=34;
      m_sma_fast=INVALID_HANDLE;
      m_sma_slow=INVALID_HANDLE;
     }

   //+------------------------------------------------------------------+
   //| ONNX-model initialization                                        |
   //+------------------------------------------------------------------+
   virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period)
     {
      //--- check symbol, period, create model
      if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_30_class))
        {
         Print("model_eurusd_D1_30_class : initialization error");
         return(false);
        }

      //--- since not all sizes defined in the input tensor we must set them explicitly
      //--- first index - batch size, second index - series size, third index - number of series (Close, MA fast, MA slow)
      const long input_shape[] = {1,m_sample_size,3};
      if(!OnnxSetInputShape(m_handle,0,input_shape))
        {
         Print("model_eurusd_D1_30_class : OnnxSetInputShape error ",GetLastError());
         return(false);
        }
   
      //--- since not all sizes defined in the output tensor we must set them explicitly
      //--- first index - batch size, must match the batch size of the input tensor
      //--- second index - number of classes (up, same or down)
      const long output_shape[] = {1,3};
      if(!OnnxSetOutputShape(m_handle,0,output_shape))
        {
         Print("model_eurusd_D1_30_class : OnnxSetOutputShape error ",GetLastError());
         return(false);
        }
      //--- indicators
      m_sma_fast=iMA(m_symbol,m_period,m_fast_period,0,MODE_SMA,PRICE_CLOSE);
      m_sma_slow=iMA(m_symbol,m_period,m_slow_period,0,MODE_SMA,PRICE_CLOSE);
      if(m_sma_fast==INVALID_HANDLE || m_sma_slow==INVALID_HANDLE)
        {
         Print("model_eurusd_D1_30_class : cannot create indicator");
         return(false);
        }
      //--- ok
      return(true);
     }

   //+------------------------------------------------------------------+
   //| Predict class                                                    |
   //+------------------------------------------------------------------+
   virtual int PredictClass(void)
     {
      static matrixf input_data(m_sample_size,3);    // matrix for prepared input data
      static vectorf output_data(3);                 // vector to get result
      static matrix  x_norm(m_sample_size,3);        // matrix for prices normalize
      static vector  vtemp(m_sample_size);
      static double  ma_buffer[];
   
      //--- request last bars
      if(!vtemp.CopyRates(m_symbol,m_period,COPY_RATES_CLOSE,1,m_sample_size))
         return(-1);
      //--- get series Mean
      double m=vtemp.Mean();
      //--- get series Std
      double s=vtemp.Std();
      //--- normalize
      vtemp-=m;
      vtemp/=s;
      x_norm.Col(vtemp,0);
      //--- fast sma
      if(CopyBuffer(m_sma_fast,0,1,m_sample_size,ma_buffer)!=m_sample_size)
         return(-1);
      vtemp.Assign(ma_buffer);
      m=vtemp.Mean();
      s=vtemp.Std();
      vtemp-=m;
      vtemp/=s;
      x_norm.Col(vtemp,1);
      //--- slow sma
      if(CopyBuffer(m_sma_slow,0,1,m_sample_size,ma_buffer)!=m_sample_size)
         return(-1);
      vtemp.Assign(ma_buffer);
      m=vtemp.Mean();
      s=vtemp.Std();
      vtemp-=m;
      vtemp/=s;
      x_norm.Col(vtemp,2);
   
      //--- run the inference
      input_data.Assign(x_norm);
      if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data))
         return(-1);
      //--- evaluate prediction
      return(int(output_data.ArgMax()));
     }
  };
//+------------------------------------------------------------------+

前のクラスと同様に、CheckInit基本クラスメソッドはInitメソッドの中で呼び出されます。基本クラスのメソッドでは、ONNXモデルのセッションが作成され、入力テンソルと出力テンソルのサイズが明示的に設定されます。

PredictClassメソッドは、過去30回の終値と算出された移動平均を提供します。データはトレーニングと同じ方法で正規化されます。

このモデルがどのように機能するか見てみましょう。そのために、テストEAの2つの文字列だけを変更します。

#include "ModelEurusdD1_30Class.mqh"
#include <Trade\Trade.mqh>

input double InpLots = 1.0;    // Lots amount to open position

CModelEurusdD1_30Class ExtModel;
CTrade                 ExtTrade;

テストパラメータは同じです。

2番目のモデルのテスト結果

このモデルは機能していることがわかります。


6.3番目のモデルクラス

最後のモデルはmodel.eurusd.D1.63.class.onnxと呼ばれます。EURUSDのD1で訓練された分類モデルは、63の終値のシリーズを対象としました。

//+------------------------------------------------------------------+
//|                                             ModelEurusdD1_63.mqh |
//|                                  Copyright 2023, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#include "ModelSymbolPeriod.mqh"

#resource "Python/model.eurusd.D1.63.class.onnx" as uchar model_eurusd_D1_63_class[]

//+------------------------------------------------------------------+
//| ONNX-model wrapper class                                         |
//+------------------------------------------------------------------+
class CModelEurusdD1_63Class : public CModelSymbolPeriod
  {
private:
   int               m_sample_size;

public:
   //+------------------------------------------------------------------+
   //| Constructor                                                      |
   //+------------------------------------------------------------------+
   CModelEurusdD1_63Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1,0.0001)
     {
      m_sample_size=63;
     }

   //+------------------------------------------------------------------+
   //| ONNX-model initialization                                        |
   //+------------------------------------------------------------------+
   virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period)
     {
      //--- check symbol, period, create model
      if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_63_class))
        {
         Print("model_eurusd_D1_63_class : initialization error");
         return(false);
        }

      //--- since not all sizes defined in the input tensor we must set them explicitly
      //--- first index - batch size, second index - series size
      const long input_shape[] = {1,m_sample_size};
      if(!OnnxSetInputShape(m_handle,0,input_shape))
        {
         Print("model_eurusd_D1_63_class : OnnxSetInputShape error ",GetLastError());
         return(false);
        }
   
      //--- since not all sizes defined in the output tensor we must set them explicitly
      //--- first index - batch size, must match the batch size of the input tensor
      //--- second index - number of classes (up, same or down)
      const long output_shape[] = {1,3};
      if(!OnnxSetOutputShape(m_handle,0,output_shape))
        {
         Print("model_eurusd_D1_63_class : OnnxSetOutputShape error ",GetLastError());
         return(false);
        }
      //--- ok
      return(true);
     }

   //+------------------------------------------------------------------+
   //| Predict class                                                    |
   //+------------------------------------------------------------------+
   virtual int PredictClass(void)
     {
      static vectorf input_data(m_sample_size);  // vector for prepared input data
      static vectorf output_data(3);             // vector to get result
   
      //--- request last bars
      if(!input_data.CopyRates(m_symbol,m_period,COPY_RATES_CLOSE,1,m_sample_size))
         return(-1);
      //--- get series Mean
      float m=input_data.Mean();
      //--- get series Std
      float s=input_data.Std();
      //--- normalize prices
      input_data-=m;
      input_data/=s;
   
      //--- run the inference
      if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data))
         return(-1);
      //--- evaluate prediction
      return(int(output_data.ArgMax()));
     }
  };
//+------------------------------------------------------------------+

これは3つの中で最もシンプルなモデルです。これが、PredictClassメソッドのコードが非常にコンパクトな理由です。

EAの2つの文字列をもう一度変更してみましょう。

#include "ModelEurusdD1_63Class.mqh"
#include <Trade\Trade.mqh>

input double InpLots = 1.0;    // Lots amount to open position

CModelEurusdD1_63Class ExtModel;
CTrade                 ExtTrade;

同じ設定でテストを開始します。

3番目のモデルのテスト結果

このモデルは機能します。



7.すべてのモデルを1つのEAに集める - ハード投票

3モデルともその能力を発揮しているので、次に、それらの努力を結集してみましょう。モデルの投票を手配しましょう。

前方宣言と定義

#include "ModelEurusdD1_10Class.mqh"
#include "ModelEurusdD1_30Class.mqh"
#include "ModelEurusdD1_63Class.mqh"
#include <Trade\Trade.mqh>

input double  InpLots  = 1.0;    // Lots amount to open position

CModelSymbolPeriod *ExtModels[3];
CTrade              ExtTrade;

OnInit関数

int OnInit()
  {
   ExtModels[0]=new CModelEurusdD1_10Class;
   ExtModels[1]=new CModelEurusdD1_30Class;
   ExtModels[2]=new CModelEurusdD1_63Class;

   for(long i=0; i<ExtModels.Size(); i++)
      if(!ExtModels[i].Init(_Symbol,_Period))
         return(INIT_FAILED);
//---
   return(INIT_SUCCEEDED);
  }

OnTick関数

void OnTick()
  {
   for(long i=0; i<ExtModels.Size(); i++)
      if(!ExtModels[i].CheckOnTick())
         return;

//--- predict next price movement
   int returned[3]={0,0,0};
//--- collect returned classes
   for(long i=0; i<ExtModels.Size(); i++)
     {
      int pred=ExtModels[i].PredictClass();
      if(pred>=0)
         returned[pred]++;
     }
//--- get one prediction for all models
   int predicted_class=-1;
//--- count votes for predictions
   for(int n=0; n<3; n++)
     {
      if(returned[n]>=2)
        {
         predicted_class=n;
         break;
        }
     }

//--- check trading according to prediction
   if(predicted_class>=0)
      if(PositionSelect(_Symbol))
         CheckForClose(predicted_class);
      else
         CheckForOpen(predicted_class);
  }

過半数の得票は、<総得票数>/2+1という式に従って計算されます。合計3票で、過半数は2票です。いわゆる「ハード投票」です。

テスト結果は同じ設定のままです。

ハード投票テスト結果

3つのモデルの仕事、すなわち、利益の出る取引と利益の出ない取引の数を別々に見てみましょう。モデル1 - 11:3、モデル2 - 6:1、モデル3 - 16:10です。

ハード投票で結果は改善されたようです(16:4)が、もちろん、フルレポートやテストチャートを見る必要があります。


8.ソフト投票

ソフト投票がハード投票と異なるのは、考慮されるのが投票数ではなく、3つのモデルから得られる3つのクラスすべての確率の合計であるという点です。クラスは最も高い確率で選ばれます。

ソフト投票を保証するためには、いくつかの変更を加える必要があります。

基本クラス

   //+------------------------------------------------------------------+
   //| Predict class (regression -> classification)                     |
   //+------------------------------------------------------------------+
   virtual int PredictClass(vector& probabilities)
     {
...
      //--- set predicted probability as 1.0
      probabilities.Fill(0);
      if(predicted_class<(int)probabilities.Size())
         probabilities[predicted_class]=1;
      //--- and return predicted class
      return(predicted_class);
     }

子孫のクラス

   //+------------------------------------------------------------------+
   //| Predict class                                                    |
   //+------------------------------------------------------------------+
   virtual int PredictClass(vector& probabilities)
     {
...
      //--- evaluate prediction
      probabilities.Assign(output_data);
      return(int(output_data.ArgMax()));
     }

EA

#include "ModelEurusdD1_10Class.mqh"
#include "ModelEurusdD1_30Class.mqh"
#include "ModelEurusdD1_63Class.mqh"
#include <Trade\Trade.mqh>

enum EnVotes
  {
   Two=2,    // Two votes
   Three=3,  // Three votes
   Soft=4    // Soft voting
  };

input double  InpLots  = 1.0;    // Lots amount to open position
input EnVotes InpVotes = Two;    // Votes to make trade decision

CModelSymbolPeriod *ExtModels[3];
CTrade              ExtTrade;
void OnTick()
  {
   for(long i=0; i<ExtModels.Size(); i++)
      if(!ExtModels[i].CheckOnTick())
         return;

//--- predict next price movement
   int    returned[3]={0,0,0};
   vector soft=vector::Zeros(3);
//--- collect returned classes
   for(long i=0; i<ExtModels.Size(); i++)
     {
      vector prob(3);
      int    pred=ExtModels[i].PredictClass(prob);
      if(pred>=0)
        {
         returned[pred]++;
         soft+=prob;
        }
     }
//--- get one prediction for all models
   int predicted_class=-1;
//--- soft or hard voting
   if(InpVotes==Soft)
      predicted_class=(int)soft.ArgMax();
   else
     {
      //--- count votes for predictions
      for(int n=0; n<3; n++)
        {
         if(returned[n]>=InpVotes)
           {
            predicted_class=n;
            break;
           }
        }
     }

//--- check trading according to prediction
   if(predicted_class>=0)
      if(PositionSelect(_Symbol))
         CheckForClose(predicted_class);
      else
         CheckForOpen(predicted_class);
  }

テストの設定は同じです。入力で、[ソフト]を選択します。

入力設定

結果は以下のとおりです。

ソフト投票テスト結果

有益な取引 - 15件、有益でない取引 - 3件。金銭的な面でも、ハード投票の方がソフト投票よりも優れていることが判明しました。


全会一致、つまり投票数3の結果を見てみましょう。

全会一致投票テスト結果

非常に保守的な取引だ。唯一の不採算取引はテスト終了時に決済された(おそらく、不採算ではない)。

全会一致投票テストグラフ


重要な注意点:この記事で使用されているモデルは、MQL5言語を使用してONNXモデルを操作する方法を示すためにのみ提示されています。EAは、実際の口座での取引を意図したものではありません。


結論

この記事では、オブジェクト指向プログラミングがいかにプログラムを書きやすくするかを紹介しました。モデルの複雑さはすべて、そのクラスに隠されています(モデルは、例として紹介したものよりもはるかに複雑になることができます)。残りの「複雑さ」は、OnTick関数の45個の文字列に収まっています。


MetaQuotes Ltdによってロシア語から翻訳されました。
元の記事: https://www.mql5.com/ru/articles/12484

添付されたファイル |
MQL5.zip (190.3 KB)
MQL4およびMQL5開発のフレームワーク内のOpenAI ChatGPT機能 MQL4およびMQL5開発のフレームワーク内のOpenAI ChatGPT機能
この記事では、エキスパートアドバイザー(EA)、指標、スクリプトの開発にかかる時間と労力を削減するという観点から、OpenAI ChatGPTの機能を理解するために、ChatGPTをいじっていきます。このテクノロジーについて簡単に説明し、MQL4およびMQL5でのプログラミングにこのテクノロジーを正しく使用する方法を説明します。
改善された同事ローソク足パターン認識指標に基づく取引戦略 改善された同事ローソク足パターン認識指標に基づく取引戦略
メタバーベースの指標は、従来のものよりも多くのローソク足を検出しました。これが自動売買に本当に役立つのか、検証してみましょう。
DoEasy - コントロール(第32部):水平スクロールバー、マウスホイールスクロール DoEasy - コントロール(第32部):水平スクロールバー、マウスホイールスクロール
この記事では、水平スクロールバーオブジェクト機能の開発を完成します。また、スクロールバーのスライダーを動かしたり、マウスホイールを回転させたりしてコンテナの内容をスクロールできるようにするほか、MQL5の新しい注文実行ポリシーや新しいランタイムエラーコードを考慮したライブラリへの追加もおこないます。
MQL5の圏論(第14回):線形順序を持つ関手 MQL5の圏論(第14回):線形順序を持つ関手
この記事は、MQL5における圏論の実装に関する広範な連載の一部であり、関手について掘り下げます。関手のおかげで線形順序が集合にどのように写像できるかを検証します。一般的には何のつながりもないと見なされてしまうような2つのデータ集合について考えます。