import pandas as pd
import numpy as np
import MetaTrader5 as mt5
from datetime import datetime, timedelta
from catboost import CatBoostClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report


# Load historical bars from MetaTrader 5
def get_mt5_data(symbol="EURUSD", timeframe=mt5.TIMEFRAME_M5, days=30):
    """Download the most recent ``days`` of bars for ``symbol`` from MetaTrader 5.

    Args:
        symbol: Instrument name passed to ``mt5.copy_rates_range``.
        timeframe: MT5 timeframe constant passed to ``mt5.copy_rates_range``.
        days: Length of the requested history window, counted back from now.

    Returns:
        A DataFrame built from the returned rates with the ``time`` column
        converted from epoch seconds to datetimes, or ``None`` when the
        terminal cannot be initialized or the rates request fails.
    """
    if not mt5.initialize():
        print(f"MT5 initialization error: {mt5.last_error()}")
        return None

    try:
        # Take "now" once so both ends of the range use the same reference.
        end_date = datetime.now()
        start_date = end_date - timedelta(days=days)
        rates = mt5.copy_rates_range(symbol, timeframe, start_date, end_date)
        if rates is None:
            # A failed request yields None; report it instead of letting the
            # DataFrame conversion below fail with a KeyError on "time".
            print(f"MT5 data request error for {symbol}: {mt5.last_error()}")
            return None
    finally:
        # Always release the terminal connection, even if the request raised.
        mt5.shutdown()

    df = pd.DataFrame(rates)
    df["time"] = pd.to_datetime(df["time"], unit="s")
    return df


# Create Renko bars
def create_renko_bars(df, brick_size=None):
    """Convert time-based bars into Renko bricks built from closing prices.

    Args:
        df: DataFrame with ``time``, ``close`` and ``tick_volume`` columns;
            ``high`` and ``low`` are also read when ``brick_size`` is None.
            The frame is not modified.
        brick_size: Brick height in price units. When None it is set to half
            of the mean 14-bar average true range of ``df``.

    Returns:
        Tuple ``(renko_df, brick_size)``. ``renko_df`` has one row per brick
        with columns time, open, high, low, close, volume, direction,
        consec_up, consec_down, consec_up_streak and consec_down_streak.
        It is empty (with those columns) when the price never moves by a
        full brick.

    Raises:
        ValueError: If ``df`` is empty or the brick size is not a positive
            finite number (e.g. too few rows to compute the ATR).
    """
    if len(df) == 0:
        raise ValueError("Cannot build Renko bars from empty price data")

    if brick_size is None:
        # Calculate ATR to define the block size. Work on local Series so
        # the caller's DataFrame does not gain "tr"/"atr" columns.
        prev_close = df["close"].shift(1)
        true_range = np.maximum(
            df["high"] - df["low"],
            np.maximum(
                np.abs(df["high"] - prev_close),
                np.abs(df["low"] - prev_close),
            ),
        )
        atr = true_range.rolling(window=14).mean()
        brick_size = atr.mean() * 0.5
        print(f"Renko block size: {brick_size:.5f}")

    if not np.isfinite(brick_size) or brick_size <= 0:
        raise ValueError(
            f"Renko brick size must be a positive finite number, got {brick_size}"
        )

    renko_bars = []
    current_price = df["close"].iloc[0]
    current_direction = None
    bar_open = current_price
    bar_time = df["time"].iloc[0]
    volume_sum = 0

    # Iterate only over the three columns that are needed; this avoids the
    # per-row Series construction of DataFrame.iterrows().
    for row_time, price, tick_volume in zip(
        df["time"], df["close"], df["tick_volume"]
    ):
        volume_sum += tick_volume
        price_change = price - current_price
        num_bricks = int(abs(price_change) / brick_size)

        if num_bricks > 0:
            direction = 1 if price_change > 0 else -1

            # NOTE(review): on a reversal one extra brick is emitted, and the
            # first brick of the new direction opens one brick beyond the
            # previous close. Kept exactly as before; confirm this matches
            # the intended Renko reversal rule.
            if current_direction is not None and direction != current_direction:
                num_bricks += 1

            for _ in range(num_bricks):
                if current_direction is None or current_direction == direction:
                    bar_close = bar_open + (brick_size * direction)
                else:
                    bar_open = bar_open + (brick_size * current_direction)
                    bar_close = bar_open + (brick_size * direction)

                renko_bars.append(
                    {
                        "time": bar_time,
                        "open": bar_open,
                        "high": max(bar_open, bar_close),
                        "low": min(bar_open, bar_close),
                        "close": bar_close,
                        # Every brick created by this source bar carries the
                        # volume accumulated since the previous brick batch.
                        "volume": volume_sum,
                        "direction": direction,
                    }
                )

                bar_open = bar_close
                bar_time = row_time
                current_direction = direction

            volume_sum = 0
            current_price = current_price + (num_bricks * brick_size * direction)

    if not renko_bars:
        # No full brick was formed: return an empty frame that still has the
        # documented columns instead of failing on a missing "direction".
        empty = pd.DataFrame(
            columns=[
                "time",
                "open",
                "high",
                "low",
                "close",
                "volume",
                "direction",
                "consec_up",
                "consec_down",
                "consec_up_streak",
                "consec_down_streak",
            ]
        )
        return empty, brick_size

    renko_df = pd.DataFrame(renko_bars)

    # Add features for sequences
    renko_df["consec_up"] = (renko_df["direction"] > 0).astype(int)
    renko_df["consec_down"] = (renko_df["direction"] < 0).astype(int)

    # Counters of consecutive movements: a new group starts whenever the flag
    # changes, and the cumulative sum inside a group is the running streak.
    for col in ["consec_up", "consec_down"]:
        g = renko_df[col].ne(renko_df[col].shift()).cumsum()
        renko_df[f"{col}_streak"] = renko_df.groupby(g)[col].cumsum()

    return renko_df, brick_size


# Prepare features
def prepare_features(renko_df, lookback=5):
    """Build a supervised data set from consecutive Renko bricks.

    Each sample describes the ``lookback`` bricks that precede brick ``i``
    and is labelled with the direction of brick ``i`` itself, i.e. the brick
    that immediately follows the window. This is the same window/target
    relationship that ``predict_next_bar`` uses at inference time.

    Args:
        renko_df: DataFrame with ``direction``, ``volume``,
            ``consec_up_streak`` and ``consec_down_streak`` columns.
        lookback: Number of preceding bricks summarised in each sample.

    Returns:
        Tuple ``(X, y)``: a DataFrame with one row of features per sample and
        an integer array of targets (1 = next brick up, 0 = next brick down).
        Both are empty when ``renko_df`` has no more than ``lookback`` rows.

    Raises:
        ValueError: If ``lookback`` is smaller than 1.
    """
    if lookback < 1:
        raise ValueError("lookback must be at least 1")

    features = []
    targets = []
    directions = renko_df["direction"]

    for i in range(lookback, len(renko_df)):
        # Bricks i - lookback .. i - 1: everything known before brick i.
        window = renko_df.iloc[i - lookback : i]
        window_dir = window["direction"]
        window_vol = window["volume"]
        last_volume = window_vol.iloc[-1]
        avg_volume = window_vol.mean()

        feature_dict = {
            # Directions of the last n bars (dir_0 is the most recent one)
            **{f"dir_{j}": window_dir.iloc[-(j + 1)] for j in range(lookback)},
            # Statistics on movement
            "up_ratio": (window_dir > 0).mean(),
            "max_up_streak": window["consec_up_streak"].max(),
            "max_down_streak": window["consec_down_streak"].max(),
            "last_up_streak": window["consec_up_streak"].iloc[-1],
            "last_down_streak": window["consec_down_streak"].iloc[-1],
            # Volume
            "last_volume": last_volume,
            "avg_volume": avg_volume,
            # Fall back to a neutral ratio when the window has no volume.
            "volume_ratio": last_volume / avg_volume if avg_volume > 0 else 1,
        }
        features.append(feature_dict)

        # Direction of the next bar (1 - up, 0 - down). The target is brick i,
        # not i + 1, so no brick is skipped between the window and its label.
        targets.append(1 if directions.iloc[i] > 0 else 0)

    return pd.DataFrame(features), np.array(targets, dtype=int)


# Train the model
def train_model(X, y, test_size=0.3, val_size=0.2):
    """Fit a CatBoost classifier and report its accuracy on held-out data.

    The samples are split in their given order, without shuffling: the first
    part is used for training and the last ``test_size`` share for testing.
    Rows built from overlapping lookback windows are nearly identical to
    their neighbours, so a shuffled split would place near-duplicates of the
    test rows in the training set and inflate the reported accuracy.

    Early stopping is driven by a separate validation slice taken from the
    end of the training part, so the test set only serves for the final
    report.

    Args:
        X: Feature DataFrame, one row per sample, in chronological order.
        y: Target array aligned with ``X``.
        test_size: Share (or count) of trailing samples held out for testing,
            passed to ``train_test_split``.
        val_size: Share (or count) of trailing training samples used for
            early stopping. A falsy value disables early stopping.

    Returns:
        Tuple ``(model, X_test, y_test)``.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, shuffle=False
    )

    params = {
        "iterations": 300,
        "learning_rate": 0.05,
        "depth": 5,
        "loss_function": "Logloss",
        "random_seed": 42,
        "verbose": False,
    }

    model = CatBoostClassifier(**params)

    if val_size and len(X_train) > 1:
        # Chronological validation slice: the most recent training samples.
        X_fit, X_val, y_fit, y_val = train_test_split(
            X_train, y_train, test_size=val_size, shuffle=False
        )
        model.fit(
            X_fit,
            y_fit,
            eval_set=(X_val, y_val),
            early_stopping_rounds=30,
            verbose=False,
        )
    else:
        # Too little data (or validation disabled): train on everything
        # available without early stopping.
        model.fit(X_train, y_train, verbose=False)

    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    print(f"Accuracy on the test sample: {accuracy:.4f}")
    print(f"Accuracy in %: {accuracy * 100:.2f}%")
    print(classification_report(y_test, y_pred))

    # Feature importance
    importance = model.get_feature_importance(prettified=True)
    print("Top 5 features:")
    print(importance.head(5))

    return model, X_test, y_test


# Forecast the next bar
def predict_next_bar(
    model, renko_df, lookback=5, feature_names=None, signal_threshold=0.75
):
    """Forecast the direction of the brick that follows the last one.

    Args:
        model: Fitted classifier exposing ``predict`` and ``predict_proba``;
            column 0 of the probabilities is read as "down", column 1 as "up".
        renko_df: DataFrame with ``direction``, ``volume``,
            ``consec_up_streak`` and ``consec_down_streak`` columns.
        lookback: Number of trailing bricks summarised into features.
        feature_names: Optional sequence (list, pandas Index, ...) with the
            column order the model expects. Missing columns are filled with
            0 and extra ones are dropped. None or empty keeps the features
            as built.
        signal_threshold: Minimum class probability required to emit a
            "BUY" or "SELL" signal instead of "NEUTRAL".

    Returns:
        ``{"error": "Insufficient data"}`` when fewer than ``lookback``
        bricks are available, otherwise a dict with ``prediction``,
        ``probability``, ``prob_up``, ``prob_down`` and ``signal``.
    """
    if len(renko_df) < lookback:
        return {"error": "Insufficient data"}

    window = renko_df.iloc[-lookback:]
    window_dir = window["direction"]
    window_vol = window["volume"]
    last_volume = window_vol.iloc[-1]
    avg_volume = window_vol.mean()

    feature_dict = {
        # dir_0 is the most recent brick, dir_1 the one before it, and so on.
        **{f"dir_{j}": window_dir.iloc[-(j + 1)] for j in range(lookback)},
        "up_ratio": (window_dir > 0).mean(),
        "max_up_streak": window["consec_up_streak"].max(),
        "max_down_streak": window["consec_down_streak"].max(),
        "last_up_streak": window["consec_up_streak"].iloc[-1],
        "last_down_streak": window["consec_down_streak"].iloc[-1],
        "last_volume": last_volume,
        "avg_volume": avg_volume,
        # Fall back to a neutral ratio when the window has no volume.
        "volume_ratio": last_volume / avg_volume if avg_volume > 0 else 1,
    }

    X_pred = pd.DataFrame([feature_dict])

    # Make sure all features are present and ordered as during training.
    # The explicit None/len test also accepts a pandas Index or numpy array,
    # whose truth value is ambiguous.
    if feature_names is not None and len(feature_names) > 0:
        X_pred = X_pred.reindex(columns=list(feature_names), fill_value=0)

    prob = model.predict_proba(X_pred)[0]
    # Flatten first so the label is a plain int usable as an index whatever
    # the shape of the array returned by predict().
    prediction = int(np.ravel(model.predict(X_pred))[0])

    if prob[1] > signal_threshold:
        signal = "BUY"
    elif prob[0] > signal_threshold:
        signal = "SELL"
    else:
        signal = "NEUTRAL"

    return {
        "prediction": "UP" if prediction == 1 else "DOWN",
        "probability": prob[prediction],
        "prob_up": prob[1],
        "prob_down": prob[0],
        "signal": signal,
    }


# Main function
def main():
    """Run the full pipeline: load bars, build Renko bricks, train, forecast.

    Progress and results are printed to standard output. The function
    returns early with a message when no price data is available or when
    there are too few Renko bricks to build any training sample.
    """
    # Get EURUSD data
    print("Load EURUSD data from MetaTrader5...")
    df = get_mt5_data(symbol="EURUSD", days=60)

    if df is None or len(df) == 0:
        print("Failed to get data")
        return

    print(f"Loaded {len(df)} bars")

    # Create Renko bars
    print("Creating Renko bars...")
    renko_df, _ = create_renko_bars(df)
    print(f"Created {len(renko_df)} Renko bars")

    # Prepare features
    print("Preparing features...")
    X, y = prepare_features(renko_df)
    print(f"Prepared {len(X)} samples")

    if len(X) == 0:
        # Nothing to split or fit on; stop before the training step fails.
        print("Not enough Renko bars to train the model")
        return

    # Train the model
    print("Training model...")
    model, X_test, y_test = train_model(X, y)

    # Forecast the next bar
    feature_names = X.columns.tolist()
    prediction = predict_next_bar(model, renko_df, feature_names=feature_names)

    print("\nFORECASTING NEXT BAR:")
    for k, v in prediction.items():
        print(f"{k}: {v}")

    # Info on the last bars ("\n" was previously written as the invalid
    # escape "\L", which printed a literal backslash)
    print("\nLast 5 Renko bars:")
    print(renko_df.tail(5)[["time", "open", "close", "direction"]])


if __name__ == "__main__":
    main()
