# MT5 Analytics Suite  |  Unified Execution Module
# Reads:  exports/analytics_export.csv
# Writes: charts/  (PNG files, 120 DPI, max width 8.167 inches)

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import seaborn as sns
from scipy import stats
from scipy.stats import gaussian_kde
from matplotlib.lines import Line2D

# Paths are anchored to this script's directory, not the current working dir.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_PATH = os.path.join(BASE_DIR, "exports", "analytics_export.csv")
OUTPUT_DIR = os.path.join(BASE_DIR, "charts")
# Import-time side effect: the default chart directory always exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Global dark theme; applies to every figure created after this point.
plt.rcParams.update(
    {
        "figure.facecolor": "#0d1117",
        "axes.facecolor": "#161b22",
        "axes.edgecolor": "#30363d",
        "axes.labelcolor": "#c9d1d9",
        "xtick.color": "#8b949e",
        "ytick.color": "#8b949e",
        "text.color": "#c9d1d9",
        "grid.color": "#21262d",
        "grid.linestyle": "--",
        "grid.linewidth": 0.6,
        "font.family": "monospace",
        "font.size": 9,
    }
)


def load_analytics_export(path: str) -> pd.DataFrame:
    """
    Read the analytics CSV export into a DataFrame and report its row count.

    Parameters
    ----------
    path : Location of the CSV file written by the MQL5 EA.

    Returns
    -------
    The parsed DataFrame, exactly as returned by ``pd.read_csv``.

    Raises
    ------
    FileNotFoundError : If ``path`` is not an existing regular file
                        (a directory also counts as "not found").
    """
    if os.path.isfile(path):
        frame = pd.read_csv(path)
        file_name = os.path.basename(path)
        print(f"[Loaded] {len(frame):,} records from {file_name}")
        return frame

    message = (
        f"Export file not found at:\n  {path}\n\n"
        "Ensure the MQL5 EA has written to this location or adjust DATA_PATH."
    )
    raise FileNotFoundError(message)


def plot_multi_asset_parameter_matrix(df: pd.DataFrame,
                                      target_period: int = 10,
                                      save_path: str = None):
    """
    Multi-Asset Parameter Match Matrix (Diverging Bar Chart)

    Filters the analytics export to a single Filter_Period and constructs
    a diverging horizontal bar plot comparing normalized performance metrics
    across all unique Symbol values present in that subset.

    Each metric is z-scored across every record of the subset, then averaged
    per Symbol. Before scoring, Max_Drawdown_PCT and False_Flips_Whipsaws are
    negated so that higher reads as better everywhere. One panel is drawn per
    metric.

    Parameters
    ----------
    df            : Full analytics DataFrame from CSV export.
    target_period : The Filter_Period integer to isolate for comparison.
    save_path     : File path for saving the output figure (optional).

    Raises
    ------
    ValueError : If no records match ``target_period``.
    KeyError   : If a required column is missing from ``df``.
    """

    # --- Data Preparation
    subset = df[df["Filter_Period"] == target_period].copy()

    if subset.empty:
        raise ValueError(f"No records found for Filter_Period = {target_period}.")

    # Metrics to include in the comparison matrix
    metric_cols = ["Net_Profit_USD", "Sortino_Ratio", "Max_Drawdown_PCT",
                   "False_Flips_Whipsaws", "Trade_Count"]

    # Invert Max_Drawdown and False_Flips so higher = better across all fields
    subset["Max_Drawdown_PCT"]     = -subset["Max_Drawdown_PCT"]
    subset["False_Flips_Whipsaws"] = -subset["False_Flips_Whipsaws"]

    # Z-score normalization per metric column so all panels share a common scale
    for col in metric_cols:
        col_mean = subset[col].mean()
        col_std  = subset[col].std()
        # The sample std (ddof=1) of a single record is NaN, which would turn
        # every z-score NaN; treat it as "no spread" so the record scores 0.
        # The epsilon keeps a zero std (constant column) from dividing by 0.
        if pd.isna(col_std):
            col_std = 0.0
        subset[col + "_z"] = (subset[col] - col_mean) / (col_std + 1e-9)

    z_cols    = [c + "_z" for c in metric_cols]
    n_metrics = len(metric_cols)

    # One mean z-score per Symbol (a symbol may have several records here).
    # groupby drops NaN symbols, so the plotted rows come from `agg` itself.
    agg        = subset.groupby("Symbol")[z_cols].mean().reset_index()
    sym_labels = agg["Symbol"].values
    y_pos      = np.arange(len(sym_labels))

    # --- Figure Layout (height scales with the number of plotted symbols)
    fig, axes = plt.subplots(1, n_metrics,
                             figsize=(8.167, max(3.5, len(sym_labels) * 0.55 + 1.5)),
                             sharey=True)
    fig.suptitle(
        f"Multi-Asset Parameter Match Matrix  |  Filter Period: {target_period}",
        fontsize=11, fontweight="bold", color="#e6edf3", y=1.01
    )

    display_labels = {
        "Net_Profit_USD_z"       : "Net Profit (Normalized)",
        "Sortino_Ratio_z"        : "Sortino Ratio",
        "Max_Drawdown_PCT_z"     : "Max DD (Inverted)",
        "False_Flips_Whipsaws_z" : "Whipsaws (Inverted)",
        "Trade_Count_z"          : "Trade Count",
    }

    bar_colors_pos = "#3fb950"   # Green for positive z-score
    bar_colors_neg = "#f85149"   # Red for negative z-score

    for ax_idx, (ax, z_col) in enumerate(zip(axes, z_cols)):
        values = agg[z_col].to_numpy(dtype=float)

        colors = [bar_colors_pos if v >= 0 else bar_colors_neg for v in values]
        bars   = ax.barh(y_pos, values, color=colors, height=0.62, alpha=0.88)

        ax.axvline(0, color="#484f58", linewidth=0.9, linestyle="-")
        ax.set_title(display_labels.get(z_col, z_col),
                     fontsize=8, color="#8b949e", pad=6)
        ax.tick_params(axis="x", labelsize=7)
        ax.grid(axis="x", alpha=0.4)

        # Symmetric limits: +/-3.5 by default, widened when a bar would be
        # clipped and its value label drawn over the neighbouring panel.
        finite_abs = np.abs(values[np.isfinite(values)])
        half_span  = max(3.5, finite_abs.max() + 1.5) if finite_abs.size else 3.5
        ax.set_xlim(-half_span, half_span)
        ax.spines[["top", "right"]].set_visible(False)

        # Annotate each bar with its per-symbol mean z-score
        for bar_rect, val in zip(bars, values):
            if not np.isfinite(val):
                continue  # metric missing for this symbol: nothing to label
            label_x = bar_rect.get_width() + (0.08 if val >= 0 else -0.08)
            ha       = "left" if val >= 0 else "right"
            ax.text(label_x, bar_rect.get_y() + bar_rect.get_height() / 2,
                    f"{val:+.2f}", va="center", ha=ha, fontsize=6.5,
                    color="#8b949e")

        # sharey=True: symbol labels only need to be set on the first panel
        if ax_idx == 0:
            ax.set_yticks(y_pos)
            ax.set_yticklabels(sym_labels, fontsize=8)

    fig.text(0.5, -0.03,
             "Z-score deviation from group mean  |  Higher = Structurally Stronger",
             ha="center", fontsize=7.5, color="#8b949e", style="italic")

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=120, bbox_inches="tight",
                    facecolor=fig.get_facecolor())
    plt.show()


def plot_lag_vs_whipsaw(df: pd.DataFrame, save_path: str = None):
    """
    Signal Lag vs. Whipsaw Frequency (Joint Scatter with Trend Lines)

    Plots Avg_Lag_On_Turn_Bars (X) against False_Flips_Whipsaws (Y) for all
    records in the export. Each unique Indicator_Name receives a distinct
    color. Where an indicator's points span more than one X value, it also
    gets a fitted OLS trend line labelled with its R².
    Dotted lines at the overall X and Y medians split the plot into quadrants.

    Parameters
    ----------
    df        : Full analytics DataFrame from CSV export.
    save_path : File path for saving the output figure (optional).

    Raises
    ------
    ValueError : If no row has all three required columns populated.
    KeyError   : If a required column is missing from ``df``.
    """

    # --- Data Preparation
    required = ["Avg_Lag_On_Turn_Bars", "False_Flips_Whipsaws", "Indicator_Name"]
    plot_df  = df[required].dropna()
    if plot_df.empty:
        raise ValueError(f"No complete rows for columns {required}.")

    indicators = plot_df["Indicator_Name"].unique()
    palette    = sns.color_palette("husl", len(indicators))
    color_map  = dict(zip(indicators, palette))

    # --- Figure Setup
    fig, ax = plt.subplots(figsize=(8.167, 5.4))
    fig.suptitle(
        "Signal Lag vs. Whipsaw Frequency  |  Alpha-Efficiency Boundary Analysis",
        fontsize=11, fontweight="bold", color="#e6edf3"
    )

    for indicator in indicators:
        mask   = plot_df["Indicator_Name"] == indicator
        x_vals = plot_df.loc[mask, "Avg_Lag_On_Turn_Bars"].values
        y_vals = plot_df.loc[mask, "False_Flips_Whipsaws"].values
        color  = color_map[indicator]

        ax.scatter(x_vals, y_vals, color=color, label=indicator,
                   alpha=0.75, s=52, linewidths=0.4,
                   edgecolors="white", zorder=3)

        # OLS trend line per indicator. linregress needs >= 2 points and
        # raises ValueError when every x is identical, which previously
        # aborted the whole chart; such groups keep their points, no line.
        if len(x_vals) >= 2 and x_vals.min() != x_vals.max():
            fit    = stats.linregress(x_vals, y_vals)
            x_line = np.linspace(x_vals.min(), x_vals.max(), 120)
            y_line = fit.slope * x_line + fit.intercept
            ax.plot(x_line, y_line, color=color, linewidth=1.5,
                    linestyle="--", alpha=0.7, zorder=2)
            ax.text(x_line[-1] + 0.15, y_line[-1],
                    f"R²={fit.rvalue**2:.2f}",
                    fontsize=7, color=color, va="center", alpha=0.9)

    # --- Quadrant Reference Lines (overall medians)
    x_med = plot_df["Avg_Lag_On_Turn_Bars"].median()
    y_med = plot_df["False_Flips_Whipsaws"].median()
    ax.axvline(x_med, color="#484f58", linewidth=0.8, linestyle=":")
    ax.axhline(y_med, color="#484f58", linewidth=0.8, linestyle=":")

    # Label the low-lag / low-whipsaw corner (bottom-left of the data range)
    ax.text(plot_df["Avg_Lag_On_Turn_Bars"].min() + 0.1,
            plot_df["False_Flips_Whipsaws"].min() + 0.5,
            "Optimal Region (Low Lag | Low Noise)",
            fontsize=7.5, color="#3fb950", style="italic", alpha=0.8)

    ax.set_xlabel("Average Lag on Turn (Bars)", fontsize=9, labelpad=8)
    ax.set_ylabel("Whipsaw Count (False Flips)", fontsize=9, labelpad=8)
    ax.grid(True, alpha=0.35)
    ax.spines[["top", "right"]].set_visible(False)

    ax.legend(title="Indicator Class", fontsize=8,
              title_fontsize=8, loc="upper left",
              framealpha=0.25, edgecolor="#30363d")

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=120, bbox_inches="tight",
                    facecolor=fig.get_facecolor())
    plt.show()


def plot_walkforward_degradation(df: pd.DataFrame,
                                  top_n: int = 20,
                                  save_path: str = None):
    """
    Walk-Forward Performance Degradation Profile (Paired Slope Chart)

    Takes the ``top_n`` complete rows with the highest IS_Score. Each row is
    drawn as a line from its In-Sample score (x = 0) to its Out-of-Sample score
    (x = 1), labelled with Symbol, Indicator_Name and Filter_Period.
    Lines are coloured by the retained ratio OOS_Score / |IS_Score|:
    green at >= 0.70, orange at >= 0.35, red below that.

    Parameters
    ----------
    df        : Full analytics DataFrame from CSV export.
    top_n     : Number of top in-sample configurations to display.
    save_path : File path for saving the output figure (optional).
    """

    columns = ["IS_Score", "OOS_Score", "Symbol", "Filter_Period", "Indicator_Name"]
    ranked = (
        df[columns]
        .dropna()
        .copy()
        .nlargest(top_n, "IS_Score")
        .reset_index(drop=True)
    )

    # Retained share of in-sample performance; the epsilon avoids a zero divisor.
    ranked["Degradation"] = ranked["OOS_Score"] / (ranked["IS_Score"].abs() + 1e-9)

    def tier_color(retained):
        # Thresholds are checked high-to-low; anything failing both (incl. NaN)
        # falls through to the "collapse" colour.
        if retained >= 0.70:
            return "#3fb950"
        return "#d29922" if retained >= 0.35 else "#f85149"

    ranked["LineColor"] = ranked["Degradation"].apply(tier_color)

    fig, ax = plt.subplots(figsize=(8.167, 6.2))
    fig.suptitle(
        "Walk-Forward Performance Degradation Profile  |  IS vs OOS Structural Decay",
        fontsize=11, fontweight="bold", color="#e6edf3"
    )

    is_x, oos_x = 0.0, 1.0
    floor   = min(ranked["IS_Score"].min(), ranked["OOS_Score"].min()) - 0.05
    ceiling = max(ranked["IS_Score"].max(), ranked["OOS_Score"].max()) + 0.05

    for _, rec in ranked.iterrows():
        shade = rec["LineColor"]
        tag = f"{rec['Symbol']} | {rec['Indicator_Name']} | P={int(rec['Filter_Period'])}"

        ax.plot([is_x, oos_x], [rec["IS_Score"], rec["OOS_Score"]],
                color=shade, linewidth=1.4, alpha=0.80, zorder=2)
        ax.scatter([is_x], [rec["IS_Score"]], color=shade, s=32, zorder=3, alpha=0.9)
        ax.scatter([oos_x], [rec["OOS_Score"]], color=shade, s=32, zorder=3,
                   alpha=0.9, marker="D")
        ax.text(is_x - 0.03, rec["IS_Score"], tag,
                ha="right", va="center", fontsize=6.2, color="#8b949e")

    for x_pos in (is_x, oos_x):
        ax.axvline(x_pos, color="#484f58", linewidth=0.7, linestyle="--", zorder=1)
    for x_pos, heading in ((is_x, "In-Sample Score"), (oos_x, "Out-of-Sample Score")):
        ax.text(x_pos, ceiling + 0.01, heading,
                ha="center", fontsize=9, fontweight="bold", color="#c9d1d9")

    tiers = [
        ("#3fb950", "Moderate Decay (>=70% retained)"),
        ("#d29922", "Significant Decay (35–70%)"),
        ("#f85149", "Collapse (<35% retained)"),
    ]
    handles = [Line2D([0], [0], color=swatch, lw=2, label=text) for swatch, text in tiers]
    ax.legend(handles=handles, fontsize=8,
              loc="upper center", bbox_to_anchor=(0.5, -0.04),
              ncol=3, framealpha=0.2, edgecolor="#30363d")

    # Wide left margin leaves room for the right-aligned configuration labels.
    ax.set_xlim(-0.65, 1.55)
    ax.set_ylim(floor, ceiling + 0.12)
    ax.set_xticks([])
    ax.set_ylabel("Performance Score (Normalized)", fontsize=9, labelpad=8)
    ax.grid(axis="y", alpha=0.25)
    ax.spines[["top", "right", "bottom"]].set_visible(False)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=120, bbox_inches="tight",
                    facecolor=fig.get_facecolor())
    plt.show()


def _draw_density_panel(ax, values, color, xlabel, title, label_fmt):
    """
    Draw one KDE + histogram panel with P50/P90/P95 reference lines on ``ax``.

    Parameters
    ----------
    ax        : Matplotlib Axes to draw into.
    values    : 1-D array of observations (must be non-empty).
    color     : Colour for the KDE curve/fill and the background histogram.
    xlabel    : X-axis label.
    title     : Panel title.
    label_fmt : ``str.format`` template for percentile labels; it receives
                ``pct`` (the percentile number) and ``val`` (its value).
    """
    sns.kdeplot(values, ax=ax,
                color=color, linewidth=2.2, fill=True, alpha=0.18)
    ax.set_xlabel(xlabel, fontsize=9, labelpad=7)
    ax.set_ylabel("Probability Density", fontsize=9, labelpad=7)
    ax.set_title(title, fontsize=9, color="#8b949e", pad=6)

    # Each label uses x in data units and y as a fraction of the axes height.
    # It therefore stays near the top even after the histogram drawn below
    # rescales the y-axis. A point offset replaces the fixed data-unit nudge,
    # which ignored each column's scale. Labels are staggered because
    # neighbouring percentiles can sit close together.
    label_transform = ax.get_xaxis_transform()
    for pct, y_frac in ((50, 0.92), (90, 0.84), (95, 0.76)):
        pct_val = np.percentile(values, pct)
        ax.axvline(pct_val, color="#d29922", linewidth=0.9, linestyle=":", alpha=0.8)
        ax.annotate(label_fmt.format(pct=pct, val=pct_val),
                    xy=(pct_val, y_frac), xycoords=label_transform,
                    xytext=(3, 0), textcoords="offset points",
                    fontsize=6.5, color="#d29922", ha="left")

    # Faint density-normalised histogram behind the KDE (zorder=0).
    ax.hist(values, bins=30, density=True,
            color=color, alpha=0.10, edgecolor="none", zorder=0)
    ax.grid(True, alpha=0.3)
    ax.spines[["top", "right"]].set_visible(False)


def plot_drawdown_distribution(df: pd.DataFrame, save_path: str = None):
    """
    Maximum Drawdown Duration and Depth Distribution (Seaborn KDE)

    Generates two KDE distribution overlays on a dual-panel figure:
      Left panel:  KDE of |Max_Drawdown_PCT| across all records.
      Right panel: KDE of Drawdown_Duration_Bars.
    Each panel also shows a background histogram and P50/P90/P95 markers.

    Parameters
    ----------
    df        : Full analytics DataFrame from CSV export.
    save_path : File path for saving the output figure (optional).

    Raises
    ------
    ValueError : If either column has no non-null values.
    KeyError   : If either column is missing from ``df``.
    """

    # Depth is plotted as a magnitude regardless of the sign it was exported with.
    dd_depth    = df["Max_Drawdown_PCT"].dropna().abs().values
    dd_duration = df["Drawdown_Duration_Bars"].dropna().values
    if dd_depth.size == 0 or dd_duration.size == 0:
        raise ValueError("Max_Drawdown_PCT and Drawdown_Duration_Bars each need "
                         "at least one non-null value.")

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8.167, 5.0))
    fig.suptitle(
        "Drawdown Distribution Profile  |  Depth and Duration Density Estimation",
        fontsize=11, fontweight="bold", color="#e6edf3"
    )

    # --- Left Panel: Drawdown Depth KDE
    _draw_density_panel(ax1, dd_depth, "#f85149",
                        "Drawdown Depth (%)", "Depth Distribution",
                        "P{pct}: {val:.1f}%")

    # --- Right Panel: Drawdown Duration KDE
    _draw_density_panel(ax2, dd_duration, "#388bfd",
                        "Drawdown Duration (Bars)", "Duration Distribution",
                        "P{pct}: {val:.0f}b")

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=120, bbox_inches="tight",
                    facecolor=fig.get_facecolor())
    plt.show()


def plot_intraday_heatmap(df: pd.DataFrame,
                           metric: str = "Net_Profit_USD",
                           save_path: str = None,
                           fill_missing: float = 0.0):
    """
    Intraday and Intraweek Alpha Clusters (Hour-by-Day Performance Heatmap)

    Aggregates the chosen performance metric (mean) by Entry_Hour (0–23) and
    Entry_DayOfWeek (0=Monday, 4=Friday), then renders a seaborn heatmap
    with annotated cell values and session zone overlays. Hours outside 0–23
    and weekdays outside 0–4 are dropped.

    NOTE(review): MQL5's MqlDateTime.day_of_week counts 0=Sunday, and MT5
    timestamps are normally broker server time. Confirm that the EA remaps
    both to the 0=Monday / UTC convention assumed here.

    Parameters
    ----------
    df           : Full analytics DataFrame from CSV export.
    metric       : Column to aggregate into the heatmap cells.
    save_path    : File path for saving the output figure (optional).
    fill_missing : Value shown for hour/weekday cells with no records
                   (default 0.0). Pass None to leave them blank; seaborn
                   masks NaN cells.

    Raises
    ------
    ValueError : If no row has all of Entry_Hour, Entry_DayOfWeek and metric.
    KeyError   : If a required column is missing from ``df``.
    """

    required = ["Entry_Hour", "Entry_DayOfWeek", metric]
    plot_df  = df[required].dropna().copy()
    if plot_df.empty:
        raise ValueError(f"No complete rows for Entry_Hour, Entry_DayOfWeek and {metric}.")
    plot_df["Entry_Hour"]      = plot_df["Entry_Hour"].astype(int)
    plot_df["Entry_DayOfWeek"] = plot_df["Entry_DayOfWeek"].astype(int)

    pivot = plot_df.pivot_table(index="Entry_Hour", columns="Entry_DayOfWeek",
                                values=metric, aggfunc="mean")
    # reindex(fill_value=...) fills only hours/days absent from the whole
    # pivot; combinations missing *inside* it stay NaN. Reindex first, then
    # fill every gap the same way so "no data" renders consistently.
    pivot = pivot.reindex(index=range(24), columns=range(5))
    if fill_missing is not None:
        pivot = pivot.fillna(fill_missing)

    day_labels  = {0: "Mon", 1: "Tue", 2: "Wed", 3: "Thu", 4: "Fri"}
    pivot       = pivot.rename(columns=day_labels)
    hour_labels = [f"{h:02d}:00" for h in range(24)]

    fig, ax = plt.subplots(figsize=(8.167, 7.6))
    fig.suptitle(
        f"Intraday Alpha Cluster Heatmap  |  {metric} by Session Hour and Weekday",
        fontsize=11, fontweight="bold", color="#e6edf3"
    )

    cmap = sns.diverging_palette(10, 130, as_cmap=True)

    # The hour labels are passed in directly so every row gets one. Setting
    # them afterwards could mismatch the tick count if seaborn thinned ticks.
    sns.heatmap(pivot, ax=ax, cmap=cmap, center=0,
                linewidths=0.4, linecolor="#21262d",
                annot=True, fmt=".0f",
                annot_kws={"size": 6.5, "color": "#e6edf3"},
                yticklabels=hour_labels,
                cbar_kws={"label": metric, "shrink": 0.75, "pad": 0.02})

    # --- Session Zone Overlays (approximate UTC ranges)
    session_zones = [
        (0,  9,  "Asian",    "#388bfd", 0.06),
        (7,  16, "London",   "#3fb950", 0.06),
        (13, 22, "New York", "#d29922", 0.06),
    ]

    for row_start, row_end, label, color, alpha in session_zones:
        # zorder=2 puts the translucent band above the heatmap mesh
        # (collections default to zorder 1) and below the cell annotations
        # (text defaults to zorder 3). At zorder=0 the opaque mesh hid it.
        ax.add_patch(
            plt.Rectangle((0, row_start), len(pivot.columns), row_end - row_start,
                          fill=True, color=color, alpha=alpha, zorder=2)
        )
        ax.text(-0.35, (row_start + row_end) / 2, label,
                va="center", ha="right", fontsize=7,
                color=color, rotation=90, style="italic")

    ax.set_xlabel("Weekday", fontsize=9, labelpad=7)
    ax.set_ylabel("Entry Hour (UTC)", fontsize=9, labelpad=7)
    ax.tick_params(axis="both", labelsize=8)
    ax.tick_params(axis="y", labelsize=7, labelrotation=0)

    cbar = ax.collections[0].colorbar
    cbar.ax.yaxis.label.set_color("#c9d1d9")
    cbar.ax.tick_params(labelcolor="#8b949e")

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=120, bbox_inches="tight",
                    facecolor=fig.get_facecolor())
    plt.show()


def run_full_suite(data_path=DATA_PATH, output_dir=OUTPUT_DIR,
                   wf_top_n=20, target_period=10):
    """
    Load the analytics export once and render every chart into ``output_dir``.

    Charts are rendered independently. If one raises, the error is printed
    and the suite continues with the next chart. The closing summary lists
    any charts that were skipped.

    Parameters
    ----------
    data_path     : Path to the CSV export.
    output_dir    : Directory for the PNG files; created if it does not exist.
    wf_top_n      : ``top_n`` forwarded to plot_walkforward_degradation.
    target_period : ``target_period`` forwarded to plot_multi_asset_parameter_matrix.

    Raises
    ------
    FileNotFoundError : Propagated from load_analytics_export if the export is missing.
    """
    df = load_analytics_export(data_path)

    # Only the default OUTPUT_DIR is created at import time; a caller-supplied
    # directory must exist before savefig can write into it.
    os.makedirs(output_dir, exist_ok=True)

    # File name -> (plot function, extra keyword arguments). The save path is
    # derived from the file name so the two can never drift apart.
    charts = {
        "Multi_Asset_Parameter_Matrix.png": (
            plot_multi_asset_parameter_matrix, {"target_period": target_period}),
        "Lag_vs_Whipsaw_Scatter.png": (plot_lag_vs_whipsaw, {}),
        "WalkForward_Degradation_Slope.png": (
            plot_walkforward_degradation, {"top_n": wf_top_n}),
        "Drawdown_Distribution_KDE.png": (plot_drawdown_distribution, {}),
        "Intraday_Performance_Heatmap.png": (
            plot_intraday_heatmap, {"metric": "Net_Profit_USD"}),
    }

    skipped = []
    for chart_name, (plot_fn, extra_kwargs) in charts.items():
        print(f"\n[Rendering] {chart_name}")
        try:
            plot_fn(df, save_path=os.path.join(output_dir, chart_name), **extra_kwargs)
        except Exception as exc:  # top-level boundary: report and keep going
            skipped.append(chart_name)
            # Include the type: str() of a KeyError is just the missing key.
            print(f"  [Warning] {chart_name} skipped: {type(exc).__name__}: {exc}")

    if skipped:
        print(f"\n[Complete] {len(charts) - len(skipped)}/{len(charts)} charts "
              f"written to: {output_dir} (skipped: {', '.join(skipped)})")
    else:
        print(f"\n[Complete] All charts written to: {output_dir}")


if __name__ == "__main__":
    # Script entry point: render the full chart set using the module defaults.
    run_full_suite(data_path=DATA_PATH, output_dir=OUTPUT_DIR)