import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import seaborn as sns
from scipy import stats
from scipy.stats import gaussian_kde

# --- Path configuration
# For reference, the MT5 terminal sandbox default path on Windows is:
#   C:\Users\<user>\AppData\Roaming\MetaQuotes\Terminal\<InstanceID>\MQL5\Files\
# By default this script looks for the export in an "exports" folder next to
# the script itself; adjust DATA_PATH to point to your local MT5 Files
# directory if the CSV lives there instead.

_script_location = os.path.abspath(__file__)
BASE_DIR = os.path.dirname(_script_location)
DATA_PATH = os.path.join(BASE_DIR, "exports", "analytics_export.csv")
OUTPUT_DIR = os.path.join(BASE_DIR, "charts")

# The chart folder must exist before any figure is saved into it.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Read the analytics export into the DataFrame used by the plotting code.
df = pd.read_csv(DATA_PATH)

# --- Global style configuration
# Dark theme applied through matplotlib's rcParams, so it affects every
# figure created after this statement runs.
plt.rcParams.update(
    {
        # Backgrounds
        "figure.facecolor": "#0d1117",
        "axes.facecolor": "#161b22",
        # Axes frame, labels and tick marks
        "axes.edgecolor": "#30363d",
        "axes.labelcolor": "#c9d1d9",
        "xtick.color": "#8b949e",
        "ytick.color": "#8b949e",
        # Grid
        "grid.color": "#21262d",
        "grid.linestyle": "--",
        "grid.linewidth": 0.6,
        # Text
        "text.color": "#c9d1d9",
        "font.family": "monospace",
        "font.size": 9,
    }
)


def plot_lag_vs_whipsaw(df: pd.DataFrame, save_path: str = None):
    """
    Signal Lag vs. Whipsaw Frequency (Joint Scatter with Trend Lines)

    Plots Avg_Lag_On_Turn_Bars (X) against False_Flips_Whipsaws (Y) for every
    row of the export in which all three required columns are populated. Each
    unique Indicator_Name receives a distinct color and, where a fit is
    possible, a dashed OLS trend line annotated with its R².

    A trend line is drawn only for indicators that have at least two points
    AND at least two distinct lag values. scipy's linregress raises ValueError
    when all x values are identical, so such groups are shown as scatter
    points only instead of aborting the whole chart.

    Parameters
    ----------
    df        : Full analytics DataFrame from CSV export. Must contain the
                columns Avg_Lag_On_Turn_Bars, False_Flips_Whipsaws and
                Indicator_Name; a missing column raises KeyError.
    save_path : File path for saving the output figure (optional). When
                omitted or empty, the figure is only displayed.

    Returns
    -------
    None. The figure is displayed via plt.show().
    """

    # --- Data Preparation
    # Rows missing any of the three columns are dropped so that the scatter,
    # the regression and the medians all operate on the same set of records.
    required = ["Avg_Lag_On_Turn_Bars", "False_Flips_Whipsaws", "Indicator_Name"]
    plot_df  = df[required].dropna()

    indicators = plot_df["Indicator_Name"].unique()
    palette    = sns.color_palette("husl", len(indicators))
    color_map  = dict(zip(indicators, palette))

    # --- Figure Setup
    fig, ax = plt.subplots(figsize=(8.167, 5.4))
    fig.suptitle(
        "Signal Lag vs. Whipsaw Frequency  |  Alpha-Efficiency Boundary Analysis",
        fontsize=11, fontweight="bold", color="#e6edf3"
    )

    for indicator in indicators:
        mask   = plot_df["Indicator_Name"] == indicator
        x_vals = plot_df.loc[mask, "Avg_Lag_On_Turn_Bars"].values
        y_vals = plot_df.loc[mask, "False_Flips_Whipsaws"].values
        color  = color_map[indicator]

        ax.scatter(x_vals, y_vals, color=color, label=indicator,
                   alpha=0.75, s=52, linewidths=0.4,
                   edgecolors="white", zorder=3)

        # OLS trend line per indicator category. A line needs >= 2 points and
        # some spread in x: linregress raises ValueError when every x value is
        # identical, so those groups keep their scatter points but get no line.
        if len(x_vals) < 2 or x_vals.min() == x_vals.max():
            continue

        fit    = stats.linregress(x_vals, y_vals)
        x_line = np.linspace(x_vals.min(), x_vals.max(), 120)
        y_line = fit.slope * x_line + fit.intercept
        ax.plot(x_line, y_line, color=color, linewidth=1.5,
                linestyle="--", alpha=0.7, zorder=2)
        # R² label sits just to the right of the trend line's last point.
        ax.text(x_line[-1] + 0.15, y_line[-1],
                f"R²={fit.rvalue**2:.2f}",
                fontsize=7, color=color, va="center", alpha=0.9)

    # --- Quadrant Reference Lines
    # Medians are taken over all plotted rows (every indicator combined).
    x_med = plot_df["Avg_Lag_On_Turn_Bars"].median()
    y_med = plot_df["False_Flips_Whipsaws"].median()
    ax.axvline(x_med, color="#484f58", linewidth=0.8, linestyle=":")
    ax.axhline(y_med, color="#484f58", linewidth=0.8, linestyle=":")

    # Caption anchored slightly inside the lower-left corner of the data range.
    ax.text(plot_df["Avg_Lag_On_Turn_Bars"].min() + 0.1,
            plot_df["False_Flips_Whipsaws"].min() + 0.5,
            "Optimal Region (Low Lag | Low Noise)",
            fontsize=7.5, color="#3fb950", style="italic", alpha=0.8)

    ax.set_xlabel("Average Lag on Turn (Bars)", fontsize=9, labelpad=8)
    ax.set_ylabel("Whipsaw Count (False Flips)", fontsize=9, labelpad=8)
    ax.grid(True, alpha=0.35)
    ax.spines[["top", "right"]].set_visible(False)

    ax.legend(title="Indicator Class", fontsize=8,
              title_fontsize=8, loc="upper left",
              framealpha=0.25, edgecolor="#30363d")

    fig.tight_layout()

    if save_path:
        # Save through the figure object rather than pyplot's implicit
        # "current figure"; facecolor is passed so the dark background is
        # kept in the saved file.
        fig.savefig(save_path, dpi=120, bbox_inches="tight",
                    facecolor=fig.get_facecolor())
    plt.show()


# --- Run: render the scatter chart and save it into the charts folder
scatter_png_path = os.path.join(OUTPUT_DIR, "Lag_vs_Whipsaw_Scatter.png")
plot_lag_vs_whipsaw(df, save_path=scatter_png_path)
