Encapsulando modelos ONNX em classes
Introdução
No artigo anterior, utilizamos dois modelos ONNX para elaborar um classificador de votação. Nesse processo, todo o código-fonte foi preparado em um único arquivo MQ5. Sim, todo o código foi dividido em funções, mas tente, por exemplo, trocar a ordem dos modelos. Bem, e se adicionarmos mais um modelo? O código-fonte ficará ainda mais extenso. Vamos tentar uma abordagem orientada a objetos.
1. Quais modelos vamos utilizar
No classificador de votação anterior, utilizamos um modelo de classificação e um modelo de regressão. No modelo de regressão, em vez de obtermos um movimento de preço previsto (para baixo, para cima, sem mudanças), obtemos um preço previsto com base no qual calculamos a classe. No entanto, nesse caso, não temos uma distribuição de probabilidades das classes, o que impede a realização do chamado "voto suave".
Preparamos 3 modelos de classificação. Dois desses modelos já foram utilizados no artigo "Exemplo de como montar modelos ONNX em MQL5". O primeiro modelo, que era de regressão, foi convertido em um modelo de classificação. O treinamento foi realizado em séries de 10 preços OHLC. O segundo modelo é de classificação e foi treinado em séries de 63 preços de fechamento.
Por fim, temos mais um modelo. O modelo de classificação foi treinado em séries de 30 preços de fechamento e séries de médias móveis simples com períodos de média de 21 e 34. Não fizemos suposições sobre o cruzamento das médias móveis com o gráfico de fechamento, e todas as regularidades serão calculadas e memorizadas pela rede na forma de matrizes de coeficientes entre as camadas.
Todos os modelos foram treinados nos dados do servidor MetaQuotes-Demo, no par de moedas EURUSD, período D1, de 01/01/2010 a 01/01/2023. Os scripts de treinamento de todos os três modelos foram escritos em Python e estão anexados a este artigo. Não os incluiremos aqui para não distrair o leitor do tópico principal do nosso artigo.
2. Necessidade de uma classe base para todos os modelos
Temos três modelos. Cada um difere do outro no tamanho dos dados de entrada e no pré-processamento desses dados. Todos os modelos têm algo em comum: a mesma interface. As classes de todos os modelos devem herdar de uma mesma classe base.
Vamos tentar apresentar a classe base.
//+------------------------------------------------------------------+ //| ModelSymbolPeriod.mqh | //| Copyright 2023, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ //--- price movement prediction #define PRICE_UP 0 #define PRICE_SAME 1 #define PRICE_DOWN 2 //+------------------------------------------------------------------+ //| Base class for models based on trained symbol and period | //+------------------------------------------------------------------+ class CModelSymbolPeriod { protected: long m_handle; // created model session handle string m_symbol; // symbol of trained data ENUM_TIMEFRAMES m_period; // timeframe of trained data datetime m_next_bar; // time of next bar (we work at bar begin only) double m_class_delta; // delta to recognize "price the same" in regression models public: //+------------------------------------------------------------------+ //| Constructor | //+------------------------------------------------------------------+ CModelSymbolPeriod(const string symbol,const ENUM_TIMEFRAMES period,const double class_delta=0.0001) { m_handle=INVALID_HANDLE; m_symbol=symbol; m_period=period; m_next_bar=0; m_class_delta=class_delta; } //+------------------------------------------------------------------+ //| Destructor | //+------------------------------------------------------------------+ ~CModelSymbolPeriod(void) { Shutdown(); } //+------------------------------------------------------------------+ //| virtual stub for Init | //+------------------------------------------------------------------+ virtual bool Init(const string symbol,const ENUM_TIMEFRAMES period) { return(false); } //+------------------------------------------------------------------+ //| Check for initialization, create model | //+------------------------------------------------------------------+ bool CheckInit(const string symbol,const ENUM_TIMEFRAMES period,const uchar& model[]) { //--- check symbol, period if(symbol!=m_symbol || period!=m_period) { PrintFormat("Model must work with %s,%s",m_symbol,EnumToString(m_period)); return(false); } //--- create a model from static buffer m_handle=OnnxCreateFromBuffer(model,ONNX_DEFAULT); if(m_handle==INVALID_HANDLE) { Print("OnnxCreateFromBuffer error ",GetLastError()); return(false); } //--- ok return(true); } //+------------------------------------------------------------------+ //| Release ONNX session | //+------------------------------------------------------------------+ void Shutdown(void) { if(m_handle!=INVALID_HANDLE) { OnnxRelease(m_handle); m_handle=INVALID_HANDLE; } } //+------------------------------------------------------------------+ //| Check for continue OnTick | //+------------------------------------------------------------------+ virtual bool CheckOnTick(void) { //--- check new bar if(TimeCurrent()<m_next_bar) return(false); //--- set next bar time m_next_bar=TimeCurrent(); m_next_bar-=m_next_bar%PeriodSeconds(m_period); m_next_bar+=PeriodSeconds(m_period); //--- work on new day bar return(true); } //+------------------------------------------------------------------+ //| virtual stub for PredictPrice (regression model) | //+------------------------------------------------------------------+ virtual double PredictPrice(void) { return(DBL_MAX); } //+------------------------------------------------------------------+ //| Predict class (regression -> classification) | //+------------------------------------------------------------------+ virtual int PredictClass(void) { double predicted_price=PredictPrice(); if(predicted_price==DBL_MAX) return(-1); int predicted_class=-1; double last_close=iClose(m_symbol,m_period,1); //--- classify predicted price movement double delta=last_close-predicted_price; if(fabs(delta)<=m_class_delta) predicted_class=PRICE_SAME; else { if(delta<0) predicted_class=PRICE_UP; else predicted_class=PRICE_DOWN; } //--- return predicted class return(predicted_class); } }; //+------------------------------------------------------------------+
Essa classe base pode ser usada tanto para modelos de regressão quanto para modelos de classificação. Apenas será necessário implementar o método adequado no classe filha - PredictPrice ou PredictClass.
Na classe base, é definido o símbolo e período com os quais o modelo deve trabalhar (com base nos dados em que foi treinado). Na classe base, é feita uma verificação para garantir que o expert que utiliza o modelo está trabalhando no símbolo e período corretos, e também é criada uma sessão ONNX para executar o modelo. A classe base garante que o expert funcione apenas no início de uma nova barra.
3. Classe para o primeiro modelo
Nosso primeiro modelo é chamado model.eurusd.D1.10.class.onnx, ou seja, é um modelo de classificação, treinado no EURUSD D1 com séries de 10 preços OHLC.
//+------------------------------------------------------------------+ //| ModelEurusdD1_10Class.mqh | //| Copyright 2023, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #include "ModelSymbolPeriod.mqh" #resource "Python/model.eurusd.D1.10.class.onnx" as uchar model_eurusd_D1_10_class[] //+------------------------------------------------------------------+ //| ONNX-model wrapper class | //+------------------------------------------------------------------+ class CModelEurusdD1_10Class : public CModelSymbolPeriod { private: int m_sample_size; public: //+------------------------------------------------------------------+ //| Constructor | //+------------------------------------------------------------------+ CModelEurusdD1_10Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1) { m_sample_size=10; } //+------------------------------------------------------------------+ //| ONNX-model initialization | //+------------------------------------------------------------------+ virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period) { //--- check symbol, period, create model if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_10_class)) { Print("model_eurusd_D1_10_class : initialization error"); return(false); } //--- since not all sizes defined in the input tensor we must set them explicitly //--- first index - batch size, second index - series size, third index - number of series (OHLC) const long input_shape[] = {1,m_sample_size,4}; if(!OnnxSetInputShape(m_handle,0,input_shape)) { Print("model_eurusd_D1_10_class : OnnxSetInputShape error ",GetLastError()); return(false); } //--- since not all sizes defined in the output tensor we must set them explicitly //--- first index - batch size, must match the batch size of the input tensor //--- second index - number of classes (up, same or down) const long output_shape[] = {1,3}; if(!OnnxSetOutputShape(m_handle,0,output_shape)) { Print("model_eurusd_D1_10_class : OnnxSetOutputShape error ",GetLastError()); return(false); } //--- ok return(true); } //+------------------------------------------------------------------+ //| Predict class | //+------------------------------------------------------------------+ virtual int PredictClass(void) { static matrixf input_data(m_sample_size,4); // matrix for prepared input data static vectorf output_data(3); // vector to get result static matrix mm(m_sample_size,4); // matrix of horizontal vectors Mean static matrix ms(m_sample_size,4); // matrix of horizontal vectors Std static matrix x_norm(m_sample_size,4); // matrix for prices normalize //--- prepare input data matrix rates; //--- request last bars if(!rates.CopyRates(m_symbol,m_period,COPY_RATES_OHLC,1,m_sample_size)) return(-1); //--- get series Mean vector m=rates.Mean(1); //--- get series Std vector s=rates.Std(1); //--- prepare matrices for prices normalization for(int i=0; i<m_sample_size; i++) { mm.Row(m,i); ms.Row(s,i); } //--- the input of the model must be a set of vertical OHLC vectors x_norm=rates.Transpose(); //--- normalize prices x_norm-=mm; x_norm/=ms; //--- run the inference input_data.Assign(x_norm); if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data)) return(-1); //--- evaluate prediction return(int(output_data.ArgMax())); } }; //+------------------------------------------------------------------+
Como mencionado anteriormente: "Temos três modelos. Cada um difere do outro no tamanho dos dados de entrada, no pré-processamento desses dados". E redefinimos apenas dois métodos - Init e PredictClass. Nos outros dois classes para os outros dois modelos, os mesmos métodos serão redefinidos.
No método Init, o método CheckInit da classe base é chamado, onde é criada uma sessão para o nosso modelo ONNX. Além disso, os tamanhos dos tensores de entrada e saída são explicitamente configurados. Aqui, há mais comentários do que código.
No método PredictClass, é fornecida a mesma preparação dos dados de entrada que no treinamento do modelo. Uma matriz de preços OHLC normalizados é fornecida como entrada.
4. Vamos verificar como isso funciona
Para testar a funcionalidade de nossa classe, foi criado um expert muito compacto.
//+------------------------------------------------------------------+ //| ONNX.eurusd.D1.Prediction.mq5 | //| Copyright 2023, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #property copyright "Copyright 2023, MetaQuotes Ltd." #property link "https://www.mql5.com" #property version "1.00" #include "ModelEurusdD1_10Class.mqh" #include <Trade\Trade.mqh> input double InpLots = 1.0; // Lots amount to open position CModelEurusdD1_10Class ExtModel; CTrade ExtTrade; //+------------------------------------------------------------------+ //| Expert initialization function | //+------------------------------------------------------------------+ int OnInit() { if(!ExtModel.Init(_Symbol,_Period)) return(INIT_FAILED); //--- return(INIT_SUCCEEDED); } //+------------------------------------------------------------------+ //| Expert deinitialization function | //+------------------------------------------------------------------+ void OnDeinit(const int reason) { ExtModel.Shutdown(); } //+------------------------------------------------------------------+ //| Expert tick function | //+------------------------------------------------------------------+ void OnTick() { if(!ExtModel.CheckOnTick()) return; //--- predict next price movement int predicted_class=ExtModel.PredictClass(); //--- check trading according to prediction if(predicted_class>=0) if(PositionSelect(_Symbol)) CheckForClose(predicted_class); else CheckForOpen(predicted_class); } //+------------------------------------------------------------------+ //| Check for open position conditions | //+------------------------------------------------------------------+ void CheckForOpen(const int predicted_class) { ENUM_ORDER_TYPE signal=WRONG_VALUE; //--- check signals if(predicted_class==PRICE_DOWN) signal=ORDER_TYPE_SELL; // sell condition else { if(predicted_class==PRICE_UP) signal=ORDER_TYPE_BUY; // buy condition } //--- open position if possible according to signal if(signal!=WRONG_VALUE && TerminalInfoInteger(TERMINAL_TRADE_ALLOWED)) { double price=SymbolInfoDouble(_Symbol,(signal==ORDER_TYPE_SELL) ? SYMBOL_BID : SYMBOL_ASK); ExtTrade.PositionOpen(_Symbol,signal,InpLots,price,0,0); } } //+------------------------------------------------------------------+ //| Check for close position conditions | //+------------------------------------------------------------------+ void CheckForClose(const int predicted_class) { bool bsignal=false; //--- position already selected before long type=PositionGetInteger(POSITION_TYPE); //--- check signals if(type==POSITION_TYPE_BUY && predicted_class==PRICE_DOWN) bsignal=true; if(type==POSITION_TYPE_SELL && predicted_class==PRICE_UP) bsignal=true; //--- close position if possible if(bsignal && TerminalInfoInteger(TERMINAL_TRADE_ALLOWED)) { ExtTrade.PositionClose(_Symbol,3); //--- open opposite CheckForOpen(predicted_class); } } //+------------------------------------------------------------------+
Como o modelo foi treinado com dados de preços até 2023, executaremos o teste a partir de 1º de janeiro de 2023.
E obteremos o seguinte resultado:
Como podemos ver, o modelo funciona perfeitamente.
5. Classe para o segundo modelo
O segundo modelo é chamado model.eurusd.D1.30.class.onnx. É um modelo de classificação, treinado no EURUSD D1 com séries de 30 preços de fechamento e duas médias móveis simples com períodos de 21 e 34.
//+------------------------------------------------------------------+ //| ModelEurusdD1_30Class.mqh | //| Copyright 2023, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #include "ModelSymbolPeriod.mqh" #resource "Python/model.eurusd.D1.30.class.onnx" as uchar model_eurusd_D1_30_class[] //+------------------------------------------------------------------+ //| ONNX-model wrapper class | //+------------------------------------------------------------------+ class CModelEurusdD1_30Class : public CModelSymbolPeriod { private: int m_sample_size; int m_fast_period; int m_slow_period; int m_sma_fast; int m_sma_slow; public: //+------------------------------------------------------------------+ //| Constructor | //+------------------------------------------------------------------+ CModelEurusdD1_30Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1) { m_sample_size=30; m_fast_period=21; m_slow_period=34; m_sma_fast=INVALID_HANDLE; m_sma_slow=INVALID_HANDLE; } //+------------------------------------------------------------------+ //| ONNX-model initialization | //+------------------------------------------------------------------+ virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period) { //--- check symbol, period, create model if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_30_class)) { Print("model_eurusd_D1_30_class : initialization error"); return(false); } //--- since not all sizes defined in the input tensor we must set them explicitly //--- first index - batch size, second index - series size, third index - number of series (Close, MA fast, MA slow) const long input_shape[] = {1,m_sample_size,3}; if(!OnnxSetInputShape(m_handle,0,input_shape)) { Print("model_eurusd_D1_30_class : OnnxSetInputShape error ",GetLastError()); return(false); } //--- since not all sizes defined in the output tensor we must set them explicitly //--- first index - batch size, must match the batch size of the input tensor //--- second index - number of classes (up, same or down) const long output_shape[] = {1,3}; if(!OnnxSetOutputShape(m_handle,0,output_shape)) { Print("model_eurusd_D1_30_class : OnnxSetOutputShape error ",GetLastError()); return(false); } //--- indicators m_sma_fast=iMA(m_symbol,m_period,m_fast_period,0,MODE_SMA,PRICE_CLOSE); m_sma_slow=iMA(m_symbol,m_period,m_slow_period,0,MODE_SMA,PRICE_CLOSE); if(m_sma_fast==INVALID_HANDLE || m_sma_slow==INVALID_HANDLE) { Print("model_eurusd_D1_30_class : cannot create indicator"); return(false); } //--- ok return(true); } //+------------------------------------------------------------------+ //| Predict class | //+------------------------------------------------------------------+ virtual int PredictClass(void) { static matrixf input_data(m_sample_size,3); // matrix for prepared input data static vectorf output_data(3); // vector to get result static matrix x_norm(m_sample_size,3); // matrix for prices normalize static vector vtemp(m_sample_size); static double ma_buffer[]; //--- request last bars if(!vtemp.CopyRates(m_symbol,m_period,COPY_RATES_CLOSE,1,m_sample_size)) return(-1); //--- get series Mean double m=vtemp.Mean(); //--- get series Std double s=vtemp.Std(); //--- normalize vtemp-=m; vtemp/=s; x_norm.Col(vtemp,0); //--- fast sma if(CopyBuffer(m_sma_fast,0,1,m_sample_size,ma_buffer)!=m_sample_size) return(-1); vtemp.Assign(ma_buffer); m=vtemp.Mean(); s=vtemp.Std(); vtemp-=m; vtemp/=s; x_norm.Col(vtemp,1); //--- slow sma if(CopyBuffer(m_sma_slow,0,1,m_sample_size,ma_buffer)!=m_sample_size) return(-1); vtemp.Assign(ma_buffer); m=vtemp.Mean(); s=vtemp.Std(); vtemp-=m; vtemp/=s; x_norm.Col(vtemp,2); //--- run the inference input_data.Assign(x_norm); if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data)) return(-1); //--- evaluate prediction return(int(output_data.ArgMax())); } }; //+------------------------------------------------------------------+
Assim como na classe anterior, no método Init, o método CheckInit da classe base é chamado, onde é criada uma sessão para o modelo ONNX e os tamanhos dos tensores de entrada e saída são explicitamente configurados.
No método PredictClass, são fornecidas as séries dos últimos 30 preços de fechamento e as médias móveis calculadas. Os dados são normalizados da mesma forma que no treinamento.
Vamos verificar como esse modelo funciona. Para isso, modificaremos apenas duas linhas no expert de teste.
#include "ModelEurusdD1_30Class.mqh" #include <Trade\Trade.mqh> input double InpLots = 1.0; // Lots amount to open position CModelEurusdD1_30Class ExtModel; CTrade ExtTrade;
Os parâmetros de teste são os mesmos.
Vemos que o modelo está funcionando.
6. Classe para o terceiro modelo
O último modelo é chamado model.eurusd.D1.63.class.onnx. É um modelo de classificação, treinado no EURUSD D1 com séries de 63 preços de fechamento.
//+------------------------------------------------------------------+ //| ModelEurusdD1_63.mqh | //| Copyright 2023, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #include "ModelSymbolPeriod.mqh" #resource "Python/model.eurusd.D1.63.class.onnx" as uchar model_eurusd_D1_63_class[] //+------------------------------------------------------------------+ //| ONNX-model wrapper class | //+------------------------------------------------------------------+ class CModelEurusdD1_63Class : public CModelSymbolPeriod { private: int m_sample_size; public: //+------------------------------------------------------------------+ //| Constructor | //+------------------------------------------------------------------+ CModelEurusdD1_63Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1,0.0001) { m_sample_size=63; } //+------------------------------------------------------------------+ //| ONNX-model initialization | //+------------------------------------------------------------------+ virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period) { //--- check symbol, period, create model if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_63_class)) { Print("model_eurusd_D1_63_class : initialization error"); return(false); } //--- since not all sizes defined in the input tensor we must set them explicitly //--- first index - batch size, second index - series size const long input_shape[] = {1,m_sample_size}; if(!OnnxSetInputShape(m_handle,0,input_shape)) { Print("model_eurusd_D1_63_class : OnnxSetInputShape error ",GetLastError()); return(false); } //--- since not all sizes defined in the output tensor we must set them explicitly //--- first index - batch size, must match the batch size of the input tensor //--- second index - number of classes (up, same or down) const long output_shape[] = {1,3}; if(!OnnxSetOutputShape(m_handle,0,output_shape)) { Print("model_eurusd_D1_63_class : OnnxSetOutputShape error ",GetLastError()); return(false); } //--- ok return(true); } //+------------------------------------------------------------------+ //| Predict class | //+------------------------------------------------------------------+ virtual int PredictClass(void) { static vectorf input_data(m_sample_size); // vector for prepared input data static vectorf output_data(3); // vector to get result //--- request last bars if(!input_data.CopyRates(m_symbol,m_period,COPY_RATES_CLOSE,1,m_sample_size)) return(-1); //--- get series Mean float m=input_data.Mean(); //--- get series Std float s=input_data.Std(); //--- normalize prices input_data-=m; input_data/=s; //--- run the inference if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data)) return(-1); //--- evaluate prediction return(int(output_data.ArgMax())); } }; //+------------------------------------------------------------------+
Este é o modelo mais simples dos três. Portanto, o código do método PredictClass ficou muito compacto.
Vamos alterar novamente duas linhas no expert de teste.
#include "ModelEurusdD1_63Class.mqh" #include <Trade\Trade.mqh> input double InpLots = 1.0; // Lots amount to open position CModelEurusdD1_63Class ExtModel; CTrade ExtTrade;
E iniciaremos o teste com as mesmas configurações.
O modelo está funcionando.
7. Combinando todos os modelos em um expert. Votação rígida
Todos os três modelos demonstraram sua eficácia. Agora vamos tentar combinar seus esforços. Faremos uma votação dos modelos.
Declarações e definições
#include "ModelEurusdD1_10Class.mqh" #include "ModelEurusdD1_30Class.mqh" #include "ModelEurusdD1_63Class.mqh" #include <Trade\Trade.mqh> input double InpLots = 1.0; // Lots amount to open position CModelSymbolPeriod *ExtModels[3]; CTrade ExtTrade;
Função OnInit
int OnInit() { ExtModels[0]=new CModelEurusdD1_10Class; ExtModels[1]=new CModelEurusdD1_30Class; ExtModels[2]=new CModelEurusdD1_63Class; for(long i=0; i<ExtModels.Size(); i++) if(!ExtModels[i].Init(_Symbol,_Period)) return(INIT_FAILED); //--- return(INIT_SUCCEEDED); }
Função OnTick
void OnTick() { for(long i=0; i<ExtModels.Size(); i++) if(!ExtModels[i].CheckOnTick()) return; //--- predict next price movement int returned[3]={0,0,0}; //--- collect returned classes for(long i=0; i<ExtModels.Size(); i++) { int pred=ExtModels[i].PredictClass(); if(pred>=0) returned[pred]++; } //--- get one prediction for all models int predicted_class=-1; //--- count votes for predictions for(int n=0; n<3; n++) { if(returned[n]>=2) { predicted_class=n; break; } } //--- check trading according to prediction if(predicted_class>=0) if(PositionSelect(_Symbol)) CheckForClose(predicted_class); else CheckForOpen(predicted_class); }
A maioria dos votos é calculada pela fórmula <número total de votos>/2 + 1. Para um número total de votos igual a 3, a maioria é de 2 votos. Isso é chamado de "votação rígida".
Resultado do teste com as mesmas configurações.
Lembremos do desempenho de cada um dos três modelos individualmente, ou seja, a quantidade de trades lucrativos e não lucrativos. O primeiro modelo - 11:3, o segundo modelo - 6:1, o terceiro modelo - 16:10. 10.
Parece que, com a votação rígida, melhoramos o resultado - 16:4. 4. No entanto, é claro que precisamos verificar os relatórios completos e os gráficos de teste.
8. Votação Suave
A votação suave difere da votação rígida pelo fato de que não se leva em conta a quantidade de votos, mas sim a soma das probabilidades de todas as três classes de todos os três modelos. E a classe é escolhida com base na probabilidade mais alta.
Para implementar a votação suave, algumas alterações são necessárias.
Na classe base:
//+------------------------------------------------------------------+ //| Predict class (regression -> classification) | //+------------------------------------------------------------------+ virtual int PredictClass(vector& probabilities) { ... //--- set predicted probability as 1.0 probabilities.Fill(0); if(predicted_class<(int)probabilities.Size()) probabilities[predicted_class]=1; //--- and return predicted class return(predicted_class); }
Nas classes filhas:
//+------------------------------------------------------------------+ //| Predict class | //+------------------------------------------------------------------+ virtual int PredictClass(vector& probabilities) { ... //--- evaluate prediction probabilities.Assign(output_data); return(int(output_data.ArgMax())); }
No expert:
#include "ModelEurusdD1_10Class.mqh" #include "ModelEurusdD1_30Class.mqh" #include "ModelEurusdD1_63Class.mqh" #include <Trade\Trade.mqh> enum EnVotes { Two=2, // Two votes Three=3, // Three votes Soft=4 // Soft voting }; input double InpLots = 1.0; // Lots amount to open position input EnVotes InpVotes = Two; // Votes to make trade decision CModelSymbolPeriod *ExtModels[3]; CTrade ExtTrade;
void OnTick() { for(long i=0; i<ExtModels.Size(); i++) if(!ExtModels[i].CheckOnTick()) return; //--- predict next price movement int returned[3]={0,0,0}; vector soft=vector::Zeros(3); //--- collect returned classes for(long i=0; i<ExtModels.Size(); i++) { vector prob(3); int pred=ExtModels[i].PredictClass(prob); if(pred>=0) { returned[pred]++; soft+=prob; } } //--- get one prediction for all models int predicted_class=-1; //--- soft or hard voting if(InpVotes==Soft) predicted_class=(int)soft.ArgMax(); else { //--- count votes for predictions for(int n=0; n<3; n++) { if(returned[n]>=InpVotes) { predicted_class=n; break; } } } //--- check trading according to prediction if(predicted_class>=0) if(PositionSelect(_Symbol)) CheckForClose(predicted_class); else CheckForOpen(predicted_class); }
Testamos tudo com as mesmas configurações. Nos parâmetros de entrada, selecionamos "Soft" (suave).
Obtemos o resultado.
Trades lucrativos - 15, Trades não lucrativos - 3. Em termos de lucro, a votação rígida também se mostrou superior à suave.
É interessante ver o resultado da votação unânime, ou seja, com três votos.
Um estilo de negociação muito conservador. O único trade não lucrativo foi fechado ao final do teste (possivelmente, ela não foi realmente não lucrativa).
Importante: destacamos que os modelos usados neste artigo são apenas para fins de demonstração do uso de modelos ONNX através da linguagem MQL5. O Expert Advisor não deve ser usado para negociações em contas reais.
Considerações finais
Demonstramos como a programação orientada a objetos permite simplificar a escrita de programas. Todas as complexidades dos modelos (e os modelos podem ser muito mais complexos do que os apresentados como exemplo) são encapsuladas em suas classes. E o restante da "complexidade" se encaixou em 45 linhas da função OnTick.
Traduzido do russo pela MetaQuotes Ltd.
Artigo original: https://www.mql5.com/ru/articles/12484
- Aplicativos de negociação gratuitos
- 8 000+ sinais para cópia
- Notícias econômicas para análise dos mercados financeiros
Você concorda com a política do site e com os termos de uso