//+------------------------------------------------------------------+
//|                                  ONNX_DQN_Trading_Script.mq5     |
//|                        Copyright 2023, MetaQuotes Ltd.           |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#property copyright "Copyright 2023, MetaQuotes Ltd."
#property link      "https://www.mql5.com"
#property version   "1.00"
#property script_show_inputs

//--- input parameters
// NOTE: the trailing comment on an `input` line is shown as its label in the
// inputs dialog. Never end such a comment with a backslash: a backslash at end
// of line can splice the next source line into the comment.
input string   ModelPath = "dueling_dqn_xauusd.onnx";  // File in MQL5\Files folder
input int      WindowSize = 30;                        // Observation window size
input int      FeatureCount = 5;                       // Number of features

//--- ONNX model handle (INVALID_HANDLE until OnStart loads the model)
long onnxHandle = INVALID_HANDLE;

//--- Normalization parameters (REPLACE WITH YOUR ACTUAL VALUES)
// Each feature is z-scored as (value - MEAN) / STD before inference.
// Presumably these must equal the statistics used during training - confirm.
const double   RSI_MEAN = 55.0,       RSI_STD = 15.0;
const double   MACD_MEAN = 0.05,      MACD_STD = 0.5;
const double   SMA20_MEAN = 1800.0,   SMA20_STD = 100.0;
const double   SMA50_MEAN = 1800.0,   SMA50_STD = 100.0;
const double   RETURN_MEAN = 0.0002,  RETURN_STD = 0.01;

//+------------------------------------------------------------------+
//| Script program start function                                    |
//+------------------------------------------------------------------+
void OnStart()
{
   //--- Script entry point: validate inputs, load the ONNX model, build one
   //--- observation window from the current chart, run inference and print
   //--- the recommended action. The model handle is released on every exit
   //--- path after a successful load.

   //--- Validate inputs before touching the model
   if(WindowSize <= 0)
   {
      Print("Invalid WindowSize: ", WindowSize);
      return;
   }
   // PrepareInputData always emits exactly 5 features per bar
   // (RSI, MACD histogram, SMA20, SMA50, return). Any other value would
   // overrun the buffer or leave zero padding.
   if(FeatureCount != 5)
   {
      Print("FeatureCount must be 5 (RSI, MACD, SMA20, SMA50, return). Got: ", FeatureCount);
      return;
   }

   //--- Load ONNX model
   ResetLastError();
   onnxHandle = OnnxCreate(ModelPath, ONNX_DEFAULT);
   if(onnxHandle == INVALID_HANDLE)
   {
      Print("Error loading model: ", GetLastError());
      return;
   }

   //--- Declare tensor shapes. MQL5 requires explicit shapes when the model
   //--- has undefined (dynamic) dimensions, e.g. the batch axis.
   // Assumes the model input is [batch, WindowSize, FeatureCount] and the
   // output is [batch, 3]. A failure only warns, so models with fixed shapes
   // (or a flattened input) still run with their own declared shape.
   ulong inputShape[3];
   inputShape[0] = 1;
   inputShape[1] = (ulong)WindowSize;
   inputShape[2] = (ulong)FeatureCount;
   ResetLastError();
   if(!OnnxSetInputShape(onnxHandle, 0, inputShape))
      Print("Warning: OnnxSetInputShape failed (", GetLastError(), "), using the model's declared input shape");

   ulong outputShape[2];
   outputShape[0] = 1;
   outputShape[1] = 3;
   ResetLastError();
   if(!OnnxSetOutputShape(onnxHandle, 0, outputShape))
      Print("Warning: OnnxSetOutputShape failed (", GetLastError(), "), using the model's declared output shape");

   //--- Prepare input data buffer (flattened window, one row of features per bar)
   double inputData[];
   ArrayResize(inputData, WindowSize * FeatureCount);

   //--- Collect and prepare data
   if(!PrepareInputData(inputData))
   {
      Print("Data preparation failed");
      OnnxRelease(onnxHandle);
      return;
   }

   //--- Run inference (one Q-value per action: HOLD, BUY, SELL)
   double outputData[3];
   if(!RunInference(inputData, outputData))
   {
      Print("Inference failed");
      OnnxRelease(onnxHandle);
      return;
   }

   //--- Interpret results
   InterpretResults(outputData);
   OnnxRelease(onnxHandle);
}
//+------------------------------------------------------------------+
//| Prepare input data for the model                                 |
//+------------------------------------------------------------------+
bool PrepareInputData(double &inputData[])
{
   //--- Fills inputData with WindowSize rows of 5 normalized features
   //--- (RSI, MACD histogram, SMA20, SMA50, bar return), oldest bar first.
   //--- Returns false (after printing the reason) if data is unavailable.
   const int featuresPerBar = 5;

   if(WindowSize <= 0)
   {
      Print("Invalid WindowSize: ", WindowSize);
      return false;
   }
   if(ArraySize(inputData) < WindowSize * featuresPerBar)
   {
      Print("Input buffer too small. Required: ", WindowSize * featuresPerBar,
            ", Available: ", ArraySize(inputData));
      return false;
   }

   //--- Get closing prices. Series indexing makes closes[0] the current bar
   //--- and closes[i+1] the bar before closes[i]. Without it, CopyClose
   //--- fills a plain array oldest-first.
   double closes[];
   ArraySetAsSeries(closes, true);
   int closeCount = WindowSize + 1;
   if(CopyClose(_Symbol, _Period, 0, closeCount, closes) != closeCount)
   {
      Print("Not enough historical data. Requested: ", closeCount, ", Received: ", ArraySize(closes));
      return false;
   }

   //--- Calculate returns: returns[i] is bar i's change relative to bar i+1
   double returns[];
   ArrayResize(returns, WindowSize);
   for(int i = 0; i < WindowSize; i++)
   {
      if(closes[i+1] <= 0.0)
      {
         Print("Invalid close price at bar ", i + 1, ": ", closes[i+1]);
         return false;
      }
      returns[i] = (closes[i] - closes[i+1]) / closes[i+1];
   }

   //--- Calculate technical indicators
   double rsi[], macd[], sma20[], sma50[];
   if(!CalculateIndicators(rsi, macd, sma20, sma50))
      return false;

   //--- Verify indicator array sizes
   if(ArraySize(rsi) < WindowSize || ArraySize(macd) < WindowSize ||
      ArraySize(sma20) < WindowSize || ArraySize(sma50) < WindowSize)
   {
      Print("Indicator data mismatch");
      return false;
   }

   //--- CopyBuffer fills plain arrays oldest-first. Flip the indexing view
   //--- (no data is moved) so index 0 is the current bar, matching closes/returns.
   ArraySetAsSeries(rsi, true);
   ArraySetAsSeries(macd, true);
   ArraySetAsSeries(sma20, true);
   ArraySetAsSeries(sma50, true);

   //--- Normalize features and fill input data, oldest bar (WindowSize-1) first
   int dataIndex = 0;
   for(int i = WindowSize - 1; i >= 0; i--)
   {
      inputData[dataIndex++] = (rsi[i] - RSI_MEAN) / RSI_STD;
      inputData[dataIndex++] = (macd[i] - MACD_MEAN) / MACD_STD;
      inputData[dataIndex++] = (sma20[i] - SMA20_MEAN) / SMA20_STD;
      inputData[dataIndex++] = (sma50[i] - SMA50_MEAN) / SMA50_STD;
      inputData[dataIndex++] = (returns[i] - RETURN_MEAN) / RETURN_STD;
   }

   return true;
}
//+------------------------------------------------------------------+
//| Calculate technical indicators                                   |
//+------------------------------------------------------------------+
bool CalculateIndicators(double &rsi[], double &macd[], double &sma20[], double &sma50[])
{
   //--- Fills each output with the latest WindowSize values of its indicator,
   //--- in the order CopyBuffer produces for the array's indexing mode (the
   //--- arrays' series flags are not changed here). Every indicator handle is
   //--- released on all paths. Returns false if any indicator is unavailable.
   bool ok;

   //--- RSI (14 period)
   int rsiHandle = iRSI(_Symbol, _Period, 14, PRICE_CLOSE);
   if(rsiHandle == INVALID_HANDLE)
   {
      Print("iRSI failed: ", GetLastError());
      return false;
   }
   ok = CopyIndicatorValues(rsiHandle, 0, WindowSize, rsi);
   IndicatorRelease(rsiHandle);
   if(!ok) return false;

   //--- MACD (12,26,9): buffer 0 = main line, buffer 1 = signal line
   int macdHandle = iMACD(_Symbol, _Period, 12, 26, 9, PRICE_CLOSE);
   if(macdHandle == INVALID_HANDLE)
   {
      Print("iMACD failed: ", GetLastError());
      return false;
   }
   double macdSignal[];
   // Keep both arrays in the same indexing mode so macd[i] and macdSignal[i]
   // refer to the same bar.
   ArraySetAsSeries(macdSignal, ArrayGetAsSeries(macd));
   ok = CopyIndicatorValues(macdHandle, 0, WindowSize, macd) &&
        CopyIndicatorValues(macdHandle, 1, WindowSize, macdSignal);
   IndicatorRelease(macdHandle);
   if(!ok) return false;

   // Calculate MACD difference (main - signal, i.e. histogram)
   for(int i = 0; i < WindowSize; i++)
      macd[i] = macd[i] - macdSignal[i];

   //--- SMA20
   int sma20Handle = iMA(_Symbol, _Period, 20, 0, MODE_SMA, PRICE_CLOSE);
   if(sma20Handle == INVALID_HANDLE)
   {
      Print("iMA(20) failed: ", GetLastError());
      return false;
   }
   ok = CopyIndicatorValues(sma20Handle, 0, WindowSize, sma20);
   IndicatorRelease(sma20Handle);
   if(!ok) return false;

   //--- SMA50
   int sma50Handle = iMA(_Symbol, _Period, 50, 0, MODE_SMA, PRICE_CLOSE);
   if(sma50Handle == INVALID_HANDLE)
   {
      Print("iMA(50) failed: ", GetLastError());
      return false;
   }
   ok = CopyIndicatorValues(sma50Handle, 0, WindowSize, sma50);
   IndicatorRelease(sma50Handle);
   return ok;
}
//+------------------------------------------------------------------+
//| Wait until an indicator has calculated at least `required` bars  |
//+------------------------------------------------------------------+
bool WaitForIndicatorData(const int handle, const int required)
{
   // Indicator handles created from a script are calculated asynchronously;
   // CopyBuffer fails until BarsCalculated reports enough bars. Poll for up
   // to ~5 seconds, stopping early if the script is being terminated.
   for(int attempt = 0; attempt < 50 && !IsStopped(); attempt++)
   {
      if(BarsCalculated(handle) >= required)
         return true;
      Sleep(100);
   }
   return false;
}
//+------------------------------------------------------------------+
//| Copy the latest `count` values of one indicator buffer           |
//+------------------------------------------------------------------+
bool CopyIndicatorValues(const int handle, const int bufferIndex, const int count, double &dest[])
{
   if(!WaitForIndicatorData(handle, count))
   {
      Print("Indicator data not ready (buffer ", bufferIndex, ")");
      return false;
   }
   ResetLastError();
   int copied = CopyBuffer(handle, bufferIndex, 0, count, dest);
   if(copied != count)
   {
      Print("CopyBuffer failed. Requested: ", count, ", Copied: ", copied, ", Error: ", GetLastError());
      return false;
   }
   return true;
}
//+------------------------------------------------------------------+
//| Run model inference                                              |
//+------------------------------------------------------------------+
bool RunInference(const double &inputData[], double &outputData[])
{
   //--- Runs the loaded model (global onnxHandle) on inputData and writes the
   //--- result into outputData. Any required input/output shapes must already
   //--- be set on the handle (OnnxSetInputShape / OnnxSetOutputShape).
   //--- Returns false and prints the error code on failure.
   if(onnxHandle == INVALID_HANDLE)
   {
      Print("Model inference failed: model is not loaded");
      return false;
   }
   if(ArraySize(inputData) == 0)
   {
      Print("Model inference failed: empty input data");
      return false;
   }

   // Clear any stale error so the reported code belongs to OnnxRun
   ResetLastError();
   if(!OnnxRun(onnxHandle, ONNX_DEBUG_LOGS, inputData, outputData))
   {
      Print("Model inference failed: ", GetLastError());
      return false;
   }
   return true;
}
//+------------------------------------------------------------------+
//| Interpret model results                                          |
//+------------------------------------------------------------------+
void InterpretResults(const double &outputData[])
{
   //--- Prints the three model outputs (index 0 = HOLD, 1 = BUY, 2 = SELL)
   //--- and the action with the largest value.
   const int actionCount = 3;
   if(ArraySize(outputData) < actionCount)
   {
      Print("Unexpected model output size: ", ArraySize(outputData), " (expected ", actionCount, ")");
      return;
   }

   //--- Find best action among the first three outputs
   int bestAction = ArrayMaximum(outputData, 0, actionCount);
   string actionText = "";

   switch(bestAction)
   {
      case 0:  actionText = "HOLD"; break;
      case 1:  actionText = "BUY"; break;
      case 2:  actionText = "SELL"; break;
      default: actionText = "UNKNOWN"; break;   // ArrayMaximum returns -1 on error
   }

   //--- Print results
   Print("Model Output: [HOLD: ", outputData[0], ", BUY: ", outputData[1], ", SELL: ", outputData[2], "]");
   Print("Recommended Action: ", actionText);
}
//+------------------------------------------------------------------+