//+------------------------------------------------------------------+
//|                                                       ESRGAN.mq5 |
//|                                  Copyright 2024, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2024, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"
//+------------------------------------------------------------------+
//| 4x image upscaling demo using ESRGAN                             |
//| esrgan.onnx model from                                           |
//| https://github.com/amannm/super-resolution-service/              |
//+------------------------------------------------------------------+
//| Xintao Wang et al (2018)                                         |
//| ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks|
//| https://arxiv.org/abs/1809.00219                                 |
//+------------------------------------------------------------------+
#resource "models\\esrgan.onnx" as uchar ExtModel[];
#include <Canvas\Canvas.mqh>
//+------------------------------------------------------------------+
//| clamp                                                            |
//+------------------------------------------------------------------+
float clamp(float value, float minValue, float maxValue)
  {
//--- first lift the value up to the lower bound, then cap it at the upper bound;
//--- because the cap is applied last, maxValue wins when minValue>maxValue
   float at_least_min=MathMax(value,minValue);
   return(MathMin(at_least_min,maxValue));
  }
//+------------------------------------------------------------------+
//| Preprocessing                                                    |
//+------------------------------------------------------------------+
bool Preprocessing(float &data[],const uint &image_data[],const int &image_width,const int &image_height)
  {
//--- Converts packed pixels (one uint per pixel, row-major, as filled by
//--- CCanvas::LoadBitmap) into a planar float buffer laid out as
//--- [R plane | G plane | B plane], each component scaled to [0..1].
//--- This matches the {1,3,height,width} input shape set in OnStart.
//--- Returns false on invalid dimensions, a too-small source buffer,
//--- or when the destination array cannot be resized.

//--- checkup: reject empty or negative dimensions
   if(image_height<=0 || image_width<=0)
      return(false);

   int plane_size=image_width*image_height;
//--- the source must contain one pixel per position, otherwise the loop below
//--- would read past the end of image_data
   if(ArraySize(image_data)<plane_size)
     {
      PrintFormat("Preprocessing: image data has %d pixels, expected %d",ArraySize(image_data),plane_size);
      return(false);
     }
//--- prepare destination array with separated RGB channels for ONNX model
   int data_count=3*plane_size;
   if(ArrayResize(data,data_count)!=data_count)
     {
      Print("ArrayResize failed");
      return(false);
     }
//--- converting: pixel i goes to index i of each channel plane
   for(int offset=0; offset<plane_size; offset++)
     {
      uint  clr=image_data[offset];
      uchar r  =GETRGBR(clr);
      uchar g  =GETRGBG(clr);
      uchar b  =GETRGBB(clr);

      data[offset]             =r/255.0f;
      data[plane_size+offset]  =g/255.0f;
      data[2*plane_size+offset]=b/255.0f;
     }
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| PostProcessing                                                   |
//+------------------------------------------------------------------+
bool PostProcessing(const float &data[], uint &image_data[], const int &image_width, const int &image_height)
  {
//--- Inverse of Preprocessing: takes a planar float buffer laid out as
//--- [R plane | G plane | B plane] (each plane width*height values, expected
//--- in [0..1]) and packs it into one XRGB uint per pixel in image_data.
//--- Returns false on invalid dimensions, a size mismatch of data,
//--- or when image_data cannot be resized.

//--- checks: reject empty or negative dimensions
   if(image_height<=0 || image_width<=0)
      return(false);

   int plane_size=image_width*image_height;

   if(ArraySize(data)!=3*plane_size)
      return(false);
   if(ArrayResize(image_data,plane_size)!=plane_size)
      return(false);
//--- pixel i is read from index i of each channel plane
   for(int offset=0; offset<plane_size; offset++)
     {
      //--- rescale to [0..255] and clamp values the model produced outside the range
      float r=clamp(data[offset]*255.0f,0,255);
      float g=clamp(data[plane_size+offset]*255.0f,0,255);
      float b=clamp(data[2*plane_size+offset]*255.0f,0,255);
      //--- round to nearest instead of truncating: a plain uchar() cast would turn
      //--- e.g. 200.9999 (float error after x/255*255) into 200 and bias the image darker;
      //--- after clamping, r+0.5 is at most 255.5, which still converts to 255
      image_data[offset]=XRGB(uchar(r+0.5f),uchar(g+0.5f),uchar(b+0.5f));
     }
//---
   return(true);
  }
//+------------------------------------------------------------------+
//| ShowImage                                                        |
//+------------------------------------------------------------------+
bool ShowImage(CCanvas &canvas,const string name,const int x0,const int y0,const int image_width,const int image_height, const uint &image_data[])
  {
//--- Creates a bitmap label named `name` at (x0,y0) on the chart, copies the
//--- row-major pixels from image_data into it and redraws the chart.
//--- Returns false when the input is unusable or the label cannot be created.
   if(ArraySize(image_data)==0 || name=="")
      return(false);
//--- the pixel loop below indexes image_data up to width*height-1
   if(image_width<=0 || image_height<=0 || ArraySize(image_data)<image_width*image_height)
      return(false);
//--- prepare canvas; drawing into a canvas whose creation failed would be meaningless
   if(!canvas.CreateBitmapLabel(name,x0,y0,image_width,image_height,COLOR_FORMAT_XRGB_NOALPHA))
     {
      PrintFormat("CreateBitmapLabel failed for '%s', error %d",name,GetLastError());
      return(false);
     }
//--- copy image to canvas
   for(int y=0; y<image_height; y++)
     {
      int row_start=y*image_width;
      for(int x=0; x<image_width; x++)
         canvas.PixelSet(x,y,image_data[row_start+x]);
     }
//--- ready to draw
   canvas.Update(true);
   return(true);
  }
//+------------------------------------------------------------------+
//| Script program start function                                    |
//+------------------------------------------------------------------+
int OnStart(void)
  {
//--- Script entry point: asks for a BMP, upscales it 4x with the embedded
//--- ESRGAN ONNX model, shows original and result on the chart, saves the
//--- result next to the source as <name>_upscaled.bmp and keeps the images
//--- visible until the script is stopped. Returns 0 on success, -1 on failure.

//--- select BMP from <data folder>\MQL5\Files
   string image_path[1];

   if(FileSelectDialog("Select BMP image",NULL,"Bitmap files (*.bmp)|*.bmp",FSD_FILE_MUST_EXIST,image_path,"lenna.bmp")!=1)
     {
      Print("file not selected");
      return(-1);
     }
//--- load BMP into array
   uint image_data[];
   int  image_width=0;
   int  image_height=0;

   if(!CCanvas::LoadBitmap(image_path[0],image_data,image_width,image_height))
     {
      PrintFormat("CCanvas::LoadBitmap failed with error %d",GetLastError());
      return(-1);
     }
//--- convert RGB image to separated RGB channels; the model must not run on an unfilled buffer
   float input_data[];
   if(!Preprocessing(input_data,image_data,image_width,image_height))
     {
      Print("Preprocessing failed");
      return(-1);
     }
   PrintFormat("input array size=%d",ArraySize(input_data));
//--- load model
   long model_handle=OnnxCreateFromBuffer(ExtModel,ONNX_DEFAULT);

   if(model_handle==INVALID_HANDLE)
     {
      PrintFormat("OnnxCreate error %d",GetLastError());
      return(-1);
     }

   PrintFormat("model loaded successfully");
   PrintFormat("original:  width=%d, height=%d  Size=%d",image_width,image_height,ArraySize(image_data));
//--- set input shape (batch, channels, height, width)
   ulong input_shape[]={1,3,image_height,image_width};

   if(!OnnxSetInputShape(model_handle,0,input_shape))
     {
      PrintFormat("error in OnnxSetInputShape. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(-1);
     }
//--- upscaled image size
   int   new_image_width =4*image_width;
   int   new_image_height=4*image_height;
   ulong output_shape[]= {1,3,new_image_height,new_image_width};

   if(!OnnxSetOutputShape(model_handle,0,output_shape))
     {
      PrintFormat("error in OnnxSetOutputShape. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(-1);
     }
//--- run the model
   float output_data[];
   int new_data_count=3*new_image_width*new_image_height;
   if(ArrayResize(output_data,new_data_count)!=new_data_count)
     {
      PrintFormat("ArrayResize for output data (%d elements) failed",new_data_count);
      OnnxRelease(model_handle);
      return(-1);
     }

   if(!OnnxRun(model_handle,ONNX_DEBUG_LOGS,input_data,output_data))
     {
      PrintFormat("error in OnnxRun. error code=%d",GetLastError());
      OnnxRelease(model_handle);
      return(-1);
     }

   Print("model successfully executed, output data size ",ArraySize(output_data));
   OnnxRelease(model_handle);
//--- postprocessing
   uint new_image[];
   if(!PostProcessing(output_data,new_image,new_image_width,new_image_height))
     {
      Print("PostProcessing failed");
      return(-1);
     }
//--- show images
   CCanvas canvas_original,canvas_scaled;
   if(!ShowImage(canvas_original,"original_image",new_image_width,0,image_width,image_height,image_data))
      Print("failed to show original image");
//--- the upscaled canvas is also the resource that gets saved, so its failure is fatal
   if(!ShowImage(canvas_scaled,"upscaled_image",0,0,new_image_width,new_image_height,new_image))
     {
      Print("failed to show upscaled image");
      return(-1);
     }
//--- build the output path: strip a trailing .bmp (any letter case) and append the suffix.
//--- Editing a copy keeps image_path intact. Because the extension check is case-insensitive,
//--- a "*.BMP" source is never silently overwritten.
   string out_path=image_path[0];
   int    path_len=StringLen(out_path);
   string extension=(path_len>=4 ? StringSubstr(out_path,path_len-4) : "");
   StringToLower(extension);
   if(extension==".bmp")
      out_path=StringSubstr(out_path,0,path_len-4);
   out_path+="_upscaled.bmp";
//--- save upscaled image
   if(ResourceSave(canvas_scaled.ResourceName(),out_path))
      PrintFormat("upscaled image saved to %s",out_path);
   else
      PrintFormat("ResourceSave to %s failed with error %d",out_path,GetLastError());
//--- keep the bitmap labels on the chart until the user stops the script
   while(!IsStopped())
      Sleep(100);

   return(0);
  }
//+------------------------------------------------------------------+