from datetime import datetime, timedelta, timezone

import MetaTrader5 as mt5
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam


def calculate_rsi(prices, period=14):
    """Calculate Wilder's Relative Strength Index (RSI).

    Args:
        prices: 1-D sequence of prices (list, ndarray or Series). Gaps (NaN)
            are forward-filled before differencing; leading NaNs that cannot
            be filled are treated as "no movement".
        period: Smoothing period (number of price changes in the seed window).

    Returns:
        Float ndarray with the same length as ``prices``. The first
        ``period + 1`` entries hold the seed RSI (so the output has no leading
        gaps); later entries use Wilder's recursive smoothing. When there are
        no losses the RSI is 100, or 50 (neutral) if the price never moved.

    Raises:
        ValueError: if ``period`` is less than 1.
    """
    if period < 1:
        raise ValueError("period must be >= 1")

    def _to_rsi(avg_gain, avg_loss):
        # Guard the RS division: no losses means RSI saturates at 100,
        # unless there were no gains either (flat market -> neutral 50).
        if avg_loss == 0:
            return 100.0 if avg_gain > 0 else 50.0
        return 100.0 - 100.0 / (1.0 + avg_gain / avg_loss)

    # Replace NaN with previous values; force float so the output never truncates
    prices = pd.Series(prices, dtype=float).ffill().to_numpy()
    rsi = np.zeros(len(prices), dtype=float)

    deltas = np.diff(prices)
    # NaN deltas compare False on both sides, so they contribute no gain/loss
    gains = np.where(deltas > 0, deltas, 0.0)
    losses = np.where(deltas < 0, -deltas, 0.0)

    # Seed: Wilder averages over `period` changes, dividing by `period`
    # (not by the number of up or down bars).
    avg_gain = gains[:period].sum() / period
    avg_loss = losses[:period].sum() / period
    rsi[: period + 1] = _to_rsi(avg_gain, avg_loss)

    # Recursive smoothing; change i-1 is the move that ends at price i
    for i in range(period + 1, len(prices)):
        avg_gain = (avg_gain * (period - 1) + gains[i - 1]) / period
        avg_loss = (avg_loss * (period - 1) + losses[i - 1]) / period
        rsi[i] = _to_rsi(avg_gain, avg_loss)

    return rsi


def calculate_macd(prices, fast=12, slow=26, signal=9):
    """Calculate MACD.

    Args:
        prices: pandas Series of prices (must support ``.ewm``).
        fast: Span of the fast exponential moving average.
        slow: Span of the slow exponential moving average.
        signal: Span of the EMA applied to the MACD line.

    Returns:
        Tuple ``(macd, signal_line)`` of Series aligned with ``prices``.
    """

    def ema(series, span):
        # Recursive (adjust=False) exponential moving average
        return series.ewm(span=span, adjust=False).mean()

    macd_line = ema(prices, fast) - ema(prices, slow)
    return macd_line, ema(macd_line, signal)


def calculate_bollinger_bands(prices, period=20, num_std=2):
    """Calculate Bollinger Bands.

    Args:
        prices: pandas Series of prices (must support ``.rolling``).
        period: Rolling window length.
        num_std: Number of rolling standard deviations between the middle
            band and each outer band.

    Returns:
        Tuple ``(upper_band, lower_band)``; the first ``period - 1`` values
        are NaN because the rolling window is not yet full.
    """
    window = prices.rolling(window=period)
    middle = window.mean()
    half_width = window.std() * num_std
    return middle + half_width, middle - half_width


class MT5DataLoader:
    """Load bars and quotes for one symbol/timeframe from a MetaTrader5 terminal."""

    def __init__(self, symbol="EURUSD", timeframe=mt5.TIMEFRAME_H1):
        self.symbol = symbol
        self.timeframe = timeframe
        # Set only after a successful initialize() so __del__ knows what to undo
        self._initialized = False
        if not mt5.initialize():
            raise RuntimeError("MetaTrader5 initialization failed")
        self._initialized = True

    def __del__(self):
        # getattr: __del__ can run on a partially constructed instance.
        # NOTE(review): mt5.shutdown() is module-global, so destroying one
        # loader also closes the connection for any other live loader.
        if getattr(self, "_initialized", False):
            mt5.shutdown()

    def load_data(self, lookback_days=100):
        """Download historical bars from MT5.

        Args:
            lookback_days: How many days back from now to request.

        Returns:
            DataFrame indexed by bar time, with ``tick_volume`` renamed to
            ``volume`` and all other MT5 rate columns kept as-is.

        Raises:
            RuntimeError: if MT5 returns no data for the symbol.
        """
        # MT5 expects UTC datetimes; a naive local datetime would shift the
        # requested range by the local UTC offset and can drop the newest bars.
        current_time = datetime.now(timezone.utc)
        past_time = current_time - timedelta(days=lookback_days)

        rates = mt5.copy_rates_range(
            self.symbol, self.timeframe, past_time, current_time
        )

        # None signals an MT5 error; an empty array means no bars in range.
        # Both leave nothing to train or predict on.
        if rates is None or len(rates) == 0:
            raise RuntimeError(f"Failed to load data for {self.symbol}")

        df = pd.DataFrame(rates)
        # Bar times come back as epoch seconds
        df["time"] = pd.to_datetime(df["time"], unit="s")
        df.set_index("time", inplace=True)

        # Only tick_volume needs renaming; the remaining columns keep their names
        return df.rename(columns={"tick_volume": "volume"})

    def get_current_price(self):
        """Return the mid price ((bid + ask) / 2) of the latest tick.

        Raises:
            RuntimeError: if MT5 returns no tick for the symbol.
        """
        tick = mt5.symbol_info_tick(self.symbol)
        if tick is None:
            raise RuntimeError(f"Failed to get current price for {self.symbol}")
        return (tick.bid + tick.ask) / 2


class PatternAnalyzer:
    """Mine binary up/down price patterns and their next-move statistics.

    Prices are encoded as moves (1 = price above the previous price,
    0 = flat or down). For each pattern length in
    ``[min_pattern_len, max_pattern_len]`` the analyzer counts how often every
    pattern was followed by an up move. Patterns seen at least
    ``min_occurrences`` times are kept in ``patterns_dict``.
    """

    def __init__(self, min_pattern_len=5, max_pattern_len=8, min_occurrences=50):
        self.min_pattern_len = min_pattern_len
        self.max_pattern_len = max_pattern_len
        # Patterns seen fewer times than this are discarded as too rare
        self.min_occurrences = min_occurrences
        # pattern tuple -> {"length", "winrate", "frequency", "reliability"}
        self.patterns_dict = {}

    def _encode_price_movement(self, prices):
        """Encode price movement as a binary sequence (1 = up, 0 = flat or down)."""
        return [1 if curr > prev else 0 for prev, curr in zip(prices, prices[1:])]

    def _extract_patterns(self, encoded_moves, pattern_len):
        """Count up/down follow-up moves for every pattern of ``pattern_len`` moves."""
        patterns = {}
        # Stop early enough that every window has a following move to score
        for i in range(len(encoded_moves) - pattern_len):
            pattern = tuple(encoded_moves[i : i + pattern_len])
            next_move = encoded_moves[i + pattern_len]
            stats = patterns.setdefault(pattern, {"up": 0, "down": 0, "total": 0})
            stats["up" if next_move == 1 else "down"] += 1
            stats["total"] += 1
        return patterns

    def analyze_patterns(self, prices):
        """Analyze patterns of every configured length and store the frequent ones.

        Results are merged into ``patterns_dict``. Entries from earlier calls
        are kept, and are overwritten only when the same pattern qualifies again.
        """
        encoded_moves = self._encode_price_movement(prices)

        for length in range(self.min_pattern_len, self.max_pattern_len + 1):
            patterns = self._extract_patterns(encoded_moves, length)
            for pattern, stats in patterns.items():
                if stats["total"] < self.min_occurrences:
                    continue
                winrate = stats["up"] / stats["total"]
                # NOTE: not bounded to [0, 1]; it scales linearly with frequency
                reliability = stats["total"] * winrate * (1 - abs(0.5 - winrate))
                self.patterns_dict[pattern] = {
                    "length": length,
                    "winrate": winrate,
                    "frequency": stats["total"],
                    "reliability": reliability,
                }

    def get_active_patterns(self, prices):
        """Return the stored patterns that match the most recent price moves.

        Returns:
            List of ``{"pattern": tuple, "stats": dict}`` in ``patterns_dict``
            order; empty if nothing is stored or too few prices are given.
        """
        if not self.patterns_dict:
            return []

        # Only the last `longest` moves can ever match, so encode just that
        # tail instead of the whole history (callers pass growing prefixes).
        longest = max(stats["length"] for stats in self.patterns_dict.values())
        encoded_moves = self._encode_price_movement(prices[-(longest + 1):])

        active_patterns = []
        for pattern, stats in self.patterns_dict.items():
            length = stats["length"]
            if len(encoded_moves) >= length and tuple(encoded_moves[-length:]) == pattern:
                active_patterns.append({"pattern": pattern, "stats": stats})

        return active_patterns


class NeuroSymbolicTrader:
    """Combine an LSTM over technical indicators with mined price patterns.

    Each model input is a window of ``SEQUENCE_LENGTH`` bars of
    close/RSI/MACD/Bollinger features, min-max scaled per window. Three extra
    columns describe the price pattern active at the end of the window
    (winrate, frequency, reliability).
    """

    # Number of consecutive bars in one model input window
    SEQUENCE_LENGTH = 10
    # Indicator columns fed to the LSTM, in column order
    FEATURE_COLUMNS = ["close", "rsi", "macd", "bb_upper", "bb_lower"]

    def __init__(self, symbol="EURUSD", timeframe=mt5.TIMEFRAME_H1):
        self.pattern_analyzer = PatternAnalyzer()
        self.model = None
        self.scaler = MinMaxScaler()
        self.data_loader = MT5DataLoader(symbol, timeframe)

    def _add_indicators(self, df):
        """Add indicator columns to ``df`` (in place) and return a gap-filled copy."""
        # Handle price gaps first so indicators see a continuous series
        df["close"] = df["close"].ffill()

        df["rsi"] = calculate_rsi(df["close"].values)
        df["macd"], df["macd_signal"] = calculate_macd(df["close"])
        df["bb_upper"], df["bb_lower"] = calculate_bollinger_bands(df["close"])

        # Indicator warm-up rows (e.g. the first Bollinger window) are NaN
        return df.ffill().bfill()

    def _window_features(self, values, closes, end):
        """Build the model input for bars ``[end - SEQUENCE_LENGTH, end)``.

        Args:
            values: 2-D array of ``FEATURE_COLUMNS`` for every bar.
            closes: 1-D array of close prices for every bar.
            end: Exclusive end index of the window (>= SEQUENCE_LENGTH).

        Returns:
            Array of shape (SEQUENCE_LENGTH, len(FEATURE_COLUMNS) + 3), or
            None if the window contains NaN or cannot be scaled.
        """
        seq_len = self.SEQUENCE_LENGTH
        sequence = values[end - seq_len : end]
        if np.isnan(sequence).any():
            return None

        try:
            # Each window is scaled independently (the scaler is refit per window)
            sequence = self.scaler.fit_transform(sequence)
        except ValueError:
            return None

        # Pattern columns: winrate, frequency, reliability (zeros if none active)
        pattern_features = np.zeros(3)
        active_patterns = self.pattern_analyzer.get_active_patterns(closes[:end])
        if active_patterns:
            stats = active_patterns[0]["stats"]
            pattern_features = np.array(
                [
                    stats["winrate"],
                    stats["frequency"] / 1000,
                    # reliability grows with frequency, so scale it the same way
                    # instead of feeding raw values in the tens/hundreds
                    stats["reliability"] / 1000,
                ]
            )

        return np.column_stack((sequence, np.tile(pattern_features, (seq_len, 1))))

    def prepare_data(self, df):
        """Prepare training data.

        Adds indicator columns to ``df`` in place and mines price patterns from
        its closes (merged into ``self.pattern_analyzer``). It then builds one
        sample per bar.

        Returns:
            ``(X, y)``: X has shape (n_samples, SEQUENCE_LENGTH, 8). ``y[k]`` is
            1 if the close rose on the bar right after sample k's window, else 0.
        """
        df = self._add_indicators(df)

        # NOTE(review): patterns are mined on the full history, so early
        # samples see statistics that include later bars (look-ahead).
        self.pattern_analyzer.analyze_patterns(df["close"].values)

        values = df[self.FEATURE_COLUMNS].values
        closes = df["close"].values

        features = []
        labels = []
        for end in range(self.SEQUENCE_LENGTH, len(df)):
            window = self._window_features(values, closes, end)
            if window is None:
                continue
            features.append(window)
            # Label the move from the window's last bar (end - 1) to the next
            # bar, the same move the active-pattern statistics describe.
            labels.append(1 if closes[end] > closes[end - 1] else 0)

        return np.array(features), np.array(labels)

    def build_model(self, input_shape):
        """Create a two-layer LSTM binary classifier for ``input_shape`` windows."""
        model = Sequential(
            [
                LSTM(128, input_shape=input_shape, return_sequences=True),
                Dropout(0.3),
                LSTM(64),
                Dropout(0.2),
                Dense(32, activation="relu"),
                Dense(1, activation="sigmoid"),
            ]
        )

        model.compile(
            optimizer=Adam(learning_rate=0.001),
            loss="binary_crossentropy",
            metrics=["accuracy"],
        )

        return model

    def train(self, lookback_days=100, validation_split=0.2):
        """Train the system on historical data and return the Keras history.

        Raises:
            ValueError: if the loaded history yields no training samples.
        """
        df = self.data_loader.load_data(lookback_days)

        X, y = self.prepare_data(df)
        if len(X) == 0:
            raise ValueError("Not enough data to build training sequences")

        self.model = self.build_model(input_shape=(X.shape[1], X.shape[2]))

        history = self.model.fit(
            X, y, validation_split=validation_split, epochs=50, batch_size=32, verbose=1
        )

        return history

    def predict_next_movement(self):
        """Get a trading signal for the move after the most recent bar.

        Pattern statistics learned in ``train`` are reused unchanged.

        Returns:
            Dict with ``direction`` ("BUY"/"SELL"), ``confidence`` (estimated
            probability of an up move, in [0, 1]), ``active_patterns`` and
            ``current_price``. Returns None if there are too few recent bars.

        Raises:
            RuntimeError: if called before ``train``.
        """
        if self.model is None:
            raise RuntimeError("Model is not trained; call train() first")

        df = self._add_indicators(self.data_loader.load_data(lookback_days=10))
        if len(df) < self.SEQUENCE_LENGTH:
            return None

        closes = df["close"].values
        # Window ending at the latest bar, so the prediction is the next move.
        # NOTE(review): the last MT5 bar may still be forming; confirm whether
        # it should be dropped.
        window = self._window_features(df[self.FEATURE_COLUMNS].values, closes, len(df))
        if window is None:
            return None

        prediction = self.model.predict(window[np.newaxis])
        confidence = float(prediction[0][0])

        active_patterns = self.pattern_analyzer.get_active_patterns(closes)
        if active_patterns:
            # Blend two estimates of P(up): the model output and the pattern's
            # historical winrate (the unbounded reliability score is not a
            # probability).
            pattern_probability = active_patterns[0]["stats"]["winrate"]
            confidence = 0.7 * confidence + 0.3 * pattern_probability

        current_price = self.data_loader.get_current_price()

        return {
            "direction": "BUY" if confidence > 0.5 else "SELL",
            "confidence": confidence,
            "active_patterns": active_patterns,
            "current_price": current_price,
        }


# Create a trading system
trader = NeuroSymbolicTrader(symbol="EURUSD", timeframe=mt5.TIMEFRAME_H1)

# Train on historical data
history = trader.train(lookback_days=100)

# Get a signal; predict_next_movement() returns None when recent data is too short
signal = trader.predict_next_movement()
if signal is None:
    print("No signal: not enough recent data to build a feature window")
else:
    print(f"Direction: {signal['direction']}")
    print(f"Confidence: {signal['confidence']:.2f}")
    print(f"Current price: {signal['current_price']}")
    print("Active patterns:", len(signal["active_patterns"]))
