How do I set up an ONNX model that has two outputs?

 

This example comes from the build 3980 release notes: https://www.mql5.com/en/forum/454439

The iris.onnx model is generated by this Python code:

# Train a 1-nearest-neighbour classifier on the iris dataset and export it to
# ONNX as "iris.onnx" in the directory that contains this script.
import os
from sys import argv

from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Directory of this script. The previous version sliced argv[0] at the last
# backslash, which only works for Windows-style paths (on POSIX it silently fell
# back to the current working directory). os.path follows the platform's own
# path rules. An empty argv[0] (interactive session) falls back to the cwd,
# exactly as before.
script_path = argv[0]
data_path = os.path.dirname(os.path.abspath(script_path)) if script_path else os.getcwd()

iris_dataset = load_iris()

X_train, X_test, y_train, y_test = train_test_split(iris_dataset['data'], iris_dataset['target'], random_state=0)

knn = KNeighborsClassifier(n_neighbors=1)
knn.fit(X_train, y_train)

# Single model input named "float_input": any number of rows, 4 float features each.
initial_type = [('float_input', FloatTensorType([None, 4]))]
# NOTE(review): for classifiers, skl2onnx by default appends a ZipMap step, so
# the second output ("output_probability") is a sequence of maps rather than
# a plain tensor. This is presumably why the MQL5 side needs a key/value struct
# for it. Passing options={id(knn): {'zipmap': False}} to convert_sklearn
# should yield a plain float probability tensor instead; confirm against the
# skl2onnx converter-options docs before relying on it (it changes the model
# interface seen by the EA).
onx = convert_sklearn(knn, initial_types=initial_type)
path = os.path.join(data_path, "iris.onnx")
with open(path, "wb") as f:
    f.write(onx.SerializeToString())

I am trying to use this model in MQL5, but I don't know how to set the shape of the second output.

The following call fails with error 5802 (ERR_ONNX_NOT_SUPPORTED):

// Attempt to set the shape of output index 1 (the model's second output)
// to a single dimension of length 2. As reported in this post, the call
// fails with error 5802 (ERR_ONNX_NOT_SUPPORTED).
const long output_pro[] = {2};
if(!OnnxSetOutputShape(m_model_handle,1,output_pro))
{
 Print("OnnxSetOutputShape error ",GetLastError());
 return(INIT_FAILED);
}

If I skip this call and run the model directly, it reports the error "ONNX: parameter is empty".

The following is the complete code:

//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
#resource "\\Files\\model\\iris.onnx" as uchar CiTorchH1ONNXModel[]

//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
//--- Wrapper around an ONNX model loaded from the CiTorchH1ONNXModel resource
class CiCatOnnx
  {
public:
                     CiCatOnnx();
                    ~CiCatOnnx();
   //--- strategy settings; none of these are read by the code in this file
   double            Scale;
   double            Bias;
   bool              UseBuy;
   bool              UseSell;
   bool              UseCloseBuyMA2;
   bool              UseCloseBuyCat;
   bool              UseCloseSell;
   bool              UseStopLoss;
   uint              StopLoss;
   bool              UseTakeProfit;
   uint              TakeProfit;
   bool              UseCloseProfit;
   int               CloseProfit;
   ENUM_TIMEFRAMES   TimeFrame;
   bool              UseTimeFilter;
   int               StartHour;
   int               EndHour;
   bool              CloseOnlyProfit;
   bool              UseMACDTypeFilter;

   //--- row count used together with m_sample_size to size m_x_data in the constructor
   int               m_matrix_row;
   //--- zero-filled float matrix built in the constructor; not otherwise used here
   matrixf           m_x_data;

   //--- ONNX session handle returned by OnnxCreateFromBuffer() in Init()
   long              m_model_handle;
   //--- features per input row; second dimension of the input shape set in Init()
   int               m_sample_size;

   //--- assigned 5 in the constructor; not otherwise used in this file
   int               loock_back;

   //--- not used by the code shown in this file
   int               hnd[];
   //--- creates the ONNX session and sets its shapes; returns INIT_SUCCEEDED or INIT_FAILED
   int               Init();

   //--- runs the model once on a fixed test batch; the returned value is a placeholder (-1) in the code shown
   double            Predict();
  };
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
void CiCatOnnx::CiCatOnnx(void)
  {
//--- Constructor: give every member read before Init() a defined value.
//--- m_model_handle starts as INVALID_HANDLE so the destructor can tell
//--- whether a session was ever opened.
   m_model_handle    = INVALID_HANDLE;
   loock_back        = 5;
   m_sample_size     = 4;
//--- m_matrix_row was previously read uninitialized by matrixf::Zeros below.
//--- 2 rows matches the {2,m_sample_size} input batch hard-coded in Init();
//--- keep the two in sync.
   m_matrix_row      = 2;
   m_x_data          = matrixf::Zeros(m_matrix_row, m_sample_size);
  }
//+------------------------------------------------------------------+
//| Destructor: closes the ONNX session if one is still open         |
//+------------------------------------------------------------------+
void CiCatOnnx::~CiCatOnnx(void)
  {
//--- the handle from OnnxCreateFromBuffer() was never released before
   if(m_model_handle!=INVALID_HANDLE)
     {
      OnnxRelease(m_model_handle);
      m_model_handle=INVALID_HANDLE;
     }
  }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
int CiCatOnnx::Init()
  {
//--- Creates the ONNX session from the embedded model and fixes its shapes.
//--- Returns INIT_SUCCEEDED, or INIT_FAILED (with the session released and
//--- m_model_handle reset to INVALID_HANDLE) if any step fails.
   m_model_handle=OnnxCreateFromBuffer(CiTorchH1ONNXModel,ONNX_DEBUG_LOGS);
   if(m_model_handle==INVALID_HANDLE)
     {
      Print("OnnxCreateFromBuffer error ",GetLastError());
      return(INIT_FAILED);
     }

//--- input 0: 2 rows x m_sample_size features (Predict() feeds a 2-row batch)
   const long input_shape[] = {2,m_sample_size};
   if(!OnnxSetInputShape(m_model_handle,0,input_shape))
     {
      Print("OnnxSetInputShape error ",GetLastError());
      //--- error code is read above before releasing; closing the session
      //--- here keeps a failed Init() from leaking it
      OnnxRelease(m_model_handle);
      m_model_handle=INVALID_HANDLE;
      return(INIT_FAILED);
     }

//--- output 0: shape {2}, one entry per row of the 2-row input batch
   const long output_label[] = {2};
   if(!OnnxSetOutputShape(m_model_handle,0,output_label))
     {
      Print("OnnxSetOutputShape error ",GetLastError());
      OnnxRelease(m_model_handle);
      m_model_handle=INVALID_HANDLE;
      return(INIT_FAILED);
     }

//--- NOTE(review): setting a shape for output 1 (read into an array of
//--- key/value structs in Predict()) fails with error 5802
//--- (ERR_ONNX_NOT_SUPPORTED), so the call stays disabled here.
//const long output_pro[] = {2};
//if(!OnnxSetOutputShape(m_model_handle,1,output_pro))
//  {
//   Print("OnnxSetOutputShape error ",GetLastError());
//   return(INIT_FAILED);
//  }

   return(INIT_SUCCEEDED);
  }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
double CiCatOnnx::Predict()
  {
//--- Runs the model once on a fixed 2x4 test batch.
//--- Returns -1 in every path for now: no value is derived from the outputs yet.
   double result = -1;

//--- fixed test batch; 2 rows x 4 features, matching the input shape set in Init()
   static matrixf res_data = {{1.1, 2.0, 3.2, 4.2},{1.1, 2.0, 2.2, 4.2}};

//--- output 0 (Init() names its shape array output_label)
//--- NOTE(review): an array of double matrices may not be an accepted
//--- container for this output; confirm the expected type against the
//--- OnnxRun documentation if "parameter is empty" keeps appearing.
   static matrix output_data[];

//--- output 1: read as key/value pairs; presumably class ids -> probabilities (confirm)
   struct MyMap
     {
      long              key[];
      float             value[];
     };

   MyMap output_probability[];

   if(!OnnxRun(m_model_handle,ONNX_DEBUG_LOGS,res_data,output_data,output_probability))
     {
      //--- include the error code, as every other failure path in this class does
      Print("OnnxRun error ",GetLastError());
      return -1;
     }

//output_data = output_data*(max-min)+min;
//result = output_data[0][m_matrix_row-1];
   return result;
  }

CiCatOnnx m_model;
//+------------------------------------------------------------------+
//| Expert initialization: forwards the model's init status          |
//+------------------------------------------------------------------+
int OnInit()
  {
//--- Init() returns INIT_FAILED when the model cannot be created or shaped.
//--- Previously that result was discarded and INIT_SUCCEEDED was always
//--- returned, letting OnTick() run against an unusable model.
   return(m_model.Init());
  }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
void OnTick(void)
  {
//--- one inference pass per incoming tick; the returned prediction is not consumed yet
   m_model.Predict();
  }
//+------------------------------------------------------------------+

Attached is the iris.onnx model.

Model input and output (see the attached screenshot).

Please help: what did I get wrong?

Linked post: "New MetaTrader 5 Platform build 3980: Improvements and fixes" (2023.09.21, www.mql5.com)
The updated version of the MetaTrader 5 platform will be released on Thursday, September 21, 2023...
Files:
iris.onnx  5 kb
 
// Set the shape of output index 1 (the model's second output) to {2}.
// The reply below reports that this call always returns false, yet the
// model works afterwards.
// NOTE(review): with this check in place, a false return makes Init()
// return INIT_FAILED.
const long output_pro[] = {2};
if(!OnnxSetOutputShape(m_model_handle,1,output_pro))
 {
   Print("OnnxSetOutputShape error ",GetLastError());
   return(INIT_FAILED);
 }

Problem solved!

OnnxSetOutputShape(m_model_handle,1,output_pro)

This call always returns false, but the model works anyway.

MetaEditor Version: 5.00 build 4040 20 Oct 2023