
import MetaTrader5 as mt5
import pandas as pd
import numpy as np
import time
from datetime import datetime, timedelta
import pytz
from typing import Dict, Optional, List, Tuple
import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
logger = logging.getLogger(__name__)

class ArbitrageModule:
    """Compare quoted FX prices against synthetic cross rates built via USD.

    Tick data is pulled from a MetaTrader 5 terminal. For each symbol in
    ``self.symbols`` a "fair" price is rebuilt from the USD legs listed in
    ``self.usd_pairs`` and compared with the symbol's own quoted mid price.
    """

    def __init__(self, terminal_path: str = "C:/Program Files/RannForex MetaTrader 5/terminal64.exe", max_trades: int = 10):
        """Store configuration only; no terminal connection is made here.

        Args:
            terminal_path: Path to the MT5 ``terminal64.exe`` to connect to.
            max_trades: Stored as ``MAX_OPEN_TRADES``; not used by the code
                in this class.
        """
        self.terminal_path = terminal_path
        self.MAX_OPEN_TRADES = max_trades
        self.z_scores = {}
        self.symbols = [
            "AUDUSD.ecn", "AUDJPY.ecn", "CADJPY.ecn", "AUDCHF.ecn", "AUDNZD.ecn", 
            "USDCAD.ecn", "USDCHF.ecn", "USDJPY.ecn", "NZDUSD.ecn", "GBPUSD.ecn", 
            "EURUSD.ecn", "CADCHF.ecn", "CHFJPY.ecn", "NZDCAD.ecn", "NZDCHF.ecn", 
            "NZDJPY.ecn", "GBPCAD.ecn", "GBPCHF.ecn", "GBPJPY.ecn", "GBPNZD.ecn", 
            "EURCAD.ecn", "EURCHF.ecn", "EURGBP.ecn", "EURJPY.ecn", "EURNZD.ecn"
        ]
        
        # Base pairs for calculating cross rates.
        # A plain string is an XXXUSD pair: its price is already the USD value
        # of one unit of XXX. A tuple (pair, True) is a USDXXX pair: the USD
        # value of XXX is the reciprocal of its price. USD itself is 1.0.
        self.usd_pairs = {
            "EUR": "EURUSD.ecn",
            "GBP": "GBPUSD.ecn", 
            "AUD": "AUDUSD.ecn",
            "NZD": "NZDUSD.ecn",
            "USD": None,
            "CAD": ("USDCAD.ecn", True),
            "CHF": ("USDCHF.ecn", True),
            "JPY": ("USDJPY.ecn", True)
        }

    def get_mt5_data(self, symbol: str, count: int = 1000) -> Optional[pd.DataFrame]:
        """Fetch the most recent ticks for *symbol* from the MT5 terminal.

        Args:
            symbol: Broker symbol name, e.g. ``"EURUSD.ecn"``.
            count: Maximum number of ticks to return. The newest ``count``
                ticks from the last 24 hours are kept; ``count <= 0`` keeps all.

        Returns:
            A DataFrame of ticks with ``time`` converted to datetime, or
            ``None`` if the terminal cannot be reached, the request fails,
            or no ticks exist in the last 24 hours.
        """
        try:
            if not mt5.initialize(path=self.terminal_path):
                logger.error(f"Failed to connect to MetaTrader 5 terminal at {self.terminal_path}")
                return None

            timezone = pytz.timezone("Etc/UTC")
            utc_to = datetime.now(timezone)
            utc_from = utc_to - timedelta(days=1)

            # copy_ticks_from(utc_from, count) returns the FIRST `count` ticks
            # after utc_from, i.e. prices from ~24h ago on liquid pairs. Ask for
            # the whole window instead and keep its tail, so that iloc[-1] is
            # the latest quote and all symbols refer to roughly the same moment.
            ticks = mt5.copy_ticks_range(symbol, utc_from, utc_to, mt5.COPY_TICKS_ALL)
            if ticks is None:
                logger.error(f"Failed to fetch data for {symbol}: {mt5.last_error()}")
                return None
            if len(ticks) == 0:
                # An empty frame would later fail on iloc[-1]; report it here.
                logger.warning(f"No ticks in the last day for {symbol}")
                return None

            if count > 0:
                ticks = ticks[-count:]

            ticks_frame = pd.DataFrame(ticks)
            ticks_frame['time'] = pd.to_datetime(ticks_frame['time'], unit='s')
            return ticks_frame

        except Exception as e:
            logger.error(f"Error getting MT5 data for {symbol}: {str(e)}")
            return None

    def get_currency_data(self, count: int = 1000) -> Dict[str, pd.DataFrame]:
        """Collect bid/ask/mid frames for every configured symbol.

        Symbols whose fetch fails are left out of the result (and logged).

        Returns:
            Mapping of symbol -> DataFrame indexed by tick time, with
            ``bid``, ``ask`` and ``close`` (the bid/ask midpoint) columns.
        """
        data = {}
        for symbol in self.symbols:
            try:
                df = self.get_mt5_data(symbol, count)
                if df is None:
                    continue
                frame = df[['time', 'bid', 'ask']].set_index('time')
                frame['close'] = (frame['bid'] + frame['ask']) / 2
                # Tick times have one-second resolution, so several ticks can
                # share a timestamp. Keep the newest quote of each second so the
                # last row really is the latest price.
                data[symbol] = frame[~frame.index.duplicated(keep='last')]
            except Exception as e:
                logger.error(f"Data error {symbol}: {str(e)}")
        return data

    def get_usd_rate(self, currency: str, data: dict) -> float:
        """Return the USD value of one unit of *currency* from the latest close.

        Raises:
            KeyError: if *currency* is not in ``usd_pairs`` or its USD leg
                is missing from *data*.
        """
        if currency == "USD":
            return 1.0
            
        pair_info = self.usd_pairs[currency]
        if isinstance(pair_info, tuple):
            # USDXXX quote: invert to get USD per unit of XXX.
            pair, inverse = pair_info
            rate = data[pair]['close'].iloc[-1]
            return 1 / rate if inverse else rate
        else:
            pair = pair_info
            return data[pair]['close'].iloc[-1]

    def calculate_cross_rate(self, base: str, quote: str, data: dict) -> float:
        """Return the synthetic *base*/*quote* rate (units of quote per base) via USD."""
        base_usd = self.get_usd_rate(base, data)
        quote_usd = self.get_usd_rate(quote, data)
        return base_usd / quote_usd

    @staticmethod
    def _pip_size(symbol: str) -> float:
        """Return the pip size for *symbol*: 0.01 if it is quoted in JPY, else 0.0001."""
        return 0.01 if symbol[3:6] == "JPY" else 0.0001

    def calculate_synthetic_prices(self, data: Dict[str, pd.DataFrame]) -> pd.DataFrame:
        """Compute and log the synthetic (fair) price of every configured symbol.

        For pairs that are themselves USD legs (e.g. EURUSD), the synthetic price
        is derived from the pair itself, so the deviation is zero.

        Args:
            data: Output of :meth:`get_currency_data`.

        Returns:
            A one-row DataFrame with a ``<symbol>_fair`` column per symbol that
            could be priced. Symbols with missing or unusable data are skipped
            and logged. An unexpected error yields an empty DataFrame.
        """
        synthetic_prices = {}
        
        try:
            for symbol in self.symbols:
                base = symbol[:3]
                quote = symbol[3:6]

                if symbol not in data or data[symbol].empty:
                    logger.warning(f"No data for {symbol}, skipping")
                    continue

                try:
                    # Calculate synthetic price via cross rates
                    fair_price = self.calculate_cross_rate(base, quote, data)
                    # Calculate the current price
                    current_price = data[symbol]['close'].iloc[-1]
                except (KeyError, IndexError, ZeroDivisionError) as e:
                    # A missing USD leg only disqualifies this symbol, not the batch.
                    logger.error(f"Error calculating synthetic price for {symbol}: {str(e)}")
                    continue

                synthetic_prices[f'{symbol}_fair'] = pd.Series([fair_price])
                deviation = current_price - fair_price
                # JPY-quoted pairs have a 0.01 pip, so a fixed *10000 would
                # overstate them 100x.
                deviation_pips = abs(deviation) / self._pip_size(symbol)
                
                # Display data in the log
                logger.info(f"Symbol: {symbol}, Current: {current_price:.5f}, Fair: {fair_price:.5f}, "
                          f"Deviation: {deviation_pips:.1f} pips")
                
            return pd.DataFrame(synthetic_prices)

        except Exception as e:
            logger.error(f"Error calculating synthetic prices: {str(e)}")
            return pd.DataFrame()

    def check_trading_time(self) -> bool:
        """Return True if the current UTC time is between 05:00 and 23:30 inclusive."""
        current_time = datetime.now(pytz.timezone("Etc/UTC")).time()
        start_time = datetime.strptime("05:00", "%H:%M").time()
        end_time = datetime.strptime("23:30", "%H:%M").time()
        return start_time <= current_time <= end_time

    def run(self, interval: float = 60.0, iterations: Optional[int] = None) -> None:
        """Periodically log quoted-vs-synthetic price deviations.

        This is monitoring only: no orders are placed. During trading hours
        (see :meth:`check_trading_time`), each cycle fetches tick data and logs
        per-symbol deviations, then sleeps for *interval* seconds.

        Args:
            interval: Seconds to wait between cycles.
            iterations: Number of cycles to run; ``None`` runs until interrupted.
        """
        cycle = 0
        while iterations is None or cycle < iterations:
            if self.check_trading_time():
                data = self.get_currency_data()
                if data:
                    self.calculate_synthetic_prices(data)
                else:
                    logger.warning("No market data received")
            else:
                logger.info("Outside trading hours")
            cycle += 1
            if iterations is None or cycle < iterations:
                time.sleep(interval)

if __name__ == "__main__":
    # Script entry point: build the module, run it until Ctrl+C, and always
    # close the MetaTrader 5 connection afterwards.
    monitor = ArbitrageModule()
    try:
        try:
            monitor.run()
        except KeyboardInterrupt:
            logger.info("\nArbitrage module stopped.")
    finally:
        mt5.shutdown()
