import MetaTrader5 as mt5
import pandas as pd
import numpy as np
import matplotlib

matplotlib.use("Agg")  # Agg backend for saving charts
import matplotlib.pyplot as plt
import scipy.stats as stats
import seaborn as sns
import logging
import time
import os
from threading import Thread
import json
from datetime import datetime, timedelta

# Logging configuration: INFO level, "time - level - message" lines
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Output directories (created at import time when they do not exist yet)
output_dir = "charts"
data_dir = "data"
for directory in (output_dir, data_dir):
    if os.path.exists(directory):
        continue
    os.makedirs(directory)


# --- MT5 initialization ---
def initialize_mt5(max_retries=3, retry_delay=5):
    """Connect to an open MT5 terminal, retrying failed initializations.

    Args:
        max_retries: Maximum number of ``mt5.initialize`` attempts.
        retry_delay: Seconds to sleep between two failed attempts.

    Returns:
        True when the terminal initialized and ``mt5.account_info()`` returned
        an account; False otherwise. Errors are logged, never raised.
    """
    try:
        for attempt in range(max_retries):
            if not mt5.initialize(timeout=10000):
                logger.error(
                    f"Attempt {attempt + 1}/{max_retries} - MT5 initialization error: {mt5.last_error()}"
                )
                # No point in sleeping after the final attempt
                if attempt < max_retries - 1:
                    logger.info(
                        f"Waiting {retry_delay} seconds before the next attempt..."
                    )
                    time.sleep(retry_delay)
                continue
            # Query the account once and reuse the result: a second call could
            # return None and turn the success log into an AttributeError.
            account = mt5.account_info()
            if account is None:
                logger.error("Terminal not authorized, enter MT5 manually")
                mt5.shutdown()
                return False
            logger.info(
                f"MT5 connected successfully, account number: {account.login}"
            )
            return True
        logger.error(f"Failed to initialize MT5 after {max_retries} attempts")
        return False
    except Exception as e:
        logger.error(f"Error in initialize_mt5: {str(e)}")
        return False


# --- Get data ---
def fetch_historical_data(
    symbol, timeframe=mt5.TIMEFRAME_M5, start_date=None, end_date=None
):
    """Download bars for ``symbol`` over a date range and save them as CSV.

    Args:
        symbol: Symbol name; it is selected in Market Watch first.
        timeframe: MT5 timeframe constant passed to ``copy_rates_range``.
        start_date: Range start; defaults to 60 days before now.
        end_date: Range end; defaults to now.

    Returns:
        DataFrame indexed by bar time with the open/high/low/close/spread
        columns, or None when the symbol cannot be selected, no bars come
        back, or an exception occurs (all logged).
    """
    try:
        selected = mt5.symbol_select(symbol, True)
        if not selected:
            logger.error(f"{symbol} not found in Market Watch")
            return None

        # Default window: the 60 days leading up to the current moment
        start_date = (
            (datetime.now() - timedelta(days=60)) if start_date is None else start_date
        )
        end_date = datetime.now() if end_date is None else end_date

        bars = mt5.copy_rates_range(symbol, timeframe, start_date, end_date)
        if bars is None or len(bars) == 0:
            logger.error(f"Failed to download data for {symbol}: {mt5.last_error()}")
            return None

        frame = pd.DataFrame(bars)
        # The "time" field is converted from epoch seconds
        frame["time"] = pd.to_datetime(frame["time"], unit="s")
        frame = frame.set_index("time")

        # The file name is fixed regardless of the timeframe/range requested
        csv_path = os.path.join(data_dir, f"{symbol}_M5_2months.csv")
        frame.to_csv(csv_path)
        logger.info(
            f"Data for {symbol} downloaded and saved: {len(frame)} entries (from {start_date} to {end_date})"
        )
        wanted_columns = ["open", "high", "low", "close", "spread"]
        return frame[wanted_columns]
    except Exception as e:
        logger.error(f"Error in fetch_historical_data for {symbol}: {str(e)}")
        return None


def fetch_tick_data(symbol, hours=24):
    """Request up to 100000 ticks for ``symbol`` starting ``hours`` hours ago.

    Returns:
        DataFrame indexed by tick time with "bid" and "ask" columns, or None
        when no ticks come back or an exception occurs (both logged).
    """
    try:
        window_start = datetime.now() - timedelta(hours=hours)
        raw_ticks = mt5.copy_ticks_from(
            symbol, window_start, 100000, mt5.COPY_TICKS_ALL
        )
        if raw_ticks is None or len(raw_ticks) == 0:
            logger.error(f"Failed to download ticks for {symbol}: {mt5.last_error()}")
            return None

        frame = pd.DataFrame(raw_ticks)
        # The "time" field is converted from epoch seconds
        frame["time"] = pd.to_datetime(frame["time"], unit="s")
        frame = frame.set_index("time")
        logger.info(f"Downloaded tick data for {symbol}: {len(frame)} entries")
        quote_columns = ["bid", "ask"]
        return frame[quote_columns]
    except Exception as e:
        logger.error(f"Error in fetch_tick_data for {symbol}: {str(e)}")
        return None


def sync_dataframes(dfs):
    """Restrict every frame in ``dfs`` to the index labels they all share.

    Returns:
        A list of frames (same order as ``dfs``) selected on the common
        index, or None when anything fails, e.g. ``dfs`` is empty (logged).
    """
    try:
        shared = dfs[0].index
        for frame in dfs[1:]:
            shared = shared.intersection(frame.index)

        aligned = []
        for frame in dfs:
            aligned.append(frame.loc[shared])

        logger.info(f"Data synchronized, total size: {len(shared)}")
        return aligned
    except Exception as e:
        logger.error(f"Error in sync_dataframes: {str(e)}")
        return None


# --- Calculate synthetic rates and imbalances ---
def calculate_synthetic_rate(eurusd, gbpusd, eurgbp, normalize=False):
    """Build the synthetic EURGBP series as EURUSD close / GBPUSD close.

    Args:
        eurusd: Frame with a "close" column (numerator).
        gbpusd: Frame with a "close" column (denominator).
        eurgbp: Not used by the calculation.
        normalize: When true, return the z-score of the ratio instead.

    Returns:
        The ratio series (optionally standardized), or None on error (logged).
    """
    try:
        rate = eurusd["close"] / gbpusd["close"]
        result = (rate - rate.mean()) / rate.std() if normalize else rate
        logger.info("Synthetic EURGBP calculated")
        return result
    except Exception as e:
        logger.error(f"Error in calculate_synthetic_rate: {str(e)}")
        return None


def compute_imbalance(real_eurgbp, synthetic_eurgbp):
    """Return real close minus synthetic rate, or None on error (logged).

    Args:
        real_eurgbp: Frame with a "close" column holding the quoted pair.
        synthetic_eurgbp: Series with the synthetic rate.
    """
    try:
        real_close = real_eurgbp["close"]
        difference = real_close - synthetic_eurgbp
        logger.info("Imbalance calculated")
        return difference
    except Exception as e:
        logger.error(f"Error in compute_imbalance: {str(e)}")
        return None


def monitor_imbalance_real_time(symbols, interval=5):
    """Log the live EURGBP imbalance (real bid - synthetic bid) in a loop.

    The loop runs until an unexpected exception occurs, which is logged and
    ends the function. A failed quote request is not fatal: it is logged and
    retried after ``interval`` seconds.

    Args:
        symbols: Symbols to poll. The imbalance reads the "EURUSD", "GBPUSD"
            and "EURGBP" entries, so all three must be included.
        interval: Seconds to wait between polls (and before a retry).
    """
    try:
        while True:
            ticks = {symbol: mt5.symbol_info_tick(symbol) for symbol in symbols}
            # A failed request yields None for the whole tick, so check the
            # tick object itself before touching its .bid attribute.
            quotes_missing = any(
                tick is None or tick.bid is None for tick in ticks.values()
            )
            if quotes_missing:
                logger.error("Failed to get quotes in real time")
                # Back off before retrying instead of spinning in a busy loop
                time.sleep(interval)
                continue

            bids = {symbol: tick.bid for symbol, tick in ticks.items()}
            if not bids["GBPUSD"]:
                # A zero bid would raise ZeroDivisionError and end monitoring
                logger.error("Failed to get quotes in real time")
                time.sleep(interval)
                continue

            synthetic = bids["EURUSD"] / bids["GBPUSD"]
            imbalance = bids["EURGBP"] - synthetic
            logger.info(f"Current imbalance: {imbalance:.6f}")
            time.sleep(interval)
    except Exception as e:
        logger.error(f"Error on monitor_imbalance_real_time: {str(e)}")

# --- Visualization ---
# Pair and imbalance chart (fixed 750 px output width)
def plot_pairs_and_imbalance(eurusd, gbpusd, eurgbp, synthetic_eurgbp, imbalance):
    """Save a two-panel chart: real vs synthetic EURGBP, and their imbalance.

    ``eurusd`` and ``gbpusd`` are accepted but not drawn. The image goes to
    ``pairs_and_imbalance_2months.png`` in the output directory; errors are
    logged, not raised.
    """
    try:
        # 750 px wide at 100 DPI, height is 0.7 of the width
        dpi = 100
        fig_width = 750 / dpi
        fig_height = fig_width * 0.7

        figure, axes = plt.subplots(
            2, 1, figsize=(fig_width, fig_height), sharex=True, dpi=dpi
        )
        price_ax, imbalance_ax = axes

        # Upper panel: quoted pair against the synthetic one
        price_ax.plot(eurgbp.index, eurgbp["close"], label="Real EURGBP", color="blue")
        price_ax.plot(
            synthetic_eurgbp.index,
            synthetic_eurgbp,
            label="Synthetic EURGBP",
            color="orange",
            linestyle="--",
        )
        price_ax.set_title("Real vs Synthetic EURGBP")
        price_ax.legend()

        # Lower panel: imbalance around a zero reference line
        imbalance_ax.plot(imbalance.index, imbalance, label="Imbalance", color="red")
        imbalance_ax.axhline(0, color="black", linestyle="--")
        imbalance_ax.set_title("Imbalance (Real - Synthetic)")
        imbalance_ax.legend()

        plt.tight_layout()
        target = os.path.join(output_dir, "pairs_and_imbalance_2months.png")
        plt.savefig(target, dpi=dpi, bbox_inches="tight")
        plt.close(figure)
        logger.info(f"Chart saved: {target}")
    except Exception as e:
        logger.error(f"Error in plot_pairs_and_imbalance: {str(e)}")


# Imbalance histogram and box plot (fixed 750 px output width)
def plot_imbalance_distribution(imbalance):
    """Save a histogram (with KDE) and a box plot of the imbalance side by side.

    NaN values are dropped before plotting. The image goes to
    ``imbalance_distribution_2months.png`` in the output directory; errors
    are logged, not raised.
    """
    try:
        # 750 px wide at 100 DPI, 2:1 aspect ratio
        dpi = 100
        fig_width = 750 / dpi
        fig_height = fig_width * 0.5

        figure, axes = plt.subplots(1, 2, figsize=(fig_width, fig_height), dpi=dpi)
        hist_ax, box_ax = axes

        sns.histplot(imbalance.dropna(), bins=50, kde=True, color="purple", ax=hist_ax)
        hist_ax.set_title("Imbalance Distribution (2 Months)")

        sns.boxplot(x=imbalance.dropna(), color="purple", ax=box_ax)
        box_ax.set_title("Imbalance Boxplot (2 Months)")

        plt.tight_layout()
        target = os.path.join(output_dir, "imbalance_distribution_2months.png")
        plt.savefig(target, dpi=dpi, bbox_inches="tight")
        plt.close(figure)
        logger.info(f"Histogram and box saved: {target}")
    except Exception as e:
        logger.error(f"Error in plot_imbalance_distribution: {str(e)}")


# Imbalance heat map by hour and weekday (fixed 750 px output width)
def plot_heatmap(imbalance, timeframe="H"):
    """Save a heat map of the mean imbalance per weekday (rows) and hour (columns).

    ``timeframe`` is accepted but not used. The image goes to
    ``imbalance_heatmap_2months.png`` in the output directory; errors are
    logged, not raised.
    """
    try:
        # 750 px wide at 100 DPI, 2:1 aspect ratio
        dpi = 100
        fig_width = 750 / dpi
        fig_height = fig_width * 0.5

        table = pd.DataFrame({"imbalance": imbalance})
        table["hour"] = table.index.hour
        table["day"] = table.index.dayofweek
        heatmap_data = table.pivot_table(
            index="day", columns="hour", values="imbalance", aggfunc="mean"
        )

        plt.figure(figsize=(fig_width, fig_height), dpi=dpi)
        # Diverging palette centred on zero imbalance
        sns.heatmap(heatmap_data, cmap="RdBu", center=0)
        plt.title("Imbalance Heatmap by Hour and Day (2 Months)")

        plt.tight_layout()
        target = os.path.join(output_dir, "imbalance_heatmap_2months.png")
        plt.savefig(target, dpi=dpi, bbox_inches="tight")
        plt.close()
        logger.info(f"Heat map saved: {target}")
    except Exception as e:
        logger.error(f"Error in plot_heatmap: {str(e)}")


# Session backtest with a fixed 750 px chart (a later backtest_strategy definition in this file takes precedence)
def backtest_strategy(
    imbalance, eurgbp, threshold=0.000126, ema_period=20, session="all"
):
    """Backtest an EMA-band rule on the imbalance and save the equity chart.

    NOTE: another ``backtest_strategy`` is defined later in this module; being
    bound last, that later definition is the one callers actually get.

    Args:
        imbalance: Imbalance series with a datetime index.
        eurgbp: Frame with a "spread" column used as trading cost.
        threshold: Distance from the EMA that triggers a signal.
        ema_period: Span of the EMA.
        session: "asian", "european" or "american" to filter by hour; any
            other value keeps all rows.
    """
    try:
        # Restrict the series to the requested session's hours
        frame = pd.DataFrame({"imbalance": imbalance})
        frame["hour"] = frame.index.hour
        if session == "asian":
            frame = frame[frame["hour"].between(0, 7)]
        elif session == "european":
            frame = frame[frame["hour"].between(8, 15)]
        elif session == "american":
            frame = frame[frame["hour"].between(16, 23)]
        imbalance = frame["imbalance"]

        ema = imbalance.ewm(span=ema_period, adjust=False).mean()
        upper_band = ema + threshold
        lower_band = ema - threshold
        signals = pd.Series(0, index=imbalance.index)
        signals[imbalance > upper_band] = -1  # Sell
        signals[imbalance < lower_band] = 1  # Buy

        # Clear the signal once the imbalance has crossed back over the EMA
        previous = signals.shift(1)
        long_exit = (previous == 1) & (imbalance > ema)
        short_exit = (previous == -1) & (imbalance < ema)
        signals[long_exit | short_exit] = 0

        # Spread from the data, scaled to price units
        spread = eurgbp["spread"].reindex(signals.index, method="ffill") / 10000
        pnl = signals.shift(1) * imbalance.diff() - spread * signals.abs()
        cumulative_returns = pnl.cumsum()

        # 750 px wide at 100 DPI, height is 0.6 of the width
        dpi = 100
        fig_width = 750 / dpi
        fig_height = fig_width * 0.6

        plt.figure(figsize=(fig_width, fig_height), dpi=dpi)
        plt.plot(cumulative_returns, label="Cumulative Returns")
        label = session.capitalize()
        plt.title(f"Optimized Backtest Results ({label} Session, 2 Months)")
        plt.legend()

        plt.tight_layout()
        target = os.path.join(output_dir, f"optimized_backtest_{session}_2months.png")
        plt.savefig(target, dpi=dpi, bbox_inches="tight")
        plt.close()

        logger.info(f"[{label}] Total profit: {cumulative_returns.iloc[-1]:.6f}")
        logger.info(
            f"[{label}] Maximum drawdown: {-(cumulative_returns - cumulative_returns.cummax()).min():.6f}"
        )
        logger.info(f"[{label}] Number of trades: {signals.abs().sum()}")
        logger.info(f"[{label}] Average spread: {spread.mean():.6f}")
    except Exception as e:
        logger.error(f"Error on backtest_strategy: {str(e)}")


# Correlation matrix of imbalances (fixed 750 px output width)
def correlation_analysis(imbalances):
    """Save an annotated heat map of the correlations between imbalance series.

    Args:
        imbalances: Data accepted by ``pd.DataFrame`` (e.g. a dict of name to
            series); the correlation is taken between its columns.

    The image goes to ``correlation_matrix_2months.png`` in the output
    directory; errors are logged, not raised.
    """
    try:
        corr_matrix = pd.DataFrame(imbalances).corr()

        # 750 px wide at 100 DPI, 1.25:1 aspect ratio
        dpi = 100
        fig_width = 750 / dpi
        fig_height = fig_width * 0.8

        plt.figure(figsize=(fig_width, fig_height), dpi=dpi)
        sns.heatmap(corr_matrix, annot=True, cmap="coolwarm", center=0)
        plt.title("Correlation of Imbalances (2 Months)")

        plt.tight_layout()
        target = os.path.join(output_dir, "correlation_matrix_2months.png")
        plt.savefig(target, dpi=dpi, bbox_inches="tight")
        plt.close()
        logger.info(f"Correlation matrix saved: {target}")
    except Exception as e:
        logger.error(f"Error in correlation_analysis: {str(e)}")


# --- Statistical analysis ---
def analyze_imbalance_stats(imbalance):
    """Log descriptive statistics of the imbalance and save them as JSON.

    Mean, standard deviation, min, max, skewness and kurtosis are logged and
    written to ``imbalance_stats_2months.json`` in the data directory. The
    Shapiro p-value and lag-1 autocorrelation are logged only. Errors are
    logged, not raised.
    """
    try:
        clean = imbalance.dropna()
        stats_dict = {
            "Mean": imbalance.mean(),
            "Std Dev": imbalance.std(),
            "Min": imbalance.min(),
            "Max": imbalance.max(),
            "Skewness": stats.skew(clean),
            "Kurtosis": stats.kurtosis(clean),
        }

        # The Shapiro test is only run with more than three observations
        if len(clean) > 3:
            normality_p = stats.shapiro(clean)[1]
        else:
            normality_p = float("nan")
        autocorr = pd.Series(imbalance).autocorr()

        logger.info("Imbalance statistics (2 months):")
        for label, number in stats_dict.items():
            logger.info(f"{label}: {number:.6f}")
        logger.info(f"Normality test (p-value): {normality_p:.6f}")
        logger.info(f"Auto correlation (lag 1): {autocorr:.6f}")

        stats_path = os.path.join(data_dir, "imbalance_stats_2months.json")
        with open(stats_path, "w") as f:
            json.dump(stats_dict, f)
    except Exception as e:
        logger.error(f"Error in analyze_imbalance_stats: {str(e)}")


def segment_by_time(imbalance):
    """Log the mean and standard deviation of the imbalance per session.

    Sessions are split by the hour of the index: Asian 0-7, European 8-15,
    American 16-23. Errors are logged, not raised.
    """
    try:
        frame = pd.DataFrame({"imbalance": imbalance})
        frame["hour"] = frame.index.hour
        frame["day"] = frame.index.dayofweek

        # (label, first hour, hour after the last one)
        hour_bounds = (("Asian", 0, 8), ("European", 8, 16), ("American", 16, 24))
        sessions = {}
        for label, start, stop in hour_bounds:
            in_session = (frame["hour"] >= start) & (frame["hour"] < stop)
            sessions[label] = frame.loc[in_session, "imbalance"]

        logger.info("Average imbalance by sessions (2 months):")
        for label, values in sessions.items():
            logger.info(f"{label}: {values.mean():.6f} (Std: {values.std():.6f})")
    except Exception as e:
        logger.error(f"Error in segment_by_time: {str(e)}")


# --- Filtering noise ---
def filter_imbalance(imbalance, method="ema", window=20, threshold=2):
    """Keep only the imbalances that leave a band around a smoothed series.

    A value is kept when it lies more than ``threshold`` rolling standard
    deviations above or below the smoothed imbalance.

    Args:
        imbalance: Imbalance series.
        method: "ma" for a centered rolling mean, "ema" for an exponential
            moving average.
        window: Rolling window / EMA span, also used for the rolling std.
        threshold: Band half-width in rolling standard deviations.

    Returns:
        The filtered series, or None on error. An unknown ``method`` is
        reported with an explicit message and also yields None.
    """
    try:
        if method == "ma":
            smoothed = imbalance.rolling(window=window, center=True).mean()
        elif method == "ema":
            smoothed = imbalance.ewm(span=window, adjust=False).mean()
        else:
            # Without this branch `smoothed` stays unbound and the caller only
            # sees an opaque "local variable referenced before assignment".
            raise ValueError(f"Unknown method {method!r}; expected 'ma' or 'ema'")

        # The deviation is always taken over a centered window, so rows near
        # either edge have no std and never pass the filter.
        std = imbalance.rolling(window=window, center=True).std()
        band = threshold * std
        outliers = (imbalance > smoothed + band) | (imbalance < smoothed - band)
        filtered = imbalance[outliers]
        logger.info(
            f"Filtering ({method}): {len(filtered.dropna())} significant imbalances"
        )
        return filtered
    except Exception as e:
        logger.error(f"Error in filter_imbalance: {str(e)}")
        return None


# --- Analysis of market conditions ---
def analyze_volatility_impact(eurusd, imbalance):
    """Log the correlation between EURUSD volatility and absolute imbalance.

    Volatility is the 20-row rolling standard deviation of close-to-close
    percentage changes, scaled by sqrt(252). Errors are logged, not raised.
    """
    try:
        pct_moves = eurusd["close"].pct_change()
        rolling_vol = pct_moves.rolling(20).std()
        # NOTE(review): sqrt(252) annualizes daily data; confirm it is the
        # intended factor for the bar size used by the caller.
        scaled_vol = rolling_vol * np.sqrt(252)

        table = pd.DataFrame({"volatility": scaled_vol, "imbalance": imbalance.abs()})
        correlation = table.corr().iloc[0, 1]
        logger.info(
            f"Volatility and imbalance correlation (2 months): {correlation:.6f}"
        )
    except Exception as e:
        logger.error(f"Error in analyze_volatility_impact: {str(e)}")


# --- Backtest ---
def backtest_strategy(
    imbalance, eurgbp, threshold=0.000126, ema_period=20, session="all"
):
    """Backtest an EMA-band rule on the imbalance and save the equity chart.

    A sell signal (-1) is set where the imbalance is more than ``threshold``
    above its EMA, a buy signal (+1) where it is more than ``threshold``
    below. A signal is cleared when the previous row's signal was long/short
    and the imbalance is back above/below the EMA. Each row's return is the
    previous signal times the imbalance change, minus the spread whenever a
    signal is set. The chart is written to
    ``optimized_backtest_<session>_2months.png`` in the output directory.

    Args:
        imbalance: Imbalance series with a datetime index.
        eurgbp: Frame with a "spread" column used as trading cost.
        threshold: Distance from the EMA that triggers a signal.
        ema_period: Span of the EMA.
        session: One of "all", "asian" (hours 0-7), "european" (8-15) or
            "american" (16-23).

    Errors, including an unknown session name or a session without data, are
    logged and not raised; nothing is returned.
    """
    # Hour range per session as (first hour, hour after the last one)
    session_hours = {
        "all": (0, 24),
        "asian": (0, 8),
        "european": (8, 16),
        "american": (16, 24),
    }
    fig = None
    try:
        # Reject unknown names: they used to run silently as "all", and the
        # raw text (possibly typed by a user) ended up in the file name.
        if session not in session_hours:
            raise ValueError(
                f"Unknown session {session!r}; expected one of {sorted(session_hours)}"
            )

        # Filter by session
        first_hour, end_hour = session_hours[session]
        hours = imbalance.index.hour
        imbalance = imbalance[(hours >= first_hour) & (hours < end_hour)]
        if imbalance.empty:
            raise ValueError(f"No imbalance data for session {session!r}")

        ema = imbalance.ewm(span=ema_period, adjust=False).mean()
        signals = pd.Series(0, index=imbalance.index)
        signals[imbalance > ema + threshold] = -1  # Sell
        signals[imbalance < ema - threshold] = 1  # Buy

        # Exit when crossing EMA
        previous = signals.shift(1)
        exits = ((previous == 1) & (imbalance > ema)) | (
            (previous == -1) & (imbalance < ema)
        )
        signals[exits] = 0

        # Real spread from data, converted to price units.
        # NOTE(review): the divisor assumes one spread unit is 0.0001 of
        # price - confirm against the symbol's point size.
        spread = eurgbp["spread"].reindex(signals.index, method="ffill") / 10000
        returns = signals.shift(1) * imbalance.diff() - spread * signals.abs()
        cumulative_returns = returns.cumsum()

        # Fixed 750 px width at 100 DPI, like the other charts in this file
        width_px = 750
        dpi = 100
        width_inches = width_px / dpi
        height_inches = width_inches * 0.6

        session_label = session.capitalize()
        fig = plt.figure(figsize=(width_inches, height_inches), dpi=dpi)
        plt.plot(cumulative_returns, label="Cumulative Returns")
        plt.title(f"Optimized Backtest Results ({session_label} Session, 2 Months)")
        plt.legend()

        plt.tight_layout()
        output_file = os.path.join(
            output_dir, f"optimized_backtest_{session}_2months.png"
        )
        plt.savefig(output_file, dpi=dpi, bbox_inches="tight")

        drawdown = -(cumulative_returns - cumulative_returns.cummax()).min()
        logger.info(
            f"[{session_label}] Total profit: {cumulative_returns.iloc[-1]:.6f}"
        )
        logger.info(f"[{session_label}] Maximum drawdown: {drawdown:.6f}")
        # Sum of absolute signals, i.e. the number of rows carrying a signal
        logger.info(f"[{session_label}] Number of trades: {signals.abs().sum()}")
        logger.info(f"[{session_label}] Average spread: {spread.mean():.6f}")
    except Exception as e:
        logger.error(f"Error in backtest_strategy: {str(e)}")
    finally:
        # Release the figure even when plotting or saving failed
        if fig is not None:
            plt.close(fig)


# --- Extended triangles ---
def analyze_extended_triangles(symbol_sets, start_date=None, end_date=None):
    """Log the mean imbalance of additional currency triangles.

    For every ``(a, b, c)`` in ``symbol_sets`` the synthetic rate is
    ``close(a) / close(b)`` and the imbalance is ``close(c) - synthetic``,
    so the symbols must be ordered such that a / b reproduces c.

    Args:
        symbol_sets: Iterable of symbol sequences with at least three
            symbols each; only the first three take part in the calculation.
        start_date: Passed to ``fetch_historical_data``.
        end_date: Passed to ``fetch_historical_data``.

    A set whose data cannot be downloaded, synchronized or evaluated is
    skipped; the remaining sets are still processed. Errors are logged, not
    raised.
    """
    try:
        for symbols in symbol_sets:
            if len(symbols) < 3:
                logger.error(f"Skipping {symbols}: a triangle needs three symbols")
                continue

            dfs = [
                fetch_historical_data(symbol, start_date=start_date, end_date=end_date)
                for symbol in symbols
            ]
            # The failed download has already been logged by the fetch helper
            if any(df is None for df in dfs):
                continue

            # Both helpers return None on failure; skip this set instead of
            # letting a TypeError abort every remaining triangle.
            synced_dfs = sync_dataframes(dfs)
            if synced_dfs is None:
                logger.error(f"Skipping {symbols}: data could not be synchronized")
                continue

            synthetic = synced_dfs[0]["close"] / synced_dfs[1]["close"]
            imbalance = compute_imbalance(synced_dfs[2], synthetic)
            if imbalance is None:
                logger.error(f"Skipping {symbols}: imbalance could not be calculated")
                continue

            logger.info(f"Imbalance for {symbols} (2 months): {imbalance.mean():.6f}")
    except Exception as e:
        logger.error(f"Error in analyze_extended_triangles: {str(e)}")


# --- Console interface ---
def console_interface(imbalance, eurgbp):
    """Run a blocking command prompt: stats, plot, backtest or exit.

    Args:
        imbalance: Imbalance series handed to the analysis functions.
        eurgbp: Frame handed to the backtest.

    The loop ends on "exit" or on any exception, which is logged.
    """
    try:
        while True:
            print("\nCommands: stats, plot, backtest, exit")
            command = input("Enter command: ").strip().lower()

            if command == "exit":
                break
            if command == "stats":
                analyze_imbalance_stats(imbalance)
                continue
            if command == "plot":
                plot_imbalance_distribution(imbalance)
                continue
            if command == "backtest":
                answer = input("Enter session (all, asian, european, american): ")
                session = answer.strip().lower()
                backtest_strategy(imbalance, eurgbp, session=session)
                continue

            print("Unknown command")
    except Exception as e:
        logger.error(f"Error in console_interface: {str(e)}")


# --- Main loop ---
def main():
    """Run the full synthetic-pair analysis for EURUSD / GBPUSD / EURGBP.

    Connects to MT5, downloads and synchronizes the three pairs, computes the
    synthetic EURGBP and its imbalance, saves the charts and statistics, runs
    the per-session backtests, starts real-time monitoring in a daemon thread,
    analyzes the extended triangles and finally hands control to the console
    interface. The MT5 connection is shut down on every exit path.
    """
    logger.info("Launch the analysis of synthetic pairs for 2 months")

    if not initialize_mt5():
        logger.error("Failed to connect to MT5")
        return

    try:
        # Analysis period: the 60 days ending on a fixed date.
        # NOTE(review): the end date is hard-coded, not the actual current date.
        end_date = datetime(2025, 3, 15)
        start_date = end_date - timedelta(days=60)  # January 14, 2025

        # Download data
        symbols = ["EURUSD", "GBPUSD", "EURGBP"]
        dfs = [
            fetch_historical_data(symbol, start_date=start_date, end_date=end_date)
            for symbol in symbols
        ]
        if any(df is None for df in dfs):
            logger.error("Failed to download data")
            return

        # sync_dataframes returns None on failure, which cannot be unpacked
        synced = sync_dataframes(dfs)
        if synced is None:
            logger.error("Failed to synchronize data")
            return
        eurusd, gbpusd, eurgbp = synced

        # Tick data (the last 24 hours for comparison); the result is not
        # used further, the call reports how many ticks are available.
        fetch_tick_data("EURUSD", hours=24)

        # Calculate synthetic pair and imbalance (both helpers may return None)
        synthetic_eurgbp = calculate_synthetic_rate(
            eurusd, gbpusd, eurgbp, normalize=False
        )
        if synthetic_eurgbp is None:
            logger.error("Failed to calculate the synthetic pair")
            return
        imbalance = compute_imbalance(eurgbp, synthetic_eurgbp)
        if imbalance is None:
            logger.error("Failed to calculate the imbalance")
            return

        # Visualization
        plot_pairs_and_imbalance(eurusd, gbpusd, eurgbp, synthetic_eurgbp, imbalance)
        plot_imbalance_distribution(imbalance)
        plot_heatmap(imbalance)

        # Statistical analysis
        analyze_imbalance_stats(imbalance)
        segment_by_time(imbalance)

        # Filtration (logs the number of significant imbalances)
        filter_imbalance(imbalance, method="ema")

        # Analyzing market conditions
        analyze_volatility_impact(eurusd, imbalance)

        # Backtest by sessions
        for session in ["all", "asian", "european", "american"]:
            backtest_strategy(imbalance, eurgbp, threshold=0.000126, session=session)

        # Real time (in a separate daemon thread, so it ends with the process)
        Thread(
            target=monitor_imbalance_real_time, args=(symbols,), daemon=True
        ).start()

        # Extended triangles: each set is ordered so that first / second
        # reproduces the third (EURJPY / USDJPY = EURUSD,
        # AUDUSD / NZDUSD = AUDNZD).
        extended_sets = [
            ["EURJPY", "USDJPY", "EURUSD"],
            ["AUDUSD", "NZDUSD", "AUDNZD"],
        ]
        analyze_extended_triangles(
            extended_sets, start_date=start_date, end_date=end_date
        )

        # Correlations
        imbalances = {"EURGBP": imbalance}
        correlation_analysis(imbalances)

        # Console interface (blocks until the user exits)
        console_interface(imbalance, eurgbp)

        logger.info("2-month analysis complete")
    finally:
        # Release the terminal connection on success, early return and error
        mt5.shutdown()


if __name__ == "__main__":
    # Script entry point: run the analysis; if an error escapes main(),
    # log it and release the MT5 connection.
    try:
        main()
    except Exception as exc:
        message = str(exc)
        logger.error(f"Unprocessed error: {message}")
        mt5.shutdown()
