Оборачиваем ONNX-модели в классы

MetaQuotes | 12 мая, 2023

Введение

В предыдущей статье мы использовали две ONNX-модели для организации классификатора голосования. При этом весь исходный текст был организван в виде одного MQ5-файла. Да, весь код разбит на функции. Но попробуйте, например, поменять местами модели. А если добавить ещё модель? Исходный текст ещё более распухнет. Попробуем объектно-ориентированный подход.


1. Какие модели мы собираемся использовать

В предыдущем классификаторе голосования мы использовали одну классификационную модель и одну регрессионную модель. В регрессионной модели вместо предсказанного движения цены (вниз, вверх, не изменяется) мы получаем предсказанную цену, на основе которой и вычисляем класс. Однако, в этом случае мы не имеем распределения вероятностей по классам, что не позволяет проводить так называемое "мягкое голосование".

Мы подготовили 3 классификационные модели. Две модели уже использовались в статье "Пример ансамбля ONNX-моделей в MQL5". Первая модель — регрессионная — переделана в классификационную. Обучение проводилось на сериях из 10 цен OHLC. Вторая модель — классификационная. Обучение проводилось на сериях из 63 цен Close.

Наконец, ещё одна модель. Классификационная модель обучалась на сериях из 30 цен Close и сериях простых скользящих средних с периодами усреднения 21 и 34. Мы не делали никаких предположений по поводу пересечения скользящих средних с графиком Close и между собой — все закономерности посчитает и запомнит сеть в виде матриц коэффициентов между слоями.

Все модели обучались на данных сервера MetaQuotes-Demo, EURUSD D1 с 2010.01.01 по 2023.01.01. Тренировочные скрипты всех трёх моделей написаны на Питоне и приложены к данной статье. Мы не будем приводить их исходные коды здесь, чтобы не отвлекать внимание читателя от основной темы нашей статьи.


2. Нужен один базовый класс для всех моделей

Три модели. Каждая отличается от другой размером входных данных, подготовкой входных данных. У всех моделей есть общность. Один и тот же интерфейс. Классы всех моделей должны наследоваться от одного и того же базового класса.

Попробуем представить базовый класс.

//+------------------------------------------------------------------+
//|                                            ModelSymbolPeriod.mqh |
//|                                  Copyright 2023, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+

//--- price movement prediction
#define PRICE_UP   0
#define PRICE_SAME 1
#define PRICE_DOWN 2

//+------------------------------------------------------------------+
//| Base class for models based on trained symbol and period         |
//+------------------------------------------------------------------+
class CModelSymbolPeriod
  {
protected:
   long              m_handle;           // created model session handle
   string            m_symbol;           // symbol of trained data
   ENUM_TIMEFRAMES   m_period;           // timeframe of trained data
   datetime          m_next_bar;         // time of next bar (we work at bar begin only)
   double            m_class_delta;      // delta to recognize "price the same" in regression models

public:
   //+------------------------------------------------------------------+
   //| Constructor                                                      |
   //+------------------------------------------------------------------+
   CModelSymbolPeriod(const string symbol,const ENUM_TIMEFRAMES period,const double class_delta=0.0001)
     {
      m_handle=INVALID_HANDLE;
      m_symbol=symbol;
      m_period=period;
      m_next_bar=0;
      m_class_delta=class_delta;
     }

   //+------------------------------------------------------------------+
   //| Destructor                                                       |
   //+------------------------------------------------------------------+
   ~CModelSymbolPeriod(void)
     {
      Shutdown();
     }

   //+------------------------------------------------------------------+
   //| virtual stub for Init                                            |
   //+------------------------------------------------------------------+
   virtual bool Init(const string symbol,const ENUM_TIMEFRAMES period)
     {
      return(false);
     }

   //+------------------------------------------------------------------+
   //| Check for initialization, create model                           |
   //+------------------------------------------------------------------+
   bool CheckInit(const string symbol,const ENUM_TIMEFRAMES period,const uchar& model[])
     {
      //--- check symbol, period
      if(symbol!=m_symbol || period!=m_period)
        {
         PrintFormat("Model must work with %s,%s",m_symbol,EnumToString(m_period));
         return(false);
        }

      //--- create a model from static buffer
      m_handle=OnnxCreateFromBuffer(model,ONNX_DEFAULT);
      if(m_handle==INVALID_HANDLE)
        {
         Print("OnnxCreateFromBuffer error ",GetLastError());
         return(false);
        }

      //--- ok
      return(true);
     }

   //+------------------------------------------------------------------+
   //| Release ONNX session                                             |
   //+------------------------------------------------------------------+
   void Shutdown(void)
     {
      if(m_handle!=INVALID_HANDLE)
        {
         OnnxRelease(m_handle);
         m_handle=INVALID_HANDLE;
        }
     }

   //+------------------------------------------------------------------+
   //| Check for continue OnTick                                        |
   //+------------------------------------------------------------------+
   virtual bool CheckOnTick(void)
     {
      //--- check new bar
      if(TimeCurrent()<m_next_bar)
         return(false);
      //--- set next bar time
      m_next_bar=TimeCurrent();
      m_next_bar-=m_next_bar%PeriodSeconds(m_period);
      m_next_bar+=PeriodSeconds(m_period);

      //--- work on new day bar
      return(true);
     }

   //+------------------------------------------------------------------+
   //| virtual stub for PredictPrice (regression model)                 |
   //+------------------------------------------------------------------+
   virtual double PredictPrice(void)
     {
      return(DBL_MAX);
     }

   //+------------------------------------------------------------------+
   //| Predict class (regression -> classification)                     |
   //+------------------------------------------------------------------+
   virtual int PredictClass(void)
     {
      double predicted_price=PredictPrice();
      if(predicted_price==DBL_MAX)
         return(-1);

      int    predicted_class=-1;
      double last_close=iClose(m_symbol,m_period,1);
      //--- classify predicted price movement
      double delta=last_close-predicted_price;
      if(fabs(delta)<=m_class_delta)
         predicted_class=PRICE_SAME;
      else
        {
         if(delta<0)
            predicted_class=PRICE_UP;
         else
            predicted_class=PRICE_DOWN;
        }

      //--- return predicted class
      return(predicted_class);
     }
  };
//+------------------------------------------------------------------+

Базовый класс можно использовать как для моделей регрессии, так и для моделей классификации. Надо будет только реализовать в классе-наследнике соответствующий метод — PredictPrice или PredictClass.

В базовом классе задаётся с каким символом-периодом должна работать модель (на каких данных обучалась модель). В базовом классе проводится проверка, что эксперт, использующий модель, работает на нужном символе-периоде, а также создаётся ONNX-сессия для исполнения модели. В базовом классе обеспечивается работа только в начале нового бара.


3. Класс для первой модели

Наша первая модель называется model.eurusd.D1.10.class.onnx, то есть классификационная модель, тренированная на EURUSD D1 на сериях из 10 цен OHLC.

//+------------------------------------------------------------------+
//|                                        ModelEurusdD1_10Class.mqh |
//|                                  Copyright 2023, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#include "ModelSymbolPeriod.mqh"

#resource "Python/model.eurusd.D1.10.class.onnx" as uchar model_eurusd_D1_10_class[]

//+------------------------------------------------------------------+
//| ONNX-model wrapper class                                         |
//+------------------------------------------------------------------+
class CModelEurusdD1_10Class : public CModelSymbolPeriod
  {
private:
   int               m_sample_size;

public:
   //+------------------------------------------------------------------+
   //| Constructor                                                      |
   //+------------------------------------------------------------------+
   CModelEurusdD1_10Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1)
     {
      m_sample_size=10;
     }

   //+------------------------------------------------------------------+
   //| ONNX-model initialization                                        |
   //+------------------------------------------------------------------+
   virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period)
     {
      //--- check symbol, period, create model
      if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_10_class))
        {
         Print("model_eurusd_D1_10_class : initialization error");
         return(false);
        }

      //--- since not all sizes defined in the input tensor we must set them explicitly
      //--- first index - batch size, second index - series size, third index - number of series (OHLC)
      const long input_shape[] = {1,m_sample_size,4};
      if(!OnnxSetInputShape(m_handle,0,input_shape))
        {
         Print("model_eurusd_D1_10_class : OnnxSetInputShape error ",GetLastError());
         return(false);
        }
   
      //--- since not all sizes defined in the output tensor we must set them explicitly
      //--- first index - batch size, must match the batch size of the input tensor
      //--- second index - number of classes (up, same or down)
      const long output_shape[] = {1,3};
      if(!OnnxSetOutputShape(m_handle,0,output_shape))
        {
         Print("model_eurusd_D1_10_class : OnnxSetOutputShape error ",GetLastError());
         return(false);
        }
      //--- ok
      return(true);
     }

   //+------------------------------------------------------------------+
   //| Predict class                                                    |
   //+------------------------------------------------------------------+
   virtual int PredictClass(void)
     {
      static matrixf input_data(m_sample_size,4);    // matrix for prepared input data
      static vectorf output_data(3);                 // vector to get result
      static matrix  mm(m_sample_size,4);            // matrix of horizontal vectors Mean
      static matrix  ms(m_sample_size,4);            // matrix of horizontal vectors Std
      static matrix  x_norm(m_sample_size,4);        // matrix for prices normalize
   
      //--- prepare input data
      matrix rates;
      //--- request last bars
      if(!rates.CopyRates(m_symbol,m_period,COPY_RATES_OHLC,1,m_sample_size))
         return(-1);
      //--- get series Mean
      vector m=rates.Mean(1);
      //--- get series Std
      vector s=rates.Std(1);
      //--- prepare matrices for prices normalization
      for(int i=0; i<m_sample_size; i++)
        {
         mm.Row(m,i);
         ms.Row(s,i);
        }
      //--- the input of the model must be a set of vertical OHLC vectors
      x_norm=rates.Transpose();
      //--- normalize prices
      x_norm-=mm;
      x_norm/=ms;
   
      //--- run the inference
      input_data.Assign(x_norm);
      if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data))
         return(-1);
      //--- evaluate prediction
      return(int(output_data.ArgMax()));
     }
  };
//+------------------------------------------------------------------+

Как было сказано выше: "Три модели. Каждая отличается от другой размером входных данных, подготовкой входных данных". И мы переопределили всего два метода — Init и PredictClass. В двух других классах для двух других моделей будут переопределены эти же самые методы.

В методе Init вызывается метод базового класса CheckInit, где создаётся сессия для нашей ONNX-модели. А также явно выставляются размеры входного и выходного тензоров. Здесь больше комментариев, чем кода.

В методе PredictClass обеспечивается точно такая же подготовка входных данных, что и при обучении модели. На вход подаётся матрица нормализованных цен OHLC.


4. Проверим, как это работает

Для проверки работоспособности нашего класса был создан очень компактный эксперт.

//+------------------------------------------------------------------+
//|                                    ONNX.eurusd.D1.Prediction.mq5 |
//|                                  Copyright 2023, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright   "Copyright 2023, MetaQuotes Ltd."
#property link        "https://www.mql5.com"
#property version     "1.00"

#include "ModelEurusdD1_10Class.mqh"
#include <Trade\Trade.mqh>

input double InpLots = 1.0;    // Lots amount to open position

CModelEurusdD1_10Class ExtModel;
CTrade                 ExtTrade;

//+------------------------------------------------------------------+
//| Expert initialization function                                   |
//+------------------------------------------------------------------+
int OnInit()
  {
   if(!ExtModel.Init(_Symbol,_Period))
      return(INIT_FAILED);
//---
   return(INIT_SUCCEEDED);
  }
//+------------------------------------------------------------------+
//| Expert deinitialization function                                 |
//+------------------------------------------------------------------+
void OnDeinit(const int reason)
  {
   ExtModel.Shutdown();
  }
//+------------------------------------------------------------------+
//| Expert tick function                                             |
//+------------------------------------------------------------------+
void OnTick()
  {
   if(!ExtModel.CheckOnTick())
      return;

//--- predict next price movement
   int predicted_class=ExtModel.PredictClass();
//--- check trading according to prediction
   if(predicted_class>=0)
      if(PositionSelect(_Symbol))
         CheckForClose(predicted_class);
      else
         CheckForOpen(predicted_class);
  }
//+------------------------------------------------------------------+
//| Check for open position conditions                               |
//+------------------------------------------------------------------+
void CheckForOpen(const int predicted_class)
  {
   ENUM_ORDER_TYPE signal=WRONG_VALUE;
//--- check signals

   if(predicted_class==PRICE_DOWN)
      signal=ORDER_TYPE_SELL;    // sell condition
   else
     {
      if(predicted_class==PRICE_UP)
         signal=ORDER_TYPE_BUY;  // buy condition
     }

//--- open position if possible according to signal
   if(signal!=WRONG_VALUE && TerminalInfoInteger(TERMINAL_TRADE_ALLOWED))
     {
      double price=SymbolInfoDouble(_Symbol,(signal==ORDER_TYPE_SELL) ? SYMBOL_BID : SYMBOL_ASK);
      ExtTrade.PositionOpen(_Symbol,signal,InpLots,price,0,0);
     }
  }
//+------------------------------------------------------------------+
//| Check for close position conditions                              |
//+------------------------------------------------------------------+
void CheckForClose(const int predicted_class)
  {
   bool bsignal=false;
//--- position already selected before
   long type=PositionGetInteger(POSITION_TYPE);
//--- check signals
   if(type==POSITION_TYPE_BUY && predicted_class==PRICE_DOWN)
      bsignal=true;
   if(type==POSITION_TYPE_SELL && predicted_class==PRICE_UP)
      bsignal=true;

//--- close position if possible
   if(bsignal && TerminalInfoInteger(TERMINAL_TRADE_ALLOWED))
     {
      ExtTrade.PositionClose(_Symbol,3);
      //--- open opposite
      CheckForOpen(predicted_class);
     }
  }
//+------------------------------------------------------------------+

Так как модель обучалась на ценовых данных до 2023 года, запустим тестирование с 1 января 2023 года.

Настройки тестирования

И получим следующий результат:

Результаты тестирования

Как видим, вполне работоспособная модель.


5. Класс для второй модели

Вторая модель называется model.eurusd.D1.30.class.onnx. Классификационная модель, тренированная на EURUSD D1 на сериях из 30 цен Close и двух простых скользящих средних с периодами усреднения 21 и 34.

//+------------------------------------------------------------------+
//|                                        ModelEurusdD1_30Class.mqh |
//|                                  Copyright 2023, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#include "ModelSymbolPeriod.mqh"

#resource "Python/model.eurusd.D1.30.class.onnx" as uchar model_eurusd_D1_30_class[]

//+------------------------------------------------------------------+
//| ONNX-model wrapper class                                         |
//+------------------------------------------------------------------+
class CModelEurusdD1_30Class : public CModelSymbolPeriod
  {
private:
   int               m_sample_size;
   int               m_fast_period;
   int               m_slow_period;
   int               m_sma_fast;
   int               m_sma_slow;

public:
   //+------------------------------------------------------------------+
   //| Constructor                                                      |
   //+------------------------------------------------------------------+
   CModelEurusdD1_30Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1)
     {
      m_sample_size=30;
      m_fast_period=21;
      m_slow_period=34;
      m_sma_fast=INVALID_HANDLE;
      m_sma_slow=INVALID_HANDLE;
     }

   //+------------------------------------------------------------------+
   //| ONNX-model initialization                                        |
   //+------------------------------------------------------------------+
   virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period)
     {
      //--- check symbol, period, create model
      if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_30_class))
        {
         Print("model_eurusd_D1_30_class : initialization error");
         return(false);
        }

      //--- since not all sizes defined in the input tensor we must set them explicitly
      //--- first index - batch size, second index - series size, third index - number of series (Close, MA fast, MA slow)
      const long input_shape[] = {1,m_sample_size,3};
      if(!OnnxSetInputShape(m_handle,0,input_shape))
        {
         Print("model_eurusd_D1_30_class : OnnxSetInputShape error ",GetLastError());
         return(false);
        }
   
      //--- since not all sizes defined in the output tensor we must set them explicitly
      //--- first index - batch size, must match the batch size of the input tensor
      //--- second index - number of classes (up, same or down)
      const long output_shape[] = {1,3};
      if(!OnnxSetOutputShape(m_handle,0,output_shape))
        {
         Print("model_eurusd_D1_30_class : OnnxSetOutputShape error ",GetLastError());
         return(false);
        }
      //--- indicators
      m_sma_fast=iMA(m_symbol,m_period,m_fast_period,0,MODE_SMA,PRICE_CLOSE);
      m_sma_slow=iMA(m_symbol,m_period,m_slow_period,0,MODE_SMA,PRICE_CLOSE);
      if(m_sma_fast==INVALID_HANDLE || m_sma_slow==INVALID_HANDLE)
        {
         Print("model_eurusd_D1_30_class : cannot create indicator");
         return(false);
        }
      //--- ok
      return(true);
     }

   //+------------------------------------------------------------------+
   //| Predict class                                                    |
   //+------------------------------------------------------------------+
   virtual int PredictClass(void)
     {
      static matrixf input_data(m_sample_size,3);    // matrix for prepared input data
      static vectorf output_data(3);                 // vector to get result
      static matrix  x_norm(m_sample_size,3);        // matrix for prices normalize
      static vector  vtemp(m_sample_size);
      static double  ma_buffer[];
   
      //--- request last bars
      if(!vtemp.CopyRates(m_symbol,m_period,COPY_RATES_CLOSE,1,m_sample_size))
         return(-1);
      //--- get series Mean
      double m=vtemp.Mean();
      //--- get series Std
      double s=vtemp.Std();
      //--- normalize
      vtemp-=m;
      vtemp/=s;
      x_norm.Col(vtemp,0);
      //--- fast sma
      if(CopyBuffer(m_sma_fast,0,1,m_sample_size,ma_buffer)!=m_sample_size)
         return(-1);
      vtemp.Assign(ma_buffer);
      m=vtemp.Mean();
      s=vtemp.Std();
      vtemp-=m;
      vtemp/=s;
      x_norm.Col(vtemp,1);
      //--- slow sma
      if(CopyBuffer(m_sma_slow,0,1,m_sample_size,ma_buffer)!=m_sample_size)
         return(-1);
      vtemp.Assign(ma_buffer);
      m=vtemp.Mean();
      s=vtemp.Std();
      vtemp-=m;
      vtemp/=s;
      x_norm.Col(vtemp,2);
   
      //--- run the inference
      input_data.Assign(x_norm);
      if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data))
         return(-1);
      //--- evaluate prediction
      return(int(output_data.ArgMax()));
     }
  };
//+------------------------------------------------------------------+

Как и в предыдущем классе, в методе Init вызывается метод базового класса CheckInit, где создаётся сессия для ONNX-модели и явно выставляются размеры входного и выходного тензоров

В методе PredictClass обеспечиваются серии из 30 предыдущих Close и рассчитанные скользящие средние. Данные нормируются тем же способом, что и при обучении.

Проверим, как работает эта модель. Для этого изменим всего две строки проверочного эксперта

#include "ModelEurusdD1_30Class.mqh"
#include <Trade\Trade.mqh>

input double InpLots = 1.0;    // Lots amount to open position

CModelEurusdD1_30Class ExtModel;
CTrade                 ExtTrade;

Параметры тестирования те же самые.

Результаты тестирования второй модели

Видим, что модель работает.


6. Класс для третьей модели

Последняя модель называется model.eurusd.D1.63.class.onnx. Классификационная модель, тренированная на EURUSD D1 на сериях из 63 цен Close.

//+------------------------------------------------------------------+
//|                                             ModelEurusdD1_63.mqh |
//|                                  Copyright 2023, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#include "ModelSymbolPeriod.mqh"

#resource "Python/model.eurusd.D1.63.class.onnx" as uchar model_eurusd_D1_63_class[]

//+------------------------------------------------------------------+
//| ONNX-model wrapper class                                         |
//+------------------------------------------------------------------+
class CModelEurusdD1_63Class : public CModelSymbolPeriod
  {
private:
   int               m_sample_size;

public:
   //+------------------------------------------------------------------+
   //| Constructor                                                      |
   //+------------------------------------------------------------------+
   CModelEurusdD1_63Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1,0.0001)
     {
      m_sample_size=63;
     }

   //+------------------------------------------------------------------+
   //| ONNX-model initialization                                        |
   //+------------------------------------------------------------------+
   virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period)
     {
      //--- check symbol, period, create model
      if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_63_class))
        {
         Print("model_eurusd_D1_63_class : initialization error");
         return(false);
        }

      //--- since not all sizes defined in the input tensor we must set them explicitly
      //--- first index - batch size, second index - series size
      const long input_shape[] = {1,m_sample_size};
      if(!OnnxSetInputShape(m_handle,0,input_shape))
        {
         Print("model_eurusd_D1_63_class : OnnxSetInputShape error ",GetLastError());
         return(false);
        }
   
      //--- since not all sizes defined in the output tensor we must set them explicitly
      //--- first index - batch size, must match the batch size of the input tensor
      //--- second index - number of classes (up, same or down)
      const long output_shape[] = {1,3};
      if(!OnnxSetOutputShape(m_handle,0,output_shape))
        {
         Print("model_eurusd_D1_63_class : OnnxSetOutputShape error ",GetLastError());
         return(false);
        }
      //--- ok
      return(true);
     }

   //+------------------------------------------------------------------+
   //| Predict class                                                    |
   //+------------------------------------------------------------------+
   virtual int PredictClass(void)
     {
      static vectorf input_data(m_sample_size);  // vector for prepared input data
      static vectorf output_data(3);             // vector to get result
   
      //--- request last bars
      if(!input_data.CopyRates(m_symbol,m_period,COPY_RATES_CLOSE,1,m_sample_size))
         return(-1);
      //--- get series Mean
      float m=input_data.Mean();
      //--- get series Std
      float s=input_data.Std();
      //--- normalize prices
      input_data-=m;
      input_data/=s;
   
      //--- run the inference
      if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data))
         return(-1);
      //--- evaluate prediction
      return(int(output_data.ArgMax()));
     }
  };
//+------------------------------------------------------------------+

Это — самая простая модель из трёх. Поэтому код метода PredictClass получился таким компактным.

Опять изменим две строки в эксперте

#include "ModelEurusdD1_63Class.mqh"
#include <Trade\Trade.mqh>

input double InpLots = 1.0;    // Lots amount to open position

CModelEurusdD1_63Class ExtModel;
CTrade                 ExtTrade;

И запустим тестирование с теми же самыми настройками.

Результаты тестирования третьей модели

Модель работает



7. Собираем все модели в одном эксперте. Жёсткое голосование

Все три модели показали свою работоспособность. Теперь попробуем объединить их усилия. Устроим голосование моделей.

Предварительные объявления и определения

#include "ModelEurusdD1_10Class.mqh"
#include "ModelEurusdD1_30Class.mqh"
#include "ModelEurusdD1_63Class.mqh"
#include <Trade\Trade.mqh>

input double  InpLots  = 1.0;    // Lots amount to open position

CModelSymbolPeriod *ExtModels[3];
CTrade              ExtTrade;

Функция OnInit

int OnInit()
  {
   ExtModels[0]=new CModelEurusdD1_10Class;
   ExtModels[1]=new CModelEurusdD1_30Class;
   ExtModels[2]=new CModelEurusdD1_63Class;

   for(long i=0; i<ExtModels.Size(); i++)
      if(!ExtModels[i].Init(_Symbol,_Period))
         return(INIT_FAILED);
//---
   return(INIT_SUCCEEDED);
  }

Функция OnTick

void OnTick()
  {
   for(long i=0; i<ExtModels.Size(); i++)
      if(!ExtModels[i].CheckOnTick())
         return;

//--- predict next price movement
   int returned[3]={0,0,0};
//--- collect returned classes
   for(long i=0; i<ExtModels.Size(); i++)
     {
      int pred=ExtModels[i].PredictClass();
      if(pred>=0)
         returned[pred]++;
     }
//--- get one prediction for all models
   int predicted_class=-1;
//--- count votes for predictions
   for(int n=0; n<3; n++)
     {
      if(returned[n]>=2)
        {
         predicted_class=n;
         break;
        }
     }

//--- check trading according to prediction
   if(predicted_class>=0)
      if(PositionSelect(_Symbol))
         CheckForClose(predicted_class);
      else
         CheckForOpen(predicted_class);
  }

Большинство голосов считается по формуле <общее количество голосов>/2 + 1. Для общего числа голосов 3 большинством являются 2 голоса. Это - так называемое "жёсткое голосование"

Результат тестирования всё с теми же самыми настройками.

Результаты тестирования жёсткого голосования

Вспомним работу всех трёх моделей по отдельности, а именно количество прибыльных и убыточных трейдов. Первая модель — 11 : 3, вторая модель — 6 : 1, третья модель — 16 : 10.

Похоже, при помощи жёсткого голосования мы улучшили результат — 16 : 4. Но, конечно же, необходимо смотреть полные отчёты и графики тестирования.


8. Мягкое голосование

Мягкое голосование отличается от жёсткого тем, что учитывается не количество голосов, а считается сумма вероятностей всех трёх классов от всех трёх моделей. И уже по самой высокой вероятности выбирается класс.

Для обеспечения мягкого голосования необходимо внести некоторые изменения.

В базовом классе:

   //+------------------------------------------------------------------+
   //| Predict class (regression -> classification)                     |
   //+------------------------------------------------------------------+
   virtual int PredictClass(vector& probabilities)
     {
...
      //--- set predicted probability as 1.0
      probabilities.Fill(0);
      if(predicted_class<(int)probabilities.Size())
         probabilities[predicted_class]=1;
      //--- and return predicted class
      return(predicted_class);
     }

В классах наследниках:

   //+------------------------------------------------------------------+
   //| Predict class                                                    |
   //+------------------------------------------------------------------+
   virtual int PredictClass(vector& probabilities)
     {
...
      //--- evaluate prediction
      probabilities.Assign(output_data);
      return(int(output_data.ArgMax()));
     }

В эксперте:

#include "ModelEurusdD1_10Class.mqh"
#include "ModelEurusdD1_30Class.mqh"
#include "ModelEurusdD1_63Class.mqh"
#include <Trade\Trade.mqh>

enum EnVotes
  {
   Two=2,    // Two votes
   Three=3,  // Three votes
   Soft=4    // Soft voting
  };

input double  InpLots  = 1.0;    // Lots amount to open position
input EnVotes InpVotes = Two;    // Votes to make trade decision

CModelSymbolPeriod *ExtModels[3];
CTrade              ExtTrade;
void OnTick()
  {
   for(long i=0; i<ExtModels.Size(); i++)
      if(!ExtModels[i].CheckOnTick())
         return;

//--- predict next price movement
   int    returned[3]={0,0,0};
   vector soft=vector::Zeros(3);
//--- collect returned classes
   for(long i=0; i<ExtModels.Size(); i++)
     {
      vector prob(3);
      int    pred=ExtModels[i].PredictClass(prob);
      if(pred>=0)
        {
         returned[pred]++;
         soft+=prob;
        }
     }
//--- get one prediction for all models
   int predicted_class=-1;
//--- soft or hard voting
   if(InpVotes==Soft)
      predicted_class=(int)soft.ArgMax();
   else
     {
      //--- count votes for predictions
      for(int n=0; n<3; n++)
        {
         if(returned[n]>=InpVotes)
           {
            predicted_class=n;
            break;
           }
        }
     }

//--- check trading according to prediction
   if(predicted_class>=0)
      if(PositionSelect(_Symbol))
         CheckForClose(predicted_class);
      else
         CheckForOpen(predicted_class);
  }

Тестируем всё с теми же настройками. Во входных параметрах выбираем Soft.

Настройки входных параметров

Получим результат.

Результаты тестирования мягкого голосования

Прибыльных трейдов — 15, Убыточных трейдов — 3. В денежном выражении жёсткое голосование тоже оказалось лучше, чем мягкое.


Интересно посмотреть на результат единогласного голосования, то есть при количестве голосов 3.

Результаты тестирования единогласного голосования

Очень консервативная торговля. При этом единственная убыточная сделка была закрыта при окончании тестирования (возможно, она и не убыточная на самом деле).

График тестирования единогласного голосования


Важно: обращаем ваше внимание, что использованные в статье модели представлены только в целях демонстрации работы с ONNX-моделями средствами языка MQL5. Советник не предназначен для торговли на реальных счетах.


Заключение

Мы показали, как объектно-ориентированное программирование позволяет упростить написание программ. Все сложности моделей (а модели могут быть гораздо более сложными, чем представленные нами в качестве примера) прячутся в своих классах. А остальная "сложность" уместилась в 45 строках функции OnTick.