
基于Python和MQL5的特征工程(第一部分):为长期 AI 模型预测移动平均线
在将 AI 用于执行任何任务时,我们必须尽最大努力为模型提供尽可能多的关于现实世界的有用信息。为了向我们的 AI 模型描述市场的不同属性,我们需要操作和转换输入数据,这一过程被称为特征工程。本系列文章将教您如何转换市场数据,以降低模型的误差水平。今天,我将专注于如何使用移动平均线来增加 AI 模型的预测范围,同时完全控制策略并合理理解其整体有效性。
策略概述
上次我们讨论用 AI 预测移动平均线时,我提供了证据表明,移动平均线的值比未来价格水平更容易被我们的 AI 模型预测,相关文章链接在这里。不过为了让我们对所发现的结果更有信心,我在 200 多种不同的交易品种上训练了两个相同的 AI 模型,并比较了预测价格与预测移动平均线的准确率,结果表明,如果预测价格而不是移动平均线,准确率平均会下降 34%。
平均而言,当预测移动平均线时,可以有望达到 70% 的准确率,而预测价格时的准确率为 52%。我们都知道,根据选择的时间周期不同,移动平均线指标并不总是紧密跟随价格水平,例如,价格可能在 20 根蜡烛图期间下跌,而移动平均线在同一区间内却在上升。这种背离对我们来说是不利的,因为我们有可能正确地预测了移动平均线的未来方向,但价格可能会出现背离。令人惊讶的是,我们观察到在所有市场中,背离率大致固定在 31% 左右,而我们预测背离的平均能力为 68%。
此外,预测背离的能力方差为 0.000041,而背离发生的方差为 0.000386。这表明我们的模型能够以可靠的方式进行自我修正。希望将 AI 应用于长期交易策略的社区成员应该考虑在更高的时间框架上采用这种替代方法。我们的讨论暂时仅限于 M1 时间框架,因为这一时间框架确保我们能够获得所有 297 个市场的足够数据,以便进行公平的比较。
有许多可能的原因可以解释为什么移动平均线比价格本身更容易预测。这可能是因为预测移动平均线更符合线性回归的理念,而预测价格则不然。线性回归假设数据是多个输入的线性组合(总和):移动平均线是之前价格值的总和,这意味着线性假设是成立的。价格本身并不是现实世界变量的简单总和,而是许多变量之间复杂的相互关系。
让我们开始
我们首先需要导入 Python 中用于科学计算的标准库。
#Load the libraries we need import pandas as pd import numpy as np import MetaTrader5 as mt5 from sklearn.model_selection import TimeSeriesSplit,cross_val_score from sklearn.linear_model import LogisticRegression,LinearRegression import matplotlib.pyplot as plt
让我们初始化 MetaTrader 5 终端。
#Initialize the terminal
mt5.initialize()
我们有多少种可用的交易品种?
#The total number of symbols we have print(f"Total Symbols Available: ",mt5.symbols_total())
获取所有交易品种的名称。
#Get the names of all pairs symbols = mt5.symbols_get() idx = [s.name for s in symbols]
创建一个数据结构,用于存储所有交易品种的预测准确率。
global_params = pd.DataFrame(index=idx,columns=["OHLC Error","MAR Error","Noise Levels","Divergence Error"]) global_params
图 1:我们的数据结构将存储终端中所有市场的准确率
定义我们的时间序列拆分对象。
#Define the time series split object tscv = TimeSeriesSplit(n_splits=5,gap=10)
在所有交易品种上检测准确率。
#Iterate over all symbols for i in np.arange(global_params.dropna().shape[0],len(idx)): #Fetch M1 Data data = pd.DataFrame(mt5.copy_rates_from_pos(cols[i],mt5.TIMEFRAME_M1,0,50000)) data.rename(columns={"open":"Open","high":"High","low":"Low","close":"Close"},inplace=True) #Define our period period = 10 #Add the classical target data.loc[data["Close"].shift(-period) > data["Close"],"OHLC Target"] = 1 #Calculate the returns data.loc[:,["Open","High","Low","Close"]] = data.loc[:,["Open","High","Low","Close"]].diff(period) data["RMA"] = data["Close"].rolling(period).mean() #Calculate our new target data.dropna(inplace=True) data.reset_index(inplace=True,drop=True) data.loc[data["RMA"].shift(-period) > data["RMA"],"New Target"] = 1 data = data.iloc[0:-period,:] #Calculate the divergence target data.loc[data["OHLC Target"] != data["New Target"],"Divergence Target"] = 1 #Noise ratio global_params.iloc[i,2] = data.loc[data["New Target"] != data["OHLC Target"]].shape[0] / data.shape[0] #Test our accuracy predicting the future close price score = cross_val_score(LogisticRegression(),data.loc[:,["Open","High","Low","Close"]],data["OHLC Target"],cv=tscv) global_params.iloc[i,0] = score.mean() #Test our accuracy predicting the moving average of future returns score = cross_val_score(LogisticRegression(),data.loc[:,["Open","Close","RMA"]],data["New Target"],cv=tscv) global_params.iloc[i,1] = score.mean() #Test our accuracy predicting the future divergence between price and its moving average score = cross_val_score(LogisticRegression(),data.loc[:,["Open","Close","RMA"]],data["Divergence Target"],cv=tscv) global_params.iloc[i,3] = score.mean() print(f"{((i/len(idx)) * 100)}% complete") #We are done print("Done")
完成!
分析结果
现在我们已经获取了市场数据,并对模型在两个目标上的表现进行了评估,接下来让我们总结一下在所有市场上的测试结果。我们先从总结对未来收盘价变化的预测准确率开始。图 2 展示了对未来收盘价变化的预测总结。红色水平线代表 50% 的准确率阈值。我们的平均准确率用蓝色水平线表示。正如我们所见,我们的平均准确率与 50% 的阈值相差不大,这并不是一个令人鼓舞的结果。
然而,为了公平起见,我们也可以观察到某些特定市场明显高于平均水平,其预测准确率超过 65%。这令人印象深刻,但也需要进一步探究以确定这些结果是否有意义,还是仅仅出于偶然。
global_params.iloc[:,0].plot() plt.title("OHLC Accuracy") plt.xlabel("Market") plt.ylabel("5-fold Accuracy %") plt.axhline(global_params.iloc[:,0].mean(),linestyle='--') plt.axhline(0.5,linestyle='--',color='red')
图 2:我们预测价格变化的平均准确率
接下来,我们将关注对移动平均线变化的预测准确率。图 3 为我们总结了数据。同样,红线代表 50% 的阈值,黄线代表我们预测价格变化的平均准确率,而蓝线是我们预测移动平均线变化的平均准确率。简单地说我们的模型更擅长预测移动平均线,这其实不言而喻。我认为这已经不再是一个有争议的问题,而是一个事实:我们的模型在预测某些指标方面比预测价格本身更胜一筹。
global_params.iloc[:,1].plot() plt.title("Moving Average Returns Accuracy") plt.xlabel("Market") plt.ylabel("5-fold Accuracy %") plt.axhline(global_params.iloc[:,1].mean(),linestyle='--') plt.axhline(global_params.iloc[:,0].mean(),linestyle='--',color='orange') plt.axhline(0.5,linestyle='--',color='red')
图 3:我们预测移动平均线变化的准确率
接下来我们观察价格和移动平均线背离的比率。接近 50% 的背离水平是不好的,因为这意味着我们无法合理确定价格和移动平均线是会朝同一方向移动,还是会朝相反方向移动。幸运的是,在我们评估的所有品种中,噪声水平保持一致。噪声水平在 35% 到 30% 之间波动。
global_params.iloc[:,2].plot() plt.title("Noise Level") plt.xlabel("Market") plt.ylabel("Percentage of Divergence:Price And Moving Average") plt.axhline(global_params.iloc[:,2].mean(),linestyle='--')
图 4:可视化所有市场的噪声水平
如果两个变量的比率几乎保持不变,那么这可能意味着它们之间存在某种我们可以建模的关系。接下来观察我们对价格和移动平均线背离的预测能力。逻辑很简单:如果我们的模型预测移动平均线将会下跌,能否合理预测价格是否会朝同一方向移动,还是会与移动平均线背离?事实证明,我们可以以相当可靠的准确率(平均接近 70%)来预测背离。
global_params.iloc[:,3].plot() plt.title("Divergence Accuracy") plt.xlabel("Market") plt.ylabel("5-fold Accuracy %") plt.axhline(global_params.iloc[:,3].mean(),linestyle='--')
图 5:我们预测价格和移动平均线背离的准确率
我们还可以将我们的发现总结成表格形式。这样就可以轻松比较预测市场价格与预测移动平均线之间的准确率水平。请注意,尽管我们的移动平均线可能“滞后”于价格,但预测反转的准确率仍然显著高于预测价格本身的准确率。
指标 | 准确率 |
---|---|
误差 | 0.525353 |
移动平均线误差 | 0.705468 |
噪声水平 | 0.317187 |
背离误差 | 0.682069 |
让我们看看我们的模型在哪些市场上表现最好。
global_params.sort_values("MAR Error",ascending=False)
图 6:我们表现最好的市场
针对表现最好的市场进行优化
接下来我们将针对表现最好的市场之一,对移动平均线指标进行定制优化。我们还将直观地比较新的特征与经典特征。首先从指定我们选择的市场开始。
symbol = "AUDJPY"
确保我们可以连接到终端。
#Reach the terminal
mt5.initialize()
现在,获取市场数据。
data = pd.DataFrame(mt5.copy_rates_from_pos(symbol,mt5.TIMEFRAME_D1,365*2,5000))
导入所需的库。
#Standard libraries import seaborn as sns from mpl_toolkits.mplot3d import Axes3D from sklearn.linear_model import LinearRegression from sklearn.neural_network import MLPRegressor from sklearn.metrics import mean_squared_error from sklearn.model_selection import cross_val_score,TimeSeriesSplit
定义我们计算周期的起始点和预测范围的终点。确保两个输入的维度相同,否则代码会出错。
#Define the input range x_min , x_max = 2,100 #Look ahead y_min , y_max = 2,100 #Period
以 5 为步长对输入域进行采样,以便我们的计算既详细又经济。
#Sample input range uniformly x_axis = np.arange(x_min,x_max,2) #Look ahead y_axis = np.arange(y_min,y_max,2) #Period
使用 x_axis 和 y_axis 创建一个网格。网格由两个二维数组组成,定义了我们希望评估的所有可能的预测范围和周期的组合。
#Create a meshgrid
x , y = np.meshgrid(x_axis,y_axis)
接下来,需要一个函数来获取市场数据并对其进行标记,以便我们评估。
def clean_data(look_ahead,period): #Fetch the data from our terminal and clean it up data = pd.DataFrame(mt5.copy_rates_from_pos('AUDJPY',mt5.TIMEFRAME_D1,365*2,5000)) data['time'] = pd.to_datetime(data['time'],unit='s') data['MA'] = data['close'].rolling(period).mean() #Transform the data #Target data['Target'] = data['MA'].shift(-look_ahead) - data['MA'] #Change in price data['close'] = data['close'] - data['close'].shift(period) #Change in MA data['MA'] = data['MA'] - data['MA'].shift(period) data.dropna(inplace=True) data.reset_index(drop=True,inplace=True) return(data)
接下来的函数将对我们的模型执行一个5折交叉验证。
#Evaluate the objective function def evaluate(look_ahead,period): #Define the model model = LinearRegression() #Define our time series split tscv = TimeSeriesSplit(n_splits=5,gap=look_ahead) temp = clean_data(look_ahead,period) score = np.mean(cross_val_score(model,temp.loc[:,["Open","High","Low","Close"]],temp["Target"],cv=tscv)) return(score)
最后是我们的目标函数。我们的目标函数仅仅是我们在新设置下对模型进行评估时的五折验证误差。回想一下,我们正在试图找到模型应该预测到未来的最佳距离,此外,我们还在试图确定用于计算价格变化的周期。
#Define the objective def objective(x,y): #Define the output matrix results = np.zeros([x.shape[0],y.shape[0]]) #Fill in the output matrix for i in np.arange(0,x.shape[0]): #Select the rows look_ahead = x[i] period = y[i] for j in np.arange(0,y.shape[0]): results[i,j] = evaluate(look_ahead[j],period[j]) return(results)
我们将评估模型与市场之间的关系,尝试直接预测价格水平的变化。图 7 展示了我们的模型与价格变化之间的关系,而图 8 展示了我们的模型与移动平均线变化之间的关系。在这两个图中,白点代表误差最低的输入组合。
res = objective(x,y) res = np.abs(res)
绘制我们的模型在预测 AUDJPY 日线时的最佳表现。数据显示,当我们预测未来价格变化时,最多只能向前预测一步。人类交易者在下单时并不会仅仅向前看一步。因此,通过直接预测市场价格所获得的结果,限制了我们的方法,并使我们的模型局限于下一个蜡烛图。
plt.contourf(x,y,res,100,cmap="jet") plt.plot(x_axis[res.min(axis=0).argmin()],y_axis[res.min(axis=1).argmin()],'.',color='white') plt.ylabel("Differencing Period") plt.xlabel("Forecast Horizon") plt.title("Linear Regression Accuracy Forecasting AUDJPY Daily Return")
图 7:可视化我们的模型预测未来价格水平的能力。
当我们开始预测移动平均线的变化,而不是价格的变化时,我们可以观察到我们的最优预测范围向右移动。图 8 显示,通过预测移动平均线的变化,我们现在可以可靠地预测到未来 22 步,而直接预测价格变化时只能预测到未来一步。
plt.contourf(x,y,res,100,cmap="jet") plt.plot(x_axis[res.min(axis=0).argmin()],y_axis[res.min(axis=1).argmin()],'.',color='white') plt.ylabel("Differencing Period") plt.xlabel("Forecast Horizon") plt.title("Linear Regression Accuracy Forecasting AUDJPY Daily Moving Average Return")
图 8:可视化我们的模型预测未来移动平均线水平的能力。
更令人印象深刻的是,在最优点上,我们对两个目标的误差水平是相同的。换句话说,我们的模型预测未来 40 步的移动平均线变化的难度,与预测未来一步的价格变化的难度相同。因此,移动平均线预测为我们提供了更大的预测范围,而不会增加我们预测的误差。
当我们以三维方式可视化我们的测试结果时,两个目标之间的差异变得清晰。图 9 展示了价格水平变化与我们模型的预测参数之间的关系。从数据中我们可以看到一个明显的趋势:随着我们向更远的未来进行预测,我们的结果变得更差。因此,以这种方式设计我们的 AI 模型时,它们在一定程度上是“短视的”,无法合理预测超过 20 步。
图 10 是根据我们的模型与预测移动平均线变化时的误差之间的关系生成的。这个图展示了理想的特性,我们可以清楚地看到,随着向更远的未来进行预测并增加计算移动平均线变化的周期,我们的误差率会平稳下降到一个最低点,然后再次上升。图像表明,与预测价格相比,我们的模型预测移动平均线要容易得多。
#Create a surface plot fig , ax = plt.subplots(subplot_kw={"projection":"3d"}) fig.set_size_inches(8,8) ax.plot_surface(x,y,optimal_nn_res,cmap="jet")
图 9:可视化我们的模型与 AUDJPY 日价格变化之间的关系。
图 10:我们的模型与 AUDJPY 对的移动平均线变化之间的关系。
非线性变换:小波去噪
到目前为止,我们只对数据应用了线性变换。我们可以更深入地探索,并对模型的输入数据应用非线性变换。特征工程有时只是一个试错的过程。我们并不总是能保证会得到更好的结果。因此,我们以一种尝试的方式应用这些变换,没有一个精确的公式来告诉我们,在任何给定时刻应该应用哪种“最佳”变换。
小波变换是一种用于创建数据的频率和时间表示的数学工具。它通常用于信号和图像处理任务,以分离我们试图处理的信号中的噪声。在应用变换后,我们的数据将进入频率域。其想法是,我们数据中的噪声来自于变换识别出的小频率值。所有低于某个阈值的值,将以两种可能的方式之一被压缩为 0。结果是对原始数据的稀疏表示。
小波去噪比其他流行技术(如快速傅里叶变换 [FFT])具有更多优势。对于可能不熟悉的读者来说,傅里叶变换将任何信号表示为正弦波和余弦波的总和。不幸的是,傅里叶变换会过滤掉高频值。这并不总是我们希望的,特别是对于信号在高频域中的数据。鉴于我们不确定信号是在高频域还是低频域中,我们需要一种足够灵活的变换,能够以无监督的方式完成这项任务。小波变换将在过滤掉尽可能多的噪声的同时,保留数据中的信号。
如果您想跟随我们,请确保您已安装 scikit-learn-image 及其依赖项 PyWavelets。对于希望在 MQL5 中创建完整交易应用程序的读者来说,从头开始实现和调试变换可能会过于复杂。对于我们来说,不使用它更容易取得进展。而对于那些可能希望使用 Python 库与终端交互的读者来说,变换是一个值得纳入您工具库的工具。
我们可以比较验证准确率的变化,看看变换是否有助于我们的模型,答案是肯定的。请注意,我们只对模型的输入应用变换,而不是目标输出,我们不对目标输出做变换。观察到验证准确率确实下降了。我们使用小波变换的硬阈值,因此它将所有噪声系数设置为 0。或者,可以使用软阈值,这将引导我们的噪声系数趋近于 0,但可能不会将它们恰好设置为 0。
#Benchmark Score np.mean(cross_val_score(LinearRegression(),data.loc[:,["MA"]],data["Target"]))
#Wavelet denoising data["Denoised"] = denoise_wavelet( data["MA"], method='BayesShrink', mode='hard', rescale_sigma=True, wavelet_levels = 3, wavelet='sym5' ) np.mean(cross_val_score(LinearRegression(),np.sqrt(np.log(data.loc[:,["Denoised"]])),data["Target"]))
构建定制化 AI 模型
现在我们已经知道了应该预测到未来的理想距离,以及我们理想的移动平均线周期。让我们直接从 MetaTrader 5 终端获取市场数据,以确保我们的 AI 模型在真实交易中观察到的指标值与训练时使用的指标值相同。我们希望尽可能模拟真实交易环境。
在脚本中,移动平均线周期将与我们上面计算的理想移动平均线周期相匹配。此外,我们还将从终端获取 RSI 读数,以稳定 AI 交易机器人的行为。通过依赖两个独立指标的预测,而不是一个指标的预测,我们的 AI 模型可能会随着时间的推移更加稳定。
//+------------------------------------------------------------------+ //| ProjectName | //| Copyright 2020, CompanyName | //| http://www.companyname.net | //+------------------------------------------------------------------+ #property copyright "Gamuchirai Zororo Ndawana" #property link "https://www.mql5.com/en/users/gamuchiraindawa" #property version "1.00" #property script_show_inputs //+------------------------------------------------------------------+ //| Script Inputs | //+------------------------------------------------------------------+ input int size = 100000; //How much data should we fetch? //+------------------------------------------------------------------+ //| Global variables | //+------------------------------------------------------------------+ int ma_handler,rsi_handler; double ma_reading[],rsi_reading[]; //+------------------------------------------------------------------+ //| On start function | //+------------------------------------------------------------------+ void OnStart() { //--- Load indicator ma_handler = iMA(Symbol(),PERIOD_CURRENT,40,0,MODE_SMA,PRICE_CLOSE); rsi_handler = iRSI(Symbol(),PERIOD_CURRENT,30,PRICE_CLOSE); //--- Load the indicator values CopyBuffer(ma_handler,0,0,size,ma_reading); CopyBuffer(rsi_handler,0,0,size,rsi_reading); ArraySetAsSeries(ma_reading,true); ArraySetAsSeries(rsi_reading,true); //--- File name string file_name = "Market Data " + Symbol() +" MA RSI " + " As Series.csv"; //--- Write to file int file_handle=FileOpen(file_name,FILE_WRITE|FILE_ANSI|FILE_CSV,","); for(int i= size;i>=0;i--) { if(i == size) { FileWrite(file_handle,"Time","Open","High","Low","Close","MA","RSI"); } else { FileWrite(file_handle,iTime(Symbol(),PERIOD_CURRENT,i), iOpen(Symbol(),PERIOD_CURRENT,i), iHigh(Symbol(),PERIOD_CURRENT,i), iLow(Symbol(),PERIOD_CURRENT,i), iClose(Symbol(),PERIOD_CURRENT,i), ma_reading[i], rsi_reading[i] ); } } //--- Close the file FileClose(file_handle); } //+------------------------------------------------------------------+
现在我们已经创建了脚本,可以直接将其拖动并放置到目标市场上,然后就可以开始处理市场数据了。为了让我们的回测结果具有实际意义,我们从生成的 CSV 文件中删除了最近两年的市场数据。这样一来,当我们从 2023 年到 2024 年回测我们的策略时,观察到的结果将真实反映我们的模型在未见过的数据上的表现。
#Read in the data data = pd.read_csv("Market Data AUDJPY MA RSI As Series.csv") #Let's drop the last two years of data. We'll use that to validate our model in the back test data = data.iloc[365:-(365 * 2),:] data
图 11:在 22 年的市场数据上训练我们的模型,排除 2023-2024 年期间。
现在让我们为机器学习标记数据。我们希望帮助我们的模型学习在技术指标变化的情况下价格的变化。为了帮助我们的模型学习这种关系,我们将转换输入以表示指标的当前状态。例如,我们的 RSI 指标将有 3 种可能的状态:
- 高于 70。
- 低于 30。
- 介于 70 和 30 之间。
#MA States data["MA 1"] = 0 data["MA 2"] = 0 data.loc[data["MA"] > data["MA"].shift(40),"MA 1"] = 1 data.loc[data["MA"] <= data["MA"].shift(40),"MA 2"] = 1 #RSI States data["RSI 1"] = 0 data["RSI 2"] = 0 data["RSI 3"] = 0 data.loc[data["RSI"] < 30,"RSI 1"] = 1 data.loc[data["RSI"] > 70,"RSI 2"] = 1 data.loc[(data["RSI"] >= 30) & (data["RSI"] <= 70),"RSI 3"] = 1 #Target data["Target"] = data["Close"].shift(-22) - data["Close"] data["MA Target"] = data["MA"].shift(-22) - data["MA"] #Clean up the data data = data.dropna() data = data.iloc[40:,:] data = data.reset_index(drop=True)
现在我们可以开始计算准确率了。
from sklearn.linear_model import Ridge from sklearn.model_selection import TimeSeriesSplit,cross_val_score
应用这些转换后,我们可以观察到 RSI 穿过我们指定的 3 个区域时价格的平均变化。我们的线性模型的系数可以被解释为与每个 RSI 区域相关的平均价格变化。这些发现有时可能与关于如何使用该指标的经典教义相悖。例如,我们的 Ridge 模型已经学会,当 RSI 读数超过 70 时,价格水平倾向于下跌,否则当 RSI 读数小于 70 时,未来价格水平倾向于上涨。
#Our model can suggest optimal ways of using the RSI indicator #Our model has learned that on average price tends to fall the RSI reading is less than 30 and increases otherwises model = Ridge() model.fit(data.loc[:,["RSI 1","RSI 2","RSI 3"]] , data["Target"]) model.coef_
我们的 Ridge 模型可以根据 RSI 的当前状态很好地预测未来价格。
#RSI state np.mean(cross_val_score(Ridge(),data.loc[:,["RSI 1","RSI 2","RSI 3"]] , data["Target"],cv=tscv))
同样,我们的模型也从移动平均线指标的变化中学到了自己的交易规则。模型的第一个系数为负,这意味着当移动平均线在 40 根蜡烛图上上升时,移动平均线倾向于下跌。第二个系数为正。因此,从我们在终端中获取的历史数据来看,当 40 期 AUDJPY 日移动平均线在 40 根蜡烛图上下跌时,它们倾向于随后上涨。我们的模型从数据中学到了一种均值回归策略。
#Our model can suggest optimal ways of using the RSI indicator #Our model has learned that on average price tends to fall the RSI reading is less than 30 and increases otherwises model = Ridge() model.fit(data.loc[:,["MA 1","MA 2"]] , data["Target"]) model.coef_
当我们将移动平均线指标的当前状态提供给模型时,模型的表现甚至更好。
#MA state np.mean(cross_val_score(Ridge(),data.loc[:,["MA 1","MA 2"]] , data["Target"],cv=tscv))
转换为 ONNX
现在我们已经找到了移动平均线预测的理想输入参数,让我们准备将模型转换为 ONNX 格式。开放神经网络交换 (ONNX) 允许我们在一个与语言无关的框架中构建机器学习模型。ONNX 协议是一个开源倡议,旨在为机器学习模型创建一个通用的标准表示,只要完全采用 ONNX API,我们就可以在任何语言中构建和部署机器学习模型。
首先,让我们根据找到的最佳输入获取所需的数据。
#Fetch clean data new_data = clean_data(140,130)
导入所需的库。
import onnx from skl2onnx import convert_sklearn from skl2onnx.common.data_types import FloatTensorType
在我们所有的数据上拟合 RSI 模型。
#First we will export the RSI model rsi_model = Ridge() rsi_model.fit(data.loc[:,['RSI 1','RSI 2','RSI 3']],data.loc[:,'Target'])
在我们所有的数据上拟合移动平均线模型。
#Finally we will export the MA model ma_model = Ridge() ma_model.fit(data.loc[:,['MA 1','MA 2']],data.loc[:,'MA Target'])
定义我们模型的输入形状,并将其保存到磁盘上。
initial_types = [('float_input', FloatTensorType([1, 3]))] onnx.save(convert_sklearn(rsi_model,initial_types=initial_types,target_opset=12),"AUDJPY D1 RSI AI F22 P40.onnx") initial_types = [('float_input', FloatTensorType([1, 2]))] onnx.save(convert_sklearn(ma_model,initial_types=initial_types,target_opset=12),"AUDJPY D1 MA AI F22 P40.onnx")
用MQL5来实现
现在,我们已经准备好开始在 MQL5 中构建交易应用程序了。我们希望构建一个能够利用我们对移动平均线的新策略来开仓和平仓的交易应用程序。不仅如此,我们通常会使用较慢的指标来引导我们的模型,这些指标生成信号的方式不会过于激进。让我们尝试模仿人类交易者并不总是强行在市场上建立头寸的方式。
此外,我们还将实施跟踪止损以确保良好的风险管理。将使用平均真实范围(ATR)指标来动态设置止损和获利水平。我们的策略主要基于移动平均线通道。
策略会提前预测未来 40 步的移动平均线,以在我们开仓之前为我们提供确认信号。我们在训练期间未向模型展示的 1 年历史数据上对这一策略进行了回测。
首先,我们将从导入我们刚刚创建的 ONNX 文件开始。
//+------------------------------------------------------------------+ //| GBPUSD AI.mq5 | //| Gamuchirai Zororo Ndawana | //| https://www.mql5.com/en/gamuchiraindawa | //+------------------------------------------------------------------+ #property copyright "Gamuchirai Zororo Ndawana" #property link "https://www.mql5.com/en/gamuchiraindawa" #property version "1.00" //+------------------------------------------------------------------+ //| Load our resources | //+------------------------------------------------------------------+ #resource "\\Files\\AUDJPY D1 MA AI F22 P40.onnx" as const uchar onnx_buffer[]; #resource "\\Files\\AUDJPY D1 RSI AI F22 P40.onnx" as const uchar rsi_onnx_buffer[];
我们将导入所需的库。
//+------------------------------------------------------------------+ //| Libraries | //+------------------------------------------------------------------+ #include <Trade\Trade.mqh> CTrade Trade; #include <Trade\OrderInfo.mqh> class COrderInfo;
现在定义我们需要的全局变量。
//+------------------------------------------------------------------+ //| Global variables | //+------------------------------------------------------------------+ long onnx_model; int ma_handler,state; double bid,ask,vol; vectorf model_forecast = vectorf::Zeros(1); vectorf rsi_model_output = vectorf::Zeros(1); double min_volume,max_volume_increase, volume_step, buy_stop_loss, sell_stop_loss,atr_stop,risk_equity; double take_profit = 0; double close_price[3],atr_reading[],ma_buffer[]; long min_distance,login; int atr,close_average,ticket_1,ticket_2; bool authorized = false; double margin,lot_step; string currency,server; bool all_closed =true; int rsi_handler; long rsi_onnx_model; double indicator_reading[]; ENUM_ACCOUNT_TRADE_MODE account_type; const double stop_percent = 1;
定义所需的输入参数。
//+------------------------------------------------------------------+ //| Technical indicators | //+------------------------------------------------------------------+ input group "Money Management" input int lot_multiple = 10; // How big should the lot size be? input double profit_target = 0; // Profit Target input double loss_target = 0; // Max Loss Allowed input group "Money Management" const int atr_period = 200; //ATR Period input double atr_multiple =2.5; //ATR Multiple
现在必须明确定义我们的交易应用程序如何初始化。我们首先会检查用户是否已授予应用程序交易权限。如果获得继续的权限,那么我们将加载技术指标和 ONNX 模型。
int OnInit() { //Authorization if(!TerminalInfoInteger(TERMINAL_TRADE_ALLOWED)) { Comment("Press Ctrl + E To Give The Robot Permission To Trade And Reload The Program"); return(INIT_FAILED); } else if(!MQLInfoInteger(MQL_TRADE_ALLOWED)) { Comment("Reload The Program And Make Sure You Clicked Allow Algo Trading"); return(INIT_FAILED); } else { Comment("This License is Genuine"); setup(); } //Everything was okay return(INIT_SUCCEEDED); }
每当我们的交易应用程序不再使用时,必须释放不再使用的资源,以确保用户能够获得良好的使用体验。
//+------------------------------------------------------------------+ //| Expert deinitialization function | //+------------------------------------------------------------------+ void OnDeinit(const int reason) { OnnxRelease(onnx_model); OnnxRelease(rsi_onnx_model); IndicatorRelease(atr) }
每当收到更新的价格报价时,我们将更新变量并检查新的交易机会。否则,如果我们已经有持仓,那么更新跟踪止损位。
//+------------------------------------------------------------------+ //| Expert tick function | //+------------------------------------------------------------------+ void OnTick() { //Update technical data update(); if(PositionsTotal() == 0) { check_setup(); } if(PositionsTotal() > 0) { check_atr_stop(); } }
要从模型中获得预测值,我们必须定义当前RSI和移动平均线的状态
//+------------------------------------------------------------------+ //| Get a prediction from our model | //+------------------------------------------------------------------+ int model_predict(void) { //MA Forecast vectorf model_inputs = vectorf::Zeros(2); vectorf rsi_model_inputs = vectorf::Zeros(3); CopyBuffer(ma_handler,0,0,40,ma_buffer); if(ma_buffer[0] > ma_buffer[39]) { model_inputs[0] = 1; model_inputs[1] = 0; } else if(ma_buffer[0] < ma_buffer[39]) { model_inputs[1] = 1; model_inputs[0] = 0; } //RSI Forecast CopyBuffer(rsi_handler,0,0,1,indicator_reading); if(indicator_reading[0] < 30) { rsi_model_inputs[0] = 1; rsi_model_inputs[1] = 0; rsi_model_inputs[2] = 0; } else if(indicator_reading[0] >70) { rsi_model_inputs[0] = 0; rsi_model_inputs[1] = 1; rsi_model_inputs[2] = 0; } else { rsi_model_inputs[0] = 0; rsi_model_inputs[1] = 0; rsi_model_inputs[2] = 1; } //Model predictions OnnxRun(onnx_model,ONNX_DEFAULT,model_inputs,model_forecast); OnnxRun(rsi_onnx_model,ONNX_DEFAULT,rsi_model_inputs,rsi_model_output); //Evaluate model output for buy setup if(((rsi_model_output[0] > 0) && (model_forecast[0] > 0))) { //AI Models forecast Comment("AI Forecast: UP"); return(1); } //Evaluate model output for a sell setup if((rsi_model_output[0] < 0) && (model_forecast[0] < 0)) { Comment("AI Forecast: DOWN"); return(-1); } //Otherwise no position was found return(0); }
更新全局变量。将这些更新操作集中在一个函数调用中,比直接在 OnTick() 处理程序中执行所有代码要更加清晰和高效。
//+------------------------------------------------------------------+ //| Update our market data | //+------------------------------------------------------------------+ void update(void) { ask = SymbolInfoDouble(_Symbol,SYMBOL_ASK); bid = SymbolInfoDouble(_Symbol,SYMBOL_BID); buy_stop_loss = 0; sell_stop_loss = 0; static datetime time_stamp; datetime time = iTime(_Symbol,PERIOD_CURRENT,0); check_price(3); CopyBuffer(atr,0,0,1,atr_reading); CopyBuffer(ma_handler,0,0,1,ma_buffer); ArraySetAsSeries(atr_reading,true); atr_stop = ((min_volume + atr_reading[0]) * atr_multiple); //On Every Candle if(time_stamp != time) { //Mark the candle time_stamp = time; OrderCalcMargin(ORDER_TYPE_BUY,_Symbol,min_volume,ask,margin); } }
加载所需的资源,如技术指标,账户和市场信息以及其他类似数据。
//+------------------------------------------------------------------+ //| Load resources | //+------------------------------------------------------------------+ bool setup(void) { //Account Info currency = AccountInfoString(ACCOUNT_CURRENCY); server = AccountInfoString(ACCOUNT_SERVER); login = AccountInfoInteger(ACCOUNT_LOGIN); //Indicators atr = iATR(_Symbol,PERIOD_CURRENT,atr_period); //Setup technical indicators ma_handler =iMA(Symbol(),PERIOD_CURRENT,40,0,MODE_SMA,PRICE_LOW); vol = SymbolInfoDouble(Symbol(),SYMBOL_VOLUME_MIN) * lot_multiple; rsi_handler = iRSI(Symbol(),PERIOD_CURRENT,30,PRICE_CLOSE); //Market Information min_volume = SymbolInfoDouble(_Symbol,SYMBOL_VOLUME_MIN); max_volume_increase = SymbolInfoDouble(_Symbol,SYMBOL_VOLUME_MAX) / SymbolInfoDouble(_Symbol,SYMBOL_VOLUME_MIN); min_distance = SymbolInfoInteger(_Symbol,SYMBOL_TRADE_STOPS_LEVEL); lot_step = SymbolInfoDouble(_Symbol,SYMBOL_VOLUME_STEP); //Define our ONNX model ulong ma_input_shape [] = {1,2}; ulong rsi_input_shape [] = {1,3}; ulong output_shape [] = {1,1}; //Create the model onnx_model = OnnxCreateFromBuffer(onnx_buffer,ONNX_DEFAULT); rsi_onnx_model = OnnxCreateFromBuffer(rsi_onnx_buffer,ONNX_DEFAULT); if((onnx_model == INVALID_HANDLE) || (rsi_onnx_model == INVALID_HANDLE)) { Comment("[ERROR] Failed to load AI module correctly"); return(false); } //Validate I/O if((!OnnxSetInputShape(onnx_model,0,ma_input_shape)) || (!OnnxSetInputShape(rsi_onnx_model,0,rsi_input_shape))) { Comment("[ERROR] Failed to set input shape correctly: ",GetLastError()); return(false); } if((!OnnxSetOutputShape(onnx_model,0,output_shape)) || (!OnnxSetOutputShape(rsi_onnx_model,0,output_shape))) { Comment("[ERROR] Failed to load AI module correctly: ",GetLastError()); return(false); } //Everything went fine return(true); }
把它们都整合起来,这就是我们的交易程序。
//+------------------------------------------------------------------+ //| GBPUSD AI.mq5 | //| Gamuchirai Zororo Ndawana | //| https://www.mql5.com/en/gamuchiraindawa | //+------------------------------------------------------------------+ #property copyright "Gamuchirai Zororo Ndawana" #property link "https://www.mql5.com/en/gamuchiraindawa" #property version "1.00" //+------------------------------------------------------------------+ //| Load our resources | //+------------------------------------------------------------------+ #resource "\\Files\\AUDJPY D1 MA AI F22 P40.onnx" as const uchar onnx_buffer[]; #resource "\\Files\\AUDJPY D1 RSI AI F22 P40.onnx" as const uchar rsi_onnx_buffer[]; //+------------------------------------------------------------------+ //| Libraries | //+------------------------------------------------------------------+ #include <Trade\Trade.mqh> CTrade Trade; #include <Trade\OrderInfo.mqh> class COrderInfo; //+------------------------------------------------------------------+ //| Global variables | //+------------------------------------------------------------------+ long onnx_model; int ma_handler,state; double bid,ask,vol; vectorf model_forecast = vectorf::Zeros(1); vectorf rsi_model_output = vectorf::Zeros(1); double min_volume,max_volume_increase, volume_step, buy_stop_loss, sell_stop_loss,atr_stop,risk_equity; double take_profit = 0; double close_price[3],atr_reading[],ma_buffer[]; long min_distance,login; int atr,close_average,ticket_1,ticket_2; bool authorized = false; double margin,lot_step; string currency,server; bool all_closed =true; int rsi_handler; long rsi_onnx_model; double indicator_reading[]; ENUM_ACCOUNT_TRADE_MODE account_type; const double stop_percent = 1; //+------------------------------------------------------------------+ //| Technical indicators | //+------------------------------------------------------------------+ input group "Money Management" input int lot_multiple = 10; // How big should the lot size be? input double profit_target = 0; // Profit Target input double loss_target = 0; // Max Loss Allowed input group "Money Management" input int bb_period = 36; //Bollinger band period input int ma_period = 4; //Moving average period const int atr_period = 200; //ATR Period input double atr_multiple =2.5; //ATR Multiple //+------------------------------------------------------------------+ //| Expert initialization function | //+------------------------------------------------------------------+ int OnInit() { //Authorization if(!TerminalInfoInteger(TERMINAL_TRADE_ALLOWED)) { Comment("Press Ctrl + E To Give The Robot Permission To Trade And Reload The Program"); return(INIT_FAILED); } else if(!MQLInfoInteger(MQL_TRADE_ALLOWED)) { Comment("Reload The Program And Make Sure You Clicked Allow Algo Trading"); return(INIT_FAILED); } else { Comment("This License is Genuine"); setup(); } //--- Everything was okay return(INIT_SUCCEEDED); } //+------------------------------------------------------------------+ //| Expert deinitialization function | //+------------------------------------------------------------------+ void OnDeinit(const int reason) { //--- OnnxRelease(onnx_model); OnnxRelease(rsi_onnx_model); } //+------------------------------------------------------------------+ //| Expert tick function | //+------------------------------------------------------------------+ void OnTick() { //--- Update technical data update(); if(PositionsTotal() == 0) { check_setup(); } if(PositionsTotal() > 0) { check_atr_stop(); } } //+------------------------------------------------------------------+ //| Get a prediction from our model | //+------------------------------------------------------------------+ int model_predict(void) { //MA Forecast vectorf model_inputs = vectorf::Zeros(2); vectorf rsi_model_inputs = vectorf::Zeros(3); CopyBuffer(ma_handler,0,0,40,ma_buffer); if(ma_buffer[0] > ma_buffer[39]) { model_inputs[0] = 1; model_inputs[1] = 0; } else if(ma_buffer[0] < ma_buffer[39]) { model_inputs[1] = 1; model_inputs[0] = 0; } //RSI Forecast CopyBuffer(rsi_handler,0,0,1,indicator_reading); if(indicator_reading[0] < 30) { rsi_model_inputs[0] = 1; rsi_model_inputs[1] = 0; rsi_model_inputs[2] = 0; } else if(indicator_reading[0] >70) { rsi_model_inputs[0] = 0; rsi_model_inputs[1] = 1; rsi_model_inputs[2] = 0; } else { rsi_model_inputs[0] = 0; rsi_model_inputs[1] = 0; rsi_model_inputs[2] = 1; } //Model predictions OnnxRun(onnx_model,ONNX_DEFAULT,model_inputs,model_forecast); OnnxRun(rsi_onnx_model,ONNX_DEFAULT,rsi_model_inputs,rsi_model_output); //Evaluate model output for buy setup if(((rsi_model_output[0] > 0) && (model_forecast[0] > 0))) { //AI Models forecast Comment("AI Forecast: UP"); return(1); } //Evaluate model output for a sell setup if((rsi_model_output[0] < 0) && (model_forecast[0] < 0)) { Comment("AI Forecast: DOWN"); return(-1); } //Otherwise no position was found return(0); } //+------------------------------------------------------------------+ //| Check for valid trade setups | //+------------------------------------------------------------------+ void check_setup(void) { int res = model_predict(); if(res == -1) { Trade.Sell(vol,Symbol(),bid,0,0,"VD V75 AI"); state = -1; } else if(res == 1) { Trade.Buy(vol,Symbol(),ask,0,0,"VD V75 AI"); state = 1; } } //+------------------------------------------------------------------+ //| Update our market data | //+------------------------------------------------------------------+ void update(void) { ask = SymbolInfoDouble(_Symbol,SYMBOL_ASK); bid = SymbolInfoDouble(_Symbol,SYMBOL_BID); buy_stop_loss = 0; sell_stop_loss = 0; static datetime time_stamp; datetime time = iTime(_Symbol,PERIOD_CURRENT,0); check_price(3); CopyBuffer(atr,0,0,1,atr_reading); CopyBuffer(ma_handler,0,0,1,ma_buffer); ArraySetAsSeries(atr_reading,true); atr_stop = ((min_volume + atr_reading[0]) * atr_multiple); //On Every Candle if(time_stamp != time) { //Mark the candle time_stamp = time; OrderCalcMargin(ORDER_TYPE_BUY,_Symbol,min_volume,ask,margin); } } //+------------------------------------------------------------------+ //+------------------------------------------------------------------+ //| Load resources | //+------------------------------------------------------------------+ bool setup(void) { //Account Info currency = AccountInfoString(ACCOUNT_CURRENCY); server = AccountInfoString(ACCOUNT_SERVER); login = AccountInfoInteger(ACCOUNT_LOGIN); //Indicators atr = iATR(_Symbol,PERIOD_CURRENT,atr_period); //Setup technical indicators ma_handler =iMA(Symbol(),PERIOD_CURRENT,40,0,MODE_SMA,PRICE_LOW); vol = SymbolInfoDouble(Symbol(),SYMBOL_VOLUME_MIN) * lot_multiple; rsi_handler = iRSI(Symbol(),PERIOD_CURRENT,30,PRICE_CLOSE); //Market Information min_volume = SymbolInfoDouble(_Symbol,SYMBOL_VOLUME_MIN); max_volume_increase = SymbolInfoDouble(_Symbol,SYMBOL_VOLUME_MAX) / SymbolInfoDouble(_Symbol,SYMBOL_VOLUME_MIN); min_distance = SymbolInfoInteger(_Symbol,SYMBOL_TRADE_STOPS_LEVEL); lot_step = SymbolInfoDouble(_Symbol,SYMBOL_VOLUME_STEP); //Define our ONNX model ulong ma_input_shape [] = {1,2}; ulong rsi_input_shape [] = {1,3}; ulong output_shape [] = {1,1}; //Create the model onnx_model = OnnxCreateFromBuffer(onnx_buffer,ONNX_DEFAULT); rsi_onnx_model = OnnxCreateFromBuffer(rsi_onnx_buffer,ONNX_DEFAULT); if((onnx_model == INVALID_HANDLE) || (rsi_onnx_model == INVALID_HANDLE)) { Comment("[ERROR] Failed to load AI module correctly"); return(false); } //--- Validate I/O if((!OnnxSetInputShape(onnx_model,0,ma_input_shape)) || (!OnnxSetInputShape(rsi_onnx_model,0,rsi_input_shape))) { Comment("[ERROR] Failed to set input shape correctly: ",GetLastError()); return(false); } if((!OnnxSetOutputShape(onnx_model,0,output_shape)) || (!OnnxSetOutputShape(rsi_onnx_model,0,output_shape))) { Comment("[ERROR] Failed to load AI module correctly: ",GetLastError()); return(false); } //--- Everything went fine return(true); } //+------------------------------------------------------------------+ //| Close all our open positions | //+------------------------------------------------------------------+ void close_all() { if(PositionsTotal() > 0) { ulong ticket; for(int i =0;i < PositionsTotal();i++) { ticket = PositionGetTicket(i); Trade.PositionClose(ticket); } } } //+------------------------------------------------------------------+ //| Update our trailing ATR stop | //+------------------------------------------------------------------+ void check_atr_stop() { for(int i = PositionsTotal() -1; i >= 0; i--) { string symbol = PositionGetSymbol(i); if(_Symbol == symbol) { ulong ticket = PositionGetInteger(POSITION_TICKET); double position_price = PositionGetDouble(POSITION_PRICE_OPEN); double type = PositionGetInteger(POSITION_TYPE); double current_stop_loss = PositionGetDouble(POSITION_SL); if(type == POSITION_TYPE_BUY) { double atr_stop_loss = (ask - (atr_stop)); double atr_take_profit = (ask + (atr_stop)); if((current_stop_loss < atr_stop_loss) || (current_stop_loss == 0)) { Trade.PositionModify(ticket,atr_stop_loss,atr_take_profit); } } else if(type == POSITION_TYPE_SELL) { double atr_stop_loss = (bid + (atr_stop)); double atr_take_profit = (bid - (atr_stop)); if((current_stop_loss > atr_stop_loss) || (current_stop_loss == 0)) { Trade.PositionModify(ticket,atr_stop_loss,atr_take_profit); } } } } } //+------------------------------------------------------------------+ //| Close our open buy positions | //+------------------------------------------------------------------+ void close_buy() { ulong ticket; int type; if(PositionsTotal() > 0) { for(int i = 0; i < PositionsTotal();i++) { if(PositionGetSymbol(i) == _Symbol) { ticket = PositionGetTicket(i); type = (int)PositionGetInteger(POSITION_TYPE); if(type == POSITION_TYPE_BUY) { Trade.PositionClose(ticket); } } } } } //+------------------------------------------------------------------+ //| Close our open sell positions | //+------------------------------------------------------------------+ void close_sell() { ulong ticket; int type; if(PositionsTotal() > 0) { for(int i = 0; i < PositionsTotal();i++) { if(PositionGetSymbol(i) == _Symbol) { ticket = PositionGetTicket(i); type = (int)PositionGetInteger(POSITION_TYPE); if(type == POSITION_TYPE_SELL) { Trade.PositionClose(ticket); } } } } } //+------------------------------------------------------------------+ //| Get the most recent price values | //+------------------------------------------------------------------+ void check_price(int candles) { for(int i = 0; i < candles;i++) { close_price[i] = iClose(_Symbol,PERIOD_CURRENT,i); } } //+------------------------------------------------------------------+
现在让我们使用在训练算法时未展示给算法的数据,来对我们的交易算法进行回测。我们选择的回测时间段是从 2023 年 1 月初到 2024 年 6 月 28 日,使用 AUDJPY 货币对的日线市场报价。将“向前测试”参数设置为“否”,因为我们已经确保在训练模型时未观察到选择的日期的数据。
图 12:我们将用来评估交易策略的货币对和时间框架。
此外,将通过首先将“延迟”参数设置为“随机延迟”来模拟真实的交易条件。此参数控制我们的订单下达和执行之间的时间延迟。将其设置为随机,类似于真实交易中的情况,我们的延迟并非始终是固定的。此外,我们将指示终端使用真实价格变动来模拟市场。此设置会稍微放慢我们的回测速度,因为终端需要先从互联网上的经纪商获取详细的市场数据。
最后的参数,控制账户存款和杠杆,应根据您的交易设置进行调整。假设成功获取了所有请求的数据,我们的回测将开始。
图 13:我们将用于回测的参数。
图 14:我们的策略在模型未训练过的数据上的表现。
图15:对未知市场数据回测的更多细节。
结论
今天分析的大量市场数据清楚地告诉我们,如果你希望预测未来不足40步的时间间隔,直接预测价格可能会更好。然而,如果你希望预测未来超过40步的时间间隔,预测移动平均线的变化可能比预测价格的变化更有优势。总是有更多的改进等待我们去观察以及它们所带来的差异。我们可以清楚地看到,花时间转换数据的输入是值得的,因为它使我们能够以更有意义的方式向模型展示底层关系。
本文由MetaQuotes Ltd译自英文
原文地址: https://www.mql5.com/en/articles/16230
注意: MetaQuotes Ltd.将保留所有关于这些材料的权利。全部或部分复制或者转载这些材料将被禁止。
本文由网站的一位用户撰写,反映了他们的个人观点。MetaQuotes Ltd 不对所提供信息的准确性负责,也不对因使用所述解决方案、策略或建议而产生的任何后果负责。




显示的结果看起来很有希望,我们会试一试。
请多多指教。
谢谢。
使用 Python 和 MQL5 的特征工程(第一部分)》一文已经发表:利用移动平均线进行长期预测的人工智能模型:
作者:Gamuchirai Zororo Ndawana
部分图片无法显示...
结果看起来很有希望。
请多多指教。
谢谢。
不客气 Too Che Ng。
有了这样一个强有力的开端,肯定还有很多可以说的。
某些图像无法显示...
我很遗憾听到这个消息,我相信版主会解决这个问题的,因为他们已经有很多事情要做了。