"""
ffd_cross_validate.py
=====================
Cross-validate FFDEngine.mqh against the Python afml implementation.

Workflow
--------
1. Run FFDValidation.mq5 in MetaTrader 5. It writes a CSV to the terminal's
   MQL5/Files/ folder (default path shown below).
2. Run this script, pointing --csv at that file.
3. The script recomputes FFD values in Python and compares them bar-by-bar.

Usage
-----
    python ffd_cross_validate.py --csv path/to/ffd_validation_EURUSD_H1_d0.40.csv

Requirements: numpy, pandas
"""

import argparse
from pathlib import Path

import numpy as np
import pandas as pd


# ---------------------------------------------------------------------------
# FFD weight and computation functions (afml parity)
# ---------------------------------------------------------------------------

def get_weights_ffd(d: float, thres: float, lim: int = 100_000) -> np.ndarray:
    """
    Build the fixed-width FFD weight vector (afml.get_weights_ffd parity).

    Expands the recursion w_k = -w_{k-1} * (d - k + 1) / k, starting from
    w_0 = 1, until the next weight's magnitude drops below `thres` or `lim`
    weights have been produced. The vector is returned oldest-lag first:
    index 0 holds the furthest-back (smallest) weight and the final element
    is always 1.0.

    Parameters
    ----------
    d     : Fractional-differencing order.
    thres : Weight cutoff. Iteration stops when |w_k| < thres.
    lim   : Hard upper bound on the number of weights (safety guard).
    """
    ws = [1.0]
    k = 1
    candidate = -ws[-1] * (d - k + 1) / k
    while abs(candidate) >= thres:
        ws.append(candidate)
        k += 1
        if k == lim:         # safety guard against runaway expansion
            break
        candidate = -ws[-1] * (d - k + 1) / k
    # Reverse so the oldest lag comes first (index 0 = smallest weight).
    return np.array(ws[::-1])


def frac_diff_ffd_from_csv(
    close: np.ndarray,
    d: float,
    thres: float,
    use_log: bool = True,
) -> np.ndarray:
    """
    Compute FFD values for a chronological price array.

    Mirrors the MQL5 ComputeBuffer() logic so the two implementations can be
    compared element-by-element. The first `width` positions of the output,
    where no full lookback window exists, are left as NaN.

    Parameters
    ----------
    close   : 1-D array, chronological (oldest first).
    d       : Fractional-differencing order.
    thres   : Weight cutoff threshold.
    use_log : Apply ln(max(p, 1e-8)) before differencing.
    """
    # Weight recursion w_k = -w_{k-1} * (d - k + 1) / k, truncated at thres
    # (same computation as get_weights_ffd, inlined; oldest lag first after
    # the reversal below; 100_000 matches that helper's default `lim` guard).
    raw = [1.0]
    k = 1
    while True:
        nxt = -raw[-1] * (d - k + 1) / k
        if abs(nxt) < thres:
            break
        raw.append(nxt)
        k += 1
        if k == 100_000:
            break
    weights = np.array(raw[::-1])
    width = len(weights) - 1

    series = (
        np.log(np.maximum(close, 1e-8))
        if use_log
        else close.copy().astype(float)
    )

    n = len(series)
    result = np.full(n, np.nan)
    # Slide a (width + 1)-long window over the series; each output is the
    # dot product weights[0]*oldest + ... + weights[width]*newest.
    for end in range(width, n):
        result[end] = np.dot(weights, series[end - width : end + 1])

    return result


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def _infer_d_from_name(filename: str) -> float:
    """
    Infer the differencing order d from a name like '..._d0.40.csv'.

    Uses a strict float pattern (digits, optional single decimal part) so a
    malformed token such as '0.4.0' raises the explicit ValueError below
    instead of an opaque float() conversion error.

    Raises
    ------
    ValueError : when the pattern is absent; caller should pass --d instead.
    """
    import re  # local: only needed when --d is not supplied
    m = re.search(r"_d(\d+(?:\.\d+)?)\.csv$", filename)
    if m is None:
        raise ValueError("Could not infer d from filename. Pass --d explicitly.")
    return float(m.group(1))


def main() -> None:
    """
    CLI entry point.

    Loads the CSV exported by FFDValidation.mq5, recomputes the FFD series in
    Python on the exported close prices, aligns the two series, and prints
    agreement statistics plus a PASS/FAIL verdict (tolerance 1e-10).
    """
    parser = argparse.ArgumentParser(description="Cross-validate MQL5 FFD against Python")
    parser.add_argument("--csv", required=True, help="Path to ffd_validation_*.csv from FFDValidation.mq5")
    parser.add_argument("--d", type=float, default=None, help="Override d (default: inferred from filename)")
    parser.add_argument("--thres", type=float, default=1e-5, help="Weight threshold (default: 1e-5)")
    parser.add_argument("--no-log", action="store_true", help="Disable log transform")
    args = parser.parse_args()

    csv_path = Path(args.csv)
    if not csv_path.exists():
        raise FileNotFoundError(csv_path)

    d = args.d if args.d is not None else _infer_d_from_name(csv_path.name)
    use_log = not args.no_log
    thres = args.thres

    print(f"\nParameters:  d={d}  threshold={thres}  use_log={use_log}")
    print(f"CSV:         {csv_path}\n")

    # --- Load CSV produced by FFDValidation.mq5 ---
    # The CSV contains only rows where ffd != EMPTY_VALUE, starting at bar=width.
    df = pd.read_csv(csv_path, parse_dates=["datetime"])
    df = df.sort_values("bar_index").reset_index(drop=True)

    first_bar = int(df["bar_index"].iloc[0])

    # Compute weights to determine the expected lookback width.
    weights = get_weights_ffd(d, thres)
    width = len(weights) - 1
    print(f"Weight vector:  width={width}  min_bars={width + 1}")

    if first_bar != width:
        print(f"WARNING: expected first data bar at index {width}, "
              f"but CSV starts at {first_bar}. Check that d and thres match.")

    # The CSV omits bars 0..width-1 (their FFD was EMPTY_VALUE), so the full
    # price history is unavailable and the earliest exported FFD values cannot
    # be reproduced directly — their windows reach back into unexported bars.
    # Instead we treat the exported close[] as a fresh series: both sides then
    # see identical price history, so Python's first valid output (slice index
    # `width`, i.e. MQL5 bar first_bar + width) onward must match the exported
    # MQL5 column exactly.
    close_exported = df["close"].values.astype(float)
    ffd_mql5       = df["ffd"].values.astype(float)

    n_exported = len(close_exported)
    if n_exported < width + 1:
        print(f"ERROR: Not enough exported bars ({n_exported}) to recompute "
              f"even one Python FFD value (need at least {width + 1}). "
              f"Increase InpBars in FFDValidation.mq5.")
        return

    # Recompute Python FFD on the exported slice.
    ffd_py_on_slice = frac_diff_ffd_from_csv(close_exported, d, thres, use_log)

    # Align: Python output is NaN for the first `width` slice rows, so the
    # mask effectively drops those rows from both series.
    valid_mask   = ~np.isnan(ffd_py_on_slice)
    valid_rows   = np.flatnonzero(valid_mask)   # hoisted: reused for reporting
    py_valid     = ffd_py_on_slice[valid_mask]
    mql5_aligned = ffd_mql5[valid_mask]

    n_compared = len(py_valid)
    if n_compared == 0:
        print("No bars available for comparison after aligning lookback.")
        return

    diffs = np.abs(py_valid - mql5_aligned)

    print(f"Bars compared:   {n_compared}  (of {n_exported} exported bars)")
    print(f"Max |diff|:      {diffs.max():.4e}")
    print(f"Mean |diff|:     {diffs.mean():.4e}")
    print(f"Bars within 1e-10: {(diffs < 1e-10).sum()} / {n_compared}")
    print(f"Bars within 1e-12: {(diffs < 1e-12).sum()} / {n_compared}")

    threshold_pass = 1e-10
    if diffs.max() < threshold_pass:
        print(f"\nPASS — max difference {diffs.max():.2e} < {threshold_pass:.0e}")
    else:
        print(f"\nFAIL — max difference {diffs.max():.2e} >= {threshold_pass:.0e}")
        # Show the 5 worst-offending bars
        worst = np.argsort(diffs)[::-1][:5]
        print("\nTop 5 discrepancies:")
        print(f"  {'slice_idx':>10}  {'mql5_ffd':>18}  {'py_ffd':>18}  {'|diff|':>12}")
        for idx in worst:
            print(f"  {valid_rows[idx]:>10}  {mql5_aligned[idx]:>18.12f}"
                  f"  {py_valid[idx]:>18.12f}  {diffs[idx]:>12.4e}")

    # Informational: first few matched values
    print("\nFirst 5 matched values (slice index = exported row):")
    print(f"  {'slice_idx':>10}  {'mql5_ffd':>20}  {'py_ffd':>20}  {'|diff|':>12}")
    for j in range(min(5, n_compared)):
        print(f"  {valid_rows[j]:>10}  {mql5_aligned[j]:>20.12f}"
              f"  {py_valid[j]:>20.12f}  {diffs[j]:>12.4e}")


if __name__ == "__main__":
    main()