import numpy as np
import pandas as pd
from scipy.optimize import minimize
from scipy.stats import norm
from datetime import datetime, timedelta
import logging
import os
import matplotlib.pyplot as plt
import seaborn as sns

def setup_logging():
    log_dir = 'logs'
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    
    log_file = os.path.join(log_dir, f'backtest_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
    
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )

def generate_valid_correlation_matrix(size):
    A = np.random.uniform(-0.5, 0.5, (size, size))
    A = (A + A.T) / 2
    A = A + size * np.eye(size)
    D = np.diag(1 / np.sqrt(np.diag(A)))
    corr_matrix = D @ A @ D
    return corr_matrix

class ForexPortfolioBacktester:
    def __init__(self, start_date, end_date, initial_capital=100000):
        self.start_date = start_date
        self.end_date = end_date
        self.initial_capital = initial_capital
        self.var_limit = 0.04
        self.confidence_level = 0.95
        self.min_weight = 0.04
        self.max_weight = 0.25
        self.risk_free_rate = 0.0001
        
        # Расширенный список валютных пар с учетом ликвидности
        self.symbols = {
            'major': ['EURUSD', 'GBPUSD', 'USDJPY', 'USDCHF'],
            'commodity': ['AUDUSD', 'NZDUSD', 'USDCAD'],
            'cross_major': ['EURJPY', 'GBPJPY', 'EURGBP', 'EURCHF'],
            'cross_commodity': ['AUDNZD', 'GBPCHF', 'EURCAD', 'AUDCAD'],
            'cross_exotic': ['CADJPY', 'AUDJPY', 'NZDJPY', 'EURAUD', 'GBPCAD']
        }
        
        self.all_symbols = [sym for group in self.symbols.values() for sym in group]
        self.returns_data = None
        self.portfolio_values = []
        self.portfolio_weights = []
        self.portfolio_returns = []
        self.dates = []
        self.transaction_costs = 0.0002  # 2 базисных пункта на сделку
        self.annual_reports = []
        
        logging.info(f"Initialized backtester with {len(self.all_symbols)} currency pairs")
        
    def generate_synthetic_data(self):
        dates = pd.date_range(self.start_date, self.end_date, freq='D')
        data = {}
        
        # Различные параметры волатильности для разных периодов рынка
        market_regimes = {
            'low_vol': {'weight': 0.3, 'multiplier': 0.7},
            'normal': {'weight': 0.4, 'multiplier': 1.0},
            'high_vol': {'weight': 0.3, 'multiplier': 1.5}
        }
        
        volatility_params = {
                                'major': (0.0003, 0.003),  # увеличили базовую доходность
                                'commodity': (0.0004, 0.004),
                                'cross_major': (0.0003, 0.004),
                                'cross_commodity': (0.0004, 0.005),
                                'cross_exotic': (0.0005, 0.006)
                            }
        
        # Генерация базовой корреляционной матрицы
        base_corr_matrix = generate_valid_correlation_matrix(len(self.all_symbols))
        n_days = len(dates)
        
        # Добавляем циклические паттерны для разных периодов
        cycles = {
            'short': np.sin(np.linspace(0, 20*np.pi, n_days)) * 0.002,
            'medium': np.sin(np.linspace(0, 8*np.pi, n_days)) * 0.003,
            'long': np.sin(np.linspace(0, 4*np.pi, n_days)) * 0.004
        }
        
        for group, symbols in self.symbols.items():
            base_mu, base_sigma = volatility_params[group]
            
            for symbol in symbols:
                returns = np.zeros(n_days)
                
                # Разбиваем период на режимы волатильности
                regime_points = np.random.choice(
                    ['low_vol', 'normal', 'high_vol'],
                    size=n_days,
                    p=[0.3, 0.4, 0.3]
                )
                
                for regime in market_regimes:
                    mask = regime_points == regime
                    mu = base_mu * market_regimes[regime]['multiplier']
                    sigma = base_sigma * market_regimes[regime]['multiplier']
                    returns[mask] = np.random.normal(mu, sigma, mask.sum())
                
                # Добавляем циклические компоненты
                returns += cycles['short']
                returns += cycles['medium']
                returns += cycles['long']
                
                # Добавляем долгосрочный тренд
                trend = np.linspace(0, base_mu * 5, n_days)
                returns += trend
                
                # Конвертируем в цены
                price = 1.0
                prices = [price]
                for r in returns:
                    price *= (1 + r)
                    prices.append(price)
                
                data[symbol] = pd.Series(prices[:-1], index=dates)
        
        self.prices = pd.DataFrame(data)
        self.returns_data = np.log(self.prices / self.prices.shift(1)).dropna()
        logging.info(f"Generated synthetic data shape: {self.returns_data.shape}")
        
    def calculate_portfolio_metrics(self, weights, window_returns=None):
        if window_returns is None:
            window_returns = self.returns_data
            
        returns = window_returns.mean()
        
        # Динамические риск-премии на основе волатильности
        volatilities = window_returns.std()
        risk_premiums = np.zeros(len(self.all_symbols))
        
        for i, symbol in enumerate(self.all_symbols):
            vol = volatilities[symbol]
            base_premium = 0.01
            
            if symbol in self.symbols['major']:
                risk_premiums[i] = base_premium * (2 + vol)
            elif symbol in self.symbols['commodity']:
                risk_premiums[i] = base_premium * (2.5 + vol)
            else:
                risk_premiums[i] = base_premium * (3.25 + vol)
                
        adjusted_returns = returns + risk_premiums
        portfolio_return = adjusted_returns.dot(weights) * 252
        portfolio_vol = np.sqrt(weights.T @ window_returns.cov() * 252 @ weights)
        
        return portfolio_return, portfolio_vol
        
    def calculate_var(self, weights, window_returns=None):
        if window_returns is None:
            window_returns = self.returns_data
            
        portfolio_returns = window_returns.dot(weights)
        var = -np.percentile(portfolio_returns, (1 - self.confidence_level) * 100)
        return var
            
    def optimize_portfolio(self, current_date):
        window = 252  # Годовое окно
        data_slice = self.returns_data[:current_date].tail(window)
        
        if len(data_slice) < window/2:
            return np.array([1/len(self.all_symbols)] * len(self.all_symbols))
            
        n_assets = len(self.all_symbols)
        
        def objective(weights):
            portfolio_return, portfolio_vol = self.calculate_portfolio_metrics(weights, data_slice)
            var = self.calculate_var(weights, data_slice)
            
            # Модифицированная целевая функция с учетом транзакционных издержек
            sharpe = (portfolio_return - self.risk_free_rate) / portfolio_vol
            var_penalty = max(0, var - self.var_limit) * 100
            
            # Добавляем штраф за частые изменения весов
            if len(self.portfolio_weights) > 0:
                turnover = np.sum(np.abs(weights - self.portfolio_weights[-1]))
                transaction_cost = turnover * self.transaction_costs
            else:
                transaction_cost = 0
                
            return -(sharpe - var_penalty - transaction_cost)
            
        constraints = [
            {'type': 'eq', 'fun': lambda x: np.sum(x) - 1},
            {'type': 'ineq', 'fun': lambda x: self.var_limit - self.calculate_var(x, data_slice)}
        ]
        
        bounds = tuple((self.min_weight, self.max_weight) for _ in range(n_assets))
        
        # Улучшенная инициализация весов с учетом предыдущего распределения
        if len(self.portfolio_weights) > 0:
            initial_weights = self.portfolio_weights[-1]
        else:
            initial_weights = np.array([
                0.25 if sym in self.symbols['major']
                else 0.15 if sym in self.symbols['commodity']
                else 0.05
                for sym in self.all_symbols
            ])
            initial_weights = initial_weights / initial_weights.sum()
        
        try:
            result = minimize(objective, initial_weights, method='SLSQP',
                            bounds=bounds, constraints=constraints,
                            options={'ftol': 1e-8, 'maxiter': 1000})
            
            if result.success:
                logging.info(f"Optimization successful for {current_date}")
                weights = np.clip(result.x, self.min_weight, self.max_weight)
                weights = weights / weights.sum()
                return weights
            else:
                logging.warning(f"Optimization failed for {current_date}, using previous weights")
                return initial_weights
        except Exception as e:
            logging.error(f"Optimization error: {str(e)}")
            return initial_weights
            
    def create_annual_report(self, year):
        year_mask = [d.year == year for d in self.dates]
        year_returns = np.array(self.portfolio_returns)[year_mask]
        year_values = np.array(self.portfolio_values)[year_mask]
        
        if len(year_returns) == 0:
            return None
            
        cumulative_return = (year_values[-1] / year_values[0]) - 1
        volatility = np.std(year_returns) * np.sqrt(252)
        sharpe = (cumulative_return - self.risk_free_rate) / volatility
        max_drawdown = np.min(
            np.array([(v / np.maximum.accumulate(year_values)[i]) - 1 
                     for i, v in enumerate(year_values)])
        )
        
        return {
            'year': year,
            'return': cumulative_return,
            'volatility': volatility,
            'sharpe': sharpe,
            'max_drawdown': max_drawdown,
            'ending_value': year_values[-1]
        }
            
    def run_backtest(self, rebalance_freq='QE'):
        logging.info("Starting backtest")
        self.generate_synthetic_data()
        
        portfolio_value = self.initial_capital
        weights = None
        rebalance_dates = pd.date_range(self.start_date, self.end_date, freq=rebalance_freq)
        
        leverage = 10  # добавляем плечо
        daily_interest = 0.0001  # процентная ставка за использование плеча (0.01% в день)
        
        for date in rebalance_dates:
            if date not in self.returns_data.index:
                continue
                
            weights = self.optimize_portfolio(date)
            next_date = date + pd.Timedelta(days=1)
            
            if next_date in self.returns_data.index:
                next_return = self.returns_data.loc[next_date]
                
                # Учитываем транзакционные издержки
                if len(self.portfolio_weights) > 0:
                    turnover = np.sum(np.abs(weights - self.portfolio_weights[-1]))
                    transaction_cost = turnover * self.transaction_costs
                else:
                    transaction_cost = 0
                
                # Расчет доходности с учетом плеча и процентов по нему
                leveraged_return = np.sum(weights * next_return) * leverage
                interest_cost = (leverage - 1) * daily_interest  # процент только на заемные средства
                portfolio_return = leveraged_return - transaction_cost - interest_cost
                
                portfolio_value *= (1 + portfolio_return)
                
                # Проверка на margin call (если падение больше 10%)
                if portfolio_return < -0.1:
                    logging.warning(f"Margin call on {date}! Portfolio return: {portfolio_return:.2%}")
                    portfolio_value *= 0.1  # оставляем 10% от капитала после margin call
                
                self.portfolio_values.append(portfolio_value)
                self.portfolio_weights.append(weights)
                self.portfolio_returns.append(portfolio_return)
                self.dates.append(date)
                
                # Создаем годовой отчет при смене года
                if len(self.dates) > 1 and self.dates[-1].year != self.dates[-2].year:
                    annual_report = self.create_annual_report(self.dates[-2].year)
                    if annual_report:
                        self.annual_reports.append(annual_report)
                        logging.info(f"Annual report for {annual_report['year']}: "
                                   f"Return: {annual_report['return']:.2%}, "
                                   f"Sharpe: {annual_report['sharpe']:.2f}")
                
        # Создаем финальный годовой отчет
        final_annual_report = self.create_annual_report(self.dates[-1].year)
        if final_annual_report:
            self.annual_reports.append(final_annual_report)
            
        self.create_performance_report()

    def create_performance_report(self):
        returns = np.array(self.portfolio_returns)
        cumulative_returns = np.cumprod(1 + returns) - 1
        total_return = (self.portfolio_values[-1] / self.initial_capital - 1)
        annual_return = (1 + total_return) ** (1 / len(self.annual_reports)) - 1
        volatility = np.std(returns) * np.sqrt(252)
        sharpe_ratio = (annual_return - self.risk_free_rate) / volatility
        max_drawdown = np.min(cumulative_returns - np.maximum.accumulate(cumulative_returns))

        # Создаем расширенные графики
        plt.figure(figsize=(20, 25))
        
        # График стоимости портфеля
        plt.subplot(5, 1, 1)
        plt.plot(self.dates, self.portfolio_values, linewidth=2)
        plt.title('Portfolio Value Over Time', fontsize=14)
        plt.grid(True)
        plt.xlabel('Date')
        plt.ylabel('Portfolio Value ($)')
        
        # График доходности
        plt.subplot(5, 1, 2)
        plt.plot(self.dates, np.array(self.portfolio_returns) * 100, linewidth=1)
        plt.title('Daily Returns (%)', fontsize=14)
        plt.grid(True)
        plt.xlabel('Date')
        plt.ylabel('Return (%)')
        
        # График накопленной доходности
        plt.subplot(5, 1, 3)
        plt.plot(self.dates, cumulative_returns * 100, linewidth=2)
        plt.title('Cumulative Returns (%)', fontsize=14)
        plt.grid(True)
        plt.xlabel('Date')
        plt.ylabel('Cumulative Return (%)')
        
        # График распределения весов
        plt.subplot(5, 1, 4)
        weights_df = pd.DataFrame(self.portfolio_weights, columns=self.all_symbols, index=self.dates)
        sns.heatmap(weights_df.T, cmap='YlOrRd', cbar_kws={'label': 'Weight'})
        plt.title('Portfolio Weights Over Time', fontsize=14)
        
        # График годовых показателей
        plt.subplot(5, 1, 5)
        annual_data = pd.DataFrame(self.annual_reports)
        annual_data.set_index('year', inplace=True)
        annual_data[['return', 'sharpe', 'volatility', 'max_drawdown']].plot(kind='bar', figsize=(15, 5))
        plt.title('Annual Performance Metrics', fontsize=14)
        plt.grid(True)
        plt.tight_layout()
        
        # Сохраняем графики
        plt.savefig('backtest_results.png')
        
        # Сохраняем отдельный график стоимости портфеля
        plt.figure(figsize=(15, 8))
        plt.plot(self.dates, self.portfolio_values, linewidth=2)
        plt.title('Portfolio Value Over Time', fontsize=14)
        plt.grid(True)
        plt.xlabel('Date')
        plt.ylabel('Portfolio Value ($)')
        plt.tight_layout()
        plt.savefig('portfolio_value.png')
        plt.close()  # Закрываем график, чтобы освободить память

        # Далее идет оригинальный код с общим графиком
        plt.figure(figsize=(20, 25))
        
        # Создаем подробный текстовый отчет
        with open('backtest_report.txt', 'w') as f:
            f.write("Forex Portfolio Backtest Report\n")
            f.write("=============================\n\n")
            
            f.write("General Information\n")
            f.write("-----------------\n")
            f.write(f"Period: {self.start_date} to {self.end_date}\n")
            f.write(f"Initial Capital: ${self.initial_capital:,.2f}\n")
            f.write(f"Final Portfolio Value: ${self.portfolio_values[-1]:,.2f}\n")
            f.write(f"Total Return: {total_return:.2%}\n")
            f.write(f"Annualized Return: {annual_return:.2%}\n")
            f.write(f"Annual Volatility: {volatility:.2%}\n")
            f.write(f"Sharpe Ratio: {sharpe_ratio:.2f}\n")
            f.write(f"Maximum Drawdown: {max_drawdown:.2%}\n\n")
            
            f.write("Annual Performance\n")
            f.write("-----------------\n")
            for report in self.annual_reports:
                f.write(f"\nYear {report['year']}:\n")
                f.write(f"  Return: {report['return']:.2%}\n")
                f.write(f"  Volatility: {report['volatility']:.2%}\n")
                f.write(f"  Sharpe Ratio: {report['sharpe']:.2f}\n")
                f.write(f"  Max Drawdown: {report['max_drawdown']:.2%}\n")
                f.write(f"  Ending Value: ${report['ending_value']:,.2f}\n")
            
            f.write("\nPortfolio Composition\n")
            f.write("--------------------\n")
            final_weights = self.portfolio_weights[-1]
            for symbol, weight in zip(self.all_symbols, final_weights):
                f.write(f"{symbol}: {weight:.2%}\n")
                
            # Добавляем статистику по группам валютных пар
            f.write("\nCurrency Pair Group Statistics\n")
            f.write("----------------------------\n")
            for group, symbols in self.symbols.items():
                group_returns = []
                for symbol in symbols:
                    symbol_idx = self.all_symbols.index(symbol)
                    symbol_weights = [w[symbol_idx] for w in self.portfolio_weights]
                    avg_weight = np.mean(symbol_weights)
                    group_returns.append(avg_weight)
                
                f.write(f"\n{group.upper()}:\n")
                f.write(f"  Average Group Weight: {np.sum(group_returns):.2%}\n")
                f.write(f"  Number of Pairs: {len(symbols)}\n")
        
        logging.info("Performance report created and saved")

# Пример использования для 10-летнего периода
if __name__ == "__main__":
    setup_logging()
    
    start_date = datetime(2000, 1, 1)  # 10 лет назад
    end_date = datetime(2024, 12, 1)   # До конца 2024
    
    backtester = ForexPortfolioBacktester(start_date, end_date, initial_capital=1000000)
    backtester.run_backtest(rebalance_freq='QE')  # Ежемесячная ребалансировка
        
