import MetaTrader5 as mt5
import pandas as pd
import time
from datetime import datetime, timedelta
import pytz
import pandas as pd

# Cap on simultaneously open positions: open_test_limit_order() refuses to send
# a new order once mt5.positions_total() reaches this value.
MAX_OPEN_TRADES = 10


def remove_duplicate_indices(df):
    """Return ``df`` restricted to the first row seen for each index label.

    Rows whose index label has already occurred earlier in the frame are
    dropped. The input object itself is not modified; a filtered object is
    returned.
    """
    # True for every occurrence of a label after its first one.
    is_repeat = df.index.duplicated(keep="first")
    return df.loc[~is_repeat]


# Path to the MetaTrader 5 terminal executable. It is passed as `path=` to
# mt5.initialize() by get_mt5_data() and open_test_limit_order().
terminal_path = "C:/Program Files/ForexBroker - MetaTrader 5/Arima/terminal64.exe"


def get_mt5_data(symbol, count, terminal_path):
    """Fetch ticks for ``symbol`` from a MetaTrader 5 terminal.

    The terminal at ``terminal_path`` is initialised on every call. Ticks are
    requested via ``mt5.copy_ticks_from`` with ``mt5.COPY_TICKS_ALL``, starting
    24 hours before the current UTC time and limited by ``count``.

    Args:
        symbol: Instrument name passed to the terminal.
        count: Number of ticks to request.
        terminal_path: Path handed to ``mt5.initialize``.

    Returns:
        A DataFrame built from the returned tick array, with its "time" column
        converted to pandas datetimes (values interpreted as seconds), or
        ``None`` if the terminal could not be initialised or the tick request
        returned ``None``.
    """
    print(f"Fetching data for {symbol}")

    connected = mt5.initialize(path=terminal_path)
    if not connected:
        print(f"Failed to connect to MetaTrader 5 terminal at {terminal_path}")
        return None

    # Use a timezone-aware UTC start so no local-time offset is applied.
    utc = pytz.timezone("Etc/UTC")
    window_start = datetime.now(utc) - timedelta(days=1)

    raw_ticks = mt5.copy_ticks_from(symbol, window_start, count, mt5.COPY_TICKS_ALL)
    if raw_ticks is None:
        print(f"Failed to fetch data for {symbol}")
        return None

    frame = pd.DataFrame(raw_ticks)
    frame["time"] = pd.to_datetime(frame["time"], unit="s")
    return frame


def get_currency_data():
    """Download recent ticks for a fixed list of currency pairs.

    Each symbol is fetched through ``get_mt5_data`` using the module-level
    ``terminal_path``. Symbols for which the fetch returns ``None`` are left
    out of the result.

    Returns:
        A dict mapping symbol name to a DataFrame indexed by "time" with the
        columns "bid" and "ask".
    """
    tick_count = 1000  # ticks requested per currency pair

    symbols = (
        "AUDUSD",
        "AUDJPY",
        "CADJPY",
        "AUDCHF",
        "AUDNZD",
        "USDCAD",
        "USDCHF",
        "USDJPY",
        "NZDUSD",
        "GBPUSD",
        "EURUSD",
        "CADCHF",
        "CHFJPY",
        "NZDCAD",
        "NZDCHF",
        "NZDJPY",
        "GBPCAD",
        "GBPCHF",
        "GBPJPY",
        "GBPNZD",
        "EURCAD",
        "EURCHF",
        "EURGBP",
        "EURJPY",
        "EURNZD",
    )

    quotes = {}
    for name in symbols:
        frame = get_mt5_data(name, tick_count, terminal_path)
        if frame is None:
            # Fetch failed; get_mt5_data has already reported it.
            continue
        quotes[name] = frame.set_index("time")[["bid", "ask"]]

    return quotes


def calculate_synthetic_prices(data):
    """Build ratio-based synthetic price series from pairs of quoted symbols.

    For every ``(pair1, pair2)`` in a fixed list, two series are produced:
    ``pair1`` bid divided by ``pair2`` ask, and ``pair1`` bid divided by
    ``pair2`` bid. Each series is stored under the key ``"{pair1}_{n}"``,
    where ``n`` is a running counter starting at 1 that advances by one per
    series.

    Args:
        data: Dict mapping symbol name to a DataFrame with "bid" and "ask"
            columns. Every value is replaced IN PLACE by a copy with duplicate
            index labels removed (first occurrence kept); callers that reuse
            ``data`` afterwards see the de-duplicated frames.

    Returns:
        A DataFrame with one column per synthetic series. Pairs for which
        either symbol is absent from ``data`` are skipped (and reported)
        instead of raising ``KeyError``; their two counter values are still
        consumed so the column names of the remaining pairs do not shift.
    """
    # De-duplicate every frame; assigning to existing keys while iterating is
    # safe because the dict's size does not change.
    for key in data:
        data[key] = remove_duplicate_indices(data[key])

    pairs = [
        ("AUDUSD", "USDCHF"),
        ("AUDUSD", "NZDUSD"),
        ("AUDUSD", "USDJPY"),
        ("USDCHF", "USDCAD"),
        ("USDCHF", "NZDCHF"),
        ("USDCHF", "CHFJPY"),
        ("USDJPY", "USDCAD"),
        ("USDJPY", "NZDJPY"),
        ("USDJPY", "GBPJPY"),
        ("NZDUSD", "NZDCAD"),
        ("NZDUSD", "NZDCHF"),
        ("NZDUSD", "NZDJPY"),
        ("GBPUSD", "GBPCAD"),
        ("GBPUSD", "GBPCHF"),
        ("GBPUSD", "GBPJPY"),
        ("EURUSD", "EURCAD"),
        ("EURUSD", "EURCHF"),
        ("EURUSD", "EURJPY"),
        ("CADCHF", "CADJPY"),
        ("CADCHF", "GBPCAD"),
        ("CADCHF", "EURCAD"),
        ("CHFJPY", "GBPCHF"),
        ("CHFJPY", "EURCHF"),
        ("CHFJPY", "NZDCHF"),
        ("NZDCAD", "NZDJPY"),
        ("NZDCAD", "GBPNZD"),
        ("NZDCAD", "EURNZD"),
        ("NZDCHF", "NZDJPY"),
        ("NZDCHF", "GBPNZD"),
        ("NZDCHF", "EURNZD"),
        ("NZDJPY", "GBPNZD"),
        ("NZDJPY", "EURNZD"),
    ]

    synthetic_prices = {}
    method_count = 1
    for pair1, pair2 in pairs:
        if pair1 not in data or pair2 not in data:
            # A symbol failed to download; skip the pair but keep numbering
            # stable so "{pair}_{n}" names match a fully populated run.
            print(
                f"Skipping synthetic price for {pair1} and {pair2}: missing data"
            )
            method_count += 2
            continue

        numerator = data[pair1]["bid"]
        # Same order as before: the ask-based ratio first, then the bid-based.
        for quote_side in ("ask", "bid"):
            print(
                f"Calculating synthetic price for {pair1} and {pair2} using method {method_count}"
            )
            synthetic_prices[f"{pair1}_{method_count}"] = (
                numerator / data[pair2][quote_side]
            )
            method_count += 1

    return pd.DataFrame(synthetic_prices)


def analyze_arbitrage(data, synthetic_prices, method_count):
    """Flag ticks where a symbol's bid exceeds a synthetic price by a margin.

    For every symbol in ``data`` and every suffix ``1..method_count``, the
    column ``"{symbol}_{suffix}"`` is looked up in ``synthetic_prices``. When
    it exists, the spread ``data[symbol]["bid"] - synthetic_prices[column]``
    is computed.

    Args:
        data: Dict mapping symbol name to a DataFrame with a "bid" column.
        synthetic_prices: DataFrame whose columns are named
            ``"{symbol}_{suffix}"``.
        method_count: Highest suffix to probe for each symbol.

    Returns:
        A boolean DataFrame with one column per matched synthetic column; a
        cell is True where the spread is greater than 0.00008. The same table
        is printed and written to "arbitrage_opportunities.csv" in the current
        working directory.
    """
    known_columns = synthetic_prices.columns

    # Candidate names in symbol-major, suffix-minor order.
    candidates = (
        (pair, f"{pair}_{suffix}")
        for pair in data
        for suffix in range(1, method_count + 1)
    )

    spreads = {}
    for pair, synthetic_pair in candidates:
        if synthetic_pair not in known_columns:
            continue
        print(f"Analyzing arbitrage opportunity for {synthetic_pair}")
        real_bid = data[pair]["bid"]
        spreads[synthetic_pair] = real_bid - synthetic_prices[synthetic_pair]

    spread_table = pd.DataFrame(spreads)
    arbitrage_opportunities = spread_table > 0.00008

    print("Arbitrage opportunities:")
    print(arbitrage_opportunities)

    # Persist the complete table for later inspection.
    arbitrage_opportunities.to_csv("arbitrage_opportunities.csv")

    return arbitrage_opportunities


def open_test_limit_order(
    symbol, order_type, price, volume, take_profit, stop_loss, terminal_path
):
    """Send a single trade request for ``symbol`` to the MetaTrader 5 terminal.

    Despite the function name, the request uses ``mt5.TRADE_ACTION_DEAL``
    (an immediate deal), not a pending limit order.

    Args:
        symbol: Instrument name.
        order_type: ``mt5.ORDER_TYPE_BUY`` or another order type; anything
            other than BUY is treated as a sell when placing TP/SL.
        price: Requested execution price.
        volume: Requested volume.
        take_profit: Take-profit distance, multiplied by ``symbol_info.point``.
        stop_loss: Stop-loss distance, multiplied by ``symbol_info.point``.
        terminal_path: Path handed to ``mt5.initialize``.

    Returns:
        ``result.order`` when the request completes with
        ``mt5.TRADE_RETCODE_DONE``. ``None`` if the terminal cannot be
        initialised, the symbol is unknown, ``MAX_OPEN_TRADES`` positions are
        already open, or the request is not accepted.
    """
    if not mt5.initialize(path=terminal_path):
        print(f"Failed to connect to MetaTrader 5 terminal at {terminal_path}")
        return None

    symbol_info = mt5.symbol_info(symbol)
    if symbol_info is None:
        print(f"Instrument not found: {symbol}")
        return None

    if mt5.positions_total() >= MAX_OPEN_TRADES:
        print("MAX POSITIONS TOTAL!")
        return None

    # TP sits on the profitable side of the price and SL on the losing side,
    # so the signs flip between buy and sell.
    point = symbol_info.point
    if order_type == mt5.ORDER_TYPE_BUY:
        tp_price = price + take_profit * point
        sl_price = price - stop_loss * point
    else:
        tp_price = price - take_profit * point
        sl_price = price + stop_loss * point

    request = {
        "action": mt5.TRADE_ACTION_DEAL,
        "symbol": symbol,
        "volume": volume,
        "type": order_type,
        "price": price,
        "deviation": 30,
        "magic": 123456,
        "comment": "ShtencoArbitrage",
        "type_time": mt5.ORDER_TIME_GTC,
        "type_filling": mt5.ORDER_FILLING_IOC,
        "tp": tp_price,
        "sl": sl_price,
    }

    result = mt5.order_send(request)
    if result is not None and result.retcode == mt5.TRADE_RETCODE_DONE:
        print(f"Test limit order placed for {symbol}")
        return result.order

    print(
        f"Error: Test limit order not placed for {symbol}, retcode={result.retcode if result is not None else 'None'}"
    )
    return None


def main():
    """Run one fetch -> analyse -> trade cycle.

    Downloads quotes, derives synthetic prices, flags arbitrage opportunities
    and, for every flagged synthetic column, sends one order for the
    underlying symbol through ``open_test_limit_order``.
    """
    data = get_currency_data()
    synthetic_prices = calculate_synthetic_prices(data)
    # Highest "_N" suffix analyze_arbitrage will probe for each symbol.
    method_count = 2000
    arbitrage_opportunities = analyze_arbitrage(data, synthetic_prices, method_count)

    take_profit = 450
    stop_loss = 200
    volume = 0.50

    # Trade based on arbitrage opportunities
    for column in arbitrage_opportunities.columns:
        flags = arbitrage_opportunities[column]
        if not flags.any():
            continue

        # NOTE(review): direction is taken from the FIRST row only, even
        # though the column qualifies when ANY row is flagged - confirm intent.
        is_buy = bool(flags.iloc[0])
        symbol = column.split("_")[0]  # strip the "_N" method suffix

        tick = mt5.symbol_info_tick(symbol)
        if tick is None:
            print(f"Error: Symbol info tick not found for {symbol}")
            continue

        # A buy is filled at the ask and a sell at the bid; the previous code
        # had the two sides swapped.
        price = tick.ask if is_buy else tick.bid
        open_test_limit_order(
            symbol,
            mt5.ORDER_TYPE_BUY if is_buy else mt5.ORDER_TYPE_SELL,
            price,
            volume,
            take_profit,
            stop_loss,
            terminal_path,
        )


if __name__ == "__main__":
    # Bounds of the no-trading window, compared against the UTC wall-clock time.
    quiet_from = datetime.strptime("23:30", "%H:%M").time()
    quiet_until = datetime.strptime("05:00", "%H:%M").time()
    utc = pytz.timezone("Etc/UTC")

    while True:
        now_utc = datetime.now(utc).time()
        # The window wraps past midnight, hence "or" rather than "and".
        in_quiet_window = now_utc >= quiet_from or now_utc <= quiet_until

        if in_quiet_window:
            print("Current time is between 23:30 and 05:00. Skipping execution.")
        else:
            main()

        # Either way, wait 5 minutes before the next check/execution.
        time.sleep(300)
