import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import norm
from scipy.optimize import minimize
import MetaTrader5 as mt5
import time
from datetime import datetime, timedelta
import logging

# Configure the root logger: INFO level, "<timestamp> - <level> - <message>" lines
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)


# Initialize MetaTrader5
def initialize_mt5():
    """Connect to the MetaTrader5 terminal.

    Returns:
        True when ``mt5.initialize()`` reports success; otherwise logs an
        error and returns False.
    """
    connected = mt5.initialize()
    if connected:
        return True
    logging.error("Failed to initialize MetaTrader5")
    return False


def get_data(symbol, timeframe, start_date, end_date):
    """Fetch bars for ``symbol`` and return them with a ``returns`` column.

    Args:
        symbol: Instrument name passed to ``mt5.copy_rates_range``.
        timeframe: MetaTrader5 timeframe constant.
        start_date: Start of the requested range.
        end_date: End of the requested range.

    Returns:
        A DataFrame indexed by ``time`` (converted from epoch seconds) with
        the rate fields plus ``returns``, the percentage change of ``close``.

    Raises:
        RuntimeError: If the terminal returns no result for the request.
    """
    rates = mt5.copy_rates_range(symbol, timeframe, start_date, end_date)
    if rates is None:
        # Without this check the failure surfaces later as an opaque
        # KeyError on the missing "time" column.
        raise RuntimeError(
            f"Failed to get rates for {symbol}: {mt5.last_error()}"
        )
    df = pd.DataFrame(rates)
    df["time"] = pd.to_datetime(df["time"], unit="s")
    df = df.set_index("time")
    df["returns"] = df["close"].pct_change()
    return df


def update_market_data(portfolio, timeframe, lookback):
    """Build a DataFrame of per-symbol returns over the lookback window.

    Args:
        portfolio: Iterable of symbol names.
        timeframe: MetaTrader5 timeframe constant forwarded to ``get_data``.
        lookback: Window length in days, counted back from now.

    Returns:
        A DataFrame with one column of non-null returns per symbol.
    """
    # NOTE(review): datetime.now() is naive local time; confirm this matches
    # the time zone the terminal expects for range queries.
    window_end = datetime.now()
    window_start = window_end - timedelta(days=lookback)
    returns_by_symbol = {
        symbol: get_data(symbol, timeframe, window_start, window_end)[
            "returns"
        ].dropna()
        for symbol in portfolio
    }
    return pd.DataFrame(returns_by_symbol)


def calculate_var(returns, confidence_level=0.95, holding_period=1):
    """Return the lower-tail return percentile scaled to the holding period.

    The percentile at ``(1 - confidence_level) * 100`` is taken from
    ``returns`` and multiplied by ``sqrt(holding_period)``. The percentile is
    returned without a sign flip, so a loss-making tail gives a negative
    number.
    """
    tail_percent = (1 - confidence_level) * 100
    tail_quantile = np.percentile(returns, tail_percent)
    horizon_scale = np.sqrt(holding_period)
    return tail_quantile * horizon_scale


def calculate_cvar(returns, confidence_level=0.95, holding_period=1):
    """Return the conditional VaR (expected shortfall) as a positive loss.

    Args:
        returns: Array-like or Series of per-period returns.
        confidence_level: Confidence level; the tail is the worst
            ``1 - confidence_level`` share of observations.
        holding_period: Number of periods; the result is scaled by its
            square root.

    Returns:
        The negated mean of all returns at or below the lower-tail
        percentile, multiplied by ``sqrt(holding_period)``.
    """
    # The tail cut-off is the raw (unscaled, signed) percentile. Negating it,
    # or scaling it by the horizon before comparing with single-period
    # returns, selects the wrong observations.
    threshold = np.percentile(returns, (1 - confidence_level) * 100)
    tail = returns[returns <= threshold]
    return -tail.mean() * np.sqrt(holding_period)


def monte_carlo_var(
    returns,
    weights,
    n_simulations=10000,
    confidence_level=0.95,
    random_state=None,
):
    """Estimate portfolio VaR by sampling from a fitted normal distribution.

    Args:
        returns: DataFrame of per-asset returns (one column per asset).
        weights: Portfolio weights aligned with the columns of ``returns``.
        n_simulations: Number of normal draws.
        confidence_level: Confidence level of the VaR.
        random_state: Optional seed or ``numpy.random.Generator`` for
            reproducible draws. When None, the global NumPy random state is
            used, as before.

    Returns:
        The negated lower-tail percentile of the simulated returns, i.e. a
        positive number when the tail is a loss.
    """
    portfolio_returns = returns.dot(weights)
    mu, sigma = portfolio_returns.mean(), portfolio_returns.std()
    if random_state is None:
        simulations = np.random.normal(mu, sigma, n_simulations)
    else:
        rng = np.random.default_rng(random_state)
        simulations = rng.normal(mu, sigma, n_simulations)
    return -np.percentile(simulations, (1 - confidence_level) * 100)


def optimize_portfolio(returns, target_return, confidence_level=0.95):
    """Find long-only weights that minimise normal VaR at a target return.

    The objective is the closed-form VaR of a normal distribution fitted to
    the portfolio, ``-(mu + sigma * z)``. This is the value a normal Monte
    Carlo simulation converges to, but it is deterministic and smooth, which
    the finite-difference gradients used by SLSQP require.

    Args:
        returns: DataFrame of per-asset returns (one column per asset).
        target_return: Required mean portfolio return (equality constraint).
        confidence_level: Confidence level of the VaR being minimised.

    Returns:
        Array of weights in [0, 1] summing to 1. The solver's last iterate is
        returned even when it reports failure; a warning is logged then.
    """
    n = len(returns.columns)

    # Hoisted out of the objective: these do not depend on the weights.
    mean_returns = returns.mean().to_numpy()
    covariance = returns.cov().to_numpy()
    z_tail = norm.ppf(1 - confidence_level)

    def portfolio_var(weights):
        w = np.asarray(weights, dtype=float)
        mu = w @ mean_returns
        # Clamp tiny negative values caused by floating-point round-off.
        sigma = np.sqrt(max(w @ covariance @ w, 0.0))
        return -(mu + sigma * z_tail)

    def portfolio_return(weights):
        return np.dot(mean_returns, weights)

    constraints = (
        {"type": "eq", "fun": lambda x: np.sum(x) - 1},
        {"type": "eq", "fun": lambda x: portfolio_return(x) - target_return},
    )

    bounds = tuple((0, 1) for _ in range(n))

    result = minimize(
        portfolio_var,
        np.full(n, 1.0 / n),
        method="SLSQP",
        bounds=bounds,
        constraints=constraints,
    )
    if not result.success:
        logging.warning(f"Portfolio optimization did not converge: {result.message}")

    return result.x


def get_account_balance():
    """Return the balance reported by ``mt5.account_info()``.

    Raises:
        RuntimeError: If the account information is unavailable.
    """
    info = mt5.account_info()
    if info is not None:
        return info.balance
    raise RuntimeError("Failed to get account info")


def get_open_positions():
    """Return the net signed volume of open positions per symbol.

    Buy positions count as positive volume and sell positions as negative.
    Several positions on the same symbol are summed instead of overwriting
    each other.

    Returns:
        Dict mapping symbol to net volume. Empty when there are no positions
        or when the terminal request fails (the failure is logged).
    """
    positions = mt5.positions_get()
    if positions is None:
        # None signals a failed request, not an empty book.
        # NOTE(review): callers treat {} as "flat"; consider raising here.
        logging.error(f"Failed to get open positions: {mt5.last_error()}")
        return {}
    if not positions:
        logging.info("No open positions")
        return {}

    net_volume = {}
    for pos in positions:
        signed = -pos.volume if pos.type == mt5.POSITION_TYPE_SELL else pos.volume
        net_volume[pos.symbol] = net_volume.get(pos.symbol, 0.0) + signed
    return net_volume


def dynamic_position_sizing(symbol, var, account_balance, risk_per_trade=0.02):
    """Size a position so that the VaR-implied loss fits the risk budget.

    Args:
        symbol: Instrument name looked up via ``mt5.symbol_info``.
        var: Value-at-risk figure for the symbol; only its magnitude is used.
        account_balance: Current account balance.
        risk_per_trade: Fraction of the balance that may be put at risk.

    Returns:
        The position size rounded to two decimals, or 0 when the symbol
        information is unavailable or the risk per lot is zero or not finite.
    """
    symbol_info = mt5.symbol_info(symbol)
    if symbol_info is None:
        logging.error(f"Failed to get symbol info for {symbol}")
        return 0
    # NOTE(review): tick value * 10 is taken as the pip value; confirm this
    # holds for every traded symbol.
    pip_value = symbol_info.trade_tick_value * 10
    max_loss = account_balance * risk_per_trade
    risk_per_lot = abs(var) * pip_value
    if not np.isfinite(risk_per_lot) or risk_per_lot <= 0:
        # A zero or NaN denominator would yield an infinite or NaN lot size.
        logging.warning(
            f"Cannot size position for {symbol}: var={var}, pip_value={pip_value}"
        )
        return 0
    return round(max_loss / risk_per_lot, 2)


def _send_market_order(symbol, order_type, volume):
    """Send a market order via a request dict and log a failed result."""
    tick = mt5.symbol_info_tick(symbol)
    if tick is None:
        logging.error(f"Failed to get tick for {symbol}: {mt5.last_error()}")
        return None
    request = {
        "action": mt5.TRADE_ACTION_DEAL,
        "symbol": symbol,
        "volume": float(volume),
        "type": order_type,
        # Buys execute at the ask, sells at the bid.
        "price": tick.ask if order_type == mt5.ORDER_TYPE_BUY else tick.bid,
        "type_time": mt5.ORDER_TIME_GTC,
        # NOTE(review): IOC filling is an assumption; confirm the filling
        # mode the broker allows for these symbols.
        "type_filling": mt5.ORDER_FILLING_IOC,
    }
    result = mt5.order_send(request)
    if result is None:
        logging.error(f"Order send failed for {symbol}: {mt5.last_error()}")
    elif result.retcode != mt5.TRADE_RETCODE_DONE:
        logging.error(
            f"Order rejected for {symbol}: retcode={result.retcode}, "
            f"comment={result.comment}"
        )
    return result


def update_positions(portfolio, portfolio_var, account_balance, min_position_change):
    """Move each symbol's position towards its VaR-based target size.

    Args:
        portfolio: Iterable of symbol names.
        portfolio_var: Mapping or Series of per-symbol VaR, indexed by symbol.
        account_balance: Current account balance used for sizing.
        min_position_change: Minimum absolute volume difference that
            triggers an order.
    """
    current_positions = get_open_positions()
    for symbol in portfolio:
        current_position = current_positions.get(symbol, 0)
        optimal_position = dynamic_position_sizing(
            symbol, portfolio_var[symbol], account_balance
        )

        delta = optimal_position - current_position
        if abs(delta) > min_position_change:
            order_type = mt5.ORDER_TYPE_BUY if delta > 0 else mt5.ORDER_TYPE_SELL
            # Round away float dust such as 0.30 - 0.10 = 0.19999999999999998.
            # NOTE(review): on hedging accounts a SELL opens a new short
            # instead of reducing the long; confirm the account mode.
            _send_market_order(symbol, order_type, round(abs(delta), 2))


def calculate_portfolio_var(returns, weights, confidence_level=0.95):
    """Return ``calculate_var`` of the weighted portfolio return series.

    Args:
        returns: DataFrame of per-asset returns.
        weights: Weights aligned with the columns of ``returns``.
        confidence_level: Confidence level forwarded to ``calculate_var``.
    """
    weighted_returns = returns.dot(weights)
    portfolio_level_var = calculate_var(weighted_returns, confidence_level)
    return portfolio_level_var


def monitor_drawdown(portfolio_var, account_balance, max_drawdown=0.2):
    """Report whether VaR relative to the balance exceeds the allowed limit.

    Args:
        portfolio_var: Portfolio VaR. It may be a signed figure (a negative
            return percentile) or a positive loss; only the magnitude is
            used, so either convention can trigger the check.
        account_balance: Current account balance.
        max_drawdown: Maximum tolerated ratio of VaR to balance.

    Returns:
        True when the limit is exceeded or the balance is not positive,
        otherwise False. A warning is logged whenever True is returned.
    """
    # NOTE(review): the ratio is only meaningful if portfolio_var is in
    # account-currency units; confirm against the caller.
    if account_balance <= 0:
        # The ratio is undefined here; treat an empty account as a breach.
        logging.warning(f"Non-positive account balance: {account_balance}")
        return True
    current_drawdown = abs(portfolio_var) / account_balance
    if current_drawdown > max_drawdown:
        logging.warning(f"Drawdown exceeded: {current_drawdown:.2%}")
        return True
    return False


def sharpe_ratio(returns, risk_free_rate=0.02, periods_per_year=252):
    """Return the annualised Sharpe ratio of per-period returns.

    Args:
        returns: Series or DataFrame of per-period returns.
        risk_free_rate: Annual risk-free rate.
        periods_per_year: Number of return periods in a year. It is used both
            to convert the annual rate to a per-period rate and to annualise
            the ratio.

    Returns:
        The Sharpe ratio for a Series, or the mean of the per-column ratios
        for a DataFrame.
    """
    # The annual rate must be converted to the return frequency before it is
    # subtracted from a per-period mean.
    rate_per_period = risk_free_rate / periods_per_year
    excess = returns.mean() - rate_per_period
    ratio = excess / returns.std() * np.sqrt(periods_per_year)
    if isinstance(returns, pd.DataFrame):
        return ratio.mean()
    return ratio


def profit_factor(returns):
    """Return the ratio of summed gains to the magnitude of summed losses.

    For a DataFrame the gains and losses are totalled over all columns.
    When there are no losses the result is ``float("inf")``.
    """
    gains = returns[returns > 0].sum()
    losses = abs(returns[returns < 0].sum())

    if isinstance(returns, pd.DataFrame):
        # Collapse the per-column totals into single figures.
        gains = gains.sum()
        losses = losses.sum()

    if losses != 0:
        return gains / losses
    return float("inf")


def log_performance(returns, var, balance):
    """Log a snapshot of return, VaR, balance, Sharpe ratio and profit factor.

    Args:
        returns: DataFrame of returns; the last row is summed for the
            "current" return.
        var: Per-symbol VaR values; their sum is reported.
        balance: Account balance to report.
    """
    latest_total_return = returns.iloc[-1].sum()
    total_var = var.sum()

    info = logging.info
    info(f"Date: {datetime.now()}")
    info(f"Current Return: {latest_total_return:.2%}")
    info(f"Current VaR: {total_var:.2%}")
    info(f"Account Balance: ${balance:.2f}")
    info(f"Sharpe Ratio: {sharpe_ratio(returns):.2f}")
    info(f"Profit Factor: {profit_factor(returns):.2f}")
    info("--------------------")


def main():
    """Run the hourly risk-management loop until the process is stopped.

    Each iteration refreshes market data, computes per-symbol VaR,
    re-optimises the weights every ``rebalance_frequency`` iterations,
    adjusts positions, checks the drawdown limit and logs performance.
    Errors inside an iteration are logged and retried after five minutes.
    The MetaTrader5 connection is shut down when the loop exits for any
    reason.
    """
    if not initialize_mt5():
        return

    portfolio = ["EURUSD", "GBPUSD", "USDJPY", "AUDUSD", "USDCAD"]
    timeframe = mt5.TIMEFRAME_H1
    lookback = 252
    confidence_level = 0.95
    max_drawdown = 0.2
    rebalance_frequency = 20
    visualization_frequency = 24 * 7
    min_position_change = 0.01

    # Start from equal weights until the first optimisation runs.
    weights = np.array([1 / len(portfolio)] * len(portfolio))

    iteration = 0
    try:
        while True:
            try:
                returns = update_market_data(portfolio, timeframe, lookback)

                if returns.empty:
                    logging.warning(
                        "Failed to get profitability data. Skip iteration."
                    )
                    time.sleep(300)
                    continue

                var = returns.apply(lambda x: calculate_var(x, confidence_level))

                if iteration % rebalance_frequency == 0:
                    target_return = returns.mean().mean()
                    weights = optimize_portfolio(
                        returns, target_return, confidence_level
                    )

                portfolio_var = calculate_portfolio_var(
                    returns, weights, confidence_level
                )

                account_balance = get_account_balance()

                update_positions(portfolio, var, account_balance, min_position_change)

                if monitor_drawdown(portfolio_var, account_balance, max_drawdown):
                    logging.warning(
                        "Reached maximum drawdown lvel. Decrease positions."
                    )
                    # Add position decrease logic here

                log_performance(returns, var, account_balance)

                if iteration % visualization_frequency == 0:
                    # Add visualization here if necessary
                    pass

                iteration += 1
                time.sleep(3600)

            except Exception as e:
                # Top-level boundary: log with traceback and retry later.
                logging.error(f"Error: {e}", exc_info=True)
                time.sleep(300)
    finally:
        # Release the terminal connection on any exit, including Ctrl+C.
        mt5.shutdown()


# Start the trading loop only when the file is executed as a script.
if __name__ == "__main__":
    main()
