import numpy as np
import MetaTrader5 as mt5
import pandas as pd
from datetime import datetime, timedelta
from qiskit import QuantumCircuit, transpile, QuantumRegister, ClassicalRegister
from qiskit_aer import AerSimulator
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import warnings

warnings.filterwarnings("ignore")


class MT5DataLoader:
    """Fetch historical bars for one symbol/timeframe from a MetaTrader5 terminal.

    Constructing the loader initializes the MetaTrader5 connection; the caller
    is responsible for calling ``mt5.shutdown()`` when finished.
    """

    def __init__(self, symbol="EURUSD", timeframe=mt5.TIMEFRAME_H1):
        """Connect to the terminal and remember which instrument to load.

        Args:
            symbol: Instrument name passed to MetaTrader5.
            timeframe: MetaTrader5 timeframe constant.

        Raises:
            Exception: If ``mt5.initialize()`` reports failure.
        """
        if not mt5.initialize():
            raise Exception("MetaTrader5 initialization failed")

        self.symbol = symbol
        self.timeframe = timeframe

    def get_historical_data(self, lookback_bars=1000):
        """Return up to ``lookback_bars`` bars ending at the current time.

        Args:
            lookback_bars: Number of bars requested from the terminal.

        Returns:
            A DataFrame built from the returned rates, with the ``time``
            column converted from epoch seconds to pandas datetimes.

        Raises:
            Exception: If the terminal returns ``None`` or no bars at all.
        """
        # NOTE(review): this is a naive local timestamp; the MetaTrader5 docs
        # recommend UTC datetimes for time-based requests -- confirm that the
        # most recent bars are not being cut off on hosts behind UTC.
        current_time = datetime.now()
        rates = mt5.copy_rates_from(
            self.symbol, self.timeframe, current_time, lookback_bars
        )

        # An empty result is as unusable as None: downstream code indexes the
        # last close and would fail with an unrelated IndexError.
        if rates is None or len(rates) == 0:
            raise Exception(
                f"Failed to get data for {self.symbol}: {mt5.last_error()}"
            )

        df = pd.DataFrame(rates)
        df["time"] = pd.to_datetime(df["time"], unit="s")
        return df


class EnhancedQuantumPredictor:
    """Price-direction heuristic driven by a small simulated quantum circuit.

    Market data is min-max scaled and encoded as RY rotation angles; the
    measured bitstrings are then mapped to a price change of at most +/-0.05%
    depending on how many qubits were measured as ``1``.
    """

    def __init__(self, num_qubits=8):  # A small qubit count keeps simulation stable
        """Create the simulator and scaler used for every prediction.

        Args:
            num_qubits: Width of the quantum and classical registers.
        """
        self.num_qubits = num_qubits
        self.simulator = AerSimulator()
        self.scaler = MinMaxScaler()

    def create_qpe_circuit(self, market_data, current_price):
        """Build the measurement circuit for one prediction.

        Despite the method name, the circuit consists only of a Hadamard
        layer, per-qubit RY rotations, a linear CX chain, one extra RY on
        qubit 0 and a full measurement.

        Args:
            market_data: Numeric array; it is flattened and min-max scaled
                as a single column.
            current_price: Latest price; only ``current_price % 0.01`` is
                encoded.

        Returns:
            The assembled ``QuantumCircuit``.
        """
        qr = QuantumRegister(self.num_qubits, "qr")
        cr = ClassicalRegister(self.num_qubits, "cr")
        qc = QuantumCircuit(qr, cr)

        # Normalize data. The scaler is refit on every call, so no state is
        # carried over between predictions.
        # NOTE(review): for 2-D input the flattening is row-major, so the
        # first ``num_qubits`` values used below come from the earliest rows
        # passed in, and all columns share one min/max -- confirm intended.
        flat = np.asarray(market_data, dtype=float).reshape(-1, 1)
        scaled_data = self.scaler.fit_transform(flat).flatten()

        # Create superposition
        for i in range(self.num_qubits):
            qc.h(qr[i])

        # Encode the scaled data as RY rotation angles in [0, pi]
        for i in range(min(len(scaled_data), self.num_qubits)):
            angle = float(scaled_data[i] * np.pi)  # Qiskit expects a plain float
            qc.ry(angle, qr[i])

        # Create entanglement with a nearest-neighbour CX chain
        for i in range(self.num_qubits - 1):
            qc.cx(qr[i], qr[i + 1])

        # Encode the current price: the remainder modulo 0.01 is rescaled to
        # an angle in [0, pi)
        price_angle = float((current_price % 0.01) * 100 * np.pi)
        qc.ry(price_angle, qr[0])

        # Measure all qubits
        qc.measure(qr, cr)

        return qc

    def predict(self, market_data, current_price, features=None, shots=2000):
        """Run the circuit and summarize the measurements as a forecast.

        Args:
            market_data: Numeric array of recent data; only the last
                ``num_qubits`` entries along the first axis are kept.
            current_price: Latest price; must be a positive number.
            features: Accepted for call compatibility; not used here.
            shots: Number of circuit executions on the simulator.

        Returns:
            A dict with ``predicted_price``, ``up_probability``,
            ``down_probability`` and ``confidence``.

        Raises:
            ValueError: If ``market_data`` is empty or ``current_price`` is
                not a positive number.
        """
        market_data = np.asarray(market_data, dtype=float)
        if market_data.size == 0:
            raise ValueError("market_data must not be empty")
        # ``not (x > 0)`` also rejects NaN, which would otherwise propagate
        # silently into every returned value.
        if not current_price > 0:
            raise ValueError("current_price must be a positive number")

        # Trim input data
        if market_data.shape[0] > self.num_qubits:
            market_data = market_data[-self.num_qubits :]

        # Create and execute the circuit
        qc = self.create_qpe_circuit(market_data, current_price)
        compiled_circuit = transpile(qc, self.simulator, optimization_level=3)
        job = self.simulator.run(compiled_circuit, shots=shots)
        counts = job.result().get_counts()

        return self._summarize_counts(counts, current_price, self.num_qubits)

    @staticmethod
    def _summarize_counts(counts, current_price, num_qubits):
        """Turn a bitstring histogram into the forecast dictionary.

        Each bitstring maps to a price via the fraction of ``1`` bits; the
        statistics are weighted by the observed counts instead of expanding
        one list entry per shot.
        """
        if not counts:
            raise ValueError("no measurement results to summarize")

        ones = np.array([bits.count("1") for bits in counts], dtype=float)
        weights = np.array(list(counts.values()), dtype=float)
        total = weights.sum()

        # Fraction of ones in [0, 1] -> relative change in [-0.05%, +0.05%]
        prices = current_price * (1 + (ones / num_qubits - 0.5) * 0.001)

        predicted_price = float(np.dot(prices, weights) / total)
        up_probability = float(weights[prices > current_price].sum() / total)

        # Count-weighted population standard deviation (same as np.std over
        # the per-shot expansion).
        variance = np.dot(weights, (prices - predicted_price) ** 2) / total
        confidence = float(1 - np.sqrt(variance) / current_price)

        return {
            "predicted_price": predicted_price,
            "up_probability": up_probability,
            "down_probability": 1 - up_probability,
            "confidence": confidence,
        }


class MarketPredictor:
    """Combine MT5 price history, technical indicators and the quantum predictor."""

    def __init__(self, symbol="EURUSD", timeframe=mt5.TIMEFRAME_H1, window_size=14):
        """Create the data loader and quantum predictor for one instrument.

        Args:
            symbol: Instrument name passed to the data loader.
            timeframe: MetaTrader5 timeframe constant.
            window_size: Look-back length for the rolling indicators and the
                number of most recent feature rows fed to the predictor.
        """
        self.symbol = symbol
        self.timeframe = timeframe
        self.window_size = window_size
        self.quantum_predictor = EnhancedQuantumPredictor()
        self.data_loader = MT5DataLoader(symbol, timeframe)

    def prepare_features(self, df):
        """Compute technical indicators from the ``close`` column.

        The input frame is left untouched: indicators are assembled in a new
        DataFrame, so passing a slice of a larger frame is safe.

        Args:
            df: DataFrame with at least a ``close`` column.

        Returns:
            DataFrame with columns ``sma``, ``ema``, ``std``, ``upper_band``,
            ``lower_band``, ``rsi``, ``momentum`` and ``rate_of_change``,
            indexed like ``df`` with rows containing NaN removed.
        """
        close = df["close"]
        window = self.window_size

        sma = close.rolling(window=window).mean()
        std = close.rolling(window=window).std()

        features = pd.DataFrame(
            {
                "sma": sma,
                "ema": close.ewm(span=window).mean(),
                "std": std,
                # Bands two standard deviations around the moving average
                "upper_band": sma + (std * 2),
                "lower_band": sma - (std * 2),
                # RSI uses calculate_rsi's own default period, not window_size
                "rsi": self.calculate_rsi(close),
                "momentum": close - close.shift(window),
                "rate_of_change": (close / close.shift(1) - 1) * 100,
            }
        )
        return features.dropna()

    def calculate_rsi(self, prices, period=14):
        """Return the RSI of ``prices`` using exponentially weighted averages.

        Args:
            prices: Series of prices.
            period: Smoothing period; the EWM uses ``alpha = 1 / period``.

        Returns:
            Series of RSI values on a 0-100 scale.
        """
        delta = prices.diff()
        # Split moves into gains and losses; non-matching entries (including
        # the leading NaN from diff) are replaced by 0.
        gain = (delta.where(delta > 0, 0)).ewm(alpha=1 / period).mean()
        loss = (-delta.where(delta < 0, 0)).ewm(alpha=1 / period).mean()
        rs = gain / loss
        return 100 - (100 / (1 + rs))

    def predict(self):
        """Fetch fresh data and return a forecast for the latest bar.

        Returns:
            The quantum predictor's result dict extended with ``timestamp``,
            ``current_price``, ``rsi``, ``sma`` and ``ema``.

        Raises:
            ValueError: If fewer than ``window_size`` feature rows remain
                after dropping NaN rows.
        """
        # Request extra bars so rows lost to the rolling warm-up still leave
        # a full window of features.
        df = self.data_loader.get_historical_data(self.window_size + 50)
        features = self.prepare_features(df)

        if len(features) < self.window_size:
            raise ValueError("Insufficient data")

        # Most recent window of features and the latest close
        recent = features.iloc[-self.window_size :]
        current_price = df["close"].iloc[-1]

        prediction = self.quantum_predictor.predict(
            market_data=recent.values,
            current_price=current_price,
            features=recent,
        )

        prediction.update(
            {
                "timestamp": datetime.now(),
                "current_price": current_price,
                "rsi": features["rsi"].iloc[-1],
                "sma": features["sma"].iloc[-1],
                "ema": features["ema"].iloc[-1],
            }
        )

        return prediction


def evaluate_model(symbol="EURUSD", timeframe=mt5.TIMEFRAME_H1, test_periods=100):
    """Walk forward over recent history and score directional forecasts.

    For each of the last ``test_periods`` bars, the model is given only the
    bars before it and its up/down call is compared with the actual change
    in close price.

    Args:
        symbol: Instrument name.
        timeframe: MetaTrader5 timeframe constant.
        test_periods: Number of trailing bars to evaluate.

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall`` and ``f1``; all
        zeros when no step could be evaluated.
    """
    # One predictor is reused for every step: the scaler is refit on each
    # call, so there is no need to re-initialize MT5 and the simulator.
    predictor = MarketPredictor(symbol, timeframe)
    window = predictor.window_size
    predictions = []
    actual_movements = []

    # Get historical data (extra bars cover the indicator warm-up)
    df = predictor.data_loader.get_historical_data(test_periods + 50)
    close = df["close"]

    for i in range(test_periods):
        held_out = test_periods - i  # bars hidden from the model at this step
        try:
            # Copy so feature preparation never touches the shared frame
            temp_df = df.iloc[:-held_out].copy()
            features_temp = predictor.prepare_features(temp_df)

            latest = features_temp.iloc[-window:]
            current_price = temp_df["close"].iloc[-1]

            prediction = predictor.quantum_predictor.predict(
                market_data=latest.values,
                current_price=current_price,
                features=latest,
            )
            predicted_movement = 1 if prediction["up_probability"] > 0.5 else 0

            actual_price_next = close.iloc[-held_out]
            actual_price_current = close.iloc[-held_out - 1]
            actual_movement = 1 if actual_price_next > actual_price_current else 0

        except Exception as e:
            print(f"Error in evaluation: {e}")
            continue

        # Record both values only after the whole step succeeded, so the two
        # lists can never get out of step with each other.
        predictions.append(predicted_movement)
        actual_movements.append(actual_movement)

    if len(predictions) > 0:
        metrics = {
            "accuracy": accuracy_score(actual_movements, predictions),
            "precision": precision_score(actual_movements, predictions),
            "recall": recall_score(actual_movements, predictions),
            "f1": f1_score(actual_movements, predictions),
        }
    else:
        metrics = {"accuracy": 0, "precision": 0, "recall": 0, "f1": 0}

    return metrics


# Script entry point: evaluate the model on recent history, then print a
# forecast for the current bar.
if __name__ == "__main__":
    if not mt5.initialize():
        print("MetaTrader5 initialization failed")
        mt5.shutdown()
    else:
        try:
            symbol = "EURUSD"
            timeframe = mt5.TIMEFRAME_H1

            print("\nTest the model...")
            metrics = evaluate_model(symbol, timeframe, test_periods=100)

            print("\nModel quality metrics:")
            print(f"Accuracy: {metrics['accuracy']:.2%}")
            print(f"Precision: {metrics['precision']:.2%}")
            print(f"Recall: {metrics['recall']:.2%}")
            print(f"F1-score: {metrics['f1']:.2%}")

            print("\nCurrent forecast:")
            predictor = MarketPredictor(symbol, timeframe)
            # predict() loads its own data and features, so nothing needs to
            # be fetched or prepared here.
            prediction = predictor.predict()

            print(f"Predicted price: {prediction['predicted_price']:.5f}")
            print(f"Growth probability: {prediction['up_probability']:.2%}")
            print(f"Fall probability: {prediction['down_probability']:.2%}")
            print(f"Forecast confidence: {prediction['confidence']:.2%}")
            print(f"Current price: {prediction['current_price']:.5f}")
            print(f"RSI: {prediction['rsi']:.2f}")
            print(f"SMA: {prediction['sma']:.5f}")
            print(f"EMA: {prediction['ema']:.5f}")

        finally:
            # Always release the terminal connection, even on failure
            mt5.shutdown()
