English 中文 日本語
preview
Работа с ONNX-моделями в форматах float16 и float8

Работа с ONNX-моделями в форматах float16 и float8

MetaTrader 5Интеграция | 27 февраля 2024, 15:01
628 6
MetaQuotes
MetaQuotes

Содержание

С развитием технологий машинного обучения и искусственного интеллекта возникает необходимость в оптимизации процессов работы с моделями. Эффективность работы моделей напрямую зависит от форматов данных, используемых для их представления. В последние годы появилось несколько новых типов данных, предназначенных специально для работы с моделями глубокого обучения.

В данной статье мы сосредоточимся на двух таких новых форматах данных — float16 и float8, которые начинают активно использоваться в современных ONNX-моделях. Эти форматы представляют собой альтернативные варианты более точных, но требовательных к ресурсам форматам данных с плавающей точкой. Они обеспечивают оптимальное сочетание производительности и точности, что делает их особенно привлекательными для различных задач машинного обучения. Мы изучим основные характеристики и преимущества форматов float16 и float8, а также представим функции для их преобразования в стандартные float и double.

Это поможет разработчикам и исследователям лучше понять, как эффективно использовать эти форматы в своих проектах и моделях. В качестве примера мы рассмотрим работу ONNX-модели ESRGAN, которая применяется для улучшения качества изображений.


1. Новые типы данных для работы с ONNX-моделями

Для ускорения расчетов некоторые модели используют типы данных с меньшей точностью, такие как Float16 и даже Float8.

Для работы с ONNX-моделями в язык MQL5 добавлена поддержка новых типов данных, позволяющих работать с 8 и 16-битными представлениями чисел с плавающей точкой.

Скрипт выводит полный список элементов перечисления ENUM_ONNX_DATA_TYPE.

//+------------------------------------------------------------------+
//|                                              ONNX_Data_Types.mq5 |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"
//+------------------------------------------------------------------+
//| Script program start function                                    |
//+------------------------------------------------------------------+
void OnStart()
  {
//---
   for(int i=0; i<21; i++)
      PrintFormat("%2d %s",i,EnumToString(ENUM_ONNX_DATA_TYPE(i)));
  }

Результат:

 0: ONNX_DATA_TYPE_UNDEFINED
 1: ONNX_DATA_TYPE_FLOAT
 2: ONNX_DATA_TYPE_UINT8
 3: ONNX_DATA_TYPE_INT8
 4: ONNX_DATA_TYPE_UINT16
 5: ONNX_DATA_TYPE_INT16
 6: ONNX_DATA_TYPE_INT32
 7: ONNX_DATA_TYPE_INT64
 8: ONNX_DATA_TYPE_STRING
 9: ONNX_DATA_TYPE_BOOL
10: ONNX_DATA_TYPE_FLOAT16
11: ONNX_DATA_TYPE_DOUBLE
12: ONNX_DATA_TYPE_UINT32
13: ONNX_DATA_TYPE_UINT64
14: ONNX_DATA_TYPE_COMPLEX64
15: ONNX_DATA_TYPE_COMPLEX128
16: ONNX_DATA_TYPE_BFLOAT16
17: ONNX_DATA_TYPE_FLOAT8E4M3FN
18: ONNX_DATA_TYPE_FLOAT8E4M3FNUZ
19: ONNX_DATA_TYPE_FLOAT8E5M2
20: ONNX_DATA_TYPE_FLOAT8E5M2FNUZ

Таким образом, теперь можно исполнять ONNX-модели, работающие с такими данными.

Кроме того, в MQL5 появились дополнительные функции для конвертации данных:

bool ArrayToFP16(ushort &dst_array[],const float &src_array[],ENUM_FLOAT16_FORMAT fmt);
bool ArrayToFP16(ushort &dst_array[],const double &src_array[],ENUM_FLOAT16_FORMAT fmt);
bool ArrayToFP8(uchar &dst_array[],const float &src_array[],ENUM_FLOAT8_FORMAT fmt);
bool ArrayToFP8(uchar &dst_array[],const double &src_array[],ENUM_FLOAT8_FORMAT fmt);

bool ArrayFromFP16(float &dst_array[],const ushort &src_array[],ENUM_FLOAT16_FORMAT fmt);
bool ArrayFromFP16(double &dst_array[],const ushort &src_array[],ENUM_FLOAT16_FORMAT fmt);
bool ArrayFromFP8(float &dst_array[],const uchar &src_array[],ENUM_FLOAT8_FORMAT fmt);
bool ArrayFromFP8(double &dst_array[],const uchar &src_array[],ENUM_FLOAT8_FORMAT fmt);

Поскольку форматы вещественных чисел для 16 и 8 бит могут отличаться, в параметре fmt в функциях конверсии необходимо указывать, какой именно формат числа требуется обработать.

Для 16-битных версий используется новое перечисление ENUM_FLOAT16_FORMAT, которое на данный момент имеет следующие значения:

  • FLOAT_FP16 — стандартный 16-битный формат, так же известный как half.
  • FLOAT_BFP16 — специальный формат brain float point.
Для 8-битных версий используется новое перечисление ENUM_FLOAT8_FORMAT, которое на данный момент имеет следующие значения:
  • FLOAT_FP8_E4M3FN — 8-битное число с плавающей точкой, 4 бита порядок и 3 бита мантисса. Обычно используется как коэффициенты.
  • FLOAT_FP8_E4M3FNUZ — 8-битное число с плавающей точкой, 4 бит порядок и 3 бита мантисса. Поддерживает NaN, не поддерживается отрицательный ноль и Inf. Обычно используется как коэффициенты.
  • FLOAT_FP8_E5M2FN — 8-битное число с плавающей точкой, 5 бит порядок и 2 бита мантисса. Поддерживает NaN и Inf. Обычно используется для градиентов.
  • FLOAT_FP8_E5M2FNUZ — 8-битное число с плавающей точкой, 5 бит порядок и 2 бита мантисса. Поддерживает NaN и Inf, не поддерживает отрицательный ноль. Также используется для градиентов.


1.1. Формат FP16

Форматы FLOAT16 и BFLOAT16 представляют собой типы данных, используемые для представления чисел с плавающей точкой.

FLOAT16, также известный как половинная точность или формат "half-precision float", использует 16 бит для представления числа с плавающей точкой. Этот формат обеспечивает баланс между точностью и эффективностью вычислений. FLOAT16 широко применяется в глубоком обучении и нейронных сетях, где требуется высокая производительность при обработке больших объемов данных. Этот формат позволяет ускорить вычисления за счет сокращения размера чисел, что особенно важно при обучении глубоких нейронных сетей на графических процессорах (GPU).

BFLOAT16 (или Brain Floating Point 16) также использует 16 бит, но он различается от FLOAT16 в способе представления чисел. В этом формате 8 бит выделены для представления экспоненты, а оставшиеся 7 бит используются для представления мантиссы. Этот формат был разработан для использования в глубоком обучении и искусственном интеллекте, особенно в процессорах Google Tensor Processing Unit (TPU). BFLOAT16 обладает хорошей производительностью при обучении нейронных сетей и может быть эффективно использован для ускорения вычислений.

Оба этих формата имеют свои преимущества и ограничения. FLOAT16 обеспечивает более высокую точность, но требует больше ресурсов для хранения и вычислений. BFLOAT16, с другой стороны, обеспечивает более высокую производительность и эффективность при обработке данных, но может быть менее точным.


Рис. Форматы битового представления чисел с плавающей точкой FLOAT16 и BFLOAT16

Рис.1. Форматы битового представления чисел с плавающей точкой FLOAT16 и BFLOAT16


Рис. Детали формата FLOAT16

Табл.1. Числа с плавающей точкой в формате FLOAT16


1.1.1. Тесты исполнения ONNX-оператора Cast для FLOAT16

В качестве иллюстрации рассмотрим задачу преобразования данных типа FLOAT16 в типы float и double.

ONNX-модели с операцией Cast:

Параметры ONNX-модели test_cast_FLOAT16_to_DOUBLE.onnx

Рис.2. Входные и выходные параметры модели test_cast_FLOAT16_to_DOUBLE.onnx


 Рис. Входные и выходные параметры модели test_cast_FLOAT16_to_FLOAT.onnx

Рис.3. Входные и выходные параметры модели test_cast_FLOAT16_to_FLOAT.onnx


Как видно из описания свойств ONNX-моделей, на входе требуется данные типа  ONNX_DATA_TYPE_FLOAT16, выходные данные модель возвратит в формате ONNX_DATA_TYPE_FLOAT.

Для конвертации значений будем использовать функции ArrayToFP16() и ArrayFromFP16() с параметром FLOAT_FP16.

Пример:

//+------------------------------------------------------------------+
//|                                              TestCastFloat16.mq5 |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"

#resource "models\\test_cast_FLOAT16_to_DOUBLE.onnx" as const uchar ExtModel1[];
#resource "models\\test_cast_FLOAT16_to_FLOAT.onnx" as const uchar ExtModel2[];

//+------------------------------------------------------------------+
//| union for data conversion                                        |
//+------------------------------------------------------------------+
template<typename T>
union U
  {
   uchar uc[sizeof(T)];
   T value;
  };
//+------------------------------------------------------------------+
//| ArrayToString                                                    |
//+------------------------------------------------------------------+
template<typename T>
string ArrayToString(const T &data[],uint length=16)
  {
   string res;

   for(uint n=0; n<MathMin(length,data.Size()); n++)
      res+="," + StringFormat("%.2x",data[n]);

   StringSetCharacter(res,0,'[');
   return res+"]";
  }

//+------------------------------------------------------------------+
//| PatchONNXModel                                                   |
//+------------------------------------------------------------------+
void PatchONNXModel(const uchar &original_model[],uchar &patched_model[])
  {
   ArrayCopy(patched_model,original_model,0,0,WHOLE_ARRAY);
//--- special ONNX model patch(IR=9,Opset=20)
   patched_model[1]=0x09;
   patched_model[ArraySize(patched_model)-1]=0x14;
  }
//+------------------------------------------------------------------+
//| CreateModel                                                      |
//+------------------------------------------------------------------+
bool CreateModel(long &model_handle,const uchar &model[])
  {
   model_handle=INVALID_HANDLE;
   ulong flags=ONNX_DEFAULT;
//ulong flags=ONNX_DEBUG_LOGS;
//---
   model_handle=OnnxCreateFromBuffer(model,flags);
   if(model_handle==INVALID_HANDLE)
      return(false);
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| PrepareShapes                                                    |
//+------------------------------------------------------------------+
bool PrepareShapes(long model_handle)
  {
   ulong input_shape1[]= {3,4};
   if(!OnnxSetInputShape(model_handle,0,input_shape1))
     {
      PrintFormat("error in OnnxSetInputShape for input1. error code=%d",GetLastError());
      //--
      OnnxRelease(model_handle);
      return(false);
     }
//---
   ulong output_shape[]= {3,4};
   if(!OnnxSetOutputShape(model_handle,0,output_shape))
     {
      PrintFormat("error in OnnxSetOutputShape for output. error code=%d",GetLastError());
      //--
      OnnxRelease(model_handle);
      return(false);
     }
//---
   return(true);
  }

//+------------------------------------------------------------------+
//| RunCastFloat16ToDouble                                           |
//+------------------------------------------------------------------+
bool RunCastFloat16ToDouble(long model_handle)
  {
   PrintFormat("test=%s",__FUNCTION__);

   double test_data[12]= {1,2,3,4,5,6,7,8,9,10,11,12};
   ushort data_uint16[12];
   if(!ArrayToFP16(data_uint16,test_data,FLOAT_FP16))
     {
      Print("error in ArrayToFP16. error code=",GetLastError());
      return(false);
     }
   Print("test array:");
   ArrayPrint(test_data);
   Print("ArrayToFP16:");
   ArrayPrint(data_uint16);

   U<ushort> input_float16_values[3*4];
   U<double> output_double_values[3*4];

   float test_data_float[];
   if(!ArrayFromFP16(test_data_float,data_uint16,FLOAT_FP16))
     {
      Print("error in ArrayFromFP16. error code=",GetLastError());
      return(false);
     }

   for(int i=0; i<12; i++)
     {
      input_float16_values[i].value=data_uint16[i];
      PrintFormat("%d input value =%f  Hex float16 = %s  ushort value=%d",i,test_data_float[i],ArrayToString(input_float16_values[i].uc),input_float16_values[i].value);
     }

   Print("ONNX input array:");
   ArrayPrint(input_float16_values);

   bool res=OnnxRun(model_handle,ONNX_NO_CONVERSION,input_float16_values,output_double_values);
   if(!res)
     {
      PrintFormat("error in OnnxRun. error code=%d",GetLastError());
      return(false);
     }

   Print("ONNX output array:");
   ArrayPrint(output_double_values);
//---
   double sum_error=0.0;
   for(int i=0; i<12; i++)
     {
      double delta=test_data[i]-output_double_values[i].value;
      sum_error+=MathAbs(delta);
      PrintFormat("%d output double %f = %s  difference=%f",i,output_double_values[i].value,ArrayToString(output_double_values[i].uc),delta);
     }
//---
   PrintFormat("test=%s   sum_error=%f",__FUNCTION__,sum_error);
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| RunCastFloat16ToFloat                                            |
//+------------------------------------------------------------------+
bool RunCastFloat16ToFloat(long model_handle)
  {
   PrintFormat("test=%s",__FUNCTION__);

   double test_data[12]= {1,2,3,4,5,6,7,8,9,10,11,12};
   ushort data_uint16[12];
   if(!ArrayToFP16(data_uint16,test_data,FLOAT_FP16))
     {
      Print("error in ArrayToFP16. error code=",GetLastError());
      return(false);
     }
   Print("test array:");
   ArrayPrint(test_data);
   Print("ArrayToFP16:");
   ArrayPrint(data_uint16);

   U<ushort> input_float16_values[3*4];
   U<float>  output_float_values[3*4];

   float test_data_float[];
   if(!ArrayFromFP16(test_data_float,data_uint16,FLOAT_FP16))
     {
      Print("error in ArrayFromFP16. error code=",GetLastError());
      return(false);
     }

   for(int i=0; i<12; i++)
     {
      input_float16_values[i].value=data_uint16[i];
      PrintFormat("%d input value =%f  Hex float16 = %s  ushort value=%d",i,test_data_float[i],ArrayToString(input_float16_values[i].uc),input_float16_values[i].value);
     }

   Print("ONNX input array:");
   ArrayPrint(input_float16_values);

   bool res=OnnxRun(model_handle,ONNX_NO_CONVERSION,input_float16_values,output_float_values);
   if(!res)
     {
      PrintFormat("error in OnnxRun. error code=%d",GetLastError());
      return(false);
     }

   Print("ONNX output array:");
   ArrayPrint(output_float_values);
//---
   double sum_error=0.0;
   for(int i=0; i<12; i++)
     {
      double delta=test_data[i]-(double)output_float_values[i].value;
      sum_error+=MathAbs(delta);
      PrintFormat("%d output float %f = %s difference=%f",i,output_float_values[i].value,ArrayToString(output_float_values[i].uc),delta);
     }
//---
   PrintFormat("test=%s   sum_error=%f",__FUNCTION__,sum_error);
//---
   return(true);
  }

//+------------------------------------------------------------------+
//| TestCastFloat16ToFloat                                           |
//+------------------------------------------------------------------+
bool TestCastFloat16ToFloat(const uchar &res_model[])
  {
   uchar model[];
   PatchONNXModel(res_model,model);
//--- get model handle
   long model_handle=INVALID_HANDLE;
//--- get model handle
   if(!CreateModel(model_handle,model))
      return(false);
//--- prepare input and output shapes
   if(!PrepareShapes(model_handle))
      return(false);
//--- run ONNX model
   if(!RunCastFloat16ToFloat(model_handle))
      return(false);
//--- release model handle
   OnnxRelease(model_handle);
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| TestCastFloat16ToDouble                                          |
//+------------------------------------------------------------------+
bool TestCastFloat16ToDouble(const uchar &res_model[])
  {
   uchar model[];
   PatchONNXModel(res_model,model);
//---
   long model_handle=INVALID_HANDLE;
//--- get model handle
   if(!CreateModel(model_handle,model))
      return(false);
//--- prepare input and output shapes
   if(!PrepareShapes(model_handle))
      return(false);
//--- run ONNX model
   if(!RunCastFloat16ToDouble(model_handle))
      return(false);
//--- release model handle
   OnnxRelease(model_handle);
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| Script program start function                                    |
//+------------------------------------------------------------------+
int OnStart(void)
  {
   if(!TestCastFloat16ToDouble(ExtModel1))
      return 1;

   if(!TestCastFloat16ToFloat(ExtModel2))
      return 1;
//---
   return 0;
  }
//+------------------------------------------------------------------+

Результат:

TestCastFloat16 (EURUSD,H1)     test=RunCastFloat16ToDouble
TestCastFloat16 (EURUSD,H1)     test array:
TestCastFloat16 (EURUSD,H1)      1.00000  2.00000  3.00000  4.00000  5.00000  6.00000  7.00000  8.00000  9.00000 10.00000 11.00000 12.00000
TestCastFloat16 (EURUSD,H1)     ArrayToFP16:
TestCastFloat16 (EURUSD,H1)     15360 16384 16896 17408 17664 17920 18176 18432 18560 18688 18816 18944
TestCastFloat16 (EURUSD,H1)     0 input value =1.000000  Hex float16 = [00,3c]  ushort value=15360
TestCastFloat16 (EURUSD,H1)     1 input value =2.000000  Hex float16 = [00,40]  ushort value=16384
TestCastFloat16 (EURUSD,H1)     2 input value =3.000000  Hex float16 = [00,42]  ushort value=16896
TestCastFloat16 (EURUSD,H1)     3 input value =4.000000  Hex float16 = [00,44]  ushort value=17408
TestCastFloat16 (EURUSD,H1)     4 input value =5.000000  Hex float16 = [00,45]  ushort value=17664
TestCastFloat16 (EURUSD,H1)     5 input value =6.000000  Hex float16 = [00,46]  ushort value=17920
TestCastFloat16 (EURUSD,H1)     6 input value =7.000000  Hex float16 = [00,47]  ushort value=18176
TestCastFloat16 (EURUSD,H1)     7 input value =8.000000  Hex float16 = [00,48]  ushort value=18432
TestCastFloat16 (EURUSD,H1)     8 input value =9.000000  Hex float16 = [80,48]  ushort value=18560
TestCastFloat16 (EURUSD,H1)     9 input value =10.000000  Hex float16 = [00,49]  ushort value=18688
TestCastFloat16 (EURUSD,H1)     10 input value =11.000000  Hex float16 = [80,49]  ushort value=18816
TestCastFloat16 (EURUSD,H1)     11 input value =12.000000  Hex float16 = [00,4a]  ushort value=18944
TestCastFloat16 (EURUSD,H1)     ONNX input array:
TestCastFloat16 (EURUSD,H1)          [uc] [value]
TestCastFloat16 (EURUSD,H1)     [ 0]  ...   15360
TestCastFloat16 (EURUSD,H1)     [ 1]  ...   16384
TestCastFloat16 (EURUSD,H1)     [ 2]  ...   16896
TestCastFloat16 (EURUSD,H1)     [ 3]  ...   17408
TestCastFloat16 (EURUSD,H1)     [ 4]  ...   17664
TestCastFloat16 (EURUSD,H1)     [ 5]  ...   17920
TestCastFloat16 (EURUSD,H1)     [ 6]  ...   18176
TestCastFloat16 (EURUSD,H1)     [ 7]  ...   18432
TestCastFloat16 (EURUSD,H1)     [ 8]  ...   18560
TestCastFloat16 (EURUSD,H1)     [ 9]  ...   18688
TestCastFloat16 (EURUSD,H1)     [10]  ...   18816
TestCastFloat16 (EURUSD,H1)     [11]  ...   18944
TestCastFloat16 (EURUSD,H1)     ONNX output array:
TestCastFloat16 (EURUSD,H1)          [uc]  [value]
TestCastFloat16 (EURUSD,H1)     [ 0]  ...  1.00000
TestCastFloat16 (EURUSD,H1)     [ 1]  ...  2.00000
TestCastFloat16 (EURUSD,H1)     [ 2]  ...  3.00000
TestCastFloat16 (EURUSD,H1)     [ 3]  ...  4.00000
TestCastFloat16 (EURUSD,H1)     [ 4]  ...  5.00000
TestCastFloat16 (EURUSD,H1)     [ 5]  ...  6.00000
TestCastFloat16 (EURUSD,H1)     [ 6]  ...  7.00000
TestCastFloat16 (EURUSD,H1)     [ 7]  ...  8.00000
TestCastFloat16 (EURUSD,H1)     [ 8]  ...  9.00000
TestCastFloat16 (EURUSD,H1)     [ 9]  ... 10.00000
TestCastFloat16 (EURUSD,H1)     [10]  ... 11.00000
TestCastFloat16 (EURUSD,H1)     [11]  ... 12.00000
TestCastFloat16 (EURUSD,H1)     0 output double 1.000000 = [00,00,00,00,00,00,f0,3f]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     1 output double 2.000000 = [00,00,00,00,00,00,00,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     2 output double 3.000000 = [00,00,00,00,00,00,08,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     3 output double 4.000000 = [00,00,00,00,00,00,10,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     4 output double 5.000000 = [00,00,00,00,00,00,14,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     5 output double 6.000000 = [00,00,00,00,00,00,18,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     6 output double 7.000000 = [00,00,00,00,00,00,1c,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     7 output double 8.000000 = [00,00,00,00,00,00,20,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     8 output double 9.000000 = [00,00,00,00,00,00,22,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     9 output double 10.000000 = [00,00,00,00,00,00,24,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     10 output double 11.000000 = [00,00,00,00,00,00,26,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     11 output double 12.000000 = [00,00,00,00,00,00,28,40]  difference=0.000000
TestCastFloat16 (EURUSD,H1)     test=RunCastFloat16ToDouble   sum_error=0.000000
TestCastFloat16 (EURUSD,H1)     test=RunCastFloat16ToFloat
TestCastFloat16 (EURUSD,H1)     test array:
TestCastFloat16 (EURUSD,H1)      1.00000  2.00000  3.00000  4.00000  5.00000  6.00000  7.00000  8.00000  9.00000 10.00000 11.00000 12.00000
TestCastFloat16 (EURUSD,H1)     ArrayToFP16:
TestCastFloat16 (EURUSD,H1)     15360 16384 16896 17408 17664 17920 18176 18432 18560 18688 18816 18944
TestCastFloat16 (EURUSD,H1)     0 input value =1.000000  Hex float16 = [00,3c]  ushort value=15360
TestCastFloat16 (EURUSD,H1)     1 input value =2.000000  Hex float16 = [00,40]  ushort value=16384
TestCastFloat16 (EURUSD,H1)     2 input value =3.000000  Hex float16 = [00,42]  ushort value=16896
TestCastFloat16 (EURUSD,H1)     3 input value =4.000000  Hex float16 = [00,44]  ushort value=17408
TestCastFloat16 (EURUSD,H1)     4 input value =5.000000  Hex float16 = [00,45]  ushort value=17664
TestCastFloat16 (EURUSD,H1)     5 input value =6.000000  Hex float16 = [00,46]  ushort value=17920
TestCastFloat16 (EURUSD,H1)     6 input value =7.000000  Hex float16 = [00,47]  ushort value=18176
TestCastFloat16 (EURUSD,H1)     7 input value =8.000000  Hex float16 = [00,48]  ushort value=18432
TestCastFloat16 (EURUSD,H1)     8 input value =9.000000  Hex float16 = [80,48]  ushort value=18560
TestCastFloat16 (EURUSD,H1)     9 input value =10.000000  Hex float16 = [00,49]  ushort value=18688
TestCastFloat16 (EURUSD,H1)     10 input value =11.000000  Hex float16 = [80,49]  ushort value=18816
TestCastFloat16 (EURUSD,H1)     11 input value =12.000000  Hex float16 = [00,4a]  ushort value=18944
TestCastFloat16 (EURUSD,H1)     ONNX input array:
TestCastFloat16 (EURUSD,H1)          [uc] [value]
TestCastFloat16 (EURUSD,H1)     [ 0]  ...   15360
TestCastFloat16 (EURUSD,H1)     [ 1]  ...   16384
TestCastFloat16 (EURUSD,H1)     [ 2]  ...   16896
TestCastFloat16 (EURUSD,H1)     [ 3]  ...   17408
TestCastFloat16 (EURUSD,H1)     [ 4]  ...   17664
TestCastFloat16 (EURUSD,H1)     [ 5]  ...   17920
TestCastFloat16 (EURUSD,H1)     [ 6]  ...   18176
TestCastFloat16 (EURUSD,H1)     [ 7]  ...   18432
TestCastFloat16 (EURUSD,H1)     [ 8]  ...   18560
TestCastFloat16 (EURUSD,H1)     [ 9]  ...   18688
TestCastFloat16 (EURUSD,H1)     [10]  ...   18816
TestCastFloat16 (EURUSD,H1)     [11]  ...   18944
TestCastFloat16 (EURUSD,H1)     ONNX output array:
TestCastFloat16 (EURUSD,H1)          [uc]  [value]
TestCastFloat16 (EURUSD,H1)     [ 0]  ...  1.00000
TestCastFloat16 (EURUSD,H1)     [ 1]  ...  2.00000
TestCastFloat16 (EURUSD,H1)     [ 2]  ...  3.00000
TestCastFloat16 (EURUSD,H1)     [ 3]  ...  4.00000
TestCastFloat16 (EURUSD,H1)     [ 4]  ...  5.00000
TestCastFloat16 (EURUSD,H1)     [ 5]  ...  6.00000
TestCastFloat16 (EURUSD,H1)     [ 6]  ...  7.00000
TestCastFloat16 (EURUSD,H1)     [ 7]  ...  8.00000
TestCastFloat16 (EURUSD,H1)     [ 8]  ...  9.00000
TestCastFloat16 (EURUSD,H1)     [ 9]  ... 10.00000
TestCastFloat16 (EURUSD,H1)     [10]  ... 11.00000
TestCastFloat16 (EURUSD,H1)     [11]  ... 12.00000
TestCastFloat16 (EURUSD,H1)     0 output float 1.000000 = [00,00,80,3f] difference=0.000000
TestCastFloat16 (EURUSD,H1)     1 output float 2.000000 = [00,00,00,40] difference=0.000000
TestCastFloat16 (EURUSD,H1)     2 output float 3.000000 = [00,00,40,40] difference=0.000000
TestCastFloat16 (EURUSD,H1)     3 output float 4.000000 = [00,00,80,40] difference=0.000000
TestCastFloat16 (EURUSD,H1)     4 output float 5.000000 = [00,00,a0,40] difference=0.000000
TestCastFloat16 (EURUSD,H1)     5 output float 6.000000 = [00,00,c0,40] difference=0.000000
TestCastFloat16 (EURUSD,H1)     6 output float 7.000000 = [00,00,e0,40] difference=0.000000
TestCastFloat16 (EURUSD,H1)     7 output float 8.000000 = [00,00,00,41] difference=0.000000
TestCastFloat16 (EURUSD,H1)     8 output float 9.000000 = [00,00,10,41] difference=0.000000
TestCastFloat16 (EURUSD,H1)     9 output float 10.000000 = [00,00,20,41] difference=0.000000
TestCastFloat16 (EURUSD,H1)     10 output float 11.000000 = [00,00,30,41] difference=0.000000
TestCastFloat16 (EURUSD,H1)     11 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastFloat16 (EURUSD,H1)     test=RunCastFloat16ToFloat   sum_error=0.000000


1.1.2. Тесты исполнения ONNX-оператора Cast для BFLOAT16

В этом примере рассматривается преобразование из типа BFLOAT16 в float.

ONNX-модель с операцией Cast:


Рис. Входные и выходные параметры модели test_cast_BFLOAT16_to_FLOAT.onnx

Рис.4. Входные и выходные параметры модели test_cast_BFLOAT16_to_FLOAT.onnx

На входе требуется данные типа  ONNX_DATA_TYPE_BFLOAT16, выходные данные модель возвратит в формате ONNX_DATA_TYPE_FLOAT.

Для конвертации значений будем использовать функции ArrayToFP16() и ArrayFromFP16() с параметром BFLOAT_FP16.

//+------------------------------------------------------------------+
//|                                             TestCastBFloat16.mq5 |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"

#resource "models\\test_cast_BFLOAT16_to_FLOAT.onnx" as const uchar ExtModel1[];

//+------------------------------------------------------------------+
//| union for data conversion                                        |
//+------------------------------------------------------------------+
template<typename T>
union U
  {
   uchar uc[sizeof(T)];
   T value;
  };
//+------------------------------------------------------------------+
//| ArrayToString                                                    |
//+------------------------------------------------------------------+
template<typename T>
string ArrayToString(const T &data[],uint length=16)
  {
   string res;

   for(uint n=0; n<MathMin(length,data.Size()); n++)
      res+="," + StringFormat("%.2x",data[n]);

   StringSetCharacter(res,0,'[');
   return res+"]";
  }

//+------------------------------------------------------------------+
//| PatchONNXModel                                                   |
//+------------------------------------------------------------------+
void PatchONNXModel(const uchar &original_model[],uchar &patched_model[])
  {
   ArrayCopy(patched_model,original_model,0,0,WHOLE_ARRAY);
//--- special ONNX model patch(IR=9,Opset=20)
   patched_model[1]=0x09;
   patched_model[ArraySize(patched_model)-1]=0x14;
  }
//+------------------------------------------------------------------+
//| CreateModel                                                      |
//+------------------------------------------------------------------+
bool CreateModel(long &model_handle,const uchar &model[])
  {
   model_handle=INVALID_HANDLE;
   ulong flags=ONNX_DEFAULT;
//ulong flags=ONNX_DEBUG_LOGS;
//---
   model_handle=OnnxCreateFromBuffer(model,flags);
   if(model_handle==INVALID_HANDLE)
      return(false);
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| PrepareShapes                                                    |
//+------------------------------------------------------------------+
bool PrepareShapes(long model_handle)
  {
   ulong input_shape1[]= {3,4};
   if(!OnnxSetInputShape(model_handle,0,input_shape1))
     {
      PrintFormat("error in OnnxSetInputShape for input1. error code=%d",GetLastError());
      //--
      OnnxRelease(model_handle);
      return(false);
     }
//---
   ulong output_shape[]= {3,4};
   if(!OnnxSetOutputShape(model_handle,0,output_shape))
     {
      PrintFormat("error in OnnxSetOutputShape for output. error code=%d",GetLastError());
      //--
      OnnxRelease(model_handle);
      return(false);
     }
//---
   return(true);
  }

//+------------------------------------------------------------------+
//| RunCastBFloat16ToFloat                                           |
//+------------------------------------------------------------------+
bool RunCastBFloat16ToFloat(long model_handle)
  {
   PrintFormat("test=%s",__FUNCTION__);

   double test_data[12]= {1,2,3,4,5,6,7,8,9,10,11,12};
   ushort data_uint16[12];
   if(!ArrayToFP16(data_uint16,test_data,FLOAT_BFP16))
     {
      Print("error in ArrayToFP16. error code=",GetLastError());
      return(false);
     }
   Print("test array:");
   ArrayPrint(test_data);
   Print("ArrayToFP16:");
   ArrayPrint(data_uint16);

   U<ushort> input_float16_values[3*4];
   U<float>  output_float_values[3*4];

   float test_data_float[];
   if(!ArrayFromFP16(test_data_float,data_uint16,FLOAT_BFP16))
     {
      Print("error in ArrayFromFP16. error code=",GetLastError());
      return(false);
     }

   for(int i=0; i<12; i++)
     {
      input_float16_values[i].value=data_uint16[i];
      PrintFormat("%d input value =%f  Hex float16 = %s  ushort value=%d",i,test_data_float[i],ArrayToString(input_float16_values[i].uc),input_float16_values[i].value);
     }

   Print("ONNX input array:");
   ArrayPrint(input_float16_values);

   bool res=OnnxRun(model_handle,ONNX_NO_CONVERSION,input_float16_values,output_float_values);
   if(!res)
     {
      PrintFormat("error in OnnxRun. error code=%d",GetLastError());
      return(false);
     }

   Print("ONNX output array:");
   ArrayPrint(output_float_values);
//---
   double sum_error=0.0;
   for(int i=0; i<12; i++)
     {
      double delta=test_data[i]-(double)output_float_values[i].value;
      sum_error+=MathAbs(delta);
      PrintFormat("%d output float %f = %s difference=%f",i,output_float_values[i].value,ArrayToString(output_float_values[i].uc),delta);
     }
//---
   PrintFormat("test=%s   sum_error=%f",__FUNCTION__,sum_error);
//---
   return(true);
  }

//+------------------------------------------------------------------+
//| Script program start function                                    |
//+------------------------------------------------------------------+
int OnStart(void)
  {
   uchar model[];
   PatchONNXModel(ExtModel1,model);
//--- get model handle
   long model_handle=INVALID_HANDLE;
//--- get model handle
   if(!CreateModel(model_handle,model))
      return 1;
//--- prepare input and output shapes
   if(!PrepareShapes(model_handle))
      return 1;
//--- run ONNX model
   if(!RunCastBFloat16ToFloat(model_handle))
      return 1;
//--- release model handle
   OnnxRelease(model_handle);
//---
   return 0;
  }
//+------------------------------------------------------------------+
Результат:
TestCastBFloat16 (EURUSD,H1)    test=RunCastBFloat16ToFloat
TestCastBFloat16 (EURUSD,H1)    test array:
TestCastBFloat16 (EURUSD,H1)     1.00000  2.00000  3.00000  4.00000  5.00000  6.00000  7.00000  8.00000  9.00000 10.00000 11.00000 12.00000
TestCastBFloat16 (EURUSD,H1)    ArrayToFP16:
TestCastBFloat16 (EURUSD,H1)    16256 16384 16448 16512 16544 16576 16608 16640 16656 16672 16688 16704
TestCastBFloat16 (EURUSD,H1)    0 input value =1.000000  Hex float16 = [80,3f]  ushort value=16256
TestCastBFloat16 (EURUSD,H1)    1 input value =2.000000  Hex float16 = [00,40]  ushort value=16384
TestCastBFloat16 (EURUSD,H1)    2 input value =3.000000  Hex float16 = [40,40]  ushort value=16448
TestCastBFloat16 (EURUSD,H1)    3 input value =4.000000  Hex float16 = [80,40]  ushort value=16512
TestCastBFloat16 (EURUSD,H1)    4 input value =5.000000  Hex float16 = [a0,40]  ushort value=16544
TestCastBFloat16 (EURUSD,H1)    5 input value =6.000000  Hex float16 = [c0,40]  ushort value=16576
TestCastBFloat16 (EURUSD,H1)    6 input value =7.000000  Hex float16 = [e0,40]  ushort value=16608
TestCastBFloat16 (EURUSD,H1)    7 input value =8.000000  Hex float16 = [00,41]  ushort value=16640
TestCastBFloat16 (EURUSD,H1)    8 input value =9.000000  Hex float16 = [10,41]  ushort value=16656
TestCastBFloat16 (EURUSD,H1)    9 input value =10.000000  Hex float16 = [20,41]  ushort value=16672
TestCastBFloat16 (EURUSD,H1)    10 input value =11.000000  Hex float16 = [30,41]  ushort value=16688
TestCastBFloat16 (EURUSD,H1)    11 input value =12.000000  Hex float16 = [40,41]  ushort value=16704
TestCastBFloat16 (EURUSD,H1)    ONNX input array:
TestCastBFloat16 (EURUSD,H1)         [uc] [value]
TestCastBFloat16 (EURUSD,H1)    [ 0]  ...   16256
TestCastBFloat16 (EURUSD,H1)    [ 1]  ...   16384
TestCastBFloat16 (EURUSD,H1)    [ 2]  ...   16448
TestCastBFloat16 (EURUSD,H1)    [ 3]  ...   16512
TestCastBFloat16 (EURUSD,H1)    [ 4]  ...   16544
TestCastBFloat16 (EURUSD,H1)    [ 5]  ...   16576
TestCastBFloat16 (EURUSD,H1)    [ 6]  ...   16608
TestCastBFloat16 (EURUSD,H1)    [ 7]  ...   16640
TestCastBFloat16 (EURUSD,H1)    [ 8]  ...   16656
TestCastBFloat16 (EURUSD,H1)    [ 9]  ...   16672
TestCastBFloat16 (EURUSD,H1)    [10]  ...   16688
TestCastBFloat16 (EURUSD,H1)    [11]  ...   16704
TestCastBFloat16 (EURUSD,H1)    ONNX output array:
TestCastBFloat16 (EURUSD,H1)         [uc]  [value]
TestCastBFloat16 (EURUSD,H1)    [ 0]  ...  1.00000
TestCastBFloat16 (EURUSD,H1)    [ 1]  ...  2.00000
TestCastBFloat16 (EURUSD,H1)    [ 2]  ...  3.00000
TestCastBFloat16 (EURUSD,H1)    [ 3]  ...  4.00000
TestCastBFloat16 (EURUSD,H1)    [ 4]  ...  5.00000
TestCastBFloat16 (EURUSD,H1)    [ 5]  ...  6.00000
TestCastBFloat16 (EURUSD,H1)    [ 6]  ...  7.00000
TestCastBFloat16 (EURUSD,H1)    [ 7]  ...  8.00000
TestCastBFloat16 (EURUSD,H1)    [ 8]  ...  9.00000
TestCastBFloat16 (EURUSD,H1)    [ 9]  ... 10.00000
TestCastBFloat16 (EURUSD,H1)    [10]  ... 11.00000
TestCastBFloat16 (EURUSD,H1)    [11]  ... 12.00000
TestCastBFloat16 (EURUSD,H1)    0 output float 1.000000 = [00,00,80,3f] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    1 output float 2.000000 = [00,00,00,40] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    2 output float 3.000000 = [00,00,40,40] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    3 output float 4.000000 = [00,00,80,40] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    4 output float 5.000000 = [00,00,a0,40] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    5 output float 6.000000 = [00,00,c0,40] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    6 output float 7.000000 = [00,00,e0,40] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    7 output float 8.000000 = [00,00,00,41] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    8 output float 9.000000 = [00,00,10,41] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    9 output float 10.000000 = [00,00,20,41] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    10 output float 11.000000 = [00,00,30,41] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    11 output float 12.000000 = [00,00,40,41] difference=0.000000
TestCastBFloat16 (EURUSD,H1)    test=RunCastBFloat16ToFloat   sum_error=0.000000


1.2. Формат FP8

Современные языковые модели могут содержать миллиарды параметров. Обучение моделей с использованием чисел FP16 уже показало свою эффективность. Переход от 16-битного числа с плавающей точкой к FP8 позволяет вдвое уменьшить требования к памяти а также ускорить обучение и исполнение моделей.

Формат FP8 (8-битное число с плавающей запятой) представляет собой один из типов данных, используемых для представления чисел с плавающей точкой. В FP8 каждое число представлено 8 битами данных, которые обычно разделяются на три компонента: знак, экспоненту и мантиссу. Этот формат обеспечивает компромисс между точностью и эффективностью хранения данных, что делает его привлекательным для использования в приложениях, где требуется экономия памяти и вычислительных ресурсов. 

Одним из ключевых преимуществ FP8 является его эффективность при обработке больших объемов данных. Благодаря компактному представлению чисел, FP8 позволяет уменьшить требования к памяти и ускорить вычисления. Это особенно важно в приложениях машинного обучения и искусственного интеллекта, где обработка больших наборов данных является обычным делом.

Кроме того, FP8 может быть полезен для реализации низкоуровневых операций, таких как арифметические вычисления и обработка сигналов. Его компактный формат делает его подходящим для использования во встраиваемых системах и приложениях, где ограничены ресурсы. Однако, следует отметить, что FP8 имеет свои ограничения, связанные с его ограниченной точностью. В некоторых приложениях, где требуется высокая точность вычислений, таких как научные вычисления или финансовая аналитика, использование FP8 может быть недостаточным.


1.2.1. Форматы fp8_e5m2 и fp8_e4m3

В 2022 году было опубликовано две статьи, вводящие числа с плавающей точкой, хранящиеся в одном байте, в отличие от чисел float32, хранящихся в 4 байтах.

В статье FP8 Formats for Deep Learning (2022) от NVIDIA, Intel и ARM вводятся два типа, следуя спецификациям IEEE. Первый тип - это E4M3, 1 бит для знака, 4 бита для экспоненты и 3 бита для мантиссы. Второй тип - это E5M2, 1 бит для знака, 5 бит для экспоненты и 2 для мантиссы.Первый тип обычно используется для весов, второй - для градиентов.

Вторая статья "8-bit Numerical Formats For Deep Neural Networks" представляет схожие типы. Стандарт IEEE присваивает одинаковое значение +0 (или целое число 0) и -0 (или целое число 128).  В статье предлагается присвоить различные значения float этим двум числам. Кроме того, также исследуются различные разделения между экспонентой и мантиссой, и показано, что E4M3 и E5M2 являются лучшими.

В результате в ONNX (с версии 1.15.0) было введено 4 новых типа:

  • E4M3FN: 1 бит для знака, 4 бита для экспоненты, 3 бита для мантиссы, только значения NaN и нет бесконечных значений (FN),
  • E4M3FNUZ: 1 бит для знака, 4 бита для экспоненты, 3 бита для мантиссы, только значения NaN и нет бесконечных значений (FN), нет отрицательного нуля (UZ)
  • E5M2: 1 бит для знака, 5 бит для экспоненты, 2 бита для мантиссы,
  • E5M2FNUZ: 1 бит для знака, 5 бит для экспоненты, 2 бита для мантиссы, только значения NaN и нет бесконечных значений (FN), нет отрицательного нуля (UZ)

Реализация обычно зависит от аппаратных средств. NVIDIA, Intel и Arm реализуют E4M3FN, а E5M2 реализованы в современных графических процессорах. GraphCore делает то же самое только с E4M3FNUZ и E5M2FNUZ.

Приведем кратко основные сведения о типе FP8 согласно статье NVIDIA Hopper: H100 and FP8 Support.


Рис. Формат битового представления чисел с плавающей точкой FP8_E4M3

Рис.5. Формат битового представления чисел с плавающей точкой FP8_E4M3


Детали формата FP8_E5M2

Табл.3. Числа с плавающей точкой в формате E5M2


Детали формата FP8_E5M2

Табл.4. Числа с плавающей точкой в формате E4M3

Сравнение дипазонов положительных значений чисел FP8_E4M3 и FP8_E5M2 приведены на рис

 Сравнение диапазонов положительных значений чисел FP8

Рис.6. Сравнение диапазонов положительных значений чисел FP8 (источник)


Сравнение точности выполнения арифметических операций (Add, Mul, Div) для чисел в форматах FP8_E5M2  и FP8_E4M3 приведены на рис:

Сравнение точности арифметических операций для чисел в форматах float8_e5m2 и float8_e4m3 (источник)

Рис.7. Сравнение точности арифметических операций для чисел в форматах float8_e5m2 и float8_e4m3 (источник)

Рекомендуемое использование чисел в формате FP8:

  • E4M3 для тензоров весов и активации;
  • E5M2 для тензоров градиентов.


1.2.2. Тесты исполнения ONNX-оператора Cast для FLOAT8

В этом примере рассматривается преобразование из различных типов FLOAT8 в float.

ONNX-модели с операцией Cast:


Рис.8. Входные и выходные параметры модели test_cast_FLOAT8E4M3FN_to_FLOAT.onnx в MetaEditor

Рис.8. Входные и выходные параметры модели test_cast_FLOAT8E4M3FN_to_FLOAT.onnx в MetaEditor



 Рис.9. Входные и выходные параметры модели test_cast_FLOAT8E4M3FNUZ_to_FLOAT.onnx в MetaEditor

Рис.9. Входные и выходные параметры модели test_cast_FLOAT8E4M3FNUZ_to_FLOAT.onnx в MetaEditor


Рис.10. Входные и выходные параметры модели test_cast_FLOAT8E5M2_to_FLOAT.onnx в MetaEditor

Рис.10. Входные и выходные параметры модели test_cast_FLOAT8E5M2_to_FLOAT.onnx в MetaEditor


 Рис.11. Входные и выходные параметры модели test_cast_FLOAT8E5M2FNUZ_to_FLOAT.onnx в MetaEditor

Рис.11. Входные и выходные параметры модели test_cast_FLOAT8E5M2FNUZ_to_FLOAT.onnx в MetaEditor


    Пример:

    //+------------------------------------------------------------------+
    //|                                              TestCastBFloat8.mq5 |
    //|                                  Copyright 2024, MetaQuotes Ltd. |
    //|                                             https://www.mql5.com |
    //+------------------------------------------------------------------+
    #property copyright "Copyright 2024, MetaQuotes Ltd."
    #property link      "https://www.mql5.com"
    #property version   "1.00"
    
    #resource "models\\test_cast_FLOAT8E4M3FN_to_FLOAT.onnx" as const uchar ExtModel_FLOAT8E4M3FN_to_FLOAT[];
    #resource "models\\test_cast_FLOAT8E4M3FNUZ_to_FLOAT.onnx" as const uchar ExtModel_FLOAT8E4M3FNUZ_to_FLOAT[];
    #resource "models\\test_cast_FLOAT8E5M2_to_FLOAT.onnx" as const uchar ExtModel_FLOAT8E5M2_to_FLOAT[];
    #resource "models\\test_cast_FLOAT8E5M2FNUZ_to_FLOAT.onnx" as const uchar ExtModel_FLOAT8E5M2FNUZ_to_FLOAT[];
    
    #define TEST_PASSED 0
    #define TEST_FAILED 1
    //+------------------------------------------------------------------+
    //| union for data conversion                                        |
    //+------------------------------------------------------------------+
    template<typename T>
    union U
      {
       uchar uc[sizeof(T)];
       T value;
      };
    //+------------------------------------------------------------------+
    //| ArrayToHexString                                                 |
    //+------------------------------------------------------------------+
    template<typename T>
    string ArrayToHexString(const T &data[],uint length=16)
      {
       string res;
    
       for(uint n=0; n<MathMin(length,data.Size()); n++)
          res+="," + StringFormat("%.2x",data[n]);
    
       StringSetCharacter(res,0,'[');
       return(res+"]");
      }
    //+------------------------------------------------------------------+
    //| ArrayToString                                                    |
    //+------------------------------------------------------------------+
    template<typename T>
    string ArrayToString(const U<T> &data[],uint length=16)
      {
       string res;
    
       for(uint n=0; n<MathMin(length,data.Size()); n++)
          res+="," + (string)data[n].value;
    
       StringSetCharacter(res,0,'[');
       return(res+"]");
      }
    //+------------------------------------------------------------------+
    //| PatchONNXModel                                                   |
    //+------------------------------------------------------------------+
    long CreatePatchedModel(const uchar &original_model[])
      {
       uchar patched_model[];
       ArrayCopy(patched_model,original_model);
    //--- special ONNX model patch(IR=9,Opset=20)
       patched_model[1]=0x09;
       patched_model[ArraySize(patched_model)-1]=0x14;
    
       return(OnnxCreateFromBuffer(patched_model,ONNX_DEFAULT));
      }
    //+------------------------------------------------------------------+
    //| PrepareShapes                                                    |
    //+------------------------------------------------------------------+
    bool PrepareShapes(long model_handle)
      {
    //--- configure input shape
       ulong input_shape[]= {3,5};
    
       if(!OnnxSetInputShape(model_handle,0,input_shape))
         {
          PrintFormat("error in OnnxSetInputShape for input1. error code=%d",GetLastError());
          OnnxRelease(model_handle);
          return(false);
         }
    //--- configure output shape
       ulong output_shape[]= {3,5};
    
       if(!OnnxSetOutputShape(model_handle,0,output_shape))
         {
          PrintFormat("error in OnnxSetOutputShape for output. error code=%d",GetLastError());
          OnnxRelease(model_handle);
          return(false);
         }
    
       return(true);
      }
    //+------------------------------------------------------------------+
    //| RunCastFloat8Float                                               |
    //+------------------------------------------------------------------+
    bool RunCastFloat8ToFloat(long model_handle,const ENUM_FLOAT8_FORMAT fmt)
      {
       PrintFormat("TEST: %s(%s)",__FUNCTION__,EnumToString(fmt));
    //---
       float test_data[15]   = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15};
       uchar data_float8[15] = {};
    
       if(!ArrayToFP8(data_float8,test_data,fmt))
         {
          Print("error in ArrayToFP8. error code=",GetLastError());
          OnnxRelease(model_handle);
          return(false);
         }
    
       U<uchar> input_float8_values[3*5];
       U<float> output_float_values[3*5];
       float    test_data_float[];
    //--- convert float8 to float
       if(!ArrayFromFP8(test_data_float,data_float8,fmt))
         {
          Print("error in ArrayFromFP8. error code=",GetLastError());
          OnnxRelease(model_handle);
          return(false);
         }
    
       for(uint i=0; i<data_float8.Size(); i++)
         {
          input_float8_values[i].value=data_float8[i];
          PrintFormat("%d input value =%f  Hex float8 = %s  ushort value=%d",i,test_data_float[i],ArrayToHexString(input_float8_values[i].uc),input_float8_values[i].value);
         }
    
       Print("ONNX input array: ",ArrayToString(input_float8_values));
    //--- execute model (convert float8 to float using ONNX)
       if(!OnnxRun(model_handle,ONNX_NO_CONVERSION,input_float8_values,output_float_values))
         {
          PrintFormat("error in OnnxRun. error code=%d",GetLastError());
          OnnxRelease(model_handle);
          return(false);
         }
    
       Print("ONNX output array: ",ArrayToString(output_float_values));
    //--- calculate error (compare ONNX and ArrayFromFP8 results)
       double sum_error=0.0;
    
       for(uint i=0; i<test_data.Size(); i++)
         {
          double delta=test_data_float[i]-(double)output_float_values[i].value;
          sum_error+=MathAbs(delta);
          PrintFormat("%d output float %f = %s difference=%f",i,output_float_values[i].value,ArrayToHexString(output_float_values[i].uc),delta);
         }
    //---
       PrintFormat("%s(%s): sum_error=%f\n",__FUNCTION__,EnumToString(fmt),sum_error);
       return(true);
      }
    //+------------------------------------------------------------------+
    //| TestModel                                                        |
    //+------------------------------------------------------------------+
    bool TestModel(const uchar &model[],const ENUM_FLOAT8_FORMAT fmt)
      {
    //--- create patched model
       long model_handle=CreatePatchedModel(model);
    
       if(model_handle==INVALID_HANDLE)
          return(false);
    //--- prepare input and output shapes
       if(!PrepareShapes(model_handle))
          return(false);
    //--- run ONNX model
       if(!RunCastFloat8ToFloat(model_handle,fmt))
          return(false);
    //--- release model handle
       OnnxRelease(model_handle);
    
       return(true);
      }
    //+------------------------------------------------------------------+
    //| Script program start function                                    |
    //+------------------------------------------------------------------+
    int OnStart(void)
      {
    //--- run ONNX model
       if(!TestModel(ExtModel_FLOAT8E4M3FN_to_FLOAT,FLOAT_FP8_E4M3FN))
          return(TEST_FAILED);
    
    //--- run ONNX model
       if(!TestModel(ExtModel_FLOAT8E4M3FNUZ_to_FLOAT,FLOAT_FP8_E4M3FNUZ))
          return(TEST_FAILED);
    
    //--- run ONNX model
       if(!TestModel(ExtModel_FLOAT8E5M2_to_FLOAT,FLOAT_FP8_E5M2FN))
          return(TEST_FAILED);
    
    //--- run ONNX model
       if(!TestModel(ExtModel_FLOAT8E5M2FNUZ_to_FLOAT,FLOAT_FP8_E5M2FNUZ))
          return(TEST_FAILED);
    
       return(TEST_PASSED);
      }
    //+------------------------------------------------------------------+
    

    Результат:

    TestCastFloat8 (EURUSD,H1)      TEST: RunCastFloat8ToFloat(FLOAT_FP8_E4M3FN)
    TestCastFloat8 (EURUSD,H1)      0 input value =1.000000  Hex float8 = [38]  ushort value=56
    TestCastFloat8 (EURUSD,H1)      1 input value =2.000000  Hex float8 = [40]  ushort value=64
    TestCastFloat8 (EURUSD,H1)      2 input value =3.000000  Hex float8 = [44]  ushort value=68
    TestCastFloat8 (EURUSD,H1)      3 input value =4.000000  Hex float8 = [48]  ushort value=72
    TestCastFloat8 (EURUSD,H1)      4 input value =5.000000  Hex float8 = [4a]  ushort value=74
    TestCastFloat8 (EURUSD,H1)      5 input value =6.000000  Hex float8 = [4c]  ushort value=76
    TestCastFloat8 (EURUSD,H1)      6 input value =7.000000  Hex float8 = [4e]  ushort value=78
    TestCastFloat8 (EURUSD,H1)      7 input value =8.000000  Hex float8 = [50]  ushort value=80
    TestCastFloat8 (EURUSD,H1)      8 input value =9.000000  Hex float8 = [51]  ushort value=81
    TestCastFloat8 (EURUSD,H1)      9 input value =10.000000  Hex float8 = [52]  ushort value=82
    TestCastFloat8 (EURUSD,H1)      10 input value =11.000000  Hex float8 = [53]  ushort value=83
    TestCastFloat8 (EURUSD,H1)      11 input value =12.000000  Hex float8 = [54]  ushort value=84
    TestCastFloat8 (EURUSD,H1)      12 input value =13.000000  Hex float8 = [55]  ushort value=85
    TestCastFloat8 (EURUSD,H1)      13 input value =14.000000  Hex float8 = [56]  ushort value=86
    TestCastFloat8 (EURUSD,H1)      14 input value =15.000000  Hex float8 = [57]  ushort value=87
    TestCastFloat8 (EURUSD,H1)      ONNX input array: [56,64,68,72,74,76,78,80,81,82,83,84,85,86,87]
    TestCastFloat8 (EURUSD,H1)      ONNX output array: [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0]
    TestCastFloat8 (EURUSD,H1)      0 output float 1.000000 = [00,00,80,3f] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      1 output float 2.000000 = [00,00,00,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      2 output float 3.000000 = [00,00,40,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      3 output float 4.000000 = [00,00,80,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      4 output float 5.000000 = [00,00,a0,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      5 output float 6.000000 = [00,00,c0,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      6 output float 7.000000 = [00,00,e0,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      7 output float 8.000000 = [00,00,00,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      8 output float 9.000000 = [00,00,10,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      9 output float 10.000000 = [00,00,20,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      10 output float 11.000000 = [00,00,30,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      11 output float 12.000000 = [00,00,40,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      12 output float 13.000000 = [00,00,50,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      13 output float 14.000000 = [00,00,60,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      14 output float 15.000000 = [00,00,70,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      RunCastFloat8ToFloat(FLOAT_FP8_E4M3FN): sum_error=0.000000
    TestCastFloat8 (EURUSD,H1)      
    TestCastFloat8 (EURUSD,H1)      TEST: RunCastFloat8ToFloat(FLOAT_FP8_E4M3FNUZ)
    TestCastFloat8 (EURUSD,H1)      0 input value =1.000000  Hex float8 = [40]  ushort value=64
    TestCastFloat8 (EURUSD,H1)      1 input value =2.000000  Hex float8 = [48]  ushort value=72
    TestCastFloat8 (EURUSD,H1)      2 input value =3.000000  Hex float8 = [4c]  ushort value=76
    TestCastFloat8 (EURUSD,H1)      3 input value =4.000000  Hex float8 = [50]  ushort value=80
    TestCastFloat8 (EURUSD,H1)      4 input value =5.000000  Hex float8 = [52]  ushort value=82
    TestCastFloat8 (EURUSD,H1)      5 input value =6.000000  Hex float8 = [54]  ushort value=84
    TestCastFloat8 (EURUSD,H1)      6 input value =7.000000  Hex float8 = [56]  ushort value=86
    TestCastFloat8 (EURUSD,H1)      7 input value =8.000000  Hex float8 = [58]  ushort value=88
    TestCastFloat8 (EURUSD,H1)      8 input value =9.000000  Hex float8 = [59]  ushort value=89
    TestCastFloat8 (EURUSD,H1)      9 input value =10.000000  Hex float8 = [5a]  ushort value=90
    TestCastFloat8 (EURUSD,H1)      10 input value =11.000000  Hex float8 = [5b]  ushort value=91
    TestCastFloat8 (EURUSD,H1)      11 input value =12.000000  Hex float8 = [5c]  ushort value=92
    TestCastFloat8 (EURUSD,H1)      12 input value =13.000000  Hex float8 = [5d]  ushort value=93
    TestCastFloat8 (EURUSD,H1)      13 input value =14.000000  Hex float8 = [5e]  ushort value=94
    TestCastFloat8 (EURUSD,H1)      14 input value =15.000000  Hex float8 = [5f]  ushort value=95
    TestCastFloat8 (EURUSD,H1)      ONNX input array: [64,72,76,80,82,84,86,88,89,90,91,92,93,94,95]
    TestCastFloat8 (EURUSD,H1)      ONNX output array: [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0]
    TestCastFloat8 (EURUSD,H1)      0 output float 1.000000 = [00,00,80,3f] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      1 output float 2.000000 = [00,00,00,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      2 output float 3.000000 = [00,00,40,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      3 output float 4.000000 = [00,00,80,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      4 output float 5.000000 = [00,00,a0,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      5 output float 6.000000 = [00,00,c0,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      6 output float 7.000000 = [00,00,e0,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      7 output float 8.000000 = [00,00,00,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      8 output float 9.000000 = [00,00,10,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      9 output float 10.000000 = [00,00,20,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      10 output float 11.000000 = [00,00,30,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      11 output float 12.000000 = [00,00,40,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      12 output float 13.000000 = [00,00,50,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      13 output float 14.000000 = [00,00,60,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      14 output float 15.000000 = [00,00,70,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      RunCastFloat8ToFloat(FLOAT_FP8_E4M3FNUZ): sum_error=0.000000
    TestCastFloat8 (EURUSD,H1)      
    TestCastFloat8 (EURUSD,H1)      TEST: RunCastFloat8ToFloat(FLOAT_FP8_E5M2FN)
    TestCastFloat8 (EURUSD,H1)      0 input value =1.000000  Hex float8 = [3c]  ushort value=60
    TestCastFloat8 (EURUSD,H1)      1 input value =2.000000  Hex float8 = [40]  ushort value=64
    TestCastFloat8 (EURUSD,H1)      2 input value =3.000000  Hex float8 = [42]  ushort value=66
    TestCastFloat8 (EURUSD,H1)      3 input value =4.000000  Hex float8 = [44]  ushort value=68
    TestCastFloat8 (EURUSD,H1)      4 input value =5.000000  Hex float8 = [45]  ushort value=69
    TestCastFloat8 (EURUSD,H1)      5 input value =6.000000  Hex float8 = [46]  ushort value=70
    TestCastFloat8 (EURUSD,H1)      6 input value =7.000000  Hex float8 = [47]  ushort value=71
    TestCastFloat8 (EURUSD,H1)      7 input value =8.000000  Hex float8 = [48]  ushort value=72
    TestCastFloat8 (EURUSD,H1)      8 input value =8.000000  Hex float8 = [48]  ushort value=72
    TestCastFloat8 (EURUSD,H1)      9 input value =10.000000  Hex float8 = [49]  ushort value=73
    TestCastFloat8 (EURUSD,H1)      10 input value =12.000000  Hex float8 = [4a]  ushort value=74
    TestCastFloat8 (EURUSD,H1)      11 input value =12.000000  Hex float8 = [4a]  ushort value=74
    TestCastFloat8 (EURUSD,H1)      12 input value =12.000000  Hex float8 = [4a]  ushort value=74
    TestCastFloat8 (EURUSD,H1)      13 input value =14.000000  Hex float8 = [4b]  ushort value=75
    TestCastFloat8 (EURUSD,H1)      14 input value =16.000000  Hex float8 = [4c]  ushort value=76
    TestCastFloat8 (EURUSD,H1)      ONNX input array: [60,64,66,68,69,70,71,72,72,73,74,74,74,75,76]
    TestCastFloat8 (EURUSD,H1)      ONNX output array: [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,8.0,10.0,12.0,12.0,12.0,14.0,16.0]
    TestCastFloat8 (EURUSD,H1)      0 output float 1.000000 = [00,00,80,3f] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      1 output float 2.000000 = [00,00,00,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      2 output float 3.000000 = [00,00,40,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      3 output float 4.000000 = [00,00,80,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      4 output float 5.000000 = [00,00,a0,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      5 output float 6.000000 = [00,00,c0,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      6 output float 7.000000 = [00,00,e0,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      7 output float 8.000000 = [00,00,00,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      8 output float 8.000000 = [00,00,00,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      9 output float 10.000000 = [00,00,20,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      10 output float 12.000000 = [00,00,40,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      11 output float 12.000000 = [00,00,40,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      12 output float 12.000000 = [00,00,40,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      13 output float 14.000000 = [00,00,60,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      14 output float 16.000000 = [00,00,80,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      RunCastFloat8ToFloat(FLOAT_FP8_E5M2FN): sum_error=0.000000
    TestCastFloat8 (EURUSD,H1)      
    TestCastFloat8 (EURUSD,H1)      TEST: RunCastFloat8ToFloat(FLOAT_FP8_E5M2FNUZ)
    TestCastFloat8 (EURUSD,H1)      0 input value =1.000000  Hex float8 = [40]  ushort value=64
    TestCastFloat8 (EURUSD,H1)      1 input value =2.000000  Hex float8 = [44]  ushort value=68
    TestCastFloat8 (EURUSD,H1)      2 input value =3.000000  Hex float8 = [46]  ushort value=70
    TestCastFloat8 (EURUSD,H1)      3 input value =4.000000  Hex float8 = [48]  ushort value=72
    TestCastFloat8 (EURUSD,H1)      4 input value =5.000000  Hex float8 = [49]  ushort value=73
    TestCastFloat8 (EURUSD,H1)      5 input value =6.000000  Hex float8 = [4a]  ushort value=74
    TestCastFloat8 (EURUSD,H1)      6 input value =7.000000  Hex float8 = [4b]  ushort value=75
    TestCastFloat8 (EURUSD,H1)      7 input value =8.000000  Hex float8 = [4c]  ushort value=76
    TestCastFloat8 (EURUSD,H1)      8 input value =8.000000  Hex float8 = [4c]  ushort value=76
    TestCastFloat8 (EURUSD,H1)      9 input value =10.000000  Hex float8 = [4d]  ushort value=77
    TestCastFloat8 (EURUSD,H1)      10 input value =12.000000  Hex float8 = [4e]  ushort value=78
    TestCastFloat8 (EURUSD,H1)      11 input value =12.000000  Hex float8 = [4e]  ushort value=78
    TestCastFloat8 (EURUSD,H1)      12 input value =12.000000  Hex float8 = [4e]  ushort value=78
    TestCastFloat8 (EURUSD,H1)      13 input value =14.000000  Hex float8 = [4f]  ushort value=79
    TestCastFloat8 (EURUSD,H1)      14 input value =16.000000  Hex float8 = [50]  ushort value=80
    TestCastFloat8 (EURUSD,H1)      ONNX input array: [64,68,70,72,73,74,75,76,76,77,78,78,78,79,80]
    TestCastFloat8 (EURUSD,H1)      ONNX output array: [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,8.0,10.0,12.0,12.0,12.0,14.0,16.0]
    TestCastFloat8 (EURUSD,H1)      0 output float 1.000000 = [00,00,80,3f] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      1 output float 2.000000 = [00,00,00,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      2 output float 3.000000 = [00,00,40,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      3 output float 4.000000 = [00,00,80,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      4 output float 5.000000 = [00,00,a0,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      5 output float 6.000000 = [00,00,c0,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      6 output float 7.000000 = [00,00,e0,40] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      7 output float 8.000000 = [00,00,00,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      8 output float 8.000000 = [00,00,00,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      9 output float 10.000000 = [00,00,20,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      10 output float 12.000000 = [00,00,40,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      11 output float 12.000000 = [00,00,40,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      12 output float 12.000000 = [00,00,40,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      13 output float 14.000000 = [00,00,60,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      14 output float 16.000000 = [00,00,80,41] difference=0.000000
    TestCastFloat8 (EURUSD,H1)      RunCastFloat8ToFloat(FLOAT_FP8_E5M2FNUZ): sum_error=0.000000
    TestCastFloat8 (EURUSD,H1)



    2. Пример использования ONNX для повышения разрешения изображений

    В этом разделе мы рассмотрим пример использования моделей ESRGAN для увеличения разрешения изображений.

    ESRGAN, или Enhanced Super-Resolution Generative Adversarial Networks, представляет собой мощную архитектуру нейронных сетей, разработанную для решения задачи суперразрешения изображений. ESRGAN разработан с целью улучшения качества изображений, повышая их разрешение до более высокого уровня. Это осуществляется путем обучения глубокой нейронной сети на большом наборе данных низкого разрешения и соответствующих им высококачественных изображений. ESRGAN использует архитектуру генеративно-состязательных сетей (GANs), которая состоит из двух основных компонентов: генератора и дискриминатора. Генератор отвечает за создание изображений высокого разрешения, тогда как дискриминатор обучается отличать сгенерированные изображения от реальных.

    В основе архитектуры ESRGAN лежат резидуальные блоки, которые помогают извлекать и сохранять важные признаки изображений на разных уровнях абстракции. Это позволяет сети эффективно восстанавливать детали и текстуры на изображениях с высоким качеством.

    Для достижения высокого качества и общности в решении задачи суперразрешения, ESRGAN требует обширных наборов данных для обучения. Это позволяет сети изучать различные стили и характеристики изображений и делает ее более адаптивной к различным типам входных данных. ESRGAN может быть использован для улучшения качества изображений во многих областях, включая фотографии, медицинскую диагностику, кино и видеопроизводство, графический дизайн и многое другое. Его гибкость и эффективность делают его одним из ведущих методов в области суперразрешения изображений.

    ESRGAN представляет собой важный шаг вперед в области обработки изображений и искусственного интеллекта, открывая новые возможности для создания и улучшения изображений.


    2.1. Пример исполнения ONNX-модели с float32

    Для исполнения примера требуется скачать файл https://github.com/amannm/super-resolution-service/blob/main/models/esrgan.onnx и скопировать его в папку \MQL5\Scripts\models.

    Модель ESRGAN.onnx содержит ~1200 ONNX-операций, начальные из них представлены на рис.12

    Рис. Модель ESRGAN в MetaEditor

    Рис.12. Модель ESRGAN в MetaEditor



    Рис. Модель ESRGAN в Netron

    Рис.13. Модель ESRGAN в Netron


    Приведенный ниже код представляет собой демонстрацию увеличения размера изображения в 4 раза с использованием ESRGAN.onnx.

    Он начинается с загрузки модели esrgan.onnxm, затем выбирается и загружается исходное изображение в формате BMP. После этого изображение конвертируется в отдельные каналы RGB, которые подаются на вход модели. Модель выполняет процесс увеличения размера изображения в 4 раза, после чего полученное увеличенное изображение проходит обратное преобразование и подготавливается к отображению.

    Для отображения используется библиотека Canvas, а для выполнения модели - библиотека ONNX Runtime. После выполнения программы увеличенное изображение сохраняется в файл с добавлением "_upscaled" к имени исходного файла. Основные функции включают предобработку и постобработку изображения, а также выполнение модели для увеличения размера изображения.

    //+------------------------------------------------------------------+
    //|                                                       ESRGAN.mq5 |
    //|                                  Copyright 2024, MetaQuotes Ltd. |
    //|                                             https://www.mql5.com |
    //+------------------------------------------------------------------+
    #property copyright "Copyright 2024, MetaQuotes Ltd."
    #property link      "https://www.mql5.com"
    #property version   "1.00"
    //+------------------------------------------------------------------+
    //| 4x image upscaling demo using ESRGAN                             |
    //| esrgan.onnx model from                                           |
    //| https://github.com/amannm/super-resolution-service/              |
    //+------------------------------------------------------------------+
    //| Xintao Wang et al (2018)                                         |
    //| ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks|
    //| https://arxiv.org/abs/1809.00219                                 |
    //+------------------------------------------------------------------+
    #resource "models\\esrgan.onnx" as uchar ExtModel[];
    #include <Canvas\Canvas.mqh>
    //+------------------------------------------------------------------+
    //| clamp                                                            |
    //+------------------------------------------------------------------+
    float clamp(float value, float minValue, float maxValue)
      {
       return MathMin(MathMax(value, minValue), maxValue);
      }
    //+------------------------------------------------------------------+
    //| Preprocessing                                                    |
    //+------------------------------------------------------------------+
    bool Preprocessing(float &data[],uint &image_data[],int &image_width,int &image_height)
      {
    //--- checkup
       if(image_height==0 || image_width==0)
          return(false);
    //--- prepare destination array with separated RGB channels for ONNX model
       int data_count=3*image_width*image_height;
    
       if(ArrayResize(data,data_count)!=data_count)
         {
          Print("ArrayResize failed");
          return(false);
         }
    //--- converting
       for(int y=0; y<image_height; y++)
          for(int x=0; x<image_width; x++)
            {
             //--- load source RGB
             int   offset=y*image_width+x;
             uint  clr   =image_data[offset];
             uchar r     =GETRGBR(clr);
             uchar g     =GETRGBG(clr);
             uchar b     =GETRGBB(clr);
             //--- store RGB components as separated channels
             int offset_ch1=0*image_width*image_height+offset;
             int offset_ch2=1*image_width*image_height+offset;
             int offset_ch3=2*image_width*image_height+offset;
    
             data[offset_ch1]=r/255.0f;
             data[offset_ch2]=g/255.0f;
             data[offset_ch3]=b/255.0f;
            }
    //---
       return(true);
      }
    //+------------------------------------------------------------------+
    //| PostProcessing                                                   |
    //+------------------------------------------------------------------+
    bool PostProcessing(const float &data[], uint &image_data[], const int &image_width, const int &image_height)
      {
    //--- checks
       if(image_height == 0 || image_width == 0)
          return(false);
    
       int data_count=image_width*image_height;
       
       if(ArraySize(data)!=3*data_count)
          return(false);
       if(ArrayResize(image_data,data_count)!=data_count)
          return(false);
    //---
       for(int y=0; y<image_height; y++)
          for(int x=0; x<image_width; x++)
            {
             int offset    =y*image_width+x;
             int offset_ch1=0*image_width*image_height+offset;
             int offset_ch2=1*image_width*image_height+offset;
             int offset_ch3=2*image_width*image_height+offset;
             //--- rescale to [0..255]
             float r=clamp(data[offset_ch1]*255,0,255);
             float g=clamp(data[offset_ch2]*255,0,255);
             float b=clamp(data[offset_ch3]*255,0,255);
             //--- set color image_data
             image_data[offset]=XRGB(uchar(r),uchar(g),uchar(b));
            }
    //---
       return(true);
      }
    //+------------------------------------------------------------------+
    //| ShowImage                                                        |
    //+------------------------------------------------------------------+
    bool ShowImage(CCanvas &canvas,const string name,const int x0,const int y0,const int image_width,const int image_height, const uint &image_data[])
      {
       if(ArraySize(image_data)==0 || name=="")
          return(false);
    //--- prepare canvas
       canvas.CreateBitmapLabel(name,x0,y0,image_width,image_height,COLOR_FORMAT_XRGB_NOALPHA);
    //--- copy image to canvas
       for(int y=0; y<image_height; y++)
          for(int x=0; x<image_width; x++)
             canvas.PixelSet(x,y,image_data[y*image_width+x]);
    //--- ready to draw
       canvas.Update(true);
       return(true);
      }
    //+------------------------------------------------------------------+
    //| Script program start function                                    |
    //+------------------------------------------------------------------+
    int OnStart(void)
      {
    //--- select BMP from <data folder>\MQL5\Files
       string image_path[1];
    
       if(FileSelectDialog("Select BMP image",NULL,"Bitmap files (*.bmp)|*.bmp",FSD_FILE_MUST_EXIST,image_path,"lenna-original4.bmp")!=1)
         {
          Print("file not selected");
          return(-1);
         }
    //--- load BMP into array
       uint image_data[];
       int  image_width;
       int  image_height;
    
       if(!CCanvas::LoadBitmap(image_path[0],image_data,image_width,image_height))
         {
          PrintFormat("CCanvas::LoadBitmap failed with error %d",GetLastError());
          return(-1);
         }
    //--- convert RGB image to separated RGB channels
       float input_data[];
       Preprocessing(input_data,image_data,image_width,image_height);
       PrintFormat("input array size=%d",ArraySize(input_data));
    //--- load model
       long model_handle=OnnxCreateFromBuffer(ExtModel,ONNX_DEFAULT);
    
       if(model_handle==INVALID_HANDLE)
         {
          PrintFormat("OnnxCreate error %d",GetLastError());
          return(-1);
         }
    
       PrintFormat("model loaded successfully");
       PrintFormat("original:  width=%d, height=%d  Size=%d",image_width,image_height,ArraySize(image_data));
    //--- set input shape
       ulong input_shape[]={1,3,image_height,image_width};
    
       if(!OnnxSetInputShape(model_handle,0,input_shape))
         {
          PrintFormat("error in OnnxSetInputShape. error code=%d",GetLastError());
          OnnxRelease(model_handle);
          return(-1);
         }
    //--- upscaled image size
       int   new_image_width =4*image_width;
       int   new_image_height=4*image_height;
       ulong output_shape[]= {1,3,new_image_height,new_image_width};
    
       if(!OnnxSetOutputShape(model_handle,0,output_shape))
         {
          PrintFormat("error in OnnxSetOutputShape. error code=%d",GetLastError());
          OnnxRelease(model_handle);
          return(-1);
         }
    //--- run the model
       float output_data[];
       int new_data_count=3*new_image_width*new_image_height;
       if(ArrayResize(output_data,new_data_count)!=new_data_count)
         {
          OnnxRelease(model_handle);
          return(-1);
         }
    
       if(!OnnxRun(model_handle,ONNX_DEBUG_LOGS,input_data,output_data))
         {
          PrintFormat("error in OnnxRun. error code=%d",GetLastError());
          OnnxRelease(model_handle);
          return(-1);
         }
    
       Print("model successfully executed, output data size ",ArraySize(output_data));
       OnnxRelease(model_handle);
    //--- postprocessing
       uint new_image[];
       PostProcessing(output_data,new_image,new_image_width,new_image_height);
    //--- show images
       CCanvas canvas_original,canvas_scaled;
       ShowImage(canvas_original,"original_image",new_image_width,0,image_width,image_height,image_data);
       ShowImage(canvas_scaled,"upscaled_image",0,0,new_image_width,new_image_height,new_image);
    //--- save upscaled image
       StringReplace(image_path[0],".bmp","_upscaled.bmp");
       Print(ResourceSave(canvas_scaled.ResourceName(),image_path[0]));
    //---
       while(!IsStopped())
          Sleep(100);
    
       return(0);
      }
    //+------------------------------------------------------------------+

    Результат:

    Рис. Увеличение изображения 160x200 в 4 раза при помощи модели ESRGAN.onnx

    Рис.14. Результат работы модели ESRGAN.onnx (160x200 -> 640x800)

    В данном примере изображение 160x200 было увеличено в 4 раза (до 640x800) при помощи модели ESRGAN.onnx.


    2.2. Пример исполнения ONNX-модели с float16

    Для конвертации моделей в float16 воспользуемся методом, описанным в Create Float16 and Mixed Precision Models.

    # Copyright 2024, MetaQuotes Ltd.
    # https://www.mql5.com
    
    import onnx
    from onnxconverter_common import float16
    
    from sys import argv
    
    # Define the path for saving the model
    data_path = argv[0]
    last_index = data_path.rfind("\\") + 1
    data_path = data_path[0:last_index]
    
    # конвертация модели в float16
    model_path = data_path+'\\models\\esrgan.onnx'
    modelfp16_path = data_path+'\\models\\esrgan_float16.onnx'
    
    model = onnx.load(model_path)
    model_fp16 = float16.convert_float_to_float16(model)
    onnx.save(model_fp16, modelfp16_path)
    

    После конвертации размер файла уменьшился вдвое (с 64Mb до 32Mb).

    Изменения в коде минимальные:

    //+------------------------------------------------------------------+
    //|                                               ESRGAN_float16.mq5 |
    //|                                  Copyright 2024, MetaQuotes Ltd. |
    //|                                             https://www.mql5.com |
    //+------------------------------------------------------------------+
    #property copyright "Copyright 2024, MetaQuotes Ltd."
    #property link      "https://www.mql5.com"
    #property version   "1.00"
    //+------------------------------------------------------------------+
    //| 4x image upscaling demo using ESRGAN                             |
    //| esrgan.onnx model from                                           |
    //| https://github.com/amannm/super-resolution-service/              |
    //+------------------------------------------------------------------+
    //| Xintao Wang et al (2018)                                         |
    //| ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks|
    //| https://arxiv.org/abs/1809.00219                                 |
    //+------------------------------------------------------------------+
    #resource "models\\esrgan_float16.onnx" as uchar ExtModel[];
    #include <Canvas\Canvas.mqh>
    //+------------------------------------------------------------------+
    //| clamp                                                            |
    //+------------------------------------------------------------------+
    float clamp(float value, float minValue, float maxValue)
      {
       return MathMin(MathMax(value, minValue), maxValue);
      }
    //+------------------------------------------------------------------+
    //| Preprocessing                                                    |
    //+------------------------------------------------------------------+
    bool Preprocessing(float &data[],uint &image_data[],int &image_width,int &image_height)
      {
    //--- checkup
       if(image_height==0 || image_width==0)
          return(false);
    //--- prepare destination array with separated RGB channels for ONNX model
       int data_count=3*image_width*image_height;
    
       if(ArrayResize(data,data_count)!=data_count)
         {
          Print("ArrayResize failed");
          return(false);
         }
    //--- converting
       for(int y=0; y<image_height; y++)
          for(int x=0; x<image_width; x++)
            {
             //--- load source RGB
             int   offset=y*image_width+x;
             uint  clr   =image_data[offset];
             uchar r     =GETRGBR(clr);
             uchar g     =GETRGBG(clr);
             uchar b     =GETRGBB(clr);
             //--- store RGB components as separated channels
             int offset_ch1=0*image_width*image_height+offset;
             int offset_ch2=1*image_width*image_height+offset;
             int offset_ch3=2*image_width*image_height+offset;
    
             data[offset_ch1]=r/255.0f;
             data[offset_ch2]=g/255.0f;
             data[offset_ch3]=b/255.0f;
            }
    //---
       return(true);
      }
    //+------------------------------------------------------------------+
    //| PostProcessing                                                   |
    //+------------------------------------------------------------------+
    bool PostProcessing(const float &data[], uint &image_data[], const int &image_width, const int &image_height)
      {
    //--- checks
       if(image_height == 0 || image_width == 0)
          return(false);
    
       int data_count=image_width*image_height;
       
       if(ArraySize(data)!=3*data_count)
          return(false);
       if(ArrayResize(image_data,data_count)!=data_count)
          return(false);
    //---
       for(int y=0; y<image_height; y++)
          for(int x=0; x<image_width; x++)
            {
             int offset    =y*image_width+x;
             int offset_ch1=0*image_width*image_height+offset;
             int offset_ch2=1*image_width*image_height+offset;
             int offset_ch3=2*image_width*image_height+offset;
             //--- rescale to [0..255]
             float r=clamp(data[offset_ch1]*255,0,255);
             float g=clamp(data[offset_ch2]*255,0,255);
             float b=clamp(data[offset_ch3]*255,0,255);
             //--- set color image_data
             image_data[offset]=XRGB(uchar(r),uchar(g),uchar(b));
            }
    //---
       return(true);
      }
    //+------------------------------------------------------------------+
    //| ShowImage                                                        |
    //+------------------------------------------------------------------+
    bool ShowImage(CCanvas &canvas,const string name,const int x0,const int y0,const int image_width,const int image_height, const uint &image_data[])
      {
       if(ArraySize(image_data)==0 || name=="")
          return(false);
    //--- prepare canvas
       canvas.CreateBitmapLabel(name,x0,y0,image_width,image_height,COLOR_FORMAT_XRGB_NOALPHA);
    //--- copy image to canvas
       for(int y=0; y<image_height; y++)
          for(int x=0; x<image_width; x++)
             canvas.PixelSet(x,y,image_data[y*image_width+x]);
    //--- ready to draw
       canvas.Update(true);
       return(true);
      }
    //+------------------------------------------------------------------+
    //| Script program start function                                    |
    //+------------------------------------------------------------------+
    int OnStart(void)
      {
    //--- select BMP from <data folder>\MQL5\Files
       string image_path[1];
    
       if(FileSelectDialog("Select BMP image",NULL,"Bitmap files (*.bmp)|*.bmp",FSD_FILE_MUST_EXIST,image_path,"lenna.bmp")!=1)
         {
          Print("file not selected");
          return(-1);
         }
    //--- load BMP into array
       uint image_data[];
       int  image_width;
       int  image_height;
    
       if(!CCanvas::LoadBitmap(image_path[0],image_data,image_width,image_height))
         {
          PrintFormat("CCanvas::LoadBitmap failed with error %d",GetLastError());
          return(-1);
         }
    //--- convert RGB image to separated RGB channels
       float input_data[];
       Preprocessing(input_data,image_data,image_width,image_height);
       PrintFormat("input array size=%d",ArraySize(input_data));
       
       ushort input_data_float16[];
       if(!ArrayToFP16(input_data_float16,input_data,FLOAT_FP16))
         {
          Print("error in ArrayToFP16. error code=",GetLastError());
          return(false);
         }   
    //--- load model
       long model_handle=OnnxCreateFromBuffer(ExtModel,ONNX_DEFAULT);
       if(model_handle==INVALID_HANDLE)
         {
          PrintFormat("OnnxCreate error %d",GetLastError());
          return(-1);
         }
    
       PrintFormat("model loaded successfully");
       PrintFormat("original:  width=%d, height=%d  Size=%d",image_width,image_height,ArraySize(image_data));
    //--- set input shape
       ulong input_shape[]={1,3,image_height,image_width};
    
       if(!OnnxSetInputShape(model_handle,0,input_shape))
         {
          PrintFormat("error in OnnxSetInputShape. error code=%d",GetLastError());
          OnnxRelease(model_handle);
          return(-1);
         }
    //--- upscaled image size
       int   new_image_width =4*image_width;
       int   new_image_height=4*image_height;
       ulong output_shape[]= {1,3,new_image_height,new_image_width};
    
       if(!OnnxSetOutputShape(model_handle,0,output_shape))
         {
          PrintFormat("error in OnnxSetOutputShape. error code=%d",GetLastError());
          OnnxRelease(model_handle);
          return(-1);
         }
    //--- run the model
       float output_data[];
       ushort output_data_float16[];
       int new_data_count=3*new_image_width*new_image_height;
       if(ArrayResize(output_data_float16,new_data_count)!=new_data_count)
         {
          OnnxRelease(model_handle);
          return(-1);
         }
    
       if(!OnnxRun(model_handle,ONNX_NO_CONVERSION,input_data_float16,output_data_float16))
         {
          PrintFormat("error in OnnxRun. error code=%d",GetLastError());
          OnnxRelease(model_handle);
          return(-1);
         }
    
       Print("model successfully executed, output data size ",ArraySize(output_data));
       OnnxRelease(model_handle);
       
       if(!ArrayFromFP16(output_data,output_data_float16,FLOAT_FP16))
         {
          Print("error in ArrayFromFP16. error code=",GetLastError());
          return(false);
         }   
    //--- postprocessing
       uint new_image[];
       PostProcessing(output_data,new_image,new_image_width,new_image_height);
    //--- show images
       CCanvas canvas_original,canvas_scaled;
       ShowImage(canvas_original,"original_image",new_image_width,0,image_width,image_height,image_data);
       ShowImage(canvas_scaled,"upscaled_image",0,0,new_image_width,new_image_height,new_image);
    //--- save upscaled image
       StringReplace(image_path[0],".bmp","_upscaled.bmp");
       Print(ResourceSave(canvas_scaled.ResourceName(),image_path[0]));
    //---
       while(!IsStopped())
          Sleep(100);
    
       return(0);
      }
    //+------------------------------------------------------------------+

    Изменения в коде, которые потребовались для исполнения модели, сконвертированной в формат float16, выделены цветом.

    Результат:

     Рис. Результат работы модели ESRGAN_float16.onnx (160x200 -> 640x800)

     Рис.15. Результат работы модели ESRGAN_float16.onnx (160x200 -> 640x800)


    Таким образом, использование чисел float16 вместо float32 позволяет сократить размер файла ONNX-модели в 2 раза (с 64Mb до 32Mb).

    При исполнении моделей на числах float16 качество изображений осталось прежним, визуально трудно найти отличия:

     Рис.16. Сравнение результатов работы модели ESRGAN для float и float16

    Рис.16. Сравнение результатов работы модели ESRGAN для float и float16


    Изменения в коде минимальные, нужно лишь позаботиться о конвертации входных и выходных данных.

    В данном случае после конвертации в float16 качество работы модели существенно не изменилось, однако при анализе финансовых данных следует стремиться к расчетам с максимально возможной точностью.


    Выводы

    Использование новых типов данных для чисел с плавающей точкой позволяет сократить размер ONNX-моделей без существенной потери качества.

    Препроцессинг и постпроцессинг данных значительно упрощаются благодаря использованию функций конвертации данных ArrayToFP16/ArrayFromFP16 и ArrayToFP8/ArrayFromFP8.

    Для работы с конвертированными ONNX-моделями требуются минимальные изменения в коде.


    Прикрепленные файлы |
    Последние комментарии | Перейти к обсуждению на форуме трейдеров (6)
    Quantum
    Quantum | 27 февр. 2024 в 16:49
    fxsaber #:
    Просьба добавить справа еще одну картинку того же размера - увеличенная в четыре раза (вместо одного пикселя - четыре (2x2) таких же по цвету) оригинальная картинка.

    Lenna-ESRGAN-ESRGAN_float and original 4-x-scaled

    Для ее вывода можно заменить код:

       //ShowImage(canvas_original,"original_image",new_image_width,0,image_width,image_height,image_data);
       ShowImage4(canvas_original,"original_image",new_image_width,0,image_width,image_height,image_data);
    //+------------------------------------------------------------------+
    //| ShowImage4                                                        |
    //+------------------------------------------------------------------+
    bool ShowImage4(CCanvas &canvas,const string name,const int x0,const int y0,const int image_width,const int image_height, const uint &image_data[])
      {
       if(ArraySize(image_data)==0 || name=="")
          return(false);
    //--- prepare canvas
       canvas.CreateBitmapLabel(name,x0,y0,4*image_width,4*image_height,COLOR_FORMAT_XRGB_NOALPHA);
    //--- copy image to canvas
       for(int y=0; y<4*image_height-1; y++)
          for(int x=0; x<4*image_width-1; x++)
          {
             uint  clr =image_data[(y/4)*image_width+(x/4)];
             canvas.PixelSet(x,y,clr);
             }
    //--- ready to draw
       canvas.Update(true);
       return(true);
      }
    fxsaber
    fxsaber | 27 февр. 2024 в 17:05
    Quantum #:

    Для ее вывода можно заменить код:

    Спасибо! Уменьшил в два раза по каждой координате, получив правое изображение, как оригинальное.

    Думал, что float16/32 при таком преобразовании станут близкими к оригиналу. Но и они заметно лучше! Т.е. UpScale+DownScale >> Original.


    ЗЫ Удивили. Похоже, все старые снимки/видео целесообразно прогонять через подобную onnx-модель.

    fxsaber
    fxsaber | 27 февр. 2024 в 17:11

    Если на вход onnx-модели подавать одни и те же данные, то на выходе будет всегда одинаковый результат?

    Есть ли элемент случайности внутри onnx-модели?

    Quantum
    Quantum | 27 февр. 2024 в 17:54
    fxsaber #:

    Если на вход onnx-модели подавать одни и те же данные, но на выходе будет всегда одинаковый результат?

    Есть ли элемент случайности внутри onnx-модели?

    В общем случае это зависит от того, какие операторы используются внутри ONNX-модели.

    Для данной модели результат должен быть одинаковый, она содержит детерминированные операции (всего 1195)

    Forester
    Forester | 28 февр. 2024 в 11:10

    Описание float16

    https://ru.wikipedia.org/wiki/%D0%A7%D0%B8%D1%81%D0%BB%D0%BE_%D0%BF%D0%BE%D0%BB%D0%BE%D0%B2%D0%B8%D0%BD%D0%BD%D0%BE%D0%B9_%D1%82%D0%BE%D1%87%D0%BD%D0%BE%D1%81%D1%82%D0%B8


    Примеры чисел половинной точности

    В данных примерах числа с плавающей запятой представлены в двоичном представлении. Они включают в себя бит знака, экспоненту и мантиссу.

    0 01111 0000000000 = +1 * 215-15 = 1
    0 01111 0000000001 = +1.0000000001 2 * 215-15=1 + 2-10 = 1.0009765625 (следующее большее число после 1)

    Т.е. для чисел с 5 знаками после запятой (большинство валют) после 1.00000 можно применить только 1.00098.
    Круто! Но не для трейдинга и работы с котировками.

    Перестановка ценовых баров в MQL5 Перестановка ценовых баров в MQL5
    В этой статье мы представляем алгоритм перестановки ценовых баров и подробно рассказываем, как тесты на перестановку (permutation tests) можно использовать для выявления случаев, когда эффективность стратегии была сфабрикована с целью обмануть потенциальных покупателей советников.
    Нейросети — это просто (Часть 78): Детектор объектов на основе Трансформера (DFFT) Нейросети — это просто (Часть 78): Детектор объектов на основе Трансформера (DFFT)
    В данной статье я предлагаю посмотреть на вопрос построения торговой стратегии с другой стороны. Мы не будем прогнозировать будущее ценовое движение, а попробуем построить торговую систему на основе анализа исторических данных.
    Добавляем пользовательскую LLM в торгового робота (Часть 2): Пример развертывания среды Добавляем пользовательскую LLM в торгового робота (Часть 2): Пример развертывания среды
    Языковые модели (LLM) являются важной частью быстро развивающегося искусственного интеллекта, поэтому нам следует подумать о том, как интегрировать мощные LLM в нашу алгоритмическую торговлю. Большинству людей сложно настроить эти модели в соответствии со своими потребностями, развернуть их локально, а затем применить к алгоритмической торговле. В этой серии статей будет рассмотрен пошаговый подход к достижению этой цели.
    Оптимизация и тестирование торговых стратегий (Часть 1): Взгляд на "Red Dragon H4", "BOLT", "YinYang", и "Statistics SAR" Оптимизация и тестирование торговых стратегий (Часть 1): Взгляд на "Red Dragon H4", "BOLT", "YinYang", и "Statistics SAR"
    Так как я постоянно занимаюсь, разработкой разного рода торговых систем сегодня хочу поделиться с Вами несколькими из них по стратегиям "Red Dragon H4", "BOLT", "YinYang" и "Statistics SAR". Данные стратегии были найдены на просторах интернета.