
클래스에서 ONNX 모델 래핑하기
소개
이전 글에서 우리는 투표 분류기를 배열하는 데 두 개의 ONNX 모델을 사용했습니다. 전체 소스 텍스트는 단일 MQ5 파일로 구성되었습니다. 전체 코드가 함수로 나뉘어져 있었습니다. 하지만 우리가 모델을 바꾸면 어떨까요? 혹은 다른 모델을 추가하면 어떨까요? 원본 텍스트가 더 커지게 될 것입니다. 이때 객체 지향 접근 방식을 사용해 보겠습니다.
1. 어떤 모델을 사용해 볼까요?
이전 투표 분류기에서 우리는 하나의 분류 모델과 하나의 회귀 모델을 사용했습니다. 회귀 모델에서는 예측된 가격 변동(하락, 상승, 변동 없음) 대신에 클래스를 계산하는 데 사용된 예측 가격을 우리는 얻게 됩니다. 그러나 이 경우 우리에게는 클래스별 확률 분포가 없기 때문에 소위 '소프트 투표'를 허용하지 않습니다.
여기 3가지 분류 모델을 준비했습니다. "MQL5에서 ONNX 모델을 앙상블하는 방법의 예" 문서에서 이미 두 가지 모델이 사용되었습니다. 첫 번째 모델(회귀)을 분류 모델로 변환했습니다. 10가지 OHLC 가격에 대한 학습이 진행되었습니다. 두 번째 모델은 분류 모델입니다. 일련의 63 종가에 대한 학습이 진행되었습니다.
마지막으로 모델이 하나 더 있습니다. 분류 모델은 일련의 30개의 종가와 평균 기간이 21 및 34인 일련의 단순 이동 평균에 대해 학습되었습니다. 우리는 이동 평균과 종가 차트의 교차점에 대해 어떤 가정도하지 않았습니다 - 모든 패턴은 레이어 간의 계수 행렬 형태로 네트워크에 의해 계산되고 기억됩니다.
모든 모델은 2010.01.01~2023.01.01의 EURUSD D1 MetaQuotes-Demo server data로 학습 되었습니다. 세 가지 모델 모두에 대한 학습 스크립트는 Python으로 작성되었으며 이 문서에 첨부되어 있습니다. 이 글의 주요 주제로부터 독자분들의 주의를 분산시키지 않기 위해 여기서는 소스 코드를 제공하지 않을 것입니다.
2. 모든 모델에 하나의 기본 클래스가 필요합니다.
세 가지 모델이 있습니다. 입력 데이터의 크기와 준비 방식이 각각 다릅니다. 모든 모델이 동일한 인터페이스를 갖습니다. 모든 모델의 클래스는 동일한 기본 클래스에서 상속되어야 합니다.
베이스 클래스를 표현해 보겠습니다.
//+------------------------------------------------------------------+ //| ModelSymbolPeriod.mqh | //| Copyright 2023, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ //--- price movement prediction #define PRICE_UP 0 #define PRICE_SAME 1 #define PRICE_DOWN 2 //+------------------------------------------------------------------+ //| Base class for models based on trained symbol and period | //+------------------------------------------------------------------+ class CModelSymbolPeriod { protected: long m_handle; // created model session handle string m_symbol; // symbol of trained data ENUM_TIMEFRAMES m_period; // timeframe of trained data datetime m_next_bar; // time of next bar (we work at bar begin only) double m_class_delta; // delta to recognize "price the same" in regression models public: //+------------------------------------------------------------------+ //| Constructor | //+------------------------------------------------------------------+ CModelSymbolPeriod(const string symbol,const ENUM_TIMEFRAMES period,const double class_delta=0.0001) { m_handle=INVALID_HANDLE; m_symbol=symbol; m_period=period; m_next_bar=0; m_class_delta=class_delta; } //+------------------------------------------------------------------+ //| Destructor | //+------------------------------------------------------------------+ ~CModelSymbolPeriod(void) { Shutdown(); } //+------------------------------------------------------------------+ //| virtual stub for Init | //+------------------------------------------------------------------+ virtual bool Init(const string symbol,const ENUM_TIMEFRAMES period) { return(false); } //+------------------------------------------------------------------+ //| Check for initialization, create model | //+------------------------------------------------------------------+ bool CheckInit(const string symbol,const ENUM_TIMEFRAMES period,const uchar& model[]) { //--- check symbol, period if(symbol!=m_symbol || period!=m_period) { PrintFormat("Model must work with %s,%s",m_symbol,EnumToString(m_period)); return(false); } //--- create a model from static buffer m_handle=OnnxCreateFromBuffer(model,ONNX_DEFAULT); if(m_handle==INVALID_HANDLE) { Print("OnnxCreateFromBuffer error ",GetLastError()); return(false); } //--- ok return(true); } //+------------------------------------------------------------------+ //| Release ONNX session | //+------------------------------------------------------------------+ void Shutdown(void) { if(m_handle!=INVALID_HANDLE) { OnnxRelease(m_handle); m_handle=INVALID_HANDLE; } } //+------------------------------------------------------------------+ //| Check for continue OnTick | //+------------------------------------------------------------------+ virtual bool CheckOnTick(void) { //--- check new bar if(TimeCurrent()<m_next_bar) return(false); //--- set next bar time m_next_bar=TimeCurrent(); m_next_bar-=m_next_bar%PeriodSeconds(m_period); m_next_bar+=PeriodSeconds(m_period); //--- work on new day bar return(true); } //+------------------------------------------------------------------+ //| virtual stub for PredictPrice (regression model) | //+------------------------------------------------------------------+ virtual double PredictPrice(void) { return(DBL_MAX); } //+------------------------------------------------------------------+ //| Predict class (regression -> classification) | //+------------------------------------------------------------------+ virtual int PredictClass(void) { double predicted_price=PredictPrice(); if(predicted_price==DBL_MAX) return(-1); int predicted_class=-1; double last_close=iClose(m_symbol,m_period,1); //--- classify predicted price movement double delta=last_close-predicted_price; if(fabs(delta)<=m_class_delta) predicted_class=PRICE_SAME; else { if(delta<0) predicted_class=PRICE_UP; else predicted_class=PRICE_DOWN; } //--- return predicted class return(predicted_class); } }; //+------------------------------------------------------------------+
기본 클래스는 회귀 및 분류 모델 모두에 사용할 수 있습니다. 우리는 하위 클래스에서 적절한 메서드(PredictPrice 또는 PredictClass)를 구현하기만 하면 됩니다.
기본 클래스가 모델이 작업할 심볼 기간(모델의 학습이 이루어진 데이터)을 설정합니다. 또한 기본 클래스는 필요한 심볼 기간에 모델을 사용하는 EA가 작동하는지 확인하고 모델을 실행하기 위해 ONNX 세션을 생성합니다. 기본 클래스는 새로운 바가 시작될 때만 작업을 제공합니다.
3. 첫 번째 모델 클래스
첫 번째 모델은 model.eurusd.D1.10.class.onnx로, 일련의 10개의 OHLC 가격에 대한 EURUSD D1에 대해 학습된 분류 모델입니다.
//+------------------------------------------------------------------+ //| ModelEurusdD1_10Class.mqh | //| Copyright 2023, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #include "ModelSymbolPeriod.mqh" #resource "Python/model.eurusd.D1.10.class.onnx" as uchar model_eurusd_D1_10_class[] //+------------------------------------------------------------------+ //| ONNX-model wrapper class | //+------------------------------------------------------------------+ class CModelEurusdD1_10Class : public CModelSymbolPeriod { private: int m_sample_size; public: //+------------------------------------------------------------------+ //| Constructor | //+------------------------------------------------------------------+ CModelEurusdD1_10Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1) { m_sample_size=10; } //+------------------------------------------------------------------+ //| ONNX-model initialization | //+------------------------------------------------------------------+ virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period) { //--- check symbol, period, create model if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_10_class)) { Print("model_eurusd_D1_10_class : initialization error"); return(false); } //--- since not all sizes defined in the input tensor we must set them explicitly //--- first index - batch size, second index - series size, third index - number of series (OHLC) const long input_shape[] = {1,m_sample_size,4}; if(!OnnxSetInputShape(m_handle,0,input_shape)) { Print("model_eurusd_D1_10_class : OnnxSetInputShape error ",GetLastError()); return(false); } //--- since not all sizes defined in the output tensor we must set them explicitly //--- first index - batch size, must match the batch size of the input tensor //--- second index - number of classes (up, same or down) const long output_shape[] = {1,3}; if(!OnnxSetOutputShape(m_handle,0,output_shape)) { Print("model_eurusd_D1_10_class : OnnxSetOutputShape error ",GetLastError()); return(false); } //--- ok return(true); } //+------------------------------------------------------------------+ //| Predict class | //+------------------------------------------------------------------+ virtual int PredictClass(void) { static matrixf input_data(m_sample_size,4); // matrix for prepared input data static vectorf output_data(3); // vector to get result static matrix mm(m_sample_size,4); // matrix of horizontal vectors Mean static matrix ms(m_sample_size,4); // matrix of horizontal vectors Std static matrix x_norm(m_sample_size,4); // matrix for prices normalize //--- prepare input data matrix rates; //--- request last bars if(!rates.CopyRates(m_symbol,m_period,COPY_RATES_OHLC,1,m_sample_size)) return(-1); //--- get series Mean vector m=rates.Mean(1); //--- get series Std vector s=rates.Std(1); //--- prepare matrices for prices normalization for(int i=0; i<m_sample_size; i++) { mm.Row(m,i); ms.Row(s,i); } //--- the input of the model must be a set of vertical OHLC vectors x_norm=rates.Transpose(); //--- normalize prices x_norm-=mm; x_norm/=ms; //--- run the inference input_data.Assign(x_norm); if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data)) return(-1); //--- evaluate prediction return(int(output_data.ArgMax())); } }; //+------------------------------------------------------------------+
위에서 이미 언급했듯이: "세 가지 모델이 있습니다. 각각의 모델은 입력 데이터의 크기와 준비가 각각 다릅니다." 우리는 Init과 PredictClass의 두 가지 메서드만 재정의했습니다. 동일한 메서드가 다른 두 모델에 다른 두 클래스에서 재정의됩니다.
Init 메서드는 ONNX 모델에 대한 세션이 생성되고 입력 및 출력 텐서의 크기가 명시적으로 설정되는 CheckInit 베이스 클래스 메서드를 호출합니다. 여기에는 코드보다 코멘트가 더 많습니다.
PredictClass 메서드는 모델을 훈련할 때와 정확히 동일한 입력 데이터 준비를 제공합니다. 입력은 정규화된 OHLC 가격의 행렬입니다.
4. 이제 작동 방식을 확인해 보겠습니다.
클래스의 성능을 테스트하기 위해 매우 간단한 Expert Advisor가 만들어졌습니다.
//+------------------------------------------------------------------+ //| ONNX.eurusd.D1.Prediction.mq5 | //| Copyright 2023, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #property copyright "Copyright 2023, MetaQuotes Ltd." #property link "https://www.mql5.com" #property version "1.00" #include "ModelEurusdD1_10Class.mqh" #include <Trade\Trade.mqh> input double InpLots = 1.0; // Lots amount to open position CModelEurusdD1_10Class ExtModel; CTrade ExtTrade; //+------------------------------------------------------------------+ //| Expert initialization function | //+------------------------------------------------------------------+ int OnInit() { if(!ExtModel.Init(_Symbol,_Period)) return(INIT_FAILED); //--- return(INIT_SUCCEEDED); } //+------------------------------------------------------------------+ //| Expert deinitialization function | //+------------------------------------------------------------------+ void OnDeinit(const int reason) { ExtModel.Shutdown(); } //+------------------------------------------------------------------+ //| Expert tick function | //+------------------------------------------------------------------+ void OnTick() { if(!ExtModel.CheckOnTick()) return; //--- predict next price movement int predicted_class=ExtModel.PredictClass(); //--- check trading according to prediction if(predicted_class>=0) if(PositionSelect(_Symbol)) CheckForClose(predicted_class); else CheckForOpen(predicted_class); } //+------------------------------------------------------------------+ //| Check for open position conditions | //+------------------------------------------------------------------+ void CheckForOpen(const int predicted_class) { ENUM_ORDER_TYPE signal=WRONG_VALUE; //--- check signals if(predicted_class==PRICE_DOWN) signal=ORDER_TYPE_SELL; // sell condition else { if(predicted_class==PRICE_UP) signal=ORDER_TYPE_BUY; // buy condition } //--- open position if possible according to signal if(signal!=WRONG_VALUE && TerminalInfoInteger(TERMINAL_TRADE_ALLOWED)) { double price=SymbolInfoDouble(_Symbol,(signal==ORDER_TYPE_SELL) ? SYMBOL_BID : SYMBOL_ASK); ExtTrade.PositionOpen(_Symbol,signal,InpLots,price,0,0); } } //+------------------------------------------------------------------+ //| Check for close position conditions | //+------------------------------------------------------------------+ void CheckForClose(const int predicted_class) { bool bsignal=false; //--- position already selected before long type=PositionGetInteger(POSITION_TYPE); //--- check signals if(type==POSITION_TYPE_BUY && predicted_class==PRICE_DOWN) bsignal=true; if(type==POSITION_TYPE_SELL && predicted_class==PRICE_UP) bsignal=true; //--- close position if possible if(bsignal && TerminalInfoInteger(TERMINAL_TRADE_ALLOWED)) { ExtTrade.PositionClose(_Symbol,3); //--- open opposite CheckForOpen(predicted_class); } } //+------------------------------------------------------------------+
이 모델은 2023년까지 가격 데이터로 학습되었습니다. 우리는 2023년 1월 1일부터 테스트를 시작하겠습니다.
결과가 아래와 같이 표시됩니다:
보시다시피 이 모델은 완벽하게 작동합니다.
5. 두 번째 모델 클래스
두 번째 모델은 model.eurusd.D1.30.class.onnx입니다. 분류 모델은 일련의 30개의 종가와 평균 기간이 21, 34인 두 개의 단순 이동 평균에 대해 EURUSD D1을 학습시켰습니다.
//+------------------------------------------------------------------+ //| ModelEurusdD1_30Class.mqh | //| Copyright 2023, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #include "ModelSymbolPeriod.mqh" #resource "Python/model.eurusd.D1.30.class.onnx" as uchar model_eurusd_D1_30_class[] //+------------------------------------------------------------------+ //| ONNX-model wrapper class | //+------------------------------------------------------------------+ class CModelEurusdD1_30Class : public CModelSymbolPeriod { private: int m_sample_size; int m_fast_period; int m_slow_period; int m_sma_fast; int m_sma_slow; public: //+------------------------------------------------------------------+ //| Constructor | //+------------------------------------------------------------------+ CModelEurusdD1_30Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1) { m_sample_size=30; m_fast_period=21; m_slow_period=34; m_sma_fast=INVALID_HANDLE; m_sma_slow=INVALID_HANDLE; } //+------------------------------------------------------------------+ //| ONNX-model initialization | //+------------------------------------------------------------------+ virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period) { //--- check symbol, period, create model if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_30_class)) { Print("model_eurusd_D1_30_class : initialization error"); return(false); } //--- since not all sizes defined in the input tensor we must set them explicitly //--- first index - batch size, second index - series size, third index - number of series (Close, MA fast, MA slow) const long input_shape[] = {1,m_sample_size,3}; if(!OnnxSetInputShape(m_handle,0,input_shape)) { Print("model_eurusd_D1_30_class : OnnxSetInputShape error ",GetLastError()); return(false); } //--- since not all sizes defined in the output tensor we must set them explicitly //--- first index - batch size, must match the batch size of the input tensor //--- second index - number of classes (up, same or down) const long output_shape[] = {1,3}; if(!OnnxSetOutputShape(m_handle,0,output_shape)) { Print("model_eurusd_D1_30_class : OnnxSetOutputShape error ",GetLastError()); return(false); } //--- indicators m_sma_fast=iMA(m_symbol,m_period,m_fast_period,0,MODE_SMA,PRICE_CLOSE); m_sma_slow=iMA(m_symbol,m_period,m_slow_period,0,MODE_SMA,PRICE_CLOSE); if(m_sma_fast==INVALID_HANDLE || m_sma_slow==INVALID_HANDLE) { Print("model_eurusd_D1_30_class : cannot create indicator"); return(false); } //--- ok return(true); } //+------------------------------------------------------------------+ //| Predict class | //+------------------------------------------------------------------+ virtual int PredictClass(void) { static matrixf input_data(m_sample_size,3); // matrix for prepared input data static vectorf output_data(3); // vector to get result static matrix x_norm(m_sample_size,3); // matrix for prices normalize static vector vtemp(m_sample_size); static double ma_buffer[]; //--- request last bars if(!vtemp.CopyRates(m_symbol,m_period,COPY_RATES_CLOSE,1,m_sample_size)) return(-1); //--- get series Mean double m=vtemp.Mean(); //--- get series Std double s=vtemp.Std(); //--- normalize vtemp-=m; vtemp/=s; x_norm.Col(vtemp,0); //--- fast sma if(CopyBuffer(m_sma_fast,0,1,m_sample_size,ma_buffer)!=m_sample_size) return(-1); vtemp.Assign(ma_buffer); m=vtemp.Mean(); s=vtemp.Std(); vtemp-=m; vtemp/=s; x_norm.Col(vtemp,1); //--- slow sma if(CopyBuffer(m_sma_slow,0,1,m_sample_size,ma_buffer)!=m_sample_size) return(-1); vtemp.Assign(ma_buffer); m=vtemp.Mean(); s=vtemp.Std(); vtemp-=m; vtemp/=s; x_norm.Col(vtemp,2); //--- run the inference input_data.Assign(x_norm); if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data)) return(-1); //--- evaluate prediction return(int(output_data.ArgMax())); } }; //+------------------------------------------------------------------+
이전 클래스에서와 마찬가지로, CheckInit 베이스 클래스 메서드는 Init 메서드에서 호출됩니다. 베이스 클래스 메서드에서는 ONNX 모델에 대한 세션이 생성되고 입력 및 출력 텐서의 크기가 명시적으로 설정됩니다.
PredictClass 메서드는 30개의 이전 종가 및 계산된 이동 평균을 제공합니다. 데이터는 학습시와 마찬가지의 방식으로 정규화됩니다.
이 모델이 어떻게 작동하는지 살펴봅시다. 이를 위해 테스트 EA의 두 문자열만 변경해 보겠습니다.
#include "ModelEurusdD1_30Class.mqh" #include <Trade\Trade.mqh> input double InpLots = 1.0; // Lots amount to open position CModelEurusdD1_30Class ExtModel; CTrade ExtTrade;
테스트 매개변수는 동일합니다.
모델이 작동합니다.
6. 세 번째 모델 클래스
마지막 모델은 model.eurusd.D1.63.class.onnx입니다. 분류 모델은 일련의 63 종가에 대해 EURUSD D1을 학습시켰습니다.
//+------------------------------------------------------------------+ //| ModelEurusdD1_63.mqh | //| Copyright 2023, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #include "ModelSymbolPeriod.mqh" #resource "Python/model.eurusd.D1.63.class.onnx" as uchar model_eurusd_D1_63_class[] //+------------------------------------------------------------------+ //| ONNX-model wrapper class | //+------------------------------------------------------------------+ class CModelEurusdD1_63Class : public CModelSymbolPeriod { private: int m_sample_size; public: //+------------------------------------------------------------------+ //| Constructor | //+------------------------------------------------------------------+ CModelEurusdD1_63Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1,0.0001) { m_sample_size=63; } //+------------------------------------------------------------------+ //| ONNX-model initialization | //+------------------------------------------------------------------+ virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period) { //--- check symbol, period, create model if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_63_class)) { Print("model_eurusd_D1_63_class : initialization error"); return(false); } //--- since not all sizes defined in the input tensor we must set them explicitly //--- first index - batch size, second index - series size const long input_shape[] = {1,m_sample_size}; if(!OnnxSetInputShape(m_handle,0,input_shape)) { Print("model_eurusd_D1_63_class : OnnxSetInputShape error ",GetLastError()); return(false); } //--- since not all sizes defined in the output tensor we must set them explicitly //--- first index - batch size, must match the batch size of the input tensor //--- second index - number of classes (up, same or down) const long output_shape[] = {1,3}; if(!OnnxSetOutputShape(m_handle,0,output_shape)) { Print("model_eurusd_D1_63_class : OnnxSetOutputShape error ",GetLastError()); return(false); } //--- ok return(true); } //+------------------------------------------------------------------+ //| Predict class | //+------------------------------------------------------------------+ virtual int PredictClass(void) { static vectorf input_data(m_sample_size); // vector for prepared input data static vectorf output_data(3); // vector to get result //--- request last bars if(!input_data.CopyRates(m_symbol,m_period,COPY_RATES_CLOSE,1,m_sample_size)) return(-1); //--- get series Mean float m=input_data.Mean(); //--- get series Std float s=input_data.Std(); //--- normalize prices input_data-=m; input_data/=s; //--- run the inference if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data)) return(-1); //--- evaluate prediction return(int(output_data.ArgMax())); } }; //+------------------------------------------------------------------+
세 가지 모델 중 가장 간단한 모델입니다. 이것이 바로 PredictClass 메서드의 코드가 매우 간결한 이유입니다.
EA에서 두 개의 문자열을 다시 변경해 보겠습니다.
#include "ModelEurusdD1_63Class.mqh" #include <Trade\Trade.mqh> input double InpLots = 1.0; // Lots amount to open position CModelEurusdD1_63Class ExtModel; CTrade ExtTrade;
동일한 설정으로 테스트를 시작합니다.
모델이 작동합니다.
7. 모든 모델을 하나의 EA에 수집합니다. 하드 투표
세 모델 모두 작업 능력이 있음을 보여주었습니다. 이제 이들의 능력을 합쳐 보겠습니다. 모델 투표를 준비해 봅시다.
선언 및 정의 전달
#include "ModelEurusdD1_10Class.mqh" #include "ModelEurusdD1_30Class.mqh" #include "ModelEurusdD1_63Class.mqh" #include <Trade\Trade.mqh> input double InpLots = 1.0; // Lots amount to open position CModelSymbolPeriod *ExtModels[3]; CTrade ExtTrade;
OnInit 함수
int OnInit() { ExtModels[0]=new CModelEurusdD1_10Class; ExtModels[1]=new CModelEurusdD1_30Class; ExtModels[2]=new CModelEurusdD1_63Class; for(long i=0; i<ExtModels.Size(); i++) if(!ExtModels[i].Init(_Symbol,_Period)) return(INIT_FAILED); //--- return(INIT_SUCCEEDED); }
OnTick 기능
void OnTick() { for(long i=0; i<ExtModels.Size(); i++) if(!ExtModels[i].CheckOnTick()) return; //--- predict next price movement int returned[3]={0,0,0}; //--- collect returned classes for(long i=0; i<ExtModels.Size(); i++) { int pred=ExtModels[i].PredictClass(); if(pred>=0) returned[pred]++; } //--- get one prediction for all models int predicted_class=-1; //--- count votes for predictions for(int n=0; n<3; n++) { if(returned[n]>=2) { predicted_class=n; break; } } //--- check trading according to prediction if(predicted_class>=0) if(PositionSelect(_Symbol)) CheckForClose(predicted_class); else CheckForOpen(predicted_class); }
과반수 득표는 <총 득표수>/2 + 1이라는 공식에 따라 계산됩니다. 총 3표 중 과반수는 2표입니다. 이를 이른바 '하드 투표'라고 합니다.
테스트 결과는 여전히 동일한 설정으로 한 것입니다.
세 가지 모델의 작업, 즉 수익성 있는 거래와 수익성 없는 거래의 수를 개별적으로 다시 살펴 보겠습니다. 첫 번째 모델 - 11 : 3, second — 6 : 1, third — 16 : 10.
하드 투표를 통해 결과를 개선한 것 같습니다 - 16 : 4. 하지만 물론 전체 보고서와 테스트 차트를 살펴봐야 합니다.
8. 소프트 투표
소프트 투표는 투표 수가 아니라 세 가지 모델에서 세 가지 클래스 모두의 확률의 합을 중요시 한다는 점에서 하드 투표와 다릅니다. 가장 높은 확률로 클래스가 선택됩니다.
소프트 투표를 보장하려면 몇 가지 사항을 변경해야 합니다.
기본 클래스에서:
//+------------------------------------------------------------------+ //| Predict class (regression -> classification) | //+------------------------------------------------------------------+ virtual int PredictClass(vector& probabilities) { ... //--- set predicted probability as 1.0 probabilities.Fill(0); if(predicted_class<(int)probabilities.Size()) probabilities[predicted_class]=1; //--- and return predicted class return(predicted_class); }
자식 클래스에서:
//+------------------------------------------------------------------+ //| Predict class | //+------------------------------------------------------------------+ virtual int PredictClass(vector& probabilities) { ... //--- evaluate prediction probabilities.Assign(output_data); return(int(output_data.ArgMax())); }
EA에서:
#include "ModelEurusdD1_10Class.mqh" #include "ModelEurusdD1_30Class.mqh" #include "ModelEurusdD1_63Class.mqh" #include <Trade\Trade.mqh> enum EnVotes { Two=2, // Two votes Three=3, // Three votes Soft=4 // Soft voting }; input double InpLots = 1.0; // Lots amount to open position input EnVotes InpVotes = Two; // Votes to make trade decision CModelSymbolPeriod *ExtModels[3]; CTrade ExtTrade;
void OnTick() { for(long i=0; i<ExtModels.Size(); i++) if(!ExtModels[i].CheckOnTick()) return; //--- predict next price movement int returned[3]={0,0,0}; vector soft=vector::Zeros(3); //--- collect returned classes for(long i=0; i<ExtModels.Size(); i++) { vector prob(3); int pred=ExtModels[i].PredictClass(prob); if(pred>=0) { returned[pred]++; soft+=prob; } } //--- get one prediction for all models int predicted_class=-1; //--- soft or hard voting if(InpVotes==Soft) predicted_class=(int)soft.ArgMax(); else { //--- count votes for predictions for(int n=0; n<3; n++) { if(returned[n]>=InpVotes) { predicted_class=n; break; } } } //--- check trading according to prediction if(predicted_class>=0) if(PositionSelect(_Symbol)) CheckForClose(predicted_class); else CheckForOpen(predicted_class); }
테스트 설정은 동일합니다. 입력에서 Soft를 선택합니다.
결과는 다음과 같습니다.
수익성 있는 거래 - 15, 수익성 없는 거래 - 3. 수익의 관점에서도 하드 투표가 소프트 투표보다 더 나은 것으로 나타났습니다.
만장일치인 투표 결과, 즉 투표 수가 3인 경우를 살펴봅시다.
매우 보수적인 트레이딩. 유일하게 수익성이 없는 거래는 테스트 종료 시 종료되었습니다(수익성이 없는 거래는 아닐 수도 있습니다).
중요 참고 사항: 이 문서에 사용된 모델은 MQL5 언어를 사용하여 ONNX 모델로 작업하는 방법을 시연하기 위한 용도로만 제공되었습니다. Expert Advisor는 실제 계좌에서 거래용으로 만든 것이 아닙니다.
결론
이 글에서 우리는 객체 지향 프로그래밍을 통해 프로그램을 더 쉽게 작성할 수 있는 방법을 알아보았습니다. 모델의 모든 복잡성은 해당 클래스에 숨겨져 있습니다(예시로 제시한 모델보다 훨씬 더 복잡할 수 있습니다). 나머지 '복잡성'은 OnTick 함수의 45개 문자열에 들어가 있습니다.
MetaQuotes 소프트웨어 사를 통해 러시아어가 번역됨.
원본 기고글: https://www.mql5.com/ru/articles/12484



