import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import MetaTrader5 as mt5
import matplotlib.pyplot as plt
import torch.onnx
import torch.nn.functional as F

# Connect to MetaTrader 5
if not mt5.initialize():
    # Nothing below can work without a terminal connection, so abort here
    # instead of failing later with an unrelated error on an empty frame.
    init_error = mt5.last_error()  # read the error before shutting down
    mt5.shutdown()
    raise SystemExit(f"Initialize failed: {init_error}")

# Load market data: the 80000 most recent M15 bars of EURUSD
symbol = "EURUSD"
timeframe = mt5.TIMEFRAME_M15
rates = mt5.copy_rates_from_pos(symbol, timeframe, 0, 80000)
rates_error = mt5.last_error()  # captured while the connection is still open
mt5.shutdown()

# The MetaTrader5 docs state that copy_rates_from_pos returns None on error;
# an empty result is equally unusable for the steps below.
if rates is None or len(rates) == 0:
    raise SystemExit(f"No rates returned for {symbol}: {rates_error}")

# Convert to DataFrame indexed by bar time ('time' is parsed as epoch seconds)
data = pd.DataFrame(rates)
data['time'] = pd.to_datetime(data['time'], unit='s')
data.set_index('time', inplace=True)

# Tokenize time: seconds since midnight divided by 86400, i.e. the fraction
# of the day that has elapsed at each bar
data['time_token'] = (data.index.hour * 3600 + data.index.minute * 60 + data.index.second) / 86400

# Normalize prices on a rolling basis resetting at the start of each day
def normalize_daily_rolling(data):
    """Scale OHLC prices into the running high/low range of each calendar day.

    For every row, the range is the expanding maximum of ``high`` and the
    expanding minimum of ``low`` over the rows of the same date seen so far,
    so the range resets at the first row of each day. Each price is mapped
    to ``(price - rolling_low) / (rolling_high - rolling_low)``.

    The frame is modified in place and also returned. Columns added:
    ``date``, ``rolling_high``, ``rolling_low``, ``norm_open``,
    ``norm_high``, ``norm_low`` and ``norm_close``.

    Args:
        data: DataFrame with ``open``, ``high``, ``low`` and ``close``
            columns and a datetime index (``data.index.date`` is used).

    Returns:
        The same DataFrame, with every NaN in the frame replaced by 0.
    """
    data['date'] = data.index.date

    by_day = data.groupby('date')
    rolling_high = by_day['high'].transform(lambda x: x.expanding(min_periods=1).max())
    rolling_low = by_day['low'].transform(lambda x: x.expanding(min_periods=1).min())
    data['rolling_high'] = rolling_high
    data['rolling_low'] = rolling_low

    # A zero range (e.g. the first bar of a day with high == low) would give
    # 0/0 = NaN or x/0 = +/-inf. Masking the zero range to NaN makes every
    # such quotient NaN, which the fillna below turns into 0; inf values
    # would otherwise slip through fillna untouched.
    price_range = rolling_high - rolling_low
    price_range = price_range.where(price_range != 0)

    for column in ('open', 'high', 'low', 'close'):
        data[f'norm_{column}'] = (data[column] - rolling_low) / price_range

    # Replace NaNs with zeros (applies to the whole frame, not only the
    # normalized columns)
    data.fillna(0, inplace=True)
    return data

data = normalize_daily_rolling(data)

# Report any missing values that are still present, column by column
if data.isnull().values.any():
    nan_counts = data.isnull().sum()
    print("Data contains NaNs")
    print(nan_counts)

# Keep only the model features: time token plus the four normalized prices
feature_columns = ['time_token', 'norm_open', 'norm_high', 'norm_low', 'norm_close']
data = data[feature_columns]

# Create sequences
def create_sequences(data, seq_length):
    """Build sliding-window samples and next-step targets from a feature frame.

    Sample ``i`` is rows ``i .. i + seq_length - 1`` of ``data`` (all
    columns); its target is the ``norm_close`` value of row
    ``i + seq_length``.

    Args:
        data: DataFrame of features; must contain a ``norm_close`` column.
        seq_length: Number of consecutive rows in each input window.

    Returns:
        Tuple ``(xs, ys)`` of NumPy arrays with shapes
        ``(n, seq_length, n_columns)`` and ``(n,)``, where
        ``n = max(len(data) - seq_length, 0)``. When there are too few rows
        for a single window, ``xs`` is an empty array that still carries the
        window and feature dimensions.
    """
    # Convert once up front: slicing a NumPy array is far cheaper than going
    # through DataFrame.iloc on every iteration.
    values = data.values
    closes = data['norm_close'].values

    n_samples = max(len(values) - seq_length, 0)
    if n_samples == 0:
        empty_xs = np.empty((0, seq_length, values.shape[1]), dtype=values.dtype)
        return empty_xs, np.array(closes[:0])

    xs = np.array([values[start:start + seq_length] for start in range(n_samples)])
    # np.array(...) copies, so the targets do not alias the frame's storage
    ys = np.array(closes[seq_length:seq_length + n_samples])
    return xs, ys

# Window length (number of bars per input sample)
seq_length = 60
X, y = create_sequences(data, seq_length)

# Chronological split: first 80% of the samples for training, rest for testing
split = int(len(X) * 0.8)

# Slice each array at the split point and convert the pieces to float32 tensors
X_train, y_train = (torch.tensor(part[:split], dtype=torch.float32) for part in (X, y))
X_test, y_test = (torch.tensor(part[split:], dtype=torch.float32) for part in (X, y))

# Fix the PyTorch RNG seed so runs are repeatable
seed_value = 42
torch.manual_seed(seed_value)

# Define LSTM model class
class LSTMModel(nn.Module):
    """Single-layer LSTM with a linear head applied to the last time step.

    Inputs are sequence-first, ``(seq_len, batch, input_size)``: the LSTM is
    built without ``batch_first`` and the batch size is read from dimension 1.

    Args:
        input_size: Number of features per time step.
        hidden_layer_size: Size of the LSTM hidden state.
        output_size: Number of values predicted per sequence.
    """

    def __init__(self, input_size, hidden_layer_size, output_size):
        super(LSTMModel, self).__init__()
        self.hidden_layer_size = hidden_layer_size
        self.lstm = nn.LSTM(input_size, hidden_layer_size)
        self.linear = nn.Linear(hidden_layer_size, output_size)

    def forward(self, input_seq):
        """Predict from the final time step of ``input_seq``.

        Args:
            input_seq: Tensor of shape ``(seq_len, batch, input_size)``.

        Returns:
            Tensor of shape ``(output_size,)`` when ``batch == 1``, otherwise
            ``(batch, output_size)``.
        """
        batch_size = input_seq.size(1)
        state_shape = (1, batch_size, self.hidden_layer_size)
        # Start every sequence from a zero hidden/cell state
        h0 = torch.zeros(*state_shape).to(input_seq.device)
        c0 = torch.zeros(*state_shape).to(input_seq.device)
        lstm_out, _ = self.lstm(input_seq, (h0, c0))

        # Only the last time step feeds the head. Indexing it directly keeps
        # the batch dimension separate from the hidden features, so batches
        # larger than one work, and avoids running the linear layer on every
        # intermediate step.
        predictions = self.linear(lstm_out[-1])

        # squeeze(0) drops the batch axis only when it has size 1, which
        # preserves the (output_size,) result for single-sequence input.
        return predictions.squeeze(0)

print(f"Seed value used: {seed_value}")

# Network dimensions
input_size = 5  # one feature each for time_token, norm_open, norm_high, norm_low, norm_close
hidden_layer_size = 100
output_size = 1

# Model, mean-squared-error objective and Adam optimizer
model = LSTMModel(
    input_size=input_size,
    hidden_layer_size=hidden_layer_size,
    output_size=output_size,
)
# Optional: wrap the model with torch.compile (left disabled)
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training: one optimizer step per training sequence.
# NOTE: range(epochs + 1) performs epochs + 1 passes (epoch 0 through epochs).
epochs = 100
for epoch in range(epochs + 1):
    epoch_loss_sum = 0.0
    update_count = 0

    for seq, labels in zip(X_train, y_train):
        optimizer.zero_grad()
        # Insert a batch axis of size 1 at position 1 before calling the model
        y_pred = model(seq.unsqueeze(1))

        # Ensure both are tensors of shape [1]
        y_pred = y_pred.view(-1)
        labels = labels.view(-1)

        single_loss = loss_function(y_pred, labels)

        # Print intermediate values to debug NaN loss
        if torch.isnan(single_loss):
            print(f'Epoch {epoch} NaN loss detected')
            print('Sequence:', seq)
            print('Prediction:', y_pred)
            print('Label:', labels)
            # Back-propagating a NaN loss would write NaN into every weight
            # and ruin all later updates, so skip this sample.
            continue

        single_loss.backward()
        optimizer.step()

        epoch_loss_sum += single_loss.item()
        update_count += 1

    if epoch % 10 == 0 or epoch == epochs:  # Include the final epoch
        # Report the mean loss over the epoch's applied updates rather than
        # the loss of whichever sample happened to be last; NaN when no
        # update was applied (e.g. empty training set).
        epoch_loss = epoch_loss_sum / update_count if update_count else float('nan')
        print(f'Epoch {epoch} loss: {epoch_loss}')

# Persist the trained weights
torch.save(model.state_dict(), 'lstm_model.pth')

# Export to ONNX by tracing the model in eval mode with a random example
# input shaped (seq_length, 1, input_size)
model.eval()
onnx_model_path = "lstm_model.onnx"
dummy_input = torch.randn(seq_length, 1, input_size, dtype=torch.float32)

# NOTE(review): axis 0 of 'output' is also declared dynamic under the name
# 'sequence'; confirm that this matches the model's actual output shape.
export_options = {
    'input_names': ['input'],
    'output_names': ['output'],
    'dynamic_axes': {'input': {0: 'sequence'}, 'output': {0: 'sequence'}},
    'opset_version': 11,
}
torch.onnx.export(model, dummy_input, onnx_model_path, **export_options)

print(f"Model has been converted to ONNX format and saved to {onnx_model_path}")

# Run the model over every test sequence with gradient tracking disabled,
# collecting one scalar prediction per sequence
model.eval()
with torch.no_grad():
    predictions = [model(seq.unsqueeze(1)).item() for seq in X_test]

# Quick visual comparison of targets and predictions
plt.plot(y_test.numpy(), label='True Prices (Normalized)')
plt.plot(predictions, label='Predicted Prices (Normalized)')
plt.legend()
plt.show()

# Step-to-step relative change of both series. A tiny epsilon is added to the
# denominator to prevent a divide by zero error.
# NOTE(review): when a value is 0 the quotient becomes very large because of
# the 1e-10 denominator; confirm this is acceptable for the plot.
true_prices = y_test.numpy()
predicted_prices = np.array(predictions)

true_pct_change = np.diff(true_prices) / (true_prices[:-1] + 1e-10)
predicted_pct_change = np.diff(predicted_prices) / (predicted_prices[:-1] + 1e-10)

# Two stacked panels: prices on top, percent change below
fig, (price_ax, change_ax) = plt.subplots(2, 1, figsize=(12, 6))

price_ax.plot(true_prices, label='True Prices (Normalized)')
price_ax.plot(predicted_prices, label='Predicted Prices (Normalized)')
price_ax.legend()
price_ax.set_title('True vs Predicted Prices (Normalized)')

change_ax.plot(true_pct_change, label='True Percent Change')
change_ax.plot(predicted_pct_change, label='Predicted Percent Change')
change_ax.legend()
change_ax.set_title('True vs Predicted Percent Change')

plt.tight_layout()
plt.show()
