import numpy as np
import pandas as pd
import MetaTrader5 as mt5
from datetime import datetime
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.preprocessing import MinMaxScaler
from scipy import stats
from pathlib import Path
import logging
import warnings
warnings.filterwarnings('ignore')

def setup_logging():
    """Send DEBUG-and-above log records to ``3d_reversal.log`` and return the root logger.

    Configuration goes through ``logging.basicConfig``, which does nothing if
    the root logger already has handlers attached.
    """
    settings = dict(
        filename='3d_reversal.log',
        level=logging.DEBUG,
        format='%(asctime)s - %(levelname)s - %(message)s',
    )
    logging.basicConfig(**settings)
    root = logging.getLogger()
    return root

def create_3d_bars(symbol, timeframe, start_date, end_date, min_spread_multiplier=45, volume_brick=500):
    """Fetch bars for *symbol* from MetaTrader 5 and build a scaled feature table.

    One output row is produced for every bar that has 20 earlier bars
    available. Each row holds that bar's own derived features plus means over
    the 21-bar window ending at it. Note that OHLCV-style column names are
    reused for derived features, not prices:

    * ``open`` is the price return.
    * ``high`` is the price acceleration.
    * ``low`` is the volume change.
    * ``close`` is the volatility change.
    * ``tick_volume`` is the volume acceleration.
    * ``spread`` / ``type`` are the hour-of-day sine / cosine.

    Most numeric columns are min-max scaled into ``[3, 9]``. The exceptions
    are ``direction``, ``trend_strength`` (signed) and the ``zscore_*``
    columns (standard z-scores).

    Args:
        symbol: MT5 symbol name, e.g. ``"EURUSD"``.
        timeframe: MT5 timeframe constant passed to ``mt5.copy_rates_range``.
        start_date: range start passed to ``mt5.copy_rates_range``.
        end_date: range end passed to ``mt5.copy_rates_range``.
        min_spread_multiplier: factor applied to ``symbol_info.spread`` (times
            ``symbol_info.point``) to compute the returned brick size.
        volume_brick: not used by this function; kept for interface compatibility.

    Returns:
        ``(result_df, min_price_brick)``.

    Raises:
        ValueError: if no rates or no symbol info are returned, or fewer than
            21 bars are available.
    """
    lookback = 20              # bars preceding the current one in a window
    window_len = lookback + 1  # the window also includes the current bar

    rates = mt5.copy_rates_range(symbol, timeframe, start_date, end_date)
    if rates is None:
        raise ValueError(f"Error getting data for {symbol}")
    if len(rates) < window_len:
        # Previously this surfaced as an opaque sklearn error on an empty frame.
        raise ValueError(
            f"Not enough bars for {symbol}: got {len(rates)}, need at least {window_len}"
        )

    df = pd.DataFrame(rates)
    df['time'] = pd.to_datetime(df['time'], unit='s')

    symbol_info = mt5.symbol_info(symbol)
    if symbol_info is None:
        raise ValueError(f"Failed to get symbol info for {symbol}")

    min_price_brick = symbol_info.spread * min_spread_multiplier * symbol_info.point
    scaler = MinMaxScaler(feature_range=(3, 9))

    def _pct_change(series):
        # A zero previous value yields +/-inf, which MinMaxScaler rejects;
        # treat such undefined relative changes as missing instead.
        return series.pct_change().replace([np.inf, -np.inf], np.nan)

    # Time dimension (hour of day on the unit circle + elapsed seconds)
    hours = df['time'].dt.hour
    df['time_sin'] = np.sin(2 * np.pi * hours / 24)
    df['time_cos'] = np.cos(2 * np.pi * hours / 24)
    df['time_numeric'] = (df['time'] - df['time'].min()).dt.total_seconds()

    # Price dimension
    df['typical_price'] = (df['high'] + df['low'] + df['close']) / 3
    df['price_return'] = _pct_change(df['typical_price'])
    df['price_acceleration'] = df['price_return'].diff()

    # Volume dimension
    df['volume_change'] = _pct_change(df['tick_volume'])
    df['volume_acceleration'] = df['volume_change'].diff()

    # Volatility dimension
    df['volatility'] = df['price_return'].rolling(20).std()
    df['volatility_change'] = _pct_change(df['volatility'])

    # Means over the 21 bars ending at each row. min_periods=1 makes the
    # rolling mean skip NaNs the same way Series.mean() does on a slice.
    window_means = df[['price_return', 'volume_change', 'price_acceleration']].rolling(
        window_len, min_periods=1
    ).mean()

    current = df.iloc[lookback:]
    means = window_means.iloc[lookback:]
    result_df = pd.DataFrame({
        'time': current['time'].to_numpy(),
        # Raw elapsed seconds, scaled together with the other features below.
        # Scaling one value at a time (as before) always produced 3.0.
        'time_numeric': current['time_numeric'].to_numpy(dtype=float),
        'open': current['price_return'].to_numpy(dtype=float),
        'high': current['price_acceleration'].to_numpy(dtype=float),
        'low': current['volume_change'].to_numpy(dtype=float),
        'close': current['volatility_change'].to_numpy(dtype=float),
        'tick_volume': current['volume_acceleration'].to_numpy(dtype=float),
        'direction': np.sign(current['price_return'].to_numpy(dtype=float)),
        'spread': current['time_sin'].to_numpy(dtype=float),
        'type': current['time_cos'].to_numpy(dtype=float),
        'trend_count': window_len,  # every window holds exactly 21 bars
        'price_change': means['price_return'].to_numpy(dtype=float),
        'volume_intensity': means['volume_change'].to_numpy(dtype=float),
        'price_velocity': means['price_acceleration'].to_numpy(dtype=float),
    })

    # Scale every feature except the timestamp and the signed direction.
    features_to_scale = [col for col in result_df.columns if col not in ('time', 'direction')]
    result_df[features_to_scale] = scaler.fit_transform(result_df[features_to_scale])

    # Add analytical metrics
    result_df['ma_5'] = result_df['close'].rolling(5).mean()
    result_df['ma_20'] = result_df['close'].rolling(20).mean()
    result_df['volume_ma_5'] = result_df['tick_volume'].rolling(5).mean()
    result_df['price_volatility'] = result_df['price_change'].rolling(10).std()
    result_df['volume_volatility'] = result_df['tick_volume'].rolling(10).std()
    result_df['trend_strength'] = result_df['trend_count'] * result_df['direction']

    # trend_strength carries the bar direction in its sign, so it is left
    # unscaled: min-max scaling into [3, 9] would make every value positive.
    ma_columns = ['ma_5', 'ma_20', 'volume_ma_5', 'price_volatility', 'volume_volatility']
    result_df[ma_columns] = scaler.fit_transform(result_df[ma_columns])

    # Keep these as standard z-scores; thresholds such as |z| > 2 are only
    # meaningful on that scale (rescaling into [3, 9] made them always true).
    result_df['zscore_price'] = stats.zscore(result_df['close'], nan_policy='omit')
    result_df['zscore_volume'] = stats.zscore(result_df['tick_volume'], nan_policy='omit')

    return result_df, min_price_brick

def detect_reversal_pattern(df, window_size=20):
    """Score every row of *df* for a potential reversal, modifying *df* in place.

    Columns added:
      * ``reversal_score``: output of ``calculate_reversal_probability``.
        It stays 0.0 for the first ``window_size`` rows.
      * ``vol_intensity``: ``volume_volatility * price_volatility``.
      * ``normalized_volume``: ``tick_volume`` minus its rolling
        ``window_size`` mean, divided by the rolling std.

    The score stored at row ``i`` is computed from the ``window_size`` rows
    *before* it (positions ``i - window_size`` .. ``i - 1``). Row ``i`` itself
    is not in the window.
    NOTE(review): confirm the score is meant to lag its inputs by one row.

    Returns:
        The same DataFrame object.
    """
    df['reversal_score'] = 0.0
    df['vol_intensity'] = df['volume_volatility'] * df['price_volatility']
    volume_roll = df['tick_volume'].rolling(window_size)
    df['normalized_volume'] = (df['tick_volume'] - volume_roll.mean()) / volume_roll.std()

    # The loop only writes reversal_score, so the column set cannot change.
    has_momentum = 'momentum' in df.columns

    for row in range(window_size, len(df)):
        past = df.iloc[row - window_size:row]
        intensity = past['vol_intensity']

        spike_in_volume = past['normalized_volume'].iloc[-1] > 2.0
        spike_in_volatility = intensity.iloc[-1] > intensity.mean() + 2 * intensity.std()
        pressure = past['trend_strength'].sum() / window_size
        momentum_delta = past['momentum'].diff().iloc[-1] if has_momentum else 0

        df.loc[df.index[row], 'reversal_score'] = calculate_reversal_probability(
            spike_in_volume,
            spike_in_volatility,
            pressure,
            momentum_delta,
            past['zscore_price'].iloc[-1],
            past['zscore_volume'].iloc[-1],
        )
    return df

def calculate_reversal_probability(volume_spike, volatility_spike, trend_pressure,
                                   momentum_change, price_zscore, volume_zscore):
    """Combine reversal signals into a heuristic score clamped to ``[0, 1]``.

    Contributions, summed in this order:
      * +0.4 if both a volume and a volatility spike are present, +0.2 if
        only one is.
      * Up to +0.3 from trend pressure: ``0.1 * |trend_pressure|``, capped.
      * +/-0.15 times the sign of ``momentum_change * trend_pressure``.
        Skipped when ``momentum_change`` is 0 or NaN.
      * +0.15 if both ``|price_zscore|`` and ``|volume_zscore|`` exceed 2.

    The negative momentum term can push the raw sum below zero, so the
    result is clamped to ``[0, 1]``. A NaN sum (possible when
    ``trend_pressure`` is NaN and ``momentum_change`` is non-zero) returns
    0.0. Previously ``min(1.0, nan)`` reported it as 1.0.

    Returns:
        The score as a Python ``float``.
    """
    score = 0.0

    if volume_spike and volatility_spike:
        score += 0.4
    elif volume_spike or volatility_spike:
        score += 0.2

    score += min(0.3, abs(trend_pressure) * 0.1)

    if abs(momentum_change) > 0:
        score += 0.15 * np.sign(momentum_change * trend_pressure)

    if abs(price_zscore) > 2 and abs(volume_zscore) > 2:
        score += 0.15

    if np.isnan(score):
        # Undefined inputs mean "no signal", never a maximal score.
        return 0.0
    return float(min(1.0, max(0.0, score)))

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def create_visualizations(df, reversal_points, symbol, save_dir):
    """Write a chart and a data slice for every detected reversal.

    For each index label in ``reversal_points``, this takes the rows of *df*
    from 50 positions before the reversal up to 49 positions after it,
    clipped to the frame. It plots them as:

      * a 3D scatter of (row offset, ``tick_volume``, ``close``) coloured by
        ``vol_intensity``;
      * a 2D ``close`` line with the reversal row marked in red.

    Files written to *save_dir* (created if needed) are
    ``reversal_<idx>.png`` and ``reversal_data_<idx>.csv``.

    Args:
        df: frame containing ``tick_volume``, ``close`` and ``vol_intensity``.
        reversal_points: frame whose index labels identify rows of *df*.
        symbol: name used in the chart titles.
        save_dir: output directory (``str`` or ``Path``).
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    for idx in reversal_points.index:
        # Convert the label to a row position once, so slicing and the marker
        # stay correct even if df's index is not 0..n-1.
        pos = df.index.get_loc(idx)
        start_idx = max(0, pos - 50)
        end_idx = min(len(df), pos + 50)
        window_df = df.iloc[start_idx:end_idx]
        offsets = np.arange(len(window_df))
        marker = pos - start_idx

        # Create a figure with two subgraphs
        fig = plt.figure(figsize=(20, 10))
        try:
            # 3D chart
            ax1 = fig.add_subplot(121, projection='3d')
            scatter = ax1.scatter(
                offsets,
                window_df['tick_volume'],
                window_df['close'],
                c=window_df['vol_intensity'],
                cmap='viridis',
            )
            ax1.set_title(f'{symbol} 3D View at Reversal')
            fig.colorbar(scatter, ax=ax1)

            # Price chart, plotted against the same 0-based offsets as the
            # marker. Plotting the Series directly used its index labels as x,
            # which moved the line away from the red marker when start_idx > 0.
            ax2 = fig.add_subplot(122)
            ax2.plot(offsets, window_df['close'].to_numpy(), color='blue', label='Close')
            ax2.scatter([marker], [window_df['close'].iloc[marker]],
                        color='red', s=100, label='Reversal Point')
            ax2.set_title(f'{symbol} Price at Reversal')
            ax2.legend()

            fig.tight_layout()
            fig.savefig(save_dir / f'reversal_{idx}.png', dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)

        # Save data
        window_df.to_csv(save_dir / f'reversal_data_{idx}.csv')

def _process_symbol(symbol, timeframe, start_date, end_date, logger):
    """Analyse one symbol and write its charts and CSVs into ``reversals_<symbol>/``."""
    logger.info(f"Processing {symbol}")

    # Create 3D bars (the brick size is not used further here)
    bars, _brick_size = create_3d_bars(
        symbol=symbol,
        timeframe=timeframe,
        start_date=start_date,
        end_date=end_date,
    )

    # Score rows and keep those at or above the reversal threshold
    scored = detect_reversal_pattern(bars)
    hits = scored[scored['reversal_score'] >= 0.7].copy()

    # create_visualizations also creates the output directory
    out_dir = Path(f'reversals_{symbol}')
    create_visualizations(scored, hits, symbol, out_dir)

    logger.info(f"Found {len(hits)} potential reversal points")

    scored.to_csv(out_dir / f'{symbol}_analysis.csv')
    hits.to_csv(out_dir / f'{symbol}_reversals.csv')


def main():
    """Run the reversal-detection pipeline for each configured symbol.

    Initializes MetaTrader 5 and processes every symbol on the M15
    timeframe over a fixed date range. Any exception is logged with its
    traceback instead of being propagated. ``mt5.shutdown()`` is always
    called.
    """
    logger = setup_logging()

    try:
        if not mt5.initialize():
            raise RuntimeError("MetaTrader5 initialization failed")

        symbols = ("EURUSD",)
        timeframe = mt5.TIMEFRAME_M15
        start_date, end_date = datetime(2024, 11, 1), datetime(2024, 12, 5)

        for symbol in symbols:
            _process_symbol(symbol, timeframe, start_date, end_date, logger)
    except Exception as exc:
        # Same as logger.error(..., exc_info=True)
        logger.exception(f"Error occurred: {exc}")
    finally:
        mt5.shutdown()


if __name__ == "__main__":
    main()
