import numpy as np
import MetaTrader5 as mt5
import pandas as pd
from datetime import datetime, timedelta
from qiskit import QuantumCircuit, transpile, QuantumRegister, ClassicalRegister
from qiskit_aer import AerSimulator
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from catboost import CatBoostClassifier
import warnings

warnings.filterwarnings("ignore")


class MT5DataLoader:
    """Fetch OHLC bars for one symbol/timeframe from a running MetaTrader 5 terminal."""

    def __init__(self, symbol="EURUSD", timeframe=mt5.TIMEFRAME_H1):
        """Connect to the terminal and remember which series to load.

        Args:
            symbol: instrument name as known to the terminal.
            timeframe: one of the ``mt5.TIMEFRAME_*`` constants.

        Raises:
            Exception: if ``mt5.initialize()`` fails.
        """
        if not mt5.initialize():
            raise Exception("MetaTrader5 initialization failed")

        self.symbol = symbol
        self.timeframe = timeframe

    def get_historical_data(self, lookback_bars=1000):
        """Return the most recent ``lookback_bars`` bars as a DataFrame.

        The ``time`` column is converted from epoch seconds to pandas datetimes.
        The other columns are whatever the terminal returns for rates.

        Raises:
            Exception: if the terminal returns no data.
        """
        # Request by position (0 = newest bar) rather than by wall-clock time.
        # copy_rates_from() expects a UTC datetime, and bar timestamps follow
        # the broker's clock. A naive local datetime.now() can therefore
        # silently cut off the newest bars.
        rates = mt5.copy_rates_from_pos(self.symbol, self.timeframe, 0, lookback_bars)

        # None signals an API error; an empty array means no bars were found.
        # Both are unusable downstream.
        if rates is None or len(rates) == 0:
            raise Exception(f"Failed to get data for {self.symbol}")

        df = pd.DataFrame(rates)
        df["time"] = pd.to_datetime(df["time"], unit="s")
        return df


class BinaryPatternGenerator:
    """Derive 0/1 indicator series from an OHLCV DataFrame.

    Each encoder returns an integer Series aligned with ``df``. Bars whose
    comparison involves NaN come out as 0, because NaN comparisons are False.
    This covers rolling-window warm-up and the first bar of a diff or shift.
    """

    def __init__(self, df, lookback=10):
        # df is read for the "close", "high", "low" and "tick_volume" columns.
        self.df = df
        # Window length used by volume_encoding.
        self.lookback = lookback

    def direction_encoding(self):
        """1 where the close is strictly above the previous bar's close."""
        close = self.df["close"]
        went_up = close > close.shift(1)
        return went_up.astype(int)

    def momentum_encoding(self, threshold=0.0001):
        """1 where the absolute one-bar percentage change exceeds ``threshold``."""
        pct_move = self.df["close"].pct_change().abs()
        return (pct_move > threshold).astype(int)

    def volume_encoding(self):
        """1 where tick volume is above its ``lookback``-bar rolling mean."""
        ticks = self.df["tick_volume"]
        baseline = ticks.rolling(self.lookback).mean()
        return (ticks > baseline).astype(int)

    def convergence_encoding(self):
        """1 where the 5-bar SMA of close is above the 20-bar SMA."""
        close = self.df["close"]
        fast, slow = close.rolling(5).mean(), close.rolling(20).mean()
        return (fast > slow).astype(int)

    def volatility_encoding(self):
        """1 where the bar's high-low range exceeds its 20-bar average range."""
        bar_range = self.df["high"] - self.df["low"]
        return (bar_range > bar_range.rolling(20).mean()).astype(int)

    def rsi_encoding(self, period=14, threshold=50):
        """1 where an EWM-smoothed RSI (alpha = 1/period) is above ``threshold``."""
        change = self.df["close"].diff()
        smoothing = 1 / period
        avg_gain = change.where(change > 0, 0).ewm(alpha=smoothing).mean()
        avg_loss = (-change.where(change < 0, 0)).ewm(alpha=smoothing).mean()
        rsi = 100 - 100 / (1 + avg_gain / avg_loss)
        return (rsi > threshold).astype(int)

    def bollinger_encoding(self, window=20):
        """1 where close lies in the upper half of its 2-sigma Bollinger band."""
        close = self.df["close"]
        rolling = close.rolling(window=window)
        mid = rolling.mean()
        band = rolling.std() * 2
        lower = mid - band
        upper = mid + band
        position = (close - lower) / (upper - lower)
        return (position > 0.5).astype(int)

    def get_all_patterns(self):
        """Return every encoding, computed with default arguments, keyed by name.

        Callers iterate this dict, so the key order below is significant.
        """
        encoders = (
            ("direction", self.direction_encoding),
            ("momentum", self.momentum_encoding),
            ("volume", self.volume_encoding),
            ("convergence", self.convergence_encoding),
            ("volatility", self.volatility_encoding),
            ("rsi", self.rsi_encoding),
            ("bollinger", self.bollinger_encoding),
        )
        return {name: encode() for name, encode in encoders}


class QuantumFeatureGenerator:
    """Encode a price window into a circuit and return its measured bitstring distribution.

    Circuit layout:
    1. Put every qubit in superposition (H).
    2. Rotate qubit ``i`` by ``pi * scaled_data[i]`` (RY).
    3. Chain CNOTs between neighbouring qubits.
    4. Rotate qubit 0 by an angle derived from the current price.
    5. Measure all qubits.

    The feature vector holds the empirical probability of each of the
    ``2**num_qubits`` bitstrings.
    """

    def __init__(self, num_qubits=8, shots=2000, seed=None):
        """
        Args:
            num_qubits: number of qubits and classical bits. The feature
                vector has ``2**num_qubits`` entries.
            shots: number of measurement shots per circuit execution.
            seed: optional simulator seed. ``None`` keeps unseeded sampling,
                so features vary from run to run. Pass an int to make the
                features reproducible.
        """
        self.num_qubits = num_qubits
        self.shots = shots
        self.seed = seed
        self.simulator = AerSimulator()
        self.scaler = MinMaxScaler()

    def create_quantum_circuit(self, market_data, current_price):
        """Build the measured circuit for one window of ``market_data``.

        ``market_data`` is min-max scaled to [0, 1] on its own values, each
        call refitting ``self.scaler``.
        """
        qr = QuantumRegister(self.num_qubits, "qr")
        cr = ClassicalRegister(self.num_qubits, "cr")
        qc = QuantumCircuit(qr, cr)

        # Normalize this window independently of any other window.
        scaled_data = self.scaler.fit_transform(market_data.reshape(-1, 1)).flatten()

        # Create superposition
        for i in range(self.num_qubits):
            qc.h(qr[i])

        # Encode market data as RY rotation angles in [0, pi].
        # NOTE(review): only the first num_qubits values are encoded. For a
        # chronological window longer than num_qubits, the most recent values
        # are therefore ignored. Confirm this is intended.
        for i in range(min(len(scaled_data), self.num_qubits)):
            angle = float(scaled_data[i] * np.pi)
            qc.ry(angle, qr[i])

        # Create entanglement
        for i in range(self.num_qubits - 1):
            qc.cx(qr[i], qr[i + 1])

        # Map the price's remainder modulo 0.01 onto an angle in [0, pi).
        price_angle = float((current_price % 0.01) * 100 * np.pi)
        qc.ry(price_angle, qr[0])

        qc.measure(qr, cr)
        return qc

    def get_quantum_features(self, market_data, current_price):
        """Simulate the circuit and return bitstring frequencies.

        Returns:
            A float array of length ``2**num_qubits``. Entry ``k`` is the
            fraction of shots whose measured bitstring equals ``k`` in binary.
        """
        qc = self.create_quantum_circuit(market_data, current_price)
        compiled_circuit = transpile(qc, self.simulator, optimization_level=3)

        run_options = {"shots": self.shots}
        if self.seed is not None:
            run_options["seed_simulator"] = self.seed
        job = self.simulator.run(compiled_circuit, **run_options)
        counts = job.result().get_counts()

        feature_vector = np.zeros(2**self.num_qubits)
        total_shots = sum(counts.values())

        for bitstring, count in counts.items():
            # Qiskit separates multiple classical registers with spaces. Strip
            # them so the key always parses as a single binary integer.
            index = int(bitstring.replace(" ", ""), 2)
            feature_vector[index] = count / total_shots

        return feature_vector


class HybridQuantumBinaryPredictor:
    """Up/down classifier combining quantum-circuit and binary-pattern features.

    For bar ``i`` the feature vector concatenates:
    * ``2**num_qubits`` bitstring frequencies from ``QuantumFeatureGenerator``.
      These are computed from the previous ``lookback`` closes and the close
      of bar ``i``.
    * For each binary pattern, three summaries of its previous ``lookback``
      values: the sum, the last value, and the sum of the last three.
    * The RSI, Bollinger and momentum bits at bar ``i``.

    The training label is 1 when the close ``forecast_window`` bars ahead is
    strictly above the close of bar ``i``, else 0.
    """

    def __init__(self, num_qubits=8, lookback=10, forecast_window=5):
        self.num_qubits = num_qubits
        self.lookback = lookback
        self.forecast_window = forecast_window
        self.quantum_generator = QuantumFeatureGenerator(num_qubits)
        self.model = CatBoostClassifier(
            iterations=500,
            learning_rate=0.03,
            depth=6,
            loss_function="Logloss",
            verbose=False,
        )

    def _binary_patterns(self, df):
        """Compute all binary pattern series for ``df``, with NaN replaced by 0."""
        generator = BinaryPatternGenerator(df, self.lookback)
        return {
            name: series.fillna(0)
            for name, series in generator.get_all_patterns().items()
        }

    def _feature_vector(self, df, binary_patterns, i):
        """Build the feature vector for bar ``i``, using data up to and including bar ``i``."""
        closes = df["close"]
        market_data = closes.iloc[i - self.lookback : i].values
        current_price = closes.iloc[i]
        quantum_features = self.quantum_generator.get_quantum_features(
            market_data, current_price
        )

        binary_vector = []
        for series in binary_patterns.values():
            window = series.iloc[i - self.lookback : i].values
            binary_vector.extend(
                [
                    sum(window),  # Total number of signals
                    window[-1],  # Last signal
                    sum(window[-3:]),  # Last 3 signals
                ]
            )

        indicators = [
            binary_patterns["rsi"].iloc[i],
            binary_patterns["bollinger"].iloc[i],
            binary_patterns["momentum"].iloc[i],
        ]
        return np.concatenate([quantum_features, binary_vector, indicators])

    def prepare_features(self, df):
        """Build (X, y) for every bar that has ``lookback`` history and a known future close.

        Bars whose feature construction raises are reported and skipped, so
        ``len(X)`` can be smaller than the number of eligible bars.
        """
        binary_patterns = self._binary_patterns(df)
        closes = df["close"]

        features = []
        labels = []

        for i in range(self.lookback, len(df) - self.forecast_window):
            try:
                feature_vector = self._feature_vector(df, binary_patterns, i)
                future_price = closes.iloc[i + self.forecast_window]
                label = 1 if future_price > closes.iloc[i] else 0
            except Exception as e:
                print(f"Error at index {i}: {str(e)}")
                continue

            features.append(feature_vector)
            labels.append(label)

        return np.array(features), np.array(labels)

    def train(self, df, validation_fraction=0.1):
        """Train on the first 80% of samples and evaluate on the last 20%.

        Args:
            df: bars as produced by ``MT5DataLoader.get_historical_data``.
            validation_fraction: fraction of the training split, taken from
                its chronological tail, used as CatBoost's ``eval_set`` for
                best-iteration selection. If it rounds to zero samples, no
                eval set is used.

        Returns:
            A dict with accuracy/precision/recall/f1 on the test split and
            mean per-feature importances for the quantum and the remaining
            features. It also holds each group's share of total importance,
            plus the raw test predictions, probabilities and labels.

        Raises:
            ValueError: if fewer than 2 samples can be built from ``df``.
        """
        print("Preparing features...")
        X, y = self.prepare_features(df)
        if len(X) < 2:
            raise ValueError(
                f"Need at least 2 samples to train, got {len(X)}; provide more "
                "bars than lookback + forecast_window"
            )

        # Chronological split: no shuffling, so the test set is strictly later.
        split_point = int(len(X) * 0.8)
        X_train, X_test = X[:split_point], X[split_point:]
        y_train, y_test = y[:split_point], y[split_point:]

        print("Training model...")
        # CatBoost turns on use_best_model when eval_set is given. Passing the
        # test split there would select the model on the data the metrics are
        # reported on. Hold out the tail of the training split instead.
        n_val = int(len(X_train) * validation_fraction)
        if 0 < n_val < len(X_train):
            self.model.fit(
                X_train[:-n_val],
                y_train[:-n_val],
                eval_set=(X_train[-n_val:], y_train[-n_val:]),
            )
        else:
            self.model.fit(X_train, y_train)

        predictions = self.model.predict(X_test)
        probas = self.model.predict_proba(X_test)

        metrics = {
            "accuracy": accuracy_score(y_test, predictions),
            "precision": precision_score(y_test, predictions, zero_division=0),
            "recall": recall_score(y_test, predictions, zero_division=0),
            "f1": f1_score(y_test, predictions, zero_division=0),
        }

        feature_importance = np.asarray(self.model.feature_importances_, dtype=float)
        n_quantum = 2**self.num_qubits
        quantum_part = feature_importance[:n_quantum]
        binary_part = feature_importance[n_quantum:]
        total_importance = float(feature_importance.sum())

        metrics.update(
            {
                "quantum_importance": np.mean(quantum_part),
                "binary_importance": np.mean(binary_part),
                "quantum_importance_share": (
                    float(quantum_part.sum()) / total_importance
                    if total_importance > 0
                    else 0.0
                ),
                "binary_importance_share": (
                    float(binary_part.sum()) / total_importance
                    if total_importance > 0
                    else 0.0
                ),
                "test_predictions": predictions,
                "test_probas": probas,
                "test_actual": y_test,
            }
        )

        return metrics

    def predict_next(self, df):
        """Forecast the direction ``forecast_window`` bars after the last bar of ``df``.

        Features are built for the most recent bar only; its future close is
        not needed. The model must already be trained.

        Returns:
            A dict with "direction" ("UP"/"DOWN"), "probability_up",
            "probability_down" and "confidence". Returns None when ``df`` has
            no more than ``lookback`` bars, or when feature construction for
            the last bar fails.
        """
        last = len(df) - 1
        if last < self.lookback:
            return None

        binary_patterns = self._binary_patterns(df)
        try:
            features = self._feature_vector(df, binary_patterns, last)
        except Exception as e:
            print(f"Error at index {last}: {str(e)}")
            return None

        last_features = features.reshape(1, -1)
        prediction_proba = self.model.predict_proba(last_features)[0]
        prediction = self.model.predict(last_features)[0]

        return {
            "direction": "UP" if prediction == 1 else "DOWN",
            "probability_up": prediction_proba[1],
            "probability_down": prediction_proba[0],
            "confidence": max(prediction_proba),
        }


def test_hybrid_model(symbol="EURUSD", timeframe=mt5.TIMEFRAME_H1, periods=1000):
    """Run the full pipeline: load bars from MT5, train, report, and forecast.

    Args:
        symbol: instrument to load.
        timeframe: one of the ``mt5.TIMEFRAME_*`` constants.
        periods: number of bars to load.

    Returns:
        ``(model, metrics)``: the trained ``HybridQuantumBinaryPredictor`` and
        the dict returned by its ``train`` method.

    The MT5 connection is always shut down on exit, even on error.
    """
    try:
        # Initialize MT5
        if not mt5.initialize():
            raise Exception("Failed to initialize MT5")

        # Load data
        print(f"Loading {periods} periods of {symbol} {timeframe} data...")
        loader = MT5DataLoader(symbol, timeframe)
        df = loader.get_historical_data(periods)

        # Create and train the model
        print("Creating hybrid model...")
        model = HybridQuantumBinaryPredictor()

        # Training and assessment
        print("Training and evaluating model...")
        metrics = model.train(df)

        # Display results
        print("\nModel Performance Metrics:")
        print(f"Accuracy: {metrics['accuracy']:.2%}")
        print(f"Precision: {metrics['precision']:.2%}")
        print(f"Recall: {metrics['recall']:.2%}")
        print(f"F1 Score: {metrics['f1']:.2%}")

        print("\nFeature Importance Analysis:")
        # metrics['*_importance'] are mean per-feature CatBoost importances,
        # not fractions, so formatting them with '%' misreports them. Print
        # each group's share of the total importance, with the mean alongside.
        importances = np.asarray(model.model.feature_importances_, dtype=float)
        n_quantum = 2**model.num_qubits
        total = importances.sum()
        if total > 0:
            quantum_share = importances[:n_quantum].sum() / total
            binary_share = importances[n_quantum:].sum() / total
            print(
                f"Quantum Features: {quantum_share:.2%} of total "
                f"(mean per feature {metrics['quantum_importance']:.4f})"
            )
            print(
                f"Binary Features: {binary_share:.2%} of total "
                f"(mean per feature {metrics['binary_importance']:.4f})"
            )
        else:
            print("Quantum Features: n/a (all feature importances are zero)")
            print("Binary Features: n/a (all feature importances are zero)")

        # Current forecast
        print("\nCurrent Market Prediction:")
        prediction = model.predict_next(df)
        if prediction:
            print(f"Predicted Direction: {prediction['direction']}")
            print(f"Up Probability: {prediction['probability_up']:.2%}")
            print(f"Down Probability: {prediction['probability_down']:.2%}")
            print(f"Confidence: {prediction['confidence']:.2%}")

        return model, metrics

    finally:
        mt5.shutdown()


if __name__ == "__main__":
    # Script entry point: run the end-to-end demo (MT5 download, training,
    # metrics report, current forecast) and keep the results as globals.
    banner = "Starting Quantum-Binary Hybrid Trading System..."
    print(banner)
    model, metrics = test_hybrid_model()
    print("\nSystem test completed.")
