
import numpy as np
import pandas as pd
import MetaTrader5 as mt5
import logging
import plotly.graph_objects as go

# Module-wide logging used by the helpers below: INFO and above, rendered as
# "LEVEL: message" (e.g. "WARNING: DataFrame is empty after cleaning.").
logging.basicConfig(
    format='%(levelname)s: %(message)s',
    level=logging.INFO,
)


# Step 1: Data Extraction

def get_ticks(symbol, start_date, end_date):
    """
    Downloads tick data from the MT5 terminal.

    Args:
        symbol (str): Financial instrument (e.g., currency pair or stock).
        start_date, end_date (str or datetime): Time range for data (YYYY-MM-DD).
            Strings and naive datetimes/dates are interpreted as UTC;
            timezone-aware values are passed through unchanged.

    Returns:
        pd.DataFrame | None: Tick data indexed by a naive datetime index built
        from ``time_msc`` (millisecond resolution); columns whose values are
        all zero/empty are dropped. ``None`` if MT5 returned no result or the
        conversion failed (the reason is logged).

    Raises:
        RuntimeError: If the MT5 terminal connection cannot be initialized.
    """
    if not mt5.initialize():
        logging.error("MT5 connection not established.")
        raise RuntimeError("MT5 connection error.")

    def _as_utc(value):
        # Strings and tz-naive values (including plain dates, which have no
        # ``tzinfo`` attribute) are treated as UTC; aware values are kept.
        if isinstance(value, str) or getattr(value, 'tzinfo', None) is None:
            return pd.Timestamp(value, tz='UTC')
        return value

    start_date = _as_utc(start_date)
    end_date = _as_utc(end_date)

    try:
        ticks = mt5.copy_ticks_range(symbol, start_date, end_date, mt5.COPY_TICKS_ALL)
        if ticks is None:
            # MT5 signals failure with None; last_error() carries the reason.
            # Without this check the failure surfaced as a KeyError on 'time_msc'.
            logging.error(f"Error while downloading ticks: {mt5.last_error()}")
            return None
        if len(ticks) == 0:
            logging.warning(f"No ticks returned for {symbol} between {start_date} and {end_date}.")

        df = pd.DataFrame(ticks)
        # Replace the seconds-resolution 'time' column with the millisecond one.
        df['time'] = pd.to_datetime(df['time_msc'], unit='ms')
        df.set_index('time', inplace=True)
        df.drop('time_msc', axis=1, inplace=True)
        # Drop columns that are entirely zero/False (e.g. unused 'last'/'volume').
        df = df[df.columns[df.any()]]
        df.info()
    except Exception as e:
        logging.error(f"Error while downloading ticks: {e}")
        return None

    return df


# Step 2: Data Cleaning

def clean_tick_data(df: pd.DataFrame, timezone='UTC', min_spread=1e-5):
    """
    Cleans and validates tick data.

    The input frame is never modified; a cleaned frame is returned. Steps:

    1. Parse a non-datetime index with ``pd.to_datetime`` and drop NaT rows.
    2. Localize a naive index to ``timezone`` or convert an aware one.
    3. Keep rows with bid > 0, ask > 0, ask > bid and ``ask - bid >= min_spread``.
       The spread comparison allows a relative tolerance of 1e-9 so that float
       round-off (e.g. ``0.3 - 0.2 == 0.09999999999999998``) does not reject a
       spread that equals ``min_spread``.
    4. Drop rows containing NaN (logged).
    5. Drop duplicate timestamps, keeping the last tick of each timestamp.
    6. Sort by time.

    Args:
        df (pd.DataFrame | None): Tick data with at least 'bid' and 'ask' columns.
        timezone (str): Timezone for localization/conversion.
        min_spread (float): Minimum acceptable spread between bid and ask.

    Returns:
        pd.DataFrame | None: Clean tick data, or ``None`` if the input is
        ``None``/empty or nothing survives cleaning.

    Raises:
        ValueError: If the index cannot be parsed as datetimes.
    """
    if df is None or df.empty:
        return None

    if not isinstance(df.index, pd.DatetimeIndex):
        try:
            # set_axis returns a new frame, so the caller's index is untouched.
            df = df.set_axis(pd.to_datetime(df.index), axis=0)
        except Exception as e:
            raise ValueError(f"Index parsing error: {e}") from e
        df = df[df.index.notna()]

    df = df.tz_localize(timezone) if df.index.tz is None else df.tz_convert(timezone)

    bid, ask = df['bid'], df['ask']
    spread = ask - bid
    meets_min_spread = (spread >= min_spread) | np.isclose(spread, min_spread, rtol=1e-9, atol=0.0)
    price_filter = (bid > 0) & (ask > 0) & (ask > bid) & meets_min_spread
    df = df[price_filter]

    na_counts = df.isna().sum()
    if na_counts.any():
        logging.warning(f"Dropped NA values:\n{na_counts}")
        df = df.dropna()

    if not df.empty and not df.index.microsecond.any():
        logging.warning("No microsecond precision found in timestamps.")

    duplicate_mask = df.index.duplicated(keep='last')
    n_duplicates = int(duplicate_mask.sum())
    if n_duplicates > 0:
        logging.info(f"Removed {n_duplicates} duplicate timestamps")
        df = df[~duplicate_mask]

    if not df.index.is_monotonic_increasing:
        df = df.sort_index()

    if df.empty:
        logging.warning("DataFrame is empty after cleaning.")
        return None

    return df


# Step 3: Create Bars and Convert to End-Time

## Resampling Frequency Conversion

def set_resampling_freq(timeframe: str) -> str:
    """
    Converts an MT5 timeframe to a pandas resampling frequency.

    Mapping (input is case-insensitive):
        'W1' -> 'W-FRI'  (weekly, anchored on Friday)
        'D1' -> 'B'      (business day)
        'Hx' -> 'xh'     (x hours)
        'Mx' -> 'xmin'   (x minutes)
        'Sx' -> 'xs'     (x seconds)

    Lower-case 'h'/'s' aliases are used because pandas 2.2 deprecated the
    upper-case 'H'/'S' spellings; older pandas versions accept both.

    Args:
        timeframe (str): MT5 timeframe (e.g., 'M1', 'H1', 'D1', 'W1').

    Returns:
        str: Pandas frequency string.

    Raises:
        ValueError: If the timeframe has no digits, has a count of 0, is
            monthly ('MN1'), or uses an unsupported prefix. 'MN1' used to be
            mistaken for 1 minute.
    """
    timeframe = timeframe.upper()
    digits = ''.join(ch for ch in timeframe if ch.isdecimal())
    if not digits:
        raise ValueError("Timeframe must include numeric values (e.g., 'M1').")

    count = int(digits)
    if count == 0:
        raise ValueError(f"Timeframe count must be positive, got {timeframe!r}.")
    if timeframe.startswith('MN'):
        # MT5's monthly timeframe also starts with 'M'; reject it explicitly
        # instead of falling into the minutes branch.
        raise ValueError("Monthly timeframes (MN) are not supported.")

    if timeframe == 'W1':
        return 'W-FRI'
    if timeframe == 'D1':
        return 'B'
    if timeframe.startswith('H'):
        return f'{count}h'
    if timeframe.startswith('M'):
        return f'{count}min'
    if timeframe.startswith('S'):
        return f'{count}s'
    raise ValueError("Valid timeframes include W1, D1, Hx, Mx, Sx.")


def calculate_ticks_per_period(df: pd.DataFrame, timeframe: str = "M1", method: str = 'median', verbose: bool = True) -> int:
    """
    Dynamically calculates the typical number of ticks per given timeframe.

    Only periods that contain at least one tick are used. Empty periods
    (weekends, closed sessions, data gaps) would otherwise count as zero and
    drag the statistic down. For an instrument that trades less than half of
    each day, the median would collapse to 0 and the result would always be 1.

    The statistic is rounded to one significant figure (e.g. 1,234 -> 1,000),
    with a floor of 1.

    Args:
        df (pd.DataFrame): Tick data with a DatetimeIndex.
        timeframe (str): MT5 timeframe.
        method (str): Name of a NumPy reduction, normally 'median' or 'mean'.
        verbose (bool): Whether to log the result.

    Returns:
        int: Rounded ticks per period.

    Raises:
        ValueError: If ``df`` contains no ticks.
        AttributeError: If ``method`` is not a NumPy function name.
    """
    freq = set_resampling_freq(timeframe)
    counts = df.resample(freq).size()
    counts = counts[counts > 0]
    if counts.empty:
        raise ValueError("Cannot compute ticks per period: df contains no ticks.")

    fn = getattr(np, method)
    num_ticks = fn(counts.to_numpy())
    num_rounded = int(np.round(num_ticks))
    # Round to the leading digit: 1234 -> round(1234, -3) -> 1000.
    num_digits = len(str(num_rounded)) - 1
    rounded_ticks = max(1, int(round(num_rounded, -num_digits)))

    if verbose:
        t0 = df.index[0].date()
        t1 = df.index[-1].date()
        logging.info(f"From {t0} to {t1}, {method} ticks per {timeframe}: {num_ticks:,} rounded to {rounded_ticks:,}")

    return rounded_ticks


## Grouping and Bar Creation

def flatten_column_names(df: pd.DataFrame):
    """
    Flatten column labels into single strings.

    Tuple (MultiIndex) labels are joined with '_', e.g. ('bid', 'open') ->
    'bid_open'. Non-tuple labels become ``str(label)`` rather than being
    split character by character.

    Returns:
        list[str]: One flat name per column, in column order.
    """
    return [
        "_".join(map(str, col)).strip() if isinstance(col, tuple) else str(col).strip()
        for col in df.columns
    ]


def _threshold_bar_ids(weights: np.ndarray, threshold: float) -> np.ndarray:
    """Label each row with its bar number; a bar closes on the row where the running sum of ``weights`` reaches ``threshold``, after which the sum restarts at 0."""
    bar_ids = np.empty(len(weights), dtype=np.int64)
    current, running = 0, 0.0
    for i, weight in enumerate(weights):
        bar_ids[i] = current
        running += weight
        if running >= threshold:
            current += 1
            running = 0.0
    return bar_ids


def make_bar_type_grouper(df: pd.DataFrame, bar_type: str = 'tick', bar_size: int = 100, timeframe: str = 'M1'):
    """
    Creates a grouping object to aggregate tick data into bars.

    The caller's DataFrame is never modified.

    Bar types:
        'time'   -- ``df.resample`` at ``timeframe`` (right-closed/right-labelled
                    except for business-day and weekly frequencies).
        'tick'   -- consecutive blocks of ``bar_size`` ticks.
        'volume' -- a bar closes on the tick at which the running sum of
                    'volume' reaches ``bar_size``; the sum then restarts at 0.
        'dollar' -- like 'volume' but accumulating ``volume * bid``.

    For non-time bars, every row gets a bar label and the grouped frame gets
    a 'time' column holding each tick's timestamp. Ticks after the last
    completed bar form a final, incomplete group.

    Args:
        df (pd.DataFrame): Tick data with a DatetimeIndex or a 'time' column.
        bar_type (str): Bar type ('time', 'tick', 'dollar', or 'volume').
        bar_size (int): Used for non-time bars; set to 0 for dynamic sizing
            (tick bars only).
        timeframe (str): Timeframe for time-based resampling or dynamic calculations.

    Returns:
        tuple: A (resampler/groupby object, final bar_size).

    Raises:
        TypeError: If there is neither a DatetimeIndex nor a 'time' column.
        ValueError: If ``bar_size`` is negative for non-time bars.
        KeyError: If volume/dollar bars are requested without a 'volume' column.
        NotImplementedError: For unsupported bar types, or ``bar_size == 0``
            with non-tick bars.
    """
    if not isinstance(df.index, pd.DatetimeIndex):
        if 'time' not in df.columns:
            raise TypeError("DatetimeIndex expected.")
        df = df.set_index('time')

    if not df.index.is_monotonic_increasing:
        # Rebind instead of sorting in place so the caller's frame is untouched.
        df = df.sort_index()

    if bar_type == 'time':
        freq = set_resampling_freq(timeframe)
        if freq.startswith(('B', 'W')):
            return df.resample(freq), bar_size
        return df.resample(freq, closed='right', label='right'), bar_size

    if bar_size == 0:
        if bar_type == 'tick':
            bar_size = calculate_ticks_per_period(df, timeframe)
        else:
            raise NotImplementedError(f"{bar_type} bars require a non-zero bar_size.")
    if bar_size < 0:
        raise ValueError(f"bar_size must be non-negative, got {bar_size}.")

    df = df.copy()
    df['time'] = df.index
    if bar_type == 'tick':
        bar_id = np.arange(len(df)) // bar_size
    elif bar_type in ('volume', 'dollar'):
        if 'volume' not in df.columns:
            raise KeyError(f"Column 'volume' required for creating {bar_type} bars.")
        weights = df['volume'].to_numpy(dtype='float64')
        if bar_type == 'dollar':
            weights = weights * df['bid'].to_numpy(dtype='float64')
        # One label per row. The old code collected only the bar-closing
        # indices, which groupby cannot use as a per-row grouper.
        bar_id = _threshold_bar_ids(weights, bar_size)
    else:
        raise NotImplementedError(f"{bar_type} bars not implemented yet.")

    return df.groupby(bar_id), bar_size


def _last_bar_is_partial(bar_group, last_key, bar_type, bar_size_):
    """Return True when group ``last_key`` did not reach ``bar_size_`` (ticks, volume, or volume*bid)."""
    last = bar_group.get_group(last_key)
    if bar_type == 'tick':
        return len(last) < bar_size_
    if bar_type == 'volume':
        filled = last['volume'].sum()
    elif bar_type == 'dollar':
        filled = (last['volume'] * last['bid']).sum()
    else:
        return False
    # Tolerate round-off between this sum and the grouper's running sum.
    return filled < bar_size_ and not np.isclose(filled, bar_size_, rtol=1e-9, atol=0.0)


def make_bars(tick_df: pd.DataFrame, bar_type: str = 'tick', bar_size: int = 0, timeframe: str = 'M1',
              price: str = 'midprice', verbose: bool = True):
    """
    Constructs OHLC bars from tick data.

    ``tick_df`` is not modified. When it lacks a 'midprice' column, one is
    derived on a copy as ``(bid + ask) / 2``.

    Time bars keep the resampler's labels and forward-fill empty periods.
    Non-time bars are indexed by the timestamp of their last tick plus 1 µs.
    A trailing bar that did not reach ``bar_size`` is dropped; the measure is
    ticks, volume, or ``volume * bid`` depending on ``bar_type``.

    Args:
        tick_df (pd.DataFrame): Tick data.
        bar_type (str): Bar type ('tick', 'time', 'volume', 'dollar').
        bar_size (int): For non-time bars; if 0, dynamic calculation is used.
        timeframe (str): Timeframe for calculation.
        price (str): Price field strategy ('midprice' or 'bid_ask').
        verbose (bool): Logs runtime details (including ``DataFrame.info``) if True.

    Returns:
        pd.DataFrame: Columns are:
            - OHLC of midprice;
            - 'tick_volume';
            - optional bid/ask OHLC (``price='bid_ask'``);
            - 'volume' sum when available.
        A timezone-aware index is converted to tz-naive UTC.
    """
    import io

    if 'midprice' not in tick_df.columns:
        tick_df = tick_df.assign(midprice=(tick_df['bid'] + tick_df['ask']) / 2)

    bar_group, bar_size_ = make_bar_type_grouper(tick_df, bar_type, bar_size, timeframe)
    ohlc_df = bar_group['midprice'].ohlc().astype('float64')
    ohlc_df['tick_volume'] = bar_group['bid'].count() if bar_type != 'tick' else bar_size_

    if price == 'bid_ask':
        bid_ask_df = bar_group.agg({k: 'ohlc' for k in ('bid', 'ask')})
        bid_ask_df.columns = flatten_column_names(bid_ask_df)
        ohlc_df = ohlc_df.join(bid_ask_df)

    if 'volume' in tick_df.columns:
        ohlc_df['volume'] = bar_group['volume'].sum()

    if bar_type == 'time':
        ohlc_df = ohlc_df.ffill()
    else:
        # Remember the last group key before the index is replaced by end times.
        last_key = ohlc_df.index[-1] if len(ohlc_df) else None
        end_time = bar_group['time'].last()
        ohlc_df.index = end_time + pd.Timedelta(microseconds=1)
        if last_key is not None and _last_bar_is_partial(bar_group, last_key, bar_type, bar_size_):
            ohlc_df = ohlc_df.iloc[:-1]

    if verbose:
        tm_info = f"{timeframe} - {bar_size_:,} ticks" if (bar_type == 'tick' and bar_size == 0) else f"{bar_size_:,}"
        logging.info(f"Tick data contains {tick_df.shape[0]:,} rows")
        logging.info(f"{bar_type}_bar with info: {tm_info}")
        # DataFrame.info() returns None; capture its text instead of logging "None".
        info_buffer = io.StringIO()
        ohlc_df.info(buf=info_buffer)
        logging.info(info_buffer.getvalue())

    if isinstance(ohlc_df.index, pd.DatetimeIndex) and ohlc_df.index.tz is not None:
        ohlc_df = ohlc_df.tz_convert(None)

    return ohlc_df


# Volatility Analysis Plotting

def plot_volatility_analysis_of_bars(df, symbol, start, end, freq, thres=0.01, bins=100):
    """
    Plots volatility analysis using Plotly.

    Each bar's absolute percentage change ``|close / open - 1| * 100`` is
    computed. Values at or above the ``1 - thres`` quantile are discarded as
    extreme. The remaining values are shown as:

    * a histogram, with bars placed at each bin's left edge;
    * two curves on a secondary axis, evaluated at every bin left edge ``b``:
      - the fraction of candles with change >= b (red);
      - the share of the total absolute change carried by those candles (green);
    * annotations on the red curve at the largest proportion not exceeding
      each of 0.01, 0.05, 0.1, 0.2, ..., 1.0.

    Args:
        df (pd.DataFrame): DataFrame with 'open' and 'close' columns.
        symbol (str): Asset symbol.
        start (str): Start date.
        end (str): End date.
        freq (str): Data frequency.
        thres (float): Fraction of the largest changes treated as extreme.
        bins (int): Number of histogram bins.

    Returns:
        go.Figure: Plotly figure object.

    Raises:
        ValueError: If no price changes remain after filtering (e.g. empty
            ``df`` or all changes equal). This previously surfaced as
            ZeroDivisionError.
    """
    abs_price_changes = (df['close'] / df['open'] - 1).mul(100).abs()
    cutoff = abs_price_changes.quantile(1 - thres)
    filtered_changes = abs_price_changes[abs_price_changes < cutoff]
    if filtered_changes.empty:
        raise ValueError("No price changes left after filtering; cannot build the volatility plot.")

    counts, bin_edges = np.histogram(filtered_changes, bins=bins)
    bins_centers = bin_edges[:-1]  # left bin edges (name kept from the original)

    # For every left edge b, count and sum the values >= b using one sort and a
    # binary search instead of re-filtering the series once per bin.
    values = np.sort(filtered_changes.to_numpy(dtype=float))
    total_counts = values.size
    first_right = np.searchsorted(values, bins_centers, side='left')
    proportion_candles_right = ((total_counts - first_right) / total_counts).tolist()
    # tail_sums[i] == values[i:].sum(); the appended 0 covers i == total_counts.
    tail_sums = np.append(np.cumsum(values[::-1])[::-1], 0.0)
    total_change = tail_sums[0]
    if total_change > 0:
        proportion_price_change_right = (tail_sums[first_right] / total_change).tolist()
    else:
        proportion_price_change_right = [0.0] * len(bins_centers)

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=bins_centers,
        y=counts,
        name='Histogram absolute price change (%)',
        marker=dict(color='#1f77b4'),
        hovertemplate='<b>Bin: %{x:.2f}</b><br>Frequency: %{y}',
        yaxis='y1',
        opacity=0.65
    ))

    ms = 3
    lw = 0.5
    fig.add_trace(go.Scatter(
        x=bins_centers,
        y=proportion_candles_right,
        name='Proportion of candles at the right',
        mode='lines+markers',
        marker=dict(color='red', size=ms),
        line=dict(width=lw),
        hoverinfo='text',
        text=[f"Bin: {x:.2f}, Proportion: {y:.4f}" for x, y in zip(bins_centers, proportion_candles_right)],
        yaxis='y2'
    ))
    fig.add_trace(go.Scatter(
        x=bins_centers,
        y=proportion_price_change_right,
        name='Proportion price change (candles right)',
        mode='lines+markers',
        marker=dict(color='green', size=ms),
        line=dict(width=lw),
        hoverinfo='text',
        text=[f"Bin: {x:.2f}, Proportion: {y:.4f}" for x, y in zip(bins_centers, proportion_price_change_right)],
        yaxis='y2'
    ))

    search_idx = [0.01, 0.05] + np.linspace(0.1, 1, 10).tolist()
    # proportion_candles_right is non-increasing, so its ascending sort is its
    # reverse. A right-side searchsorted result ix means the ix-th smallest
    # proportion (original position -ix) is the largest one <= target.
    price_idxs = np.searchsorted(sorted(proportion_candles_right), search_idx, side='right')
    n_bins = len(bins_centers)
    for ix in sorted({int(i) for i in price_idxs}):
        # ix == 0: no proportion <= target. The old code wrapped [-0] to [0]
        # and annotated the wrong point. ix == n_bins is valid (points at
        # index 0) but the old `ix < n_bins` check skipped it.
        if 0 < ix <= n_bins:
            x_val = bins_centers[-ix]
            y_val = proportion_candles_right[-ix]
            fig.add_annotation(
                x=x_val,
                y=y_val,
                text=f"{y_val:.4f}",
                showarrow=True,
                arrowhead=1,
                ax=0,
                ay=-15,
                font=dict(color="salmon"),
                arrowcolor="red",
                yref='y2'
            )

    fig.update_layout(
        title=f'Volatility Analysis of {symbol} {freq} from {start} to {end}',
        xaxis_title='Absolute price change (%)',
        yaxis_title='Frequency',
        yaxis2=dict(
            title='Proportion',
            overlaying='y',
            side='right',
            gridcolor='#444'
        ),
        plot_bgcolor='#222',
        paper_bgcolor='#222',
        font=dict(color='white'),
        xaxis=dict(gridcolor='#444'),
        yaxis=dict(gridcolor='#444'),
        legend=dict(x=0.3, y=0.95, traceorder="normal", font=dict(color="white"))
    )

    return fig

