import numpy as np
import pandas as pd
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
from datetime import datetime
import pickle
import MetaTrader5 as mt5
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans


def connect_to_metatrader():
    """Initialize the MetaTrader5 terminal connection; return True on success."""
    if mt5.initialize():
        print(f"MetaTrader5 initialized successfully. Terminal: {mt5.terminal_info()}")
        return True
    # Initialization failed: report the terminal's last error code.
    print(f"MetaTrader5 initialization error: {mt5.last_error()}")
    return False


def get_eurusd_data(bars_count=5000):
    """
    Download the most recent *bars_count* EURUSD H1 bars from the terminal.

    Returns a DataFrame with a parsed datetime "time" column, or None when
    the terminal returns no data.
    """
    rates = mt5.copy_rates_from_pos("EURUSD", mt5.TIMEFRAME_H1, 0, bars_count)
    if rates is None:
        print(f"Data retrieval error: {mt5.last_error()}")
        return None
    frame = pd.DataFrame(rates)
    # The terminal delivers epoch seconds; convert to datetime64.
    frame["time"] = pd.to_datetime(frame["time"], unit="s")
    print(f"Received {len(frame)} bars EURUSD H1")
    return frame


def calc_rsi(series, period=14):
    """Simple-moving-average RSI of *series* over *period* bars (NaN during warm-up)."""
    change = series.diff()
    # Split each bar's move into its gain and loss components before averaging.
    avg_gain = change.where(change > 0, 0).rolling(window=period).mean()
    avg_loss = change.where(change < 0, 0).rolling(window=period).mean().mul(-1)
    strength = avg_gain / avg_loss
    return 100 - 100 / (1 + strength)


def add_indicators(df):
    """
    Add indicators divided into three groups:
    - price
    - time
    - volume

    Expects columns: time (datetime64), open, high, low, close, tick_volume.
    Mutates *df* in place and returns it.
    """
    # --- PRICE INDICATORS ---
    # Trend
    df["ema_9"] = df["close"].ewm(span=9, adjust=False).mean()
    df["ema_21"] = df["close"].ewm(span=21, adjust=False).mean()
    df["ema_50"] = df["close"].ewm(span=50, adjust=False).mean()

    # EMA crossings
    df["ema_cross_9_21"] = (df["ema_9"] > df["ema_21"]).astype(int)
    df["ema_cross_21_50"] = (df["ema_21"] > df["ema_50"]).astype(int)

    # Volatility
    # NOTE(review): this is a 14-bar high/low channel width, not a true ATR
    # (no previous-close gaps); kept as-is to preserve model behavior.
    df["atr_14"] = df["high"].rolling(14).max() - df["low"].rolling(14).min()
    df["range"] = df["high"] - df["low"]
    df["range_ratio"] = df["range"] / df["range"].rolling(14).mean()

    # Impulse
    df["rsi_14"] = calc_rsi(df["close"], 14)
    df["momentum"] = df["close"] - df["close"].shift(10)

    # Candle characteristics
    # NOTE(review): a zero-range (doji) bar divides by zero here, producing
    # inf/NaN that dropna does NOT fully remove (inf survives) — confirm
    # this cannot occur for H1 EURUSD data before reuse on other symbols.
    df["body_size"] = abs(df["close"] - df["open"]) / (df["high"] - df["low"])
    df["upper_shadow"] = (df["high"] - df[["open", "close"]].max(axis=1)) / df["range"]
    df["lower_shadow"] = (df[["open", "close"]].min(axis=1) - df["low"]) / df["range"]

    # --- TIME INDICATORS ---
    # Week hour and day
    df["hour"] = df["time"].dt.hour
    df["day_of_week"] = df["time"].dt.dayofweek

    # Cyclic time features
    df["hour_sin"] = np.sin(2 * np.pi * df["hour"] / 24)
    df["hour_cos"] = np.cos(2 * np.pi * df["hour"] / 24)
    df["day_sin"] = np.sin(2 * np.pi * df["day_of_week"] / 7)
    df["day_cos"] = np.cos(2 * np.pi * df["day_of_week"] / 7)

    # Seasonality
    df["month"] = df["time"].dt.month
    df["week_of_year"] = df["time"].dt.isocalendar().week
    df["month_sin"] = np.sin(2 * np.pi * df["month"] / 12)
    df["month_cos"] = np.cos(2 * np.pi * df["month"] / 12)

    # Sessions (European, American, Asian) — hour ranges overlap on purpose
    df["european_session"] = ((df["hour"] >= 7) & (df["hour"] < 16)).astype(int)
    df["american_session"] = ((df["hour"] >= 13) & (df["hour"] < 22)).astype(int)
    df["asian_session"] = ((df["hour"] >= 0) & (df["hour"] < 9)).astype(int)

    # --- VOLUME INDICATORS ---
    # Basic volume indicators
    df["tick_volume"] = df["tick_volume"].astype(float)  # Provide float type
    df["volume_change"] = df["tick_volume"].pct_change(1)
    df["volume_ma_14"] = df["tick_volume"].rolling(14).mean()
    df["rel_volume"] = df["tick_volume"] / df["volume_ma_14"]

    # Volume relative to range
    df["volume_per_range"] = df["tick_volume"] / df["range"]

    # Accumulation/distribution line, vectorized.
    # Fix: the previous per-row `df.loc[i, "ad"]` loop was O(n) Python-level
    # work and silently assumed a 0..n-1 RangeIndex. The recurrence
    # ad[0] = 0, ad[i] = ad[i-1] + mfv[i] is exactly a cumulative sum with
    # the first bar's money-flow volume zeroed out.
    close = df["close"].to_numpy()
    high = df["high"].to_numpy()
    low = df["low"].to_numpy()
    volume = df["tick_volume"].to_numpy()

    hl_span = high - low
    money_flow_mult = np.zeros(len(df))
    nonzero = hl_span != 0  # zero-range bars contribute a multiplier of 0
    money_flow_mult[nonzero] = (
        (close[nonzero] - low[nonzero]) - (high[nonzero] - close[nonzero])
    ) / hl_span[nonzero]
    money_flow_volume = money_flow_mult * volume
    if len(money_flow_volume):
        # The original loop started at i = 1, so the first bar adds nothing.
        money_flow_volume[0] = 0.0
    df["ad"] = np.cumsum(money_flow_volume)

    # Volume impulse
    df["volume_momentum"] = df["tick_volume"] - df["tick_volume"].shift(5)
    df["volume_rsi"] = calc_rsi(df["tick_volume"], 14)

    return df


def create_price_patterns(df, lookback=5):
    """
    Add binary up/down direction labels, lagged labels and short
    sequence-pattern columns to *df* (mutated in place and returned).

    Requires lookback >= 3 because the pattern columns read lags 1..3.
    """
    # Direction of each bar: 1 = close rose versus the previous bar.
    df["price_change"] = df["close"].diff()
    df["binary_label"] = df["price_change"].gt(0).astype(int)

    # Lagged copies of the direction label (NaN in the first rows).
    for lag in range(1, lookback + 1):
        df[f"binary_lag_{lag}"] = df["binary_label"].shift(lag)

    lag1 = df["binary_lag_1"]
    lag2 = df["binary_lag_2"]
    lag3 = df["binary_lag_3"]

    # Two-bar sequences of the previous bars' directions.
    df["rise_after_rise"] = ((lag1 == 1) & (lag2 == 1)).astype(int)
    df["fall_after_rise"] = ((lag1 == 0) & (lag2 == 1)).astype(int)
    df["fall_after_fall"] = ((lag1 == 0) & (lag2 == 0)).astype(int)
    df["rise_after_fall"] = ((lag1 == 1) & (lag2 == 0)).astype(int)

    # Three-bar runs in a single direction.
    df["triple_rise"] = ((lag1 == 1) & (lag2 == 1) & (lag3 == 1)).astype(int)
    df["triple_fall"] = ((lag1 == 0) & (lag2 == 0) & (lag3 == 0)).astype(int)

    # Encode the last three directions as a 3-bit number (lag 1 is the MSB).
    df["pattern_3bar"] = lag1 * 4 + lag2 * 2 + lag3 * 1

    return df


def prepare_features_grouped(df, lookback=5):
    """
    Prepare scaled feature matrices divided into three groups: price,
    time and volume.

    Parameters
    ----------
    df : pandas.DataFrame
        Raw OHLCV bars with a datetime "time" column (as returned by
        get_eurusd_data).
    lookback : int, default 5
        Number of lagged direction labels created by create_price_patterns.

    Returns
    -------
    tuple
        (feature_groups, labels, forecast_labels, df): feature_groups maps
        "price"/"time"/"volume" to {"data", "scaler", "features"} and "all"
        to the concatenated scaled matrix (no scaler of its own); labels is
        the current-bar 0/1 direction; forecast_labels the next-bar direction.

    NOTE(review): the StandardScalers are fitted on the FULL history here,
    before main() splits train/test — a mild look-ahead leak into the test
    evaluation. Confirm whether that is acceptable for this pipeline.
    """
    # Add indicators
    df = add_indicators(df)
    df = create_price_patterns(df, lookback)

    # Remove NaN (warm-up rows of the rolling/lagged indicators)
    df.dropna(inplace=True)

    # Forecast labels
    # NOTE(review): price_change/binary_label were already computed inside
    # create_price_patterns; recomputing after the dropna above makes the
    # first surviving row's diff NaN again, so the second dropna removes
    # that row together with the last row (whose forecast_label is NaN).
    df["price_change"] = df["close"].diff()
    df["binary_label"] = (df["price_change"] > 0).astype(int)
    df["forecast_label"] = df["binary_label"].shift(-1)
    df.dropna(inplace=True)

    # --- GROUPING FEATURES ---
    # Price features (hour/day/month/week_of_year columns exist on df but
    # only their cyclic encodings below are used as features)
    price_features = [
        "ema_9",
        "ema_21",
        "ema_50",
        "ema_cross_9_21",
        "ema_cross_21_50",
        "atr_14",
        "range",
        "range_ratio",
        "rsi_14",
        "momentum",
        "body_size",
        "upper_shadow",
        "lower_shadow",
        "rise_after_rise",
        "fall_after_rise",
        "fall_after_fall",
        "rise_after_fall",
        "triple_rise",
        "triple_fall",
        "pattern_3bar",
    ]

    # Time features
    time_features = [
        "hour_sin",
        "hour_cos",
        "day_sin",
        "day_cos",
        "month_sin",
        "month_cos",
        "european_session",
        "american_session",
        "asian_session",
    ]

    # Volume features
    volume_features = [
        "tick_volume",
        "volume_change",
        "volume_ma_14",
        "rel_volume",
        "volume_per_range",
        "ad",
        "volume_momentum",
        "volume_rsi",
    ]

    # Check the presence of all features in the dataframe (warn only)
    all_features = price_features + time_features + volume_features
    for feature in all_features:
        if feature not in df.columns:
            print(f"Warning: {feature} feature not in dataframe")

    # Get the group of features as plain numpy arrays
    price_data = df[price_features].values
    time_data = df[time_features].values
    volume_data = df[volume_features].values

    # Scale data — one independent scaler per group so each can be reused
    # at prediction time
    price_scaler = StandardScaler()
    time_scaler = StandardScaler()
    volume_scaler = StandardScaler()

    scaled_price = price_scaler.fit_transform(price_data)
    scaled_time = time_scaler.fit_transform(time_data)
    scaled_volume = volume_scaler.fit_transform(volume_data)

    # Combine all features
    all_scaled_data = np.hstack((scaled_price, scaled_time, scaled_volume))

    # Labels
    labels = df["binary_label"].values
    forecast_labels = df["forecast_label"].values

    # Grouped features and scalers
    feature_groups = {
        "price": {
            "data": scaled_price,
            "scaler": price_scaler,
            "features": price_features,
        },
        "time": {"data": scaled_time, "scaler": time_scaler, "features": time_features},
        "volume": {
            "data": scaled_volume,
            "scaler": volume_scaler,
            "features": volume_features,
        },
        "all": {"data": all_scaled_data, "features": all_features},
    }

    return feature_groups, labels, forecast_labels, df


def create_state_clusters(feature_groups, n_clusters_per_group=3):
    """
    Fit one KMeans model per feature group (price/time/volume) and return
    ({group: cluster labels}, {group: fitted KMeans}).
    """
    group_clusters = {}
    kmeans_models = {}

    for name, payload in feature_groups.items():
        # The combined "all" group is not clustered separately.
        if name == "all":
            continue
        print(f"Creating {n_clusters_per_group} clusters for group {name}...")

        # Fixed random_state keeps the cluster assignment reproducible.
        model = KMeans(n_clusters=n_clusters_per_group, random_state=42, n_init=10)
        group_clusters[name] = model.fit_predict(payload["data"])
        kmeans_models[name] = model

    return group_clusters, kmeans_models


def combine_state_clusters(group_clusters, labels):
    """
    Combine per-group cluster ids into one state sequence and build the
    state-transition and rise-probability matrices.

    Returns (states, state_transitions, rise_probability_matrix,
    state_price_direction, state_mapping).
    """
    price_clusters = group_clusters["price"]
    time_clusters = group_clusters["time"]
    volume_clusters = group_clusters["volume"]

    transition_matrix = np.zeros((9, 9))
    rise_matrix = np.zeros((9, 9))

    # Map (price, time, volume) cluster triples to state ids 0..8.
    # NOTE(review): the `< 9` cap means only triples with price cluster 0
    # receive an id (p*9 + t*3 + v < 9 forces p == 0); every other triple
    # falls back to state 0 below. Kept as-is to preserve model behavior,
    # but confirm this collapse is intentional.
    state_mapping = {}
    for p in range(3):
        for t in range(3):
            for v in range(3):
                candidate = p * 3 * 3 + t * 3 + v
                if candidate < 9:
                    state_mapping[(p, t, v)] = candidate

    print(f"Created the display of clusters to states: {state_mapping}")

    # Resolve each bar's cluster triple to a state (default 0 when unmapped).
    states = np.zeros(len(price_clusters), dtype=int)
    for idx in range(len(price_clusters)):
        triple = (
            price_clusters[idx] % 3,
            time_clusters[idx] % 3,
            volume_clusters[idx] % 3,
        )
        states[idx] = state_mapping.get(triple, 0)

    print(f"State distribution: {np.bincount(states)}")

    # Per-state observation and rise counts.
    state_counts = np.bincount(states, minlength=9).astype(float)
    rise_counts = np.bincount(states, weights=labels, minlength=9)

    # Empirical rise probability per state; 0.5 prior for unseen states.
    state_price_direction = {}
    for state in range(9):
        count = state_counts[state]
        state_price_direction[state] = rise_counts[state] / count if count > 0 else 0.5

    # Count observed transitions; a transition counts as "rising" when the
    # bar entered at the next step closes up.
    for idx, (current, following) in enumerate(zip(states[:-1], states[1:])):
        transition_matrix[current, following] += 1
        if idx + 1 < len(labels) and labels[idx + 1] == 1:
            rise_matrix[current, following] += 1

    # Row-normalize transition counts; unseen rows get a uniform distribution.
    state_transitions = np.zeros((9, 9))
    for row in range(9):
        total = np.sum(transition_matrix[row, :])
        if total > 0:
            state_transitions[row, :] = transition_matrix[row, :] / total
        else:
            state_transitions[row, :] = 1 / 9

    # Conditional rise probability per transition; 0.5 prior when unseen.
    rise_probability_matrix = np.full((9, 9), 0.5)
    seen = transition_matrix > 0
    rise_probability_matrix[seen] = rise_matrix[seen] / transition_matrix[seen]

    return (
        states,
        state_transitions,
        rise_probability_matrix,
        state_price_direction,
        state_mapping,
    )


def predict_with_matrix(state_transitions, rise_probability_matrix, current_state):
    """
    Forecast the next price move from *current_state* by weighting each
    transition's rise probability with its transition probability.

    Returns (prediction 0/1, confidence, most probable next state,
    weighted rise probability).
    """
    transition_row = state_transitions[current_state, :]
    rise_row = rise_probability_matrix[current_state, :]

    # Expected rise probability over every possible next state.
    weighted_prob = 0.0
    total_prob = 0.0
    for candidate, transition_prob in enumerate(transition_row):
        weighted_prob += transition_prob * rise_row[candidate]
        total_prob += transition_prob

    # Renormalize in case the row does not sum to exactly 1.
    if total_prob > 0:
        weighted_prob /= total_prob

    prediction = 1 if weighted_prob > 0.5 else 0
    confidence = max(weighted_prob, 1 - weighted_prob)
    next_state = np.argmax(transition_row)

    return prediction, confidence, next_state, weighted_prob


def get_state_for_features(
    kmeans_models, state_mapping, price_features, time_features, volume_features
):
    """
    Map one already-scaled feature vector per group to a combined state id.

    Returns the mapped state, or 0 when the cluster triple has no entry in
    *state_mapping*.
    """
    # Predict each group's cluster for the single sample; % 3 keeps the id
    # inside the 0..2 range expected by the mapping keys.
    cluster_triple = tuple(
        kmeans_models[group].predict([vector])[0] % 3
        for group, vector in (
            ("price", price_features),
            ("time", time_features),
            ("volume", volume_features),
        )
    )
    return state_mapping.get(cluster_triple, 0)


def save_model(
    feature_groups,
    kmeans_models,
    state_transitions,
    rise_probability_matrix,
    state_price_direction,
    state_mapping,
    filename="hmm_eurusd_h1_matrix.bin",
):
    """
    Persist the trained model to *filename* with pickle.

    Per-group scalers and feature lists are stored for every group except
    "all" (which has no scaler of its own), together with the KMeans
    models, both matrices and the state mapping — everything
    predict_next_bar() needs to run without refitting.
    """
    with open(filename, "wb") as f:
        pickle.dump(
            {
                # Drop the combined "all" group: it carries no scaler.
                "feature_groups": {
                    k: {"scaler": v["scaler"], "features": v.get("features", None)}
                    for k, v in feature_groups.items()
                    if k != "all"
                },
                "kmeans_models": kmeans_models,
                "state_transitions": state_transitions,
                "rise_probability_matrix": rise_probability_matrix,
                "state_price_direction": state_price_direction,
                "state_mapping": state_mapping,
                "created_at": datetime.now(),
            },
            f,
        )
    # Bug fix: the message previously printed the literal "(unknown)"
    # instead of the actual file name.
    print(f"Model saved to file {filename}")


def visualize_matrices(states, state_transitions, rise_probability_matrix):
    """
    Save diagnostic charts as PNG files in the working directory:
    state_distribution.png, transition_matrix.png and
    rise_probability_matrix.png.
    """

    def _annotated_heatmap(matrix, cmap, colorbar_label, title, filename, **imshow_kwargs):
        # Shared renderer for the two 9x9 matrix plots.
        plt.figure(figsize=(10, 8))
        plt.imshow(matrix, cmap=cmap, interpolation="none", **imshow_kwargs)
        plt.colorbar(label=colorbar_label)
        plt.title(title)
        plt.xlabel("Next state")
        plt.ylabel("Current state")
        plt.xticks(range(9))
        plt.yticks(range(9))

        # Overlay each cell's value; white text on dark (low-value) cells.
        for row in range(9):
            for col in range(9):
                cell_color = "white" if matrix[row, col] < 0.5 else "black"
                plt.text(
                    col,
                    row,
                    f"{matrix[row, col]:.2f}",
                    ha="center",
                    va="center",
                    color=cell_color,
                )

        plt.savefig(filename)
        plt.close()

    # Bar chart: how often each of the 9 states occurred in training.
    plt.figure(figsize=(12, 6))
    occurrences = np.bincount(states, minlength=9)
    plt.bar(range(9), occurrences)
    plt.title("Distribution of states in training data")
    plt.xlabel("State")
    plt.ylabel("Number of observations")
    plt.xticks(range(9))
    plt.grid(axis="y", linestyle="--", alpha=0.7)
    plt.savefig("state_distribution.png")
    plt.close()

    _annotated_heatmap(
        state_transitions,
        "viridis",
        "Transition probability",
        "Matrix of transitions between states",
        "transition_matrix.png",
    )

    _annotated_heatmap(
        rise_probability_matrix,
        "RdYlGn",
        "Rise probability",
        "Rise probability when transitioning between states",
        "rise_probability_matrix.png",
        vmin=0,
        vmax=1,
    )


def evaluate_model_with_all_transitions(
    test_states, state_transitions, rise_probability_matrix, test_labels
):
    """
    Evaluate the model on a test state sequence.

    For each bar the forecast is the transition-probability-weighted rise
    probability over every possible next state.

    Parameters
    ----------
    test_states : sequence of int
        State id (0..8) of each test bar.
    state_transitions : ndarray (9, 9)
        Row-normalized transition probabilities.
    rise_probability_matrix : ndarray (9, 9)
        Rise probability conditioned on transition i -> j.
    test_labels : sequence of int
        Actual 0/1 direction of each test bar.

    Returns
    -------
    tuple
        (accuracy, confusion, predictions, confidences,
        weighted_probabilities).
    """
    # Error matrix: rows = predicted class, columns = actual class
    confusion = np.zeros((2, 2))

    # Evaluate forecasts
    predictions = []
    confidences = []
    weighted_probabilities = []

    for i in range(len(test_states) - 1):
        curr_state = test_states[i]

        # Get all transition probabilities from the current state
        next_state_probs = state_transitions[curr_state, :]

        # Calculate weighted rise considering all possible transitions
        weighted_prob = 0
        total_transition_probability = 0

        # Iterate over all possible combinations (3x3x3=27 states, but use 9 for simplification)
        for next_state in range(9):
            transition_prob = next_state_probs[next_state]
            rise_prob = rise_probability_matrix[curr_state, next_state]
            weighted_prob += transition_prob * rise_prob
            total_transition_probability += transition_prob

        # Normalize (if necessary — rows of state_transitions should already sum to ~1)
        if total_transition_probability > 0:
            weighted_prob = weighted_prob / total_transition_probability

        # Forecast: rise when the expected rise probability exceeds 0.5
        prediction = 1 if weighted_prob > 0.5 else 0
        confidence = max(weighted_prob, 1 - weighted_prob)

        predictions.append(prediction)
        confidences.append(confidence)
        weighted_probabilities.append(weighted_prob)

        # Update error matrixes (the forecast made at bar i targets bar i + 1)
        if i + 1 < len(test_labels):
            actual = test_labels[i + 1]
            confusion[prediction, actual] += 1

    # Calculate different efficiency metrics
    # (np.sum around an already-scalar sum is redundant but harmless)
    accuracy = np.sum(confusion[0, 0] + confusion[1, 1]) / np.sum(confusion)
    precision = (
        confusion[1, 1] / (confusion[1, 0] + confusion[1, 1])
        if (confusion[1, 0] + confusion[1, 1]) > 0
        else 0
    )
    recall = (
        confusion[1, 1] / (confusion[0, 1] + confusion[1, 1])
        if (confusion[0, 1] + confusion[1, 1]) > 0
        else 0
    )
    f1_score = (
        2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    )

    print(f"Model accuracy on test sample: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-score: {f1_score:.4f}")
    print(f"Error matrix:\n{confusion}")

    # Analyze forecast confidence: bucket predictions by confidence level.
    # NOTE(review): np.digitize maps a confidence of exactly 1.0 to bin
    # index len(bins), which the range below skips, so such predictions
    # are excluded from this breakdown.
    bins = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0]
    confidence_groups = np.digitize(confidences, bins)
    accuracy_by_confidence = {}

    for bin_idx in range(1, len(bins)):
        bin_mask = confidence_groups == bin_idx
        if np.sum(bin_mask) > 0:
            bin_accuracy = np.mean(
                [
                    1 if predictions[i] == test_labels[i + 1] else 0
                    for i in range(len(predictions))
                    if bin_mask[i] and i + 1 < len(test_labels)
                ]
            )
            accuracy_by_confidence[f"{bins[bin_idx-1]:.2f}-{bins[bin_idx]:.2f}"] = (
                bin_accuracy,
                np.sum(bin_mask),
            )

    print("\nAccuracy by confidence groups:")
    for conf_range, (acc, count) in sorted(accuracy_by_confidence.items()):
        print(f"Confidence {conf_range}: accuracy {acc:.4f} (number: {count})")

    return accuracy, confusion, predictions, confidences, weighted_probabilities


def main():
    """
    End-to-end training pipeline: download EURUSD H1 bars from MetaTrader5,
    build grouped features, cluster them into market states, fit the
    state-transition / rise-probability matrices, evaluate on a hold-out
    tail, report diagnostics, and save the model to disk.
    """
    # Connect to MT5
    if not connect_to_metatrader():
        return

    # Get data
    df = get_eurusd_data(bars_count=5000)
    if df is None:
        return

    # Prepare features by groups
    feature_groups, labels, forecast_labels, df = prepare_features_grouped(
        df, lookback=5
    )

    # Divide data chronologically: first 80% train, last 20% test
    train_size = int(0.8 * len(labels))
    train_labels = labels[:train_size]
    test_labels = labels[train_size:]

    # Create feature groups for training and testing
    # (.get() is used because the "all" group has no scaler of its own)
    train_groups = {}
    test_groups = {}

    for group_name, group_data in feature_groups.items():
        train_groups[group_name] = {
            "data": group_data["data"][:train_size],
            "scaler": group_data.get("scaler"),
            "features": group_data.get("features"),
        }

        test_groups[group_name] = {
            "data": group_data["data"][train_size:],
            "scaler": group_data.get("scaler"),
            "features": group_data.get("features"),
        }

    # Create state cluster (one KMeans per group, fitted on the train slice)
    group_clusters, kmeans_models = create_state_clusters(train_groups)

    # Combine clusters into a matrix
    print("Creating a matrix of states...")
    (
        states,
        state_transitions,
        rise_probability_matrix,
        state_price_direction,
        state_mapping,
    ) = combine_state_clusters(group_clusters, train_labels)

    # Visualizing matrixes (writes PNG files to the working directory)
    visualize_matrices(states, state_transitions, rise_probability_matrix)

    # Evaluate the model
    print("Evaluating model...")

    # Get states for test data
    test_states = []

    for i in range(len(test_groups["price"]["data"])):
        price_features = test_groups["price"]["data"][i]
        time_features = test_groups["time"]["data"][i]
        volume_features = test_groups["volume"]["data"][i]

        # Get the state
        state = get_state_for_features(
            kmeans_models, state_mapping, price_features, time_features, volume_features
        )
        test_states.append(state)

    # Evaluate the model considering all possible transitions
    accuracy, confusion, predictions, confidences, weighted_probabilities = (
        evaluate_model_with_all_transitions(
            test_states, state_transitions, rise_probability_matrix, test_labels
        )
    )

    # Analyze states: frequency and empirical rise probability of each state
    print("\nAnalyzing states:")
    for state in range(9):
        count = np.sum(states == state)
        percentage = count / len(states) * 100
        direction = "rise" if state_price_direction[state] > 0.5 else "fall"
        print(
            f"State {state}: {direction} (rise probability: {state_price_direction[state]:.4f}, {percentage:.2f}%)"
        )

    # Analyze transition matrix
    print("\nMost probable state transitions:")
    for i in range(9):
        next_state = np.argmax(state_transitions[i, :])
        prob = state_transitions[i, next_state]
        rise_prob = rise_probability_matrix[i, next_state]
        direction = "rise" if rise_prob > 0.5 else "fall"
        print(
            f"{i} -> {next_state}: transition probability {prob:.4f}, {direction} (rise probability: {rise_prob:.4f})"
        )

    # Forecast for the last state considering all transitions
    last_state = test_states[-1]

    # Transition probabilities to the next state
    next_state_probs = state_transitions[last_state, :]

    # Calculate the weighted rise probability considering all possible transitions
    weighted_rise_prob = 0
    for next_state, prob in enumerate(next_state_probs):
        weighted_rise_prob += prob * rise_probability_matrix[last_state, next_state]

    # Forecast
    prediction = "rise" if weighted_rise_prob > 0.5 else "fall"
    confidence = max(weighted_rise_prob, 1 - weighted_rise_prob)

    print(f"\nForecast for the next bar (considering all 27 transitions): {prediction}")
    print(f"Confidence: {confidence:.4f}")
    print(f"Current state: {last_state}")
    print(f"Weighted rise probability: {weighted_rise_prob:.4f}")

    # Analyze the effect of feature groups
    print("\nAnalyzing the effect of feature groups on forecasts:")

    # Get the last bar features
    last_price = test_groups["price"]["data"][-1]
    last_time = test_groups["time"]["data"][-1]
    last_volume = test_groups["volume"]["data"][-1]

    # Get clusters
    price_cluster = kmeans_models["price"].predict([last_price])[0] % 3
    time_cluster = kmeans_models["time"].predict([last_time])[0] % 3
    volume_cluster = kmeans_models["volume"].predict([last_volume])[0] % 3

    print(
        f"Last bar: price cluster = {price_cluster}, time cluster = {time_cluster}, volume cluster = {volume_cluster}"
    )

    # Importance of each group
    print("\nMost important features in each group:")

    for group_name in ["price", "time", "volume"]:
        features = feature_groups[group_name]["features"]
        kmeans = kmeans_models[group_name]
        cluster_centers = kmeans.cluster_centers_

        for cluster_idx in range(3):
            center = cluster_centers[cluster_idx]
            # Get the importance of each feature as its deviation from zero in the cluster center
            # (features are standardized, so |center| is in standard deviations)
            importances = np.abs(center)
            # Sort features by importance
            sorted_idx = np.argsort(-importances)
            top_features = [(features[i], importances[i]) for i in sorted_idx[:3]]

            print(f"Cluster {cluster_idx} of {group_name} group. Top 3 features:")
            for feature, importance in top_features:
                print(f"  - {feature}: {importance:.4f}")

    # Save the model
    save_model(
        feature_groups,
        kmeans_models,
        state_transitions,
        rise_probability_matrix,
        state_price_direction,
        state_mapping,
    )

    # Shutdown
    mt5.shutdown()
    print("MetaTrader5 disabled")


def load_model(filename="hmm_eurusd_h1_matrix.bin"):
    """
    Load a model dictionary previously written by save_model().

    WARNING: pickle.load executes arbitrary code embedded in the file —
    only load model files from a trusted source.
    """
    with open(filename, "rb") as f:
        model_data = pickle.load(f)

    # Bug fix: the message previously printed the literal "(unknown)"
    # instead of the actual file name.
    print(f"Model uploaded from {filename} file, created: {model_data['created_at']}")
    return model_data


def predict_next_bar(model_data, new_data):
    """
    Forecast the direction of the next bar from fresh raw bars using a
    model dictionary produced by save_model()/load_model().
    """
    # Rebuild the same indicator/pattern columns used during training.
    frame = new_data.copy()
    frame = add_indicators(frame)
    frame = create_price_patterns(frame)
    frame.dropna(inplace=True)

    # Only the most recent completed bar matters for the forecast.
    latest = frame.iloc[-1]

    groups = model_data["feature_groups"]

    def _scaled_group_vector(group_name):
        # Extract the group's raw features in training order and apply the
        # scaler fitted during training.
        raw = [latest[column] for column in groups[group_name]["features"]]
        return groups[group_name]["scaler"].transform([raw])[0]

    scaled_price = _scaled_group_vector("price")
    scaled_time = _scaled_group_vector("time")
    scaled_volume = _scaled_group_vector("volume")

    # Resolve the bar's combined state from the per-group clusters.
    state = get_state_for_features(
        model_data["kmeans_models"],
        model_data["state_mapping"],
        scaled_price,
        scaled_time,
        scaled_volume,
    )

    state_transitions = model_data["state_transitions"]
    rise_probability_matrix = model_data["rise_probability_matrix"]
    next_state_probs = state_transitions[state, :]

    # Expected rise probability over all possible next states.
    weighted_rise_prob = 0
    for candidate, transition_prob in enumerate(next_state_probs):
        weighted_rise_prob += transition_prob * rise_probability_matrix[state, candidate]

    prediction = "rise" if weighted_rise_prob > 0.5 else "fall"
    confidence = max(weighted_rise_prob, 1 - weighted_rise_prob)

    return {
        "prediction": prediction,
        "confidence": confidence,
        "current_state": state,
        "weighted_rise_probability": weighted_rise_prob,
        "next_state_probabilities": next_state_probs,
    }


# Script entry point: run the full training/evaluation pipeline.
if __name__ == "__main__":
    main()
