import MetaTrader5 as mt5
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
import pytz

# Path to the MetaTrader 5 terminal executable; main() passes it to
# get_historical_data(), which hands it to mt5.initialize(path=...).
# NOTE(review): machine-specific absolute path — adjust for the local install.
terminal_path = "C:/Program Files/ForexBroker - MetaTrader 5/Arima/terminal64.exe"


def remove_duplicate_indices(df):
    """Return *df* with repeated index labels dropped.

    When an index label occurs on several rows, only the first such row is
    retained; the relative order of the surviving rows is unchanged.
    """
    keep_mask = ~df.index.duplicated(keep="first")
    return df.loc[keep_mask]


def get_historical_data(start_date, end_date, terminal_path, symbols=None):
    """Download M1 bars for a set of symbols from a MetaTrader 5 terminal.

    Args:
        start_date: Start of the requested range (passed to
            ``mt5.copy_rates_range``).
        end_date: End of the requested range.
        terminal_path: Path of the terminal executable to attach to.
        symbols: Optional iterable of symbol names. Defaults to the 25
            currency pairs this backtest was written for.

    Returns:
        ``None`` if the terminal connection fails. Otherwise a dict mapping
        each symbol that returned bars to a DataFrame indexed by bar
        ``time`` with columns ``open``, ``high``, ``low``, ``close``,
        ``bid`` and ``ask``. Symbols that return no bars are reported and
        omitted.
    """
    if symbols is None:
        symbols = [
            "AUDUSD",
            "AUDJPY",
            "CADJPY",
            "AUDCHF",
            "AUDNZD",
            "USDCAD",
            "USDCHF",
            "USDJPY",
            "NZDUSD",
            "GBPUSD",
            "EURUSD",
            "CADCHF",
            "CHFJPY",
            "NZDCAD",
            "NZDCHF",
            "NZDJPY",
            "GBPCAD",
            "GBPCHF",
            "GBPJPY",
            "GBPNZD",
            "EURCAD",
            "EURCHF",
            "EURGBP",
            "EURJPY",
            "EURNZD",
        ]

    if not mt5.initialize(path=terminal_path):
        print(f"Failed to connect to MetaTrader 5 terminal at {terminal_path}")
        return None

    timeframe = mt5.TIMEFRAME_M1
    historical_data = {}
    try:
        for symbol in symbols:
            rates = mt5.copy_rates_range(symbol, timeframe, start_date, end_date)
            if rates is None or len(rates) == 0:
                print(f"No data returned for {symbol}: {mt5.last_error()}")
                continue

            df = pd.DataFrame(rates)
            df["time"] = pd.to_datetime(df["time"], unit="s")
            # .copy() makes this an independent frame, so the column
            # assignments below don't write into a slice of another frame.
            df = df.set_index("time")[["open", "high", "low", "close"]].copy()
            df["bid"] = df["close"]  # Simplification: use 'close' as 'bid'
            df["ask"] = df["close"] + 0.000001  # Simplification: add spread
            historical_data[symbol] = df
    finally:
        # Always release the terminal connection, even if a download raises.
        mt5.shutdown()

    return historical_data


def calculate_synthetic_prices(data):
    """Compute two price ratios for each configured pair of symbols.

    For every ``(first, second)`` pair in the fixed list below where both
    symbols are keys of *data*, two columns are emitted:

    * ``"{first}_{second}_1"`` = ``data[first]["bid"] / data[second]["ask"]``
    * ``"{first}_{second}_2"`` = ``data[first]["bid"] / data[second]["bid"]``

    Pairs with a missing symbol are skipped. The division uses pandas index
    alignment, so a timestamp present for only one symbol yields NaN.

    NOTE(review): the ratio of two quotes is not in general a tradable cross
    rate (e.g. AUDUSD x USDCHF, not AUDUSD / USDCHF, gives AUDCHF) — confirm
    the intended formula.

    Args:
        data: Mapping of symbol name to a DataFrame with ``bid`` and ``ask``
            columns.

    Returns:
        DataFrame with one column per emitted ratio (empty when no pair has
        both symbols available).
    """
    leg_pairs = (
        ("AUDUSD", "USDCHF"),
        ("AUDUSD", "NZDUSD"),
        ("AUDUSD", "USDJPY"),
        ("USDCHF", "USDCAD"),
        ("USDCHF", "NZDCHF"),
        ("USDCHF", "CHFJPY"),
        ("USDJPY", "USDCAD"),
        ("USDJPY", "NZDJPY"),
        ("USDJPY", "GBPJPY"),
        ("NZDUSD", "NZDCAD"),
        ("NZDUSD", "NZDCHF"),
        ("NZDUSD", "NZDJPY"),
        ("GBPUSD", "GBPCAD"),
        ("GBPUSD", "GBPCHF"),
        ("GBPUSD", "GBPJPY"),
        ("EURUSD", "EURCAD"),
        ("EURUSD", "EURCHF"),
        ("EURUSD", "EURJPY"),
        ("CADCHF", "CADJPY"),
        ("CADCHF", "GBPCAD"),
        ("CADCHF", "EURCAD"),
        ("CHFJPY", "GBPCHF"),
        ("CHFJPY", "EURCHF"),
        ("CHFJPY", "NZDCHF"),
        ("NZDCAD", "NZDJPY"),
        ("NZDCAD", "GBPNZD"),
        ("NZDCAD", "EURNZD"),
        ("NZDCHF", "NZDJPY"),
        ("NZDCHF", "GBPNZD"),
        ("NZDCHF", "EURNZD"),
        ("NZDJPY", "GBPNZD"),
        ("NZDJPY", "EURNZD"),
    )

    columns = {}
    for first, second in leg_pairs:
        if first not in data or second not in data:
            continue
        numerator = data[first]["bid"]
        other = data[second]
        columns[f"{first}_{second}_1"] = numerator / other["ask"]
        columns[f"{first}_{second}_2"] = numerator / other["bid"]

    return pd.DataFrame(columns)


def analyze_arbitrage(data, synthetic_prices, threshold=0.00008):
    """Flag bars where a symbol's bid exceeds its synthetic price.

    Synthetic columns are named ``"{first}_{second}_{n}"`` (the naming used
    by ``calculate_synthetic_prices``). The spread for a column is the bid
    of its *first* leg — the symbol the backtest trades — minus the
    synthetic price.

    Args:
        data: Mapping of symbol name to a DataFrame with a ``bid`` column.
        synthetic_prices: DataFrame of synthetic prices, one column each.
        threshold: A bar is flagged when its spread is strictly greater
            than this value (price units).

    Returns:
        Boolean DataFrame with one column per synthetic column whose first
        leg is in *data*. NaN spreads (bars missing on either side) compare
        as ``False``.
    """
    spreads = {}
    for synth_name in synthetic_prices.columns:
        # Use the first leg explicitly. The old substring test matched both
        # legs and kept whichever came last in `data`, which was not
        # necessarily the symbol that later gets traded.
        base_symbol = synth_name.split("_")[0]
        if base_symbol in data:
            spreads[synth_name] = (
                data[base_symbol]["bid"] - synthetic_prices[synth_name]
            )

    return pd.DataFrame(spreads) > threshold


def simulate_trade(
    data, direction, entry_price, take_profit, stop_loss, contract_size=100000
):
    """Walk bars forward until the take-profit or stop-loss level is touched.

    Args:
        data: DataFrame with ``bid`` and ``ask`` columns, in time order. The
            caller should pass only bars after the entry.
        direction: ``"BUY"``; any other value is treated as a SELL.
        entry_price: Fill price of the position.
        take_profit: Profit target as a price distance from *entry_price*.
        stop_loss: Loss limit as a price distance from *entry_price*.
        contract_size: Units per position. Every profit figure returned is a
            price move multiplied by this, so TP/SL and the end-of-data
            close use the same units.

    Returns:
        dict with ``profit`` (price move x *contract_size*) and ``duration``
        (number of bars scanned until exit). The take-profit check runs
        before the stop-loss check on each bar. If neither level is touched,
        the position is closed at the last bar. Empty *data* gives
        ``{"profit": 0.0, "duration": 0}``.
    """
    is_buy = direction == "BUY"
    # Longs are closed at the bid, shorts at the ask.
    exit_col = "bid" if is_buy else "ask"
    if data.empty:
        return {"profit": 0.0, "duration": 0}

    if is_buy:
        tp_level = entry_price + take_profit
        sl_level = entry_price - stop_loss
    else:
        tp_level = entry_price - take_profit
        sl_level = entry_price + stop_loss

    exit_prices = data[exit_col]
    for n_bars, price in enumerate(exit_prices, start=1):
        hit_tp = price >= tp_level if is_buy else price <= tp_level
        hit_sl = price <= sl_level if is_buy else price >= sl_level
        if hit_tp:
            return {"profit": take_profit * contract_size, "duration": n_bars}
        if hit_sl:
            return {"profit": -stop_loss * contract_size, "duration": n_bars}

    # Neither level touched: close at the last available price.
    last_price = exit_prices.iloc[-1]
    move = last_price - entry_price if is_buy else entry_price - last_price
    return {"profit": move * contract_size, "duration": len(data)}


def _split_by_day(historical_data):
    """Group each symbol's bars by calendar date once, up front."""
    return {
        symbol: {day: bars for day, bars in df.groupby(df.index.date)}
        for symbol, df in historical_data.items()
    }


def backtest_arbitrage_system(
    historical_data,
    start_date,
    end_date,
    take_profit_points=800,
    stop_loss_points=400,
    point=0.00001,
    initial_equity=10000,
):
    """Run the synthetic-arbitrage strategy one calendar day at a time.

    For each day from *start_date* to *end_date* (daily frequency), every
    symbol's bars are restricted to that day. Synthetic prices and arbitrage
    flags are then computed. For each synthetic column with at least one
    flagged bar, one trade is opened in its first-leg symbol at the first
    flagged bar. That trade is simulated only on bars after the entry bar,
    so no bar that precedes the entry can decide its outcome.

    Args:
        historical_data: Mapping of symbol -> DataFrame with a DatetimeIndex
            and ``bid``/``ask`` columns.
        start_date: First day of the backtest.
        end_date: Last day of the backtest.
        take_profit_points: Take-profit distance in points.
        stop_loss_points: Stop-loss distance in points.
        point: Price value of one point.
            NOTE(review): the single default fits 5-digit quotes, not JPY
            pairs — confirm.
        initial_equity: Starting account balance.

    Returns:
        ``(equity_curve, trades)``. ``equity_curve`` starts at
        *initial_equity* and gains one entry per trade. ``trades`` is the
        list of ``simulate_trade`` result dicts.
    """
    take_profit = take_profit_points * point
    stop_loss = stop_loss_points * point
    equity_curve = [initial_equity]
    trades = []

    # Grouping once avoids rebuilding `df.index.date` for every symbol/day.
    bars_by_day = _split_by_day(historical_data)

    for current_date in pd.date_range(start=start_date, end=end_date, freq="D"):
        day = current_date.date()
        print(f"Backtesting for date: {day}")

        # Symbols with no bars that day get an empty frame with the same columns.
        data = {
            symbol: days.get(day, historical_data[symbol].iloc[:0])
            for symbol, days in bars_by_day.items()
        }

        # Skip if no data for the current day
        if all(df.empty for df in data.values()):
            continue

        synthetic_prices = calculate_synthetic_prices(data)
        arbitrage_opportunities = analyze_arbitrage(data, synthetic_prices)

        for synth_name in arbitrage_opportunities.columns:
            signal = arbitrage_opportunities[synth_name]
            if not signal.any():
                continue

            base_symbol = synth_name.split("_")[0]
            bars = data.get(base_symbol)
            if bars is None or bars.empty:
                continue

            # NOTE(review): direction rule kept from the original (depends on
            # whether the day's first bar is flagged) — confirm intent.
            direction = "BUY" if signal.iloc[0] else "SELL"

            # Enter at the first flagged bar; only later bars may decide the exit.
            entry_time = signal.idxmax()
            upto_entry = bars[bars.index <= entry_time]
            after_entry = bars[bars.index > entry_time]
            if upto_entry.empty or after_entry.empty:
                continue  # nothing to fill at, or no bars left to simulate

            entry_bar = upto_entry.iloc[-1]
            # A buy is filled at the ask, a sell at the bid.
            price = entry_bar["ask"] if direction == "BUY" else entry_bar["bid"]

            trade_result = simulate_trade(
                after_entry, direction, price, take_profit, stop_loss
            )
            trades.append(trade_result)
            equity_curve.append(equity_curve[-1] + trade_result["profit"])

    return equity_curve, trades


def main():
    """Fetch M1 data for Jan 1 – Aug 31 2024, backtest, report, and plot.

    Prints total profit, win rate and final equity. Saves the equity curve
    to ``equity_curve.png``. Exits early if no historical data was obtained.
    """
    start_date = datetime(2024, 1, 1, tzinfo=pytz.UTC)
    end_date = datetime(2024, 8, 31, tzinfo=pytz.UTC)  # Backtest Jan–Aug 2024

    print("Fetching historical data...")
    historical_data = get_historical_data(start_date, end_date, terminal_path)

    # None: terminal connection failed; {}: no symbol returned any bars.
    if not historical_data:
        print("Failed to fetch historical data. Exiting.")
        return

    print("Starting backtest...")
    equity_curve, trades = backtest_arbitrage_system(
        historical_data, start_date, end_date
    )

    total_profit = sum(trade["profit"] for trade in trades)
    win_rate = (
        sum(1 for trade in trades if trade["profit"] > 0) / len(trades) if trades else 0
    )

    print("Backtest completed. Results:")
    print(f"Total Profit: ${total_profit:.2f}")
    print(f"Win Rate: {win_rate:.2%}")
    print(f"Final Equity: ${equity_curve[-1]:.2f}")

    # Plot equity curve (index 0 is the starting balance, before any trade).
    plt.figure(figsize=(15, 10))
    plt.plot(equity_curve)
    plt.title("Equity Curve: Backtest Results")
    plt.xlabel("Trade Number")
    plt.ylabel("Account Balance ($)")
    plt.savefig("equity_curve.png")
    plt.close()

    print("Equity curve saved as 'equity_curve.png'.")


# Run the backtest only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
