"""
RenkoRegressorCascade.py
========================
Регрессионная система прогнозирования Ренко-баров с ИТЕРАТИВНОЙ
каскадной коррекцией ошибок и финальным bias correction.

Логика каскада (residual boosting over CatBoost):
-------------------------------------------------
Стадия 0:  M_0(X) -> ŷ_0,        residual_0 = y - ŷ_0
Стадия 1:  M_1(X) -> r̂_1,        ŷ_1 = ŷ_0 + η·r̂_1,    residual_1 = y - ŷ_1
Стадия 2:  M_2(X) -> r̂_2,        ŷ_2 = ŷ_1 + η·r̂_2,    residual_2 = y - ŷ_2
...
Итоговый сырой прогноз: ŷ = ŷ_0 + η·(r̂_1 + r̂_2 + ... + r̂_K)
Финальный прогноз:       ŷ_final = ŷ - bias(ŷ)

Ключевые отличия от "просто CatBoost с большим числом итераций":
  * Каждая стадия — полностью отдельный CatBoostRegressor со своими
    early-stopping'ом и сплитом на train/val по времени.
  * Между стадиями residuals считаются на OOF (TimeSeriesSplit),
    так что каскад не подглядывает в будущее.
  * Shrinkage η < 1 замедляет переобучение на шуме residuals.
  * Early stopping на уровне КАСКАДА — останавливаемся, когда OOF MAE
    перестаёт падать на `patience` итераций подряд.
  * После каскада — ещё один слой conditional bias correction.

Автор: Evgeniy Koshtenko (MIDAS ecosystem)
"""

import numpy as np
import pandas as pd
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple

try:
    import MetaTrader5 as mt5
    MT5_AVAILABLE = True
except ImportError:
    MT5_AVAILABLE = False

from catboost import CatBoostRegressor, Pool
from sklearn.model_selection import TimeSeriesSplit
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score


# ============================================================
# 1. ЗАГРУЗКА ДАННЫХ (MT5)
# ============================================================

def get_mt5_data(symbol: str = "EURUSD",
                 timeframe: int = None,
                 days: int = 60) -> Optional[pd.DataFrame]:
    """Pull the last *days* days of bars for *symbol* from a running MT5 terminal.

    Args:
        symbol: instrument name as known to the terminal.
        timeframe: MT5 timeframe constant; defaults to M5 when None.
        days: depth of history to request, counted back from now.

    Returns:
        DataFrame with a parsed ``time`` column, or None when the terminal
        fails to initialise or returns no data.

    Raises:
        ImportError: if the MetaTrader5 package is not installed.
    """
    if not MT5_AVAILABLE:
        raise ImportError("MetaTrader5 не установлен. pip install MetaTrader5")
    tf = mt5.TIMEFRAME_M5 if timeframe is None else timeframe
    if not mt5.initialize():
        print(f"Ошибка инициализации MT5: {mt5.last_error()}")
        return None
    since = datetime.now() - timedelta(days=days)
    rates = mt5.copy_rates_range(symbol, tf, since, datetime.now())
    mt5.shutdown()
    if rates is None or len(rates) == 0:
        return None
    frame = pd.DataFrame(rates)
    # MT5 returns epoch seconds; convert once at the boundary.
    frame["time"] = pd.to_datetime(frame["time"], unit="s")
    return frame


# ============================================================
# 2. РЕНКО-БАРЫ
# ============================================================

def _calc_atr(df: pd.DataFrame, period: int = 14) -> pd.Series:
    tr = pd.concat([
        df["high"] - df["low"],
        (df["high"] - df["close"].shift()).abs(),
        (df["low"] - df["close"].shift()).abs(),
    ], axis=1).max(axis=1)
    return tr.rolling(period).mean()


def create_renko_bars(df: pd.DataFrame,
                      brick_size: Optional[float] = None,
                      atr_multiplier: float = 0.5) -> Tuple[pd.DataFrame, float]:
    """Convert OHLC bars into fixed-size Renko bricks.

    Args:
        df: OHLC frame with at least ``time``/``high``/``low``/``close``
            columns (and optionally ``tick_volume`` or ``volume``).
        brick_size: price size of one brick; when None it is derived as
            mean ATR(14) * atr_multiplier.
        atr_multiplier: scale applied to the ATR when auto-sizing bricks.

    Returns:
        (renko_df, brick_size): the brick frame (with ``direction``,
        ``is_up``/``is_down`` flags and their streak columns) and the
        brick size actually used.

    Raises:
        ValueError: if the resulting brick_size is not positive.
    """
    if brick_size is None:
        atr = _calc_atr(df, 14)
        brick_size = float(atr.mean() * atr_multiplier)
        print(f"Размер Ренко-блока: {brick_size:.6f}")
    if brick_size <= 0:
        raise ValueError("brick_size должен быть > 0")

    renko_bars = []
    current_price = float(df.iloc[0]["close"])  # anchor price of the current brick grid
    current_direction = None                    # None until the first brick is emitted
    bar_open = current_price
    bar_time = df.iloc[0]["time"]
    volume_sum = 0.0                            # volume accumulated since the last emitted brick

    # Prefer MT5's tick_volume; fall back to a generic volume column, else no volume.
    vol_col = "tick_volume" if "tick_volume" in df.columns else (
        "volume" if "volume" in df.columns else None
    )

    for _, row in df.iterrows():
        if vol_col is not None:
            volume_sum += float(row[vol_col])
        price = float(row["close"])
        price_change = price - current_price
        # Whole bricks covered by the move since the anchor.
        num_bricks = int(abs(price_change) / brick_size)

        if num_bricks > 0:
            direction = 1 if price_change > 0 else -1
            # NOTE(review): a direction flip emits one EXTRA brick (the gap brick),
            # so a reversal triggers after only ~1 brick of adverse movement —
            # this differs from the classic 2x-brick reversal rule; confirm intended.
            if current_direction is not None and direction != current_direction:
                num_bricks += 1

            for _ in range(num_bricks):
                if current_direction is None or current_direction == direction:
                    bar_close = bar_open + brick_size * direction
                else:
                    # First brick after a flip: shift the open one brick in the OLD
                    # direction before stepping in the new one (reversal gap brick).
                    bar_open = bar_open + brick_size * current_direction
                    bar_close = bar_open + brick_size * direction

                renko_bars.append({
                    "time": bar_time, "open": bar_open,
                    "high": max(bar_open, bar_close),
                    "low": min(bar_open, bar_close),
                    "close": bar_close, "volume": volume_sum,
                    "direction": direction,
                })
                bar_open = bar_close
                bar_time = row["time"]
                current_direction = direction

            # All accumulated volume was attributed to the bricks just emitted.
            volume_sum = 0.0
            # NOTE(review): after a flip, num_bricks includes the extra reversal
            # brick, so the anchor moves by (bricks+1)*size in the new direction —
            # presumably intended, but verify the anchor does not drift from quotes.
            current_price = current_price + num_bricks * brick_size * direction

    renko_df = pd.DataFrame(renko_bars)
    if len(renko_df) == 0:
        return renko_df, brick_size

    renko_df["is_up"] = (renko_df["direction"] > 0).astype(int)
    renko_df["is_down"] = (renko_df["direction"] < 0).astype(int)
    # Run-length streaks: group consecutive equal flag values, cumulative-sum within each run.
    for col in ["is_up", "is_down"]:
        grp = renko_df[col].ne(renko_df[col].shift()).cumsum()
        renko_df[f"{col}_streak"] = renko_df.groupby(grp)[col].cumsum()
    return renko_df, brick_size


# ============================================================
# 3. ПРИЗНАКИ И ТАРГЕТ
# ============================================================

def prepare_regression_features(renko_df: pd.DataFrame,
                                lookback: int = 10,
                                horizon: int = 1) -> Tuple[pd.DataFrame, np.ndarray, pd.Series]:
    """Build a supervised dataset from a Renko bar frame.

    Each sample uses a window of `lookback` bars ending at bar i-1; the target
    is the log-return between close[i-1] and close[i-1+horizon].

    Returns:
        (features, targets, timestamps); all three empty when the frame is
        shorter than lookback + horizon + 1 bars.
    """
    n_bars = len(renko_df)
    if n_bars < lookback + horizon + 1:
        return pd.DataFrame(), np.array([]), pd.Series(dtype="datetime64[ns]")

    close_arr = renko_df["close"].values
    dir_arr = renko_df["direction"].values
    vol_arr = renko_df["volume"].values
    up_arr = renko_df["is_up_streak"].values
    dn_arr = renko_df["is_down_streak"].values

    rows, ys, stamps = [], [], []

    for idx in range(lookback, n_bars - horizon):
        anchor = close_arr[idx - 1]
        future = close_arr[idx - 1 + horizon]
        if anchor <= 0 or future <= 0:
            continue  # log-return target undefined for non-positive prices

        d_win = dir_arr[idx - lookback:idx]
        c_win = close_arr[idx - lookback:idx]
        v_win = vol_arr[idx - lookback:idx]

        mean_vol = v_win.mean() if v_win.mean() > 0 else 1.0  # guard div-by-zero
        step_proxy = np.mean(np.abs(np.diff(c_win))) if len(c_win) > 1 else 0.0

        # Key order matters: it fixes the feature-column order downstream.
        row = {f"dir_{k}": int(d_win[-(k + 1)]) for k in range(lookback)}
        row.update(
            up_ratio=float((d_win > 0).mean()),
            last_dir=int(d_win[-1]),
            dir_changes=int(np.sum(np.abs(np.diff(d_win)) > 0)),
            last_up_streak=int(up_arr[idx - 1]),
            last_down_streak=int(dn_arr[idx - 1]),
            max_up_streak=int(up_arr[idx - lookback:idx].max()),
            max_down_streak=int(dn_arr[idx - lookback:idx].max()),
            last_volume=float(v_win[-1]),
            avg_volume=float(mean_vol),
            volume_ratio=float(v_win[-1] / mean_vol),
            volume_trend=float(np.polyfit(range(lookback), v_win, 1)[0]) if lookback > 1 else 0.0,
            price_range=float(c_win.max() - c_win.min()),
            range_per_bar=float(step_proxy),
            last_return=float(np.log(c_win[-1] / c_win[-2])) if c_win[-2] > 0 else 0.0,
            cum_return=float(np.log(c_win[-1] / c_win[0])) if c_win[0] > 0 else 0.0,
        )

        rows.append(row)
        ys.append(float(np.log(future / anchor)))
        stamps.append(renko_df["time"].iloc[idx - 1])

    return pd.DataFrame(rows), np.array(ys, dtype=float), pd.Series(stamps, name="time")


# ============================================================
# 4. BIAS CORRECTOR (финальный слой)
# ============================================================

@dataclass
class BiasCorrector:
    """Final prediction-debiasing layer fitted on out-of-fold residuals.

    ``method="conditional"`` keeps separate biases for bullish/bearish raw
    predictions (falling back to the global mean bias when a side has too few
    samples); ``"global_median"`` uses the median residual; any other value
    uses the mean residual.
    """
    method: str = "conditional"
    global_bias: float = 0.0
    bull_bias: float = 0.0
    bear_bias: float = 0.0
    residual_std: float = 0.0

    def fit(self, y_true: np.ndarray, y_pred_oof: np.ndarray) -> "BiasCorrector":
        """Estimate bias terms from OOF predictions vs truth; returns self (fluent)."""
        err = y_pred_oof - y_true
        self.residual_std = float(np.std(err, ddof=1)) if len(err) > 1 else 0.0
        self.global_bias = float(np.mean(err))
        if self.method == "global_median":
            self.global_bias = float(np.median(err))
        elif self.method == "conditional":
            pos_mask = y_pred_oof > 0
            neg_mask = y_pred_oof < 0
            # Need more than 5 samples on a side for a reliable conditional estimate.
            self.bull_bias = float(np.mean(err[pos_mask])) if pos_mask.sum() > 5 else self.global_bias
            self.bear_bias = float(np.mean(err[neg_mask])) if neg_mask.sum() > 5 else self.global_bias
        return self

    def correct(self, raw_pred: float) -> float:
        """Subtract the fitted bias matching the sign of *raw_pred*."""
        bias = self.global_bias
        if self.method == "conditional":
            if raw_pred > 0:
                bias = self.bull_bias
            elif raw_pred < 0:
                bias = self.bear_bias
        return raw_pred - bias

    def confidence_interval(self, pred: float, k: float = 1.96) -> Tuple[float, float]:
        """Symmetric ±k·σ interval around *pred* (k=1.96 ≈ 95% under normality)."""
        half_width = k * self.residual_std
        return pred - half_width, pred + half_width


# ============================================================
# 5. КАСКАД МОДЕЛЕЙ — ядро новой версии
# ============================================================

@dataclass
class CascadeStage:
    """One cascade stage: a model trained on the residuals of the previous stage."""
    # Fitted CatBoost model (stage 0 predicts the target, later stages predict residuals).
    model: CatBoostRegressor
    # OOF MAE of the cumulative cascade prediction at the point this stage was accepted.
    oof_mae: float
    # Multiplier applied to this stage's output at inference: 1.0 for stage 0, η for residual stages.
    shrinkage_applied: float


@dataclass
class RenkoCascadeModel:
    """Residual-boosting cascade of CatBoost regressors on Renko-bar features.

    ``fit()`` builds Renko bars, extracts features, trains a base model
    (stage 0) and then iteratively fits residual models whose shrunk outputs
    are added to the cumulative out-of-fold prediction.  A stage is accepted
    only if it improves OOF MAE by more than ``min_improvement``; the cascade
    stops after ``patience`` consecutive rejections.  A ``BiasCorrector`` is
    then fitted on the final OOF predictions (and disabled if it hurts MAE).
    """

    # Feature-engineering parameters
    lookback: int = 10
    horizon: int = 1
    atr_multiplier: float = 0.5

    # Cascade parameters
    max_stages: int = 15          # maximum number of residual stages to try
    patience: int = 3             # consecutive non-improving stages before early stop
    min_improvement: float = 1e-7 # minimum MAE drop that counts as a real improvement
    shrinkage: float = 0.6        # η applied to residual-model outputs (stage 0 always uses 1.0)

    # CatBoost parameters
    base_iterations: int = 500
    stage_iterations: int = 300   # fewer for residual stages (each solves a smaller problem)
    base_depth: int = 6
    stage_depth: int = 4          # residuals are simpler — a shallower tree suffices

    # Cross-validation
    n_splits: int = 5

    # Bias correction
    bias_method: str = "conditional"

    # --- State populated by fit() ---
    stages: List[CascadeStage] = field(default_factory=list)
    corrector: Optional[BiasCorrector] = None
    feature_names: List[str] = field(default_factory=list)
    brick_size: float = 0.0
    training_history: List[Dict] = field(default_factory=list)
    final_metrics: Dict[str, float] = field(default_factory=dict)

    def _base_params(self, stage_idx: int) -> dict:
        """CatBoost hyperparameters for the given stage (index 0 = base model)."""
        is_base = (stage_idx == 0)
        return {
            "iterations": self.base_iterations if is_base else self.stage_iterations,
            "learning_rate": 0.03 if is_base else 0.02,
            "depth": self.base_depth if is_base else self.stage_depth,
            "loss_function": "RMSE",
            "eval_metric": "MAE",
            "l2_leaf_reg": 3.0 if is_base else 5.0,  # residuals are noisier -> stronger regularization
            "random_seed": 42 + stage_idx,
            "verbose": False,
            "early_stopping_rounds": 40,
        }

    def _compute_oof(self, X: pd.DataFrame, y: np.ndarray, stage_idx: int) -> Tuple[np.ndarray, CatBoostRegressor]:
        """Compute OOF predictions for the current target via TimeSeriesSplit.

        Returns (oof, final_model): ``oof`` stays NaN wherever no split used the
        row as validation (the earliest samples); ``final_model`` is refit on
        ALL rows for use at inference time.
        """
        tscv = TimeSeriesSplit(n_splits=self.n_splits)
        oof = np.full(len(X), np.nan)
        params = self._base_params(stage_idx)

        for tr_idx, va_idx in tscv.split(X):
            m = CatBoostRegressor(**params)
            m.fit(
                Pool(X.iloc[tr_idx], y[tr_idx]),
                eval_set=Pool(X.iloc[va_idx], y[va_idx]),
                verbose=False,
            )
            oof[va_idx] = m.predict(X.iloc[va_idx])

        # This stage's final model — trained on all data.
        # NOTE(review): no eval_set here, so early_stopping_rounds presumably has
        # no effect on this fit — confirm against CatBoost docs.
        final_model = CatBoostRegressor(**params)
        final_model.fit(X, y, verbose=False)
        return oof, final_model

    def fit(self, df: pd.DataFrame) -> "RenkoCascadeModel":
        """Train the full pipeline on raw OHLC data; returns self (fluent).

        Raises:
            ValueError: when fewer than 100 feature samples can be built.
        """
        # --- Renko bars + features ---
        renko_df, self.brick_size = create_renko_bars(df, atr_multiplier=self.atr_multiplier)
        print(f"Создано {len(renko_df)} Ренко-баров")

        X, y, _ = prepare_regression_features(renko_df, self.lookback, self.horizon)
        if len(X) < 100:
            raise ValueError(f"Слишком мало образцов: {len(X)}")
        self.feature_names = X.columns.tolist()
        print(f"Обучающая выборка: {len(X)} сэмплов, {len(self.feature_names)} признаков\n")

        # --- Stage 0: the base model ---
        print("=" * 70)
        print("КАСКАДНОЕ ОБУЧЕНИЕ")
        print("=" * 70)

        print(f"\n[Stage 0] Базовая модель (shrinkage=1.00)")
        oof_cum, base_model = self._compute_oof(X, y, stage_idx=0)
        # Valid indices — those where TimeSeriesSplit produced a prediction
        valid = ~np.isnan(oof_cum)
        y_v = y[valid]
        oof_cum_v = oof_cum[valid]

        mae_stage = mean_absolute_error(y_v, oof_cum_v)
        self.stages.append(CascadeStage(model=base_model, oof_mae=mae_stage, shrinkage_applied=1.0))
        self.training_history.append({
            "stage": 0, "oof_mae": mae_stage,
            "oof_rmse": float(np.sqrt(mean_squared_error(y_v, oof_cum_v))),
            "oof_r2": r2_score(y_v, oof_cum_v),
            "dir_acc": float(np.mean(np.sign(oof_cum_v) == np.sign(y_v))),
        })
        print(f"          OOF MAE = {mae_stage:.8f}   R² = {self.training_history[-1]['oof_r2']:+.4f}"
              f"   DirAcc = {self.training_history[-1]['dir_acc']*100:.2f}%")

        # --- Iterative residual stages ---
        best_mae = mae_stage
        best_stage = 0
        no_improve_count = 0

        for s in range(1, self.max_stages + 1):
            # Residuals for ALL points (invalid ones give residual = y - nan = nan;
            # they are filtered out below)
            residuals = y - oof_cum

            # Train the stage on residuals, using only the valid points
            X_s = X.loc[valid].reset_index(drop=True)
            r_s = residuals[valid]

            print(f"\n[Stage {s}] Residual-модель (shrinkage={self.shrinkage:.2f}, "
                  f"residual_std={np.std(r_s):.8f})")

            oof_r, stage_model = self._compute_oof(X_s, r_s, stage_idx=s)
            valid_s = ~np.isnan(oof_r)
            if valid_s.sum() < 50:
                print("          Слишком мало OOF-точек, стоп.")
                break

            # Update the cumulative OOF prediction with shrinkage applied.
            # Note: oof_r was computed on X_s (the valid points) — only those
            # positions of the full array are updated, via idx_valid mapping.
            oof_cum_updated = oof_cum.copy()
            idx_valid = np.where(valid)[0]
            # Apply shrinkage to the points that have a valid oof_r
            mask_inner = valid_s
            oof_cum_updated[idx_valid[mask_inner]] = (
                oof_cum[idx_valid[mask_inner]] + self.shrinkage * oof_r[mask_inner]
            )

            # Overall valid mask after this stage (NaN positions are never filled,
            # so in practice this equals the previous `valid`)
            valid_new = ~np.isnan(oof_cum_updated)
            y_v_new = y[valid_new]
            oof_v_new = oof_cum_updated[valid_new]
            mae_new = mean_absolute_error(y_v_new, oof_v_new)

            improvement = best_mae - mae_new
            better = improvement > self.min_improvement

            self.training_history.append({
                "stage": s,
                "oof_mae": mae_new,
                "oof_rmse": float(np.sqrt(mean_squared_error(y_v_new, oof_v_new))),
                "oof_r2": r2_score(y_v_new, oof_v_new),
                "dir_acc": float(np.mean(np.sign(oof_v_new) == np.sign(y_v_new))),
                "improvement": improvement,
                "accepted": better,
            })

            print(f"          OOF MAE = {mae_new:.8f}   ΔMAE = {improvement:+.2e}"
                  f"   R² = {self.training_history[-1]['oof_r2']:+.4f}"
                  f"   DirAcc = {self.training_history[-1]['dir_acc']*100:.2f}%")

            if better:
                # Accept the stage
                self.stages.append(CascadeStage(
                    model=stage_model, oof_mae=mae_new, shrinkage_applied=self.shrinkage
                ))
                oof_cum = oof_cum_updated
                valid = valid_new
                best_mae = mae_new
                best_stage = s
                no_improve_count = 0
                print(f"          ✓ принята (улучшение)")
            else:
                no_improve_count += 1
                print(f"          ✗ отклонена (без улучшения, {no_improve_count}/{self.patience})")
                if no_improve_count >= self.patience:
                    print(f"\nРанняя остановка: нет улучшения {self.patience} стадий подряд.")
                    break

        # Trim the cascade to the last accepted stage. Defensive: rejected stages
        # are never appended, so best_stage + 1 >= len(self.stages) and this slice
        # is effectively a no-op.
        self.stages = self.stages[:best_stage + 1]
        print(f"\nИтоговая глубина каскада: {len(self.stages)} стадий "
              f"(базовая + {len(self.stages)-1} residual)")

        # --- Final bias correction on top of the cascade ---
        print("\n" + "=" * 70)
        print("ФИНАЛЬНАЯ BIAS-КОРРЕКЦИЯ (conditional)")
        print("=" * 70)

        # The walrus rebinds y_v_new to the final valid targets used by the metrics below.
        self.corrector = BiasCorrector(method=self.bias_method).fit(y_v_new := y[valid], oof_cum[valid])
        corrected = np.array([self.corrector.correct(p) for p in oof_cum[valid]])
        mae_corr_candidate = mean_absolute_error(y_v_new, corrected)

        # Guard: if bias correction worsens MAE, zero out the biases (disable it)
        if mae_corr_candidate > best_mae:
            print(f"\n  ⚠ Bias correction ухудшает MAE ({best_mae:.8f} -> {mae_corr_candidate:.8f}),"
                  f" отключаю её.")
            self.corrector.global_bias = 0.0
            self.corrector.bull_bias = 0.0
            self.corrector.bear_bias = 0.0
            corrected = oof_cum[valid].copy()

        mae_corr = mean_absolute_error(y_v_new, corrected)
        rmse_corr = float(np.sqrt(mean_squared_error(y_v_new, corrected)))
        r2_corr = r2_score(y_v_new, corrected)
        dir_corr = float(np.mean(np.sign(corrected) == np.sign(y_v_new)))

        # Stage 0 — the plain base model without cascade or bias (for comparison)
        stage0_mae = self.training_history[0]["oof_mae"]
        stage0_r2 = self.training_history[0]["oof_r2"]
        stage0_dir = self.training_history[0]["dir_acc"]

        self.final_metrics = {
            "stage0_mae": stage0_mae,
            "cascade_mae": best_mae,
            "final_mae": mae_corr,
            "stage0_r2": stage0_r2,
            "final_r2": r2_corr,
            "stage0_dir_acc": stage0_dir,
            "final_dir_acc": dir_corr,
            "residual_std": self.corrector.residual_std,
            "global_bias": self.corrector.global_bias,
            "bull_bias": self.corrector.bull_bias,
            "bear_bias": self.corrector.bear_bias,
            "cascade_depth": len(self.stages),
        }

        print(f"\n  MAE:       base={stage0_mae:.8f}  ->  cascade={best_mae:.8f}"
              f"  ->  +bias={mae_corr:.8f}")
        print(f"  RMSE:                                                         {rmse_corr:.8f}")
        print(f"  R²:        base={stage0_r2:+.4f}        ->  final={r2_corr:+.4f}")
        print(f"  DirAcc:    base={stage0_dir*100:.2f}%      ->  final={dir_corr*100:.2f}%")
        mae_improvement_pct = (stage0_mae - mae_corr) / stage0_mae * 100
        print(f"  Суммарное улучшение MAE: {mae_improvement_pct:+.2f}%")
        print(f"\n  Bias:  global={self.corrector.global_bias:+.6e}  "
              f"bull={self.corrector.bull_bias:+.6e}  "
              f"bear={self.corrector.bear_bias:+.6e}")
        print(f"  Residual σ: {self.corrector.residual_std:.6e}")

        return self

    def predict_raw(self, X_row: pd.DataFrame) -> float:
        """Run the full cascade (base + shrunk residual stages) without bias correction."""
        pred = float(self.stages[0].model.predict(X_row)[0])
        for st in self.stages[1:]:
            pred += st.shrinkage_applied * float(st.model.predict(X_row)[0])
        return pred

    def predict(self, df: pd.DataFrame) -> Dict[str, float]:
        """Forecast the next Renko move from raw OHLC data.

        Returns a dict with the forecast, per-stage contributions and a 95% CI;
        when too few Renko bars can be built, returns ``{"error": ...}`` instead
        (note the different schema).

        Raises:
            RuntimeError: if called before fit().
        """
        if not self.stages or self.corrector is None:
            raise RuntimeError("Модель не обучена — вызовите fit()")

        # Rebuild bars with the SAME brick size that was used during training.
        renko_df, _ = create_renko_bars(df, brick_size=self.brick_size)
        X, _, _ = prepare_regression_features(renko_df, self.lookback, self.horizon)
        if len(X) == 0:
            return {"error": "Недостаточно Ренко-баров"}

        # Reorder columns to the training layout; predict from the latest window.
        X = X[self.feature_names]
        X_row = X.iloc[[-1]]

        # Stage-by-stage breakdown so each stage's contribution is visible
        stage_contributions = []
        pred_cum = 0.0
        for i, st in enumerate(self.stages):
            p = float(st.model.predict(X_row)[0])
            contrib = p if i == 0 else st.shrinkage_applied * p
            pred_cum += contrib
            stage_contributions.append({
                "stage": i,
                "raw_output": p,
                "contribution": contrib,
                "cumulative": pred_cum,
            })

        raw = pred_cum
        corrected = self.corrector.correct(raw)
        lo, hi = self.corrector.confidence_interval(corrected, 1.96)

        # Map log-return forecast back to price space.
        last_close = float(renko_df["close"].iloc[-1])
        price_pred = last_close * np.exp(corrected)
        price_lo = last_close * np.exp(lo)
        price_hi = last_close * np.exp(hi)

        return {
            "last_close": last_close,
            "cascade_depth": len(self.stages),
            "stage_contributions": stage_contributions,
            "raw_log_return": raw,
            "bias_applied": raw - corrected,
            "corrected_log_return": corrected,
            "predicted_price": price_pred,
            "ci95_low": price_lo,
            "ci95_high": price_hi,
            "direction": "UP" if corrected > 0 else ("DOWN" if corrected < 0 else "FLAT"),
            "expected_move_pct": 100.0 * (np.exp(corrected) - 1.0),
        }


# ============================================================
# 6. MAIN
# ============================================================

def main():
    """Fetch EURUSD M5 data, train the cascade model and print a forecast."""
    rule = "=" * 70
    print(rule)
    print("RENKO CASCADE REGRESSION with ITERATIVE ERROR CORRECTION")
    print(rule)

    raw_df = get_mt5_data(symbol="EURUSD", days=60)
    if raw_df is None or len(raw_df) == 0:
        print("Не удалось получить данные MT5")
        return
    print(f"Загружено {len(raw_df)} баров M5 EURUSD\n")

    cascade = RenkoCascadeModel(
        lookback=10,
        horizon=1,
        atr_multiplier=0.5,
        max_stages=15,
        patience=3,
        shrinkage=0.6,
        bias_method="conditional",
    )
    cascade.fit(raw_df)

    forecast = cascade.predict(raw_df)
    print("\n" + rule)
    print("ПРОГНОЗ")
    print(rule)
    print(f"  Последняя цена закрытия:  {forecast['last_close']:.6f}")
    print(f"  Глубина каскада:          {forecast['cascade_depth']}")
    print(f"\n  Вклад стадий:")
    for stage_info in forecast["stage_contributions"]:
        print(f"    Stage {stage_info['stage']}: out={stage_info['raw_output']:+.6e}   "
              f"contrib={stage_info['contribution']:+.6e}   cum={stage_info['cumulative']:+.6e}")
    print(f"\n  Raw log-return:         {forecast['raw_log_return']:+.6e}")
    print(f"  Bias applied:           {forecast['bias_applied']:+.6e}")
    print(f"  Corrected log-return:   {forecast['corrected_log_return']:+.6e}")
    print(f"  Predicted price:        {forecast['predicted_price']:.6f}")
    print(f"  95% CI:                 [{forecast['ci95_low']:.6f} ... {forecast['ci95_high']:.6f}]")
    print(f"  Direction:              {forecast['direction']}")
    print(f"  Expected move:          {forecast['expected_move_pct']:+.4f}%")
