
import MetaTrader5 as mt5
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # Using Agg backend for working without GUI
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Conv1D, MaxPooling1D, Flatten, Dropout, BatchNormalization, Input
from tensorflow.keras.callbacks import EarlyStopping
import time
from datetime import datetime, timedelta

# Default figure size: 7 x 4.2 inches, i.e. 700 x 420 px at the dpi=100
# passed to every savefig call in this script.
plt.rcParams.update({'figure.figsize': [7, 4.2]})

# Connect to MetaTrader5 terminal
def connect_to_mt5(**init_kwargs):
    """Initialize the connection to the MetaTrader5 terminal.

    Args:
        **init_kwargs: optional keyword arguments forwarded unchanged to
            ``mt5.initialize`` (see the MetaTrader5 ``initialize``
            documentation for accepted names). With none given this calls
            ``mt5.initialize()`` exactly as before.

    Returns:
        True when the terminal was initialized; False otherwise, after
        printing the terminal's last error and calling ``mt5.shutdown()``.
    """
    if mt5.initialize(**init_kwargs):
        return True
    # Read the error before shutdown() so the reported code is the one
    # produced by initialize(), not by the cleanup call.
    print(f"Error initializing MetaTrader5: {mt5.last_error()}")
    mt5.shutdown()
    return False

# Get historical quotes
def get_historical_data(symbol="EURUSD", timeframe=mt5.TIMEFRAME_H1, num_bars=1000):
    """Load the most recent ``num_bars`` bars of ``symbol`` from the terminal.

    Bars are requested by position (counting back from the current bar at
    position 0) rather than by date range. A date range sized as
    ``num_bars / 24`` days is only right for hourly bars. It also spans
    calendar time in which no bars are produced and depends on the local
    clock, so it returned fewer bars than asked for. Querying by position
    returns up to ``num_bars`` rows for any ``timeframe`` (fewer only if the
    terminal holds less history).

    Args:
        symbol: instrument name passed to the terminal.
        timeframe: an ``mt5.TIMEFRAME_*`` constant.
        num_bars: number of most recent bars to request.

    Returns:
        DataFrame indexed by bar time (the ``time`` column converted from
        epoch seconds), or None when no data could be loaded.
    """
    rates = mt5.copy_rates_from_pos(symbol, timeframe, 0, num_bars)
    if rates is None or len(rates) == 0:
        print(f"Error loading historical quotes: {mt5.last_error()}")
        return None

    # Convert the structured array returned by the terminal to a DataFrame.
    rates_frame = pd.DataFrame(rates)
    rates_frame['time'] = pd.to_datetime(rates_frame['time'], unit='s')
    return rates_frame.set_index('time')

# Convert data to images for computer vision
def create_images(data, window_size=48, prediction_window=24):
    """Slice OHLC history into normalized windows ("images") with direction labels.

    Rows of ``data`` are treated as consecutive bars in time order. Each sample
    holds ``window_size`` consecutive bars of open/high/low/close. Each column
    is min-max scaled to [0, 1] independently, using only that window (a
    constant column maps to 0). The label is 1 when the close
    ``prediction_window`` bars after the window's last bar is strictly higher
    than that last close, otherwise 0.

    Args:
        data: DataFrame with at least 'open', 'high', 'low', 'close' columns.
        window_size: number of bars per sample.
        prediction_window: forecast horizon in bars.

    Returns:
        (images, targets): float array of shape (n, window_size, 4) and int
        array of shape (n,), where
        n = len(data) - window_size - prediction_window + 1, which is every
        window that has a complete forecast horizon. n is 0, with correctly
        shaped empty arrays, when ``data`` is too short.
    """
    ohlc = data[['open', 'high', 'low', 'close']].to_numpy(dtype=float)
    closes = ohlc[:, 3]

    # +1: the last valid start index is len - window_size - prediction_window
    # (its horizon ends on the final row), so it must be included.
    num_samples = len(ohlc) - window_size - prediction_window + 1
    if num_samples <= 0:
        return np.empty((0, window_size, 4)), np.empty(0, dtype=int)

    images = np.empty((num_samples, window_size, 4))
    targets = np.empty(num_samples, dtype=int)
    for start in range(num_samples):
        end = start + window_size
        window = ohlc[start:end]

        # Per-column min-max scaling inside the window. A zero range is
        # replaced by 1 so flat columns become 0 instead of NaN.
        col_min = window.min(axis=0)
        col_span = window.max(axis=0) - col_min
        col_span[col_span == 0] = 1.0
        images[start] = (window - col_min) / col_span

        # Compare the close at the end of the horizon with the window's last close.
        targets[start] = int(closes[end + prediction_window - 1] > closes[end - 1])

    return images, targets

# Create and train an improved computer vision model
def train_cv_model(images, targets):
    """Build, train and evaluate the 1-D CNN direction classifier.

    Samples are assumed to be in chronological order, as ``create_images``
    produces them. The last 20% are held out for validation. The model is
    trained with Adam on binary cross-entropy for up to 50 epochs (batch 32),
    with early stopping on validation loss (patience 10, best weights
    restored).

    Args:
        images: array of shape (samples, time_steps, features).
        targets: array of 0/1 labels, one per sample.

    Returns:
        (model, history, feature_model):
        - the trained classifier;
        - the ``History`` object returned by ``fit``;
        - a model sharing the classifier's layers that outputs the pooled
          feature maps of the second convolutional block.
    """
    # Chronological split (no shuffling). Consecutive windows overlap almost
    # entirely, and so do their forecast horizons. A shuffled split therefore
    # puts near-duplicates of every validation sample into the training set
    # and inflates validation accuracy. Samples right at the boundary still
    # share bars with the first validation windows.
    X_train, X_val, y_train, y_val = train_test_split(images, targets, test_size=0.2, shuffle=False)

    # Conv1D expects (samples, time_steps, features); the windows already have
    # that layout, so the per-sample shape is used directly.
    input_shape = X_train.shape[1:]

    inputs = Input(shape=input_shape)

    # First convolutional block
    x = Conv1D(filters=64, kernel_size=3, padding='same', activation='relu')(inputs)
    x = BatchNormalization()(x)
    x = Conv1D(filters=64, kernel_size=3, padding='same', activation='relu')(x)
    x = BatchNormalization()(x)
    x = MaxPooling1D(pool_size=2)(x)
    x = Dropout(0.2)(x)

    # Second convolutional block
    x = Conv1D(filters=128, kernel_size=3, padding='same', activation='relu')(x)
    x = BatchNormalization()(x)
    x = Conv1D(filters=128, kernel_size=3, padding='same', activation='relu')(x)
    x = BatchNormalization()(x)
    x = MaxPooling1D(pool_size=2)(x)
    feature_maps = x  # taken before dropout; exposed for the visualizations
    x = Dropout(0.2)(x)

    # Classification head
    x = Flatten()(x)
    x = Dense(128, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(0.3)(x)
    x = Dense(64, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(0.3)(x)
    outputs = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=inputs, outputs=outputs)

    # Shares layers (and therefore trained weights) with `model`.
    feature_model = Model(inputs=model.input, outputs=feature_maps)

    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        verbose=1
    )

    history = model.fit(
        X_train, y_train,
        epochs=50,
        batch_size=32,
        validation_data=(X_val, y_val),
        callbacks=[early_stopping],
        verbose=1
    )

    _, accuracy = model.evaluate(X_val, y_val)
    print(f'Model validation accuracy: {accuracy * 100:.2f}%')

    return model, history, feature_model

# Visualize the training process
def plot_learning_history(history):
    """Save side-by-side accuracy and loss curves from a training run.

    Args:
        history: the object returned by ``Model.fit``. Its ``history`` dict
            must contain 'accuracy', 'val_accuracy', 'loss' and 'val_loss'.

    Writes ``learning_history.png`` at 100 dpi and closes the figure.
    """
    curves = history.history
    # (metric key, panel title, y-axis label); the validation series is
    # stored under the same key with a 'val_' prefix.
    panels = (
        ('accuracy', 'Model Accuracy', 'Accuracy'),
        ('loss', 'Model Loss', 'Loss'),
    )

    plt.figure(figsize=(7, 4.2))
    for position, (metric, title, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(curves[metric])
        plt.plot(curves['val_' + metric])
        plt.title(title)
        plt.ylabel(y_label)
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'], loc='upper left')
        plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('learning_history.png', dpi=100)  # write the figure to disk
    plt.close()  # free the figure once saved
    print("Learning history plot saved to 'learning_history.png'")

# Predict on the last available window
def make_prediction(model, data, window_size=48):
    """Classify the direction of the next move from the latest ``window_size`` bars.

    The window's open/high/low/close columns are min-max scaled to [0, 1],
    each column independently and fitted on this window alone, then passed
    to ``model`` as a batch of one.

    Args:
        model: trained classifier; ``model.predict(batch)[0][0]`` is read as
            the probability of an upward move.
        data: DataFrame with 'open', 'high', 'low', 'close' columns.
        window_size: number of most recent bars to feed the model.

    Returns:
        (direction, confidence, last_window_scaled):
        - direction is "UP ▲" when the model output is above 0.5, otherwise
          "DOWN ▼";
        - confidence is the output (or one minus it) in percent;
        - last_window_scaled is the scaled window array.

    Raises:
        ValueError: if ``window_size`` is not positive or ``data`` has fewer
            than ``window_size`` rows. Otherwise this would surface as an
            opaque input-shape error from the model, or a zero window would
            silently select every row.
    """
    if window_size <= 0 or len(data) < window_size:
        raise ValueError(
            f"make_prediction needs window_size > 0 and at least window_size bars; "
            f"got window_size={window_size}, {len(data)} bars"
        )

    last_window = data.iloc[-window_size:][['open', 'high', 'low', 'close']]
    last_window_scaled = MinMaxScaler(feature_range=(0, 1)).fit_transform(last_window)

    # Batch of one; the sigmoid output sits at [0][0].
    probability = float(model.predict(np.array([last_window_scaled]))[0][0])

    if probability > 0.5:
        direction, confidence = "UP ▲", probability
    else:
        direction, confidence = "DOWN ▼", 1.0 - probability

    return direction, confidence * 100, last_window_scaled

# Visualize the price prediction
def plot_prediction(data, window_size=48, prediction_window=24, direction="UP ▲", symbol="EURUSD"):
    """Save a chart of recent closes with an arrow showing the forecast direction.

    Args:
        data: DataFrame indexed by bar time, with at least 'high', 'low' and
            'close' columns and at least two rows (used to infer bar spacing).
        window_size: number of most recent bars to draw.
        prediction_window: forecast horizon in bars. The arrow ends this many
            bar intervals after the last bar.
        direction: forecast label. Strings starting with "UP" (such as the
            "UP ▲" produced by ``make_prediction``) draw a green upward
            arrow; anything else draws a red downward one.
        symbol: instrument name shown in the title.

    Writes ``prediction.png`` at 100 dpi and closes the figure.
    """
    plt.figure(figsize=(7, 4.2))

    last_window = data.iloc[-window_size:]
    last_date = last_window.index[-1]

    # The most common gap between bars is the step for future timestamps, so
    # occasional longer gaps in the history do not distort it.
    bar_step = data.index.to_series().diff().mode()[0]
    future_dates = pd.date_range(start=last_date, periods=prediction_window + 1, freq=bar_step)[1:]

    plt.plot(last_window.index, last_window['close'], label='Historical Data')

    # Mark and label the latest close.
    current_price = last_window['close'].iloc[-1]
    plt.scatter(last_date, current_price, color='blue', s=100, zorder=5)
    plt.annotate(f'Current price: {current_price:.5f}',
                 xy=(last_date, current_price),
                 xytext=(10, -30),
                 textcoords='offset points',
                 fontsize=10,
                 arrowprops=dict(arrowstyle='->', color='black'))

    # The arrow's height is 10% of the window's high-low range.
    price_range = last_window['high'].max() - last_window['low'].min()
    arrow_length = price_range * 0.1
    if direction.startswith("UP"):
        arrow_end_price = current_price + arrow_length
        arrow_color = 'green'
    else:
        arrow_end_price = current_price - arrow_length
        arrow_color = 'red'

    # Annotations are ignored by autoscaling. By default, an annotation whose
    # data-coordinate target lies outside the axes is not drawn at all. The
    # arrow tip is prediction_window bars in the future, beyond the
    # autoscaled range of the historical line, so widen both axes to contain
    # it: one extra bar of right margin, plus a small vertical margin.
    plt.xlim(last_window.index[0], future_dates[-1] + bar_step)
    y_low, y_high = plt.ylim()
    margin = 0.1 * arrow_length
    plt.ylim(min(y_low, arrow_end_price - margin), max(y_high, arrow_end_price + margin))

    plt.annotate('',
                 xy=(future_dates[-1], arrow_end_price),
                 xytext=(last_date, current_price),
                 arrowprops=dict(arrowstyle='->', lw=2, color=arrow_color),
                 annotation_clip=False)

    plt.title(f'{symbol} - Forecast for {prediction_window} periods: {direction}')
    plt.xlabel('Date')
    plt.ylabel('Closing Price')
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig('prediction.png', dpi=100)  # Save plot to file
    plt.close()  # Close figure to save memory
    print("Prediction plot saved to 'prediction.png'")

# Visualize how the model "sees" the market
def visualize_model_perception(feature_model, last_window_scaled, window_size=48):
    """Save a figure of the input window next to the model's feature-map activations.

    Panels, top to bottom:
    1. the normalized OHLC lines;
    2. a candlestick rendering of the same bars;
    3-5. the activations of up to three feature-map channels that
       ``feature_model`` produces for this window.

    Args:
        feature_model: model whose ``predict`` maps a batch of windows to
            feature maps indexed as ``[sample, step, channel]``.
        last_window_scaled: 2-D array of normalized bars with columns open,
            high, low, close, in that order.
        window_size: retained for backward compatibility. The x-axis comes
            from ``len(last_window_scaled)``, so a value that disagrees with
            the actual window length no longer breaks the plots.

    Writes ``model_perception.png`` at 100 dpi and closes the figure.
    """
    window = np.asarray(last_window_scaled)
    steps = np.arange(len(window))
    opens, highs, lows, closes = window[:, 0], window[:, 1], window[:, 2], window[:, 3]

    feature_maps = feature_model.predict(np.array([window]))[0]
    time_indices = np.arange(feature_maps.shape[0])

    plt.figure(figsize=(7, 10))

    # Panel 1: the normalized series themselves.
    plt.subplot(5, 1, 1)
    plt.title("Original Price Data (Normalized)")
    plt.plot(steps, opens, label='Open', alpha=0.7)
    plt.plot(steps, highs, label='High', alpha=0.7)
    plt.plot(steps, lows, label='Low', alpha=0.7)
    plt.plot(steps, closes, label='Close', color='black', linewidth=2)
    plt.xlabel('Time Steps')
    plt.ylabel('Normalized Price')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Panel 2: candlesticks. All bodies are drawn in one bar() call and all
    # wicks in one vlines() call, instead of one bar()+plot() pair per candle.
    # NOTE(review): if the window was min-max scaled per column (as
    # make_prediction does), bodies and wicks are not on a common scale, so a
    # wick can end inside or below its body.
    plt.subplot(5, 1, 2)
    plt.title("Candlestick Representation")
    bullish = closes >= opens
    body_bottom = np.where(bullish, opens, closes)
    body_height = np.abs(closes - opens)
    body_colors = ['green' if up else 'red' for up in bullish]
    plt.bar(steps, body_height, bottom=body_bottom, color=body_colors, width=0.6, alpha=0.5)
    plt.vlines(steps, lows, highs, colors='black', linewidth=1)
    plt.xlabel('Time Steps')
    plt.ylabel('Normalized Price')
    plt.grid(True, alpha=0.3)

    # Panels 3-5: the first channels of the feature maps.
    # NOTE(review): the descriptive suffixes in these titles are illustrative
    # only. Channels are learned, and nothing here establishes that channel 1,
    # 2 or 3 encodes patterns, trend or volatility.
    channel_panels = (
        ("Feature Map Channel 1 - Pattern Recognition", 'blue'),
        ("Feature Map Channel 2 - Trend Detection", 'orange'),
        ("Feature Map Channel 3 - Volatility Detection", 'green'),
    )
    num_channels = min(len(channel_panels), feature_maps.shape[-1])
    for channel in range(num_channels):
        title, color = channel_panels[channel]
        plt.subplot(5, 1, 3 + channel)
        plt.title(title)
        plt.plot(time_indices, feature_maps[:, channel], color=color)
        plt.xlabel('Time Steps')
        plt.ylabel('Activation')
        plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('model_perception.png', dpi=100)  # Save plot to file
    plt.close()  # Close figure to save memory
    print("Model perception visualization saved to 'model_perception.png'")

# Create a heatmap to show which time steps the model focuses on
def visualize_attention_heatmap(feature_model, last_window_scaled, window_size=48):
    """Save a chart highlighting the bars to which the convolutional features respond most.

    "Attention" here is a heuristic, not a learned attention mechanism. It is
    the mean absolute activation across all channels of ``feature_model``'s
    output, min-max normalized to [0, 1] and stretched back onto the input
    bars.

    Args:
        feature_model: model whose ``predict`` maps a batch of windows to
            feature maps indexed as ``[sample, step, channel]``.
        last_window_scaled: 2-D array of normalized bars; column 3 is the close.
        window_size: retained for backward compatibility. The number of bars
            is taken from ``len(last_window_scaled)``.

    Writes ``attention_heatmap.png`` at 100 dpi and closes the figure.
    """
    window = np.asarray(last_window_scaled)
    closes = window[:, 3]
    n_bars = len(window)
    time_indices = np.arange(n_bars)

    feature_maps = feature_model.predict(np.array([window]))[0]
    upsampled_attention = _attention_profile(feature_maps, n_bars)

    plt.figure(figsize=(7, 6))

    # Top panel: close prices over per-bar attention shading.
    plt.subplot(2, 1, 1)
    plt.title("Model Attention Heatmap")
    plt.plot(time_indices, closes, color='black', linewidth=2, label='Close Price')

    # fill_between draws a single polygon, so an array alpha cannot vary
    # along x: depending on the Matplotlib version it is rejected or
    # collapsed to one value. Instead, draw one full-height bar per step,
    # each with its own RGBA alpha (red, at half the attention value).
    shade = np.zeros((n_bars, 4))
    shade[:, 0] = 1.0
    shade[:, 3] = upsampled_attention * 0.5
    price_low, price_high = closes.min(), closes.max()
    plt.bar(time_indices, price_high - price_low, bottom=price_low,
            width=1.0, color=shade, linewidth=0)

    # Mark bars whose attention exceeds the threshold.
    high_attention_threshold = 0.7
    high_attention_indices = np.where(upsampled_attention > high_attention_threshold)[0]
    plt.scatter(high_attention_indices, closes[high_attention_indices],
                color='red', s=50, zorder=5, label='High Attention Points')

    plt.xlabel('Time Steps')
    plt.ylabel('Normalized Price')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Bottom panel: the attention values themselves.
    plt.subplot(2, 1, 2)
    plt.title("Model Attention Distribution")
    plt.bar(time_indices, upsampled_attention, color='blue', alpha=0.6)
    plt.xlabel('Time Steps')
    plt.ylabel('Attention Level')
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('attention_heatmap.png', dpi=100)  # Save plot to file
    plt.close()  # Close figure to save memory
    print("Attention heatmap saved to 'attention_heatmap.png'")


def _attention_profile(feature_maps, window_size):
    """Reduce (steps, channels) feature maps to a per-bar attention curve in [0, 1].

    Returns an array of length ``window_size``. Bar ``t`` takes the value of
    feature step ``floor(t * steps / window_size)``. When ``window_size`` is a
    multiple of the step count (48 bars / 12 steps by default), every group
    of consecutive bars covered by one step shares that step's value.
    """
    avg_activation = np.mean(np.abs(feature_maps), axis=1)

    lo = avg_activation.min()
    span = avg_activation.max() - lo
    if span > 0:
        attention = (avg_activation - lo) / span
    else:
        # Flat activations carry no ranking; show no emphasis rather than the
        # NaNs that 0/0 would produce.
        attention = np.zeros_like(avg_activation)

    # Integer index mapping covers every bar. Float-truncated slice bounds can
    # leave the last bars unassigned.
    steps = len(attention)
    source = np.minimum(np.arange(window_size) * steps // window_size, steps - 1)
    return attention[source]

# Main function
def main():
    """Run the end-to-end pipeline: connect, load data, train, predict and plot.

    The MetaTrader5 connection is released in a ``finally`` clause, so an
    exception during loading, training, prediction or plotting does not leave
    it open. All figures are written as PNG files to the working directory.
    """
    print("Starting EURUSD prediction system with computer vision")

    # On failure, connect_to_mt5 has already printed the error and shut down.
    if not connect_to_mt5():
        return

    print("Successfully connected to MetaTrader5")

    try:
        bars_to_load = 2000  # Load more than needed for training
        data = get_historical_data(num_bars=bars_to_load)
        if data is None:
            return  # the finally clause still shuts the terminal down

        print(f"Loaded {len(data)} bars of EURUSD history")

        print("Converting data for computer vision processing...")
        images, targets = create_images(data)
        print(f"Created {len(images)} images for training")
        if len(images) == 0:
            # There is nothing to split or fit; stop with a clear message
            # instead of an error from deep inside training.
            print("Not enough historical bars to build training images")
            return

        print("Training computer vision model...")
        model, history, feature_model = train_cv_model(images, targets)
        plot_learning_history(history)

        direction, confidence, last_window_scaled = make_prediction(model, data)
        print(f"Forecast for the next 24 periods: {direction} (confidence: {confidence:.2f}%)")

        plot_prediction(data, direction=direction)
        visualize_model_perception(feature_model, last_window_scaled)
        visualize_attention_heatmap(feature_model, last_window_scaled)
    finally:
        # Release the terminal on every exit path, including exceptions.
        mt5.shutdown()

    print("Work completed")

if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()
