import numpy as np
import pandas as pd
from scipy.optimize import minimize
from scipy.stats import norm
from datetime import datetime, timedelta
import logging
import os
import matplotlib.pyplot as plt
import seaborn as sns

def setup_logging(log_dir='logs'):
    """Configure root logging to a timestamped file in ``log_dir`` and to stderr.

    Args:
        log_dir: Directory for log files; created (with parents) if missing.

    Returns:
        Path of the log file created for this run.

    Note:
        ``logging.basicConfig`` is a no-op when the root logger already has
        handlers; in that case the handlers built here are not installed.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(log_dir, exist_ok=True)

    log_file = os.path.join(log_dir, f'backtest_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(),
        ],
    )
    return log_file

def generate_valid_correlation_matrix(size):
    """Return a random ``size`` x ``size`` positive-definite correlation matrix.

    A symmetric matrix of uniform(-0.5, 0.5) noise is made strictly
    diagonally dominant by adding ``size`` to its diagonal, which makes it
    positive definite. It is then rescaled so the diagonal is exactly 1.
    Uses the global ``np.random`` state.
    """
    noise = np.random.uniform(-0.5, 0.5, (size, size))
    symmetric = (noise + noise.T) / 2
    dominant = symmetric + size * np.eye(size)

    # Normalise by the square root of the diagonal: corr[i, j] = a[i, j] / sqrt(a[i, i] * a[j, j]).
    inv_scale = 1 / np.sqrt(np.diag(dominant))
    return (inv_scale[:, None] * dominant) * inv_scale[None, :]

class ForexPortfolioBacktester:
    """Backtest a leveraged, long-only FX portfolio on synthetic prices.

    ``run_backtest`` does the following:
      1. Generates synthetic business-day prices for a fixed universe of
         currency pairs.
      2. Re-optimises weights on a rebalance schedule.
      3. Simulates the levered portfolio every trading day.
      4. Writes charts and a text report.
    """

    # Observations per year used to annualise daily statistics. Synthetic
    # data is generated on business days, so this approximately matches it.
    TRADING_DAYS = 252

    def __init__(self, start_date, end_date, initial_capital=100000,
                 leverage=10, daily_interest=0.0001):
        """Create a backtester.

        Args:
            start_date: First calendar date of the synthetic data.
            end_date: Last calendar date of the synthetic data.
            initial_capital: Starting portfolio equity.
            leverage: Gross exposure as a multiple of equity.
            daily_interest: Interest charged per simulated trading day on the
                borrowed part of the exposure, ``leverage - 1``
                (0.0001 = 0.01% per day).
        """
        self.start_date = start_date
        self.end_date = end_date
        self.initial_capital = initial_capital
        self.var_limit = 0.04  # max 1-day VaR of the *levered* portfolio
        self.confidence_level = 0.95
        self.min_weight = 0.04
        self.max_weight = 0.25
        self.risk_free_rate = 0.0001  # compared against annualised returns
        self.leverage = leverage
        self.daily_interest = daily_interest
        self.lookback_window = 252  # trading days of history used to optimise
        self.margin_call_threshold = 0.1  # one-day loss that triggers a margin call
        self.margin_call_retention = 0.1  # share of equity kept after a margin call
        self.transaction_costs = 0.0002  # 2 basis points of traded notional

        # Extended list of currency pairs considering liquidity
        self.symbols = {
            'major': ['EURUSD', 'GBPUSD', 'USDJPY', 'USDCHF'],
            'commodity': ['AUDUSD', 'NZDUSD', 'USDCAD'],
            'cross_major': ['EURJPY', 'GBPJPY', 'EURGBP', 'EURCHF'],
            'cross_commodity': ['AUDNZD', 'GBPCHF', 'EURCAD', 'AUDCAD'],
            'cross_exotic': ['CADJPY', 'AUDJPY', 'NZDJPY', 'EURAUD', 'GBPCAD']
        }

        # Column order used by every weight vector in this class.
        self.all_symbols = [sym for group in self.symbols.values() for sym in group]
        self.returns_data = None
        self.prices = None
        self._reset_results()

        logging.info(f"Initialized backtester with {len(self.all_symbols)} currency pairs")

    def _reset_results(self):
        """Clear all per-run simulation results."""
        self.portfolio_values = []
        self.portfolio_weights = []
        self.portfolio_returns = []
        self.dates = []
        self.annual_reports = []

    def generate_synthetic_data(self):
        """Generate synthetic business-day prices and log returns for every pair.

        Each pair's daily return is built from three parts:
          * a regime-switching normal shock (low / normal / high volatility);
          * deterministic cycles shared by all pairs;
          * a linear drift that depends on the pair's group.

        Shocks are correlated across pairs through a random positive-definite
        correlation matrix. The method populates ``self.prices`` and
        ``self.returns_data`` (log returns; the first date is dropped).

        Raises:
            ValueError: If the date range holds fewer than two business days.
        """
        # Business days, so the 252-per-year annualisation used elsewhere
        # matches the sampling frequency (spot FX does not trade at weekends).
        dates = pd.date_range(self.start_date, self.end_date, freq='B')
        n_days = len(dates)
        if n_days < 2:
            raise ValueError("Need at least two business days between start_date and end_date")

        # Various volatility parameters for various market periods
        market_regimes = {
            'low_vol': {'weight': 0.3, 'multiplier': 0.7},
            'normal': {'weight': 0.4, 'multiplier': 1.0},
            'high_vol': {'weight': 0.3, 'multiplier': 1.5}
        }
        regime_names = list(market_regimes)
        regime_probs = [market_regimes[name]['weight'] for name in regime_names]
        regime_multipliers = np.array([market_regimes[name]['multiplier'] for name in regime_names])

        volatility_params = {
            'major': (0.0003, 0.003),  # (daily mean, daily std) in the normal regime
            'commodity': (0.0004, 0.004),
            'cross_major': (0.0003, 0.004),
            'cross_commodity': (0.0004, 0.005),
            'cross_exotic': (0.0005, 0.006)
        }

        # Correlated standard-normal shocks: each row is a draw from N(0, corr).
        corr = generate_valid_correlation_matrix(len(self.all_symbols))
        shocks = np.random.standard_normal((n_days, len(self.all_symbols))) @ np.linalg.cholesky(corr).T
        column_of = {symbol: i for i, symbol in enumerate(self.all_symbols)}

        # Cyclical patterns of different periods, identical for every pair.
        common_cycles = (
            np.sin(np.linspace(0, 20 * np.pi, n_days)) * 0.002
            + np.sin(np.linspace(0, 8 * np.pi, n_days)) * 0.003
            + np.sin(np.linspace(0, 4 * np.pi, n_days)) * 0.004
        )

        data = {}
        for group, symbols in self.symbols.items():
            base_mu, base_sigma = volatility_params[group]
            trend = np.linspace(0, base_mu * 5, n_days)  # long-term drift

            for symbol in symbols:
                regimes = np.random.choice(len(regime_names), size=n_days, p=regime_probs)
                multiplier = regime_multipliers[regimes]
                returns = (base_mu * multiplier
                           + base_sigma * multiplier * shocks[:, column_of[symbol]]
                           + common_cycles
                           + trend)

                # Price starts at 1.0 on the first date; the final day's
                # return has no subsequent price and is not used.
                prices = np.concatenate(([1.0], np.cumprod(1 + returns[:-1])))
                data[symbol] = pd.Series(prices, index=dates)

        self.prices = pd.DataFrame(data, columns=self.all_symbols)
        self.returns_data = np.log(self.prices / self.prices.shift(1)).dropna()
        logging.info(f"Generated synthetic data shape: {self.returns_data.shape}")

    def _risk_premiums(self, volatilities):
        """Return the per-pair daily risk premium added to sample mean returns.

        Args:
            volatilities: Daily return std of each pair, in ``self.all_symbols`` order.

        NOTE(review): these premiums (about 0.02-0.033 per day before
        annualisation) are far larger than typical daily FX returns. They
        dominate the optimiser's objective and favour crosses. Confirm intent.
        """
        base_premium = 0.01
        offsets = np.array([
            2.0 if symbol in self.symbols['major']
            else 2.5 if symbol in self.symbols['commodity']
            else 3.25
            for symbol in self.all_symbols
        ])
        return base_premium * (offsets + np.asarray(volatilities, dtype=float))

    def _window_stats(self, window_returns):
        """Return ``(premium-adjusted mean, covariance, return matrix)`` for a window.

        Columns are aligned to ``self.all_symbols`` so positional weight
        vectors always line up with the right pair.
        """
        frame = window_returns[self.all_symbols]
        adjusted_mean = frame.mean().to_numpy() + self._risk_premiums(frame.std().to_numpy())
        return adjusted_mean, frame.cov().to_numpy(), frame.to_numpy()

    def _historical_var(self, return_matrix, weights):
        """Return the historical 1-day VaR (positive = loss) of ``return_matrix @ weights``."""
        tail_percent = (1 - self.confidence_level) * 100
        return -np.percentile(return_matrix @ weights, tail_percent)

    def calculate_portfolio_metrics(self, weights, window_returns=None):
        """Return the annualised ``(return, volatility)`` of a weight vector.

        The return includes the per-pair risk premium. Both figures are unlevered.

        Args:
            weights: Portfolio weights in ``self.all_symbols`` order.
            window_returns: DataFrame of daily log returns; defaults to
                ``self.returns_data``.
        """
        if window_returns is None:
            window_returns = self.returns_data

        adjusted_mean, cov, _ = self._window_stats(window_returns)
        weights = np.asarray(weights, dtype=float)
        portfolio_return = adjusted_mean @ weights * self.TRADING_DAYS
        portfolio_vol = np.sqrt(weights @ cov @ weights * self.TRADING_DAYS)
        return portfolio_return, portfolio_vol

    def calculate_var(self, weights, window_returns=None):
        """Return the historical 1-day VaR of an *unlevered* weight vector.

        Args:
            weights: Portfolio weights in ``self.all_symbols`` order.
            window_returns: DataFrame of daily log returns; defaults to
                ``self.returns_data``.
        """
        if window_returns is None:
            window_returns = self.returns_data
        return_matrix = window_returns[self.all_symbols].to_numpy()
        return self._historical_var(return_matrix, np.asarray(weights, dtype=float))

    def optimize_portfolio(self, current_date):
        """Return target weights for ``current_date`` using data up to that date.

        SLSQP maximises the risk-premium-adjusted Sharpe ratio, minus penalties
        for breaching the VaR limit and for trading costs. It is subject to:
          * full investment (weights sum to 1);
          * per-pair bounds ``[min_weight, max_weight]``;
          * the VaR limit.

        VaR and trading costs are scaled by ``self.leverage``, the exposure
        the simulation actually runs.

        Fallbacks:
          * equal weights when less than half a lookback window of history
            exists;
          * the previously held weights (equal weights if none) when the
            optimiser fails.
        """
        window = self.lookback_window
        data_slice = self.returns_data.loc[:current_date].tail(window)
        n_assets = len(self.all_symbols)
        equal_weights = np.full(n_assets, 1.0 / n_assets)

        if len(data_slice) < window / 2:
            return equal_weights

        # Window statistics are constant during optimisation: compute them once.
        adjusted_mean, cov, return_matrix = self._window_stats(data_slice)
        previous_weights = self.portfolio_weights[-1] if self.portfolio_weights else None

        def levered_var(weights):
            return self.leverage * self._historical_var(return_matrix, weights)

        def objective(weights):
            portfolio_return = adjusted_mean @ weights * self.TRADING_DAYS
            portfolio_vol = np.sqrt(weights @ cov @ weights * self.TRADING_DAYS)
            sharpe = (portfolio_return - self.risk_free_rate) / portfolio_vol
            var_penalty = max(0, levered_var(weights) - self.var_limit) * 100

            # Penalise turnover by its (levered) trading cost.
            if previous_weights is not None:
                turnover = np.sum(np.abs(weights - previous_weights))
                transaction_cost = turnover * self.transaction_costs * self.leverage
            else:
                transaction_cost = 0

            return -(sharpe - var_penalty - transaction_cost)

        constraints = [
            {'type': 'eq', 'fun': lambda x: np.sum(x) - 1},
            {'type': 'ineq', 'fun': lambda x: self.var_limit - levered_var(x)},
        ]
        bounds = [(self.min_weight, self.max_weight)] * n_assets

        if previous_weights is not None:
            initial_weights = previous_weights
            fallback, fallback_label = np.array(previous_weights, dtype=float), "previous weights"
        else:
            # Starting guess tilted to majors; SLSQP clips it into the bounds.
            initial_weights = np.array([
                0.25 if sym in self.symbols['major']
                else 0.15 if sym in self.symbols['commodity']
                else 0.05
                for sym in self.all_symbols
            ])
            initial_weights = initial_weights / initial_weights.sum()
            # Equal weights are feasible for the default bounds, unlike the
            # raw starting guess (its cross-pair weights are below min_weight).
            fallback, fallback_label = equal_weights, "equal weights"

        try:
            result = minimize(objective, initial_weights, method='SLSQP',
                              bounds=bounds, constraints=constraints,
                              options={'ftol': 1e-8, 'maxiter': 1000})
        except Exception as e:  # best effort: keep the backtest running
            logging.error(f"Optimization error: {str(e)}")
            return fallback

        if not result.success:
            logging.warning(f"Optimization failed for {current_date}, using {fallback_label}")
            return fallback

        logging.info(f"Optimization successful for {current_date}")
        weights = np.clip(result.x, self.min_weight, self.max_weight)
        return weights / weights.sum()

    def create_annual_report(self, year):
        """Return performance statistics for calendar ``year``, or None without data.

        The year's return is measured from the previous close, so the first
        trading day of the year is included. The previous close is the prior
        recorded value, or ``initial_capital`` if there is none. The max
        drawdown uses the same starting point.
        """
        positions = np.flatnonzero([d.year == year for d in self.dates])
        if positions.size == 0:
            return None

        year_returns = np.asarray(self.portfolio_returns, dtype=float)[positions]
        year_values = np.asarray(self.portfolio_values, dtype=float)[positions]
        first = positions[0]
        start_value = self.portfolio_values[first - 1] if first > 0 else self.initial_capital

        cumulative_return = year_values[-1] / start_value - 1
        volatility = np.std(year_returns) * np.sqrt(self.TRADING_DAYS)
        sharpe = ((cumulative_return - self.risk_free_rate) / volatility
                  if volatility > 0 else float('nan'))

        path = np.concatenate(([start_value], year_values))
        max_drawdown = np.min(path / np.maximum.accumulate(path) - 1)

        return {
            'year': year,
            'return': cumulative_return,
            'volatility': volatility,
            'sharpe': sharpe,
            'max_drawdown': max_drawdown,
            'ending_value': year_values[-1]
        }

    def _rebalance_positions(self, index, rebalance_freq):
        """Return the positions in ``index`` at which to rebalance.

        Schedule dates come from ``pd.date_range(start_date, end_date,
        freq=rebalance_freq)``. A date that is not a trading day rolls back to
        the previous trading day. Dates before the first trading day are dropped.
        """
        scheduled = pd.date_range(self.start_date, self.end_date, freq=rebalance_freq)
        positions = index.searchsorted(scheduled, side='right') - 1
        return {int(p) for p in positions if p >= 0}

    def _report_completed_year(self):
        """Append and log the previous year's report when the latest date starts a new year."""
        if len(self.dates) < 2 or self.dates[-1].year == self.dates[-2].year:
            return
        annual_report = self.create_annual_report(self.dates[-2].year)
        if annual_report:
            self.annual_reports.append(annual_report)
            logging.info(f"Annual report for {annual_report['year']}: "
                         f"Return: {annual_report['return']:.2%}, "
                         f"Sharpe: {annual_report['sharpe']:.2f}")

    def run_backtest(self, rebalance_freq='QE'):
        """Run the full simulation and write the performance report.

        Generates fresh synthetic data, then walks forward one trading day at
        a time:
          * On each rebalance day, weights are re-optimised with data up to
            and including that day.
          * Between rebalances the weights are held constant (constant-mix
            simplification).
          * Each day's P&L is the levered simple return of the held weights,
            minus financing interest and, on rebalance days, trading costs on
            the levered turnover.

        Args:
            rebalance_freq: pandas offset alias for the rebalance schedule
                (default ``'QE'``, quarter end).
        """
        logging.info("Starting backtest")
        self._reset_results()
        self.generate_synthetic_data()

        index = self.returns_data.index
        # Simple (not log) returns are what compound portfolio value.
        simple_returns = np.expm1(self.returns_data[self.all_symbols].to_numpy())
        rebalance_positions = self._rebalance_positions(index, rebalance_freq)
        interest_cost = (self.leverage - 1) * self.daily_interest  # interest on borrowed funds only

        portfolio_value = float(self.initial_capital)
        weights = None

        # Decide at the close of index[pos]; realise the return of index[pos + 1].
        for pos in range(len(index) - 1):
            transaction_cost = 0.0
            if pos in rebalance_positions:
                new_weights = self.optimize_portfolio(index[pos])
                held = weights if weights is not None else np.zeros_like(new_weights)
                turnover = np.sum(np.abs(new_weights - held))
                transaction_cost = turnover * self.transaction_costs * self.leverage
                weights = new_weights
            elif weights is None:
                continue  # stay in cash until the first scheduled rebalance

            next_date = index[pos + 1]
            leveraged_return = self.leverage * float(weights @ simple_returns[pos + 1])
            portfolio_return = leveraged_return - transaction_cost - interest_cost
            new_value = portfolio_value * (1 + portfolio_return)

            # Margin call check (if decrease exceeds the threshold)
            if portfolio_return < -self.margin_call_threshold:
                logging.warning(f"Margin call on {next_date}! Portfolio return: {portfolio_return:.2%}")
                new_value *= self.margin_call_retention

            new_value = max(new_value, 0.0)  # equity cannot go below zero
            # Record the effective return so returns and values stay consistent.
            effective_return = new_value / portfolio_value - 1
            portfolio_value = new_value

            self.portfolio_values.append(portfolio_value)
            self.portfolio_weights.append(weights)
            self.portfolio_returns.append(effective_return)
            self.dates.append(next_date)

            self._report_completed_year()

            if portfolio_value <= 0:
                logging.error(f"Portfolio equity exhausted on {next_date}; stopping backtest")
                break

        if not self.dates:
            logging.warning("No trades were simulated; no report created")
            return

        # Report for the final (possibly partial) year
        final_annual_report = self.create_annual_report(self.dates[-1].year)
        if final_annual_report:
            self.annual_reports.append(final_annual_report)

        self.create_performance_report()

    def create_performance_report(self):
        """Write summary charts and a text report for the completed backtest.

        Files written to the current directory: ``backtest_results.png``,
        ``portfolio_value.png`` and ``backtest_report.txt``.
        """
        if not self.portfolio_returns:
            logging.warning("No backtest results to report")
            return

        returns = np.asarray(self.portfolio_returns, dtype=float)
        values = np.asarray(self.portfolio_values, dtype=float)
        cumulative_returns = np.cumprod(1 + returns) - 1
        total_return = values[-1] / self.initial_capital - 1

        # Annualise over elapsed calendar time (not the number of year reports).
        elapsed_days = (self.dates[-1] - self.dates[0]).days
        n_years = elapsed_days / 365.25 if elapsed_days > 0 else len(returns) / self.TRADING_DAYS
        annual_return = (1 + total_return) ** (1 / n_years) - 1

        volatility = np.std(returns) * np.sqrt(self.TRADING_DAYS)
        sharpe_ratio = ((annual_return - self.risk_free_rate) / volatility
                        if volatility > 0 else float('nan'))

        # Drawdown = value relative to its running peak, starting from initial capital.
        wealth = np.concatenate(([self.initial_capital], values))
        max_drawdown = np.min(wealth / np.maximum.accumulate(wealth) - 1)

        self._plot_results(returns, cumulative_returns)
        self._write_text_report(total_return, annual_return, volatility, sharpe_ratio, max_drawdown)

        logging.info("Performance report created and saved")

    def _plot_results(self, returns, cumulative_returns):
        """Save the five-panel overview chart and the standalone portfolio value chart."""
        fig, axes = plt.subplots(5, 1, figsize=(20, 25))

        # Portfolio value graph
        axes[0].plot(self.dates, self.portfolio_values, linewidth=2)
        axes[0].set_title('Portfolio Value Over Time', fontsize=14)
        axes[0].grid(True)
        axes[0].set_xlabel('Date')
        axes[0].set_ylabel('Portfolio Value ($)')

        # Returns graph
        axes[1].plot(self.dates, returns * 100, linewidth=1)
        axes[1].set_title('Daily Returns (%)', fontsize=14)
        axes[1].grid(True)
        axes[1].set_xlabel('Date')
        axes[1].set_ylabel('Return (%)')

        # Cumulative returns graph
        axes[2].plot(self.dates, cumulative_returns * 100, linewidth=2)
        axes[2].set_title('Cumulative Returns (%)', fontsize=14)
        axes[2].grid(True)
        axes[2].set_xlabel('Date')
        axes[2].set_ylabel('Cumulative Return (%)')

        # Weight distribution graph
        weights_df = pd.DataFrame(self.portfolio_weights, columns=self.all_symbols,
                                  index=[d.strftime('%Y-%m-%d') for d in self.dates])
        sns.heatmap(weights_df.T, cmap='YlOrRd', cbar_kws={'label': 'Weight'}, ax=axes[3])
        axes[3].set_title('Portfolio Weights Over Time', fontsize=14)

        # Annual performance graph; ax= keeps it in this panel instead of a new figure.
        if self.annual_reports:
            annual_data = pd.DataFrame(self.annual_reports).set_index('year')
            annual_data[['return', 'sharpe', 'volatility', 'max_drawdown']].plot(kind='bar', ax=axes[4])
        axes[4].set_title('Annual Performance Metrics', fontsize=14)
        axes[4].grid(True)

        fig.tight_layout()
        fig.savefig('backtest_results.png')
        plt.close(fig)  # release memory

        # Portfolio value graph saved separately
        fig, ax = plt.subplots(figsize=(15, 8))
        ax.plot(self.dates, self.portfolio_values, linewidth=2)
        ax.set_title('Portfolio Value Over Time', fontsize=14)
        ax.grid(True)
        ax.set_xlabel('Date')
        ax.set_ylabel('Portfolio Value ($)')
        fig.tight_layout()
        fig.savefig('portfolio_value.png')
        plt.close(fig)

    def _write_text_report(self, total_return, annual_return, volatility,
                           sharpe_ratio, max_drawdown):
        """Write ``backtest_report.txt`` with summary, annual and composition statistics."""
        weights_matrix = np.asarray(self.portfolio_weights, dtype=float)
        average_weights = dict(zip(self.all_symbols, weights_matrix.mean(axis=0)))

        with open('backtest_report.txt', 'w', encoding='utf-8') as f:
            f.write("Forex Portfolio Backtest Report\n")
            f.write("=============================\n\n")

            f.write("General Information\n")
            f.write("-----------------\n")
            f.write(f"Period: {self.start_date} to {self.end_date}\n")
            f.write(f"Initial Capital: ${self.initial_capital:,.2f}\n")
            f.write(f"Final Portfolio Value: ${self.portfolio_values[-1]:,.2f}\n")
            f.write(f"Total Return: {total_return:.2%}\n")
            f.write(f"Annualized Return: {annual_return:.2%}\n")
            f.write(f"Annual Volatility: {volatility:.2%}\n")
            f.write(f"Sharpe Ratio: {sharpe_ratio:.2f}\n")
            f.write(f"Maximum Drawdown: {max_drawdown:.2%}\n\n")

            f.write("Annual Performance\n")
            f.write("-----------------\n")
            for report in self.annual_reports:
                f.write(f"\nYear {report['year']}:\n")
                f.write(f"  Return: {report['return']:.2%}\n")
                f.write(f"  Volatility: {report['volatility']:.2%}\n")
                f.write(f"  Sharpe Ratio: {report['sharpe']:.2f}\n")
                f.write(f"  Max Drawdown: {report['max_drawdown']:.2%}\n")
                f.write(f"  Ending Value: ${report['ending_value']:,.2f}\n")

            f.write("\nPortfolio Composition\n")
            f.write("--------------------\n")
            for symbol, weight in zip(self.all_symbols, self.portfolio_weights[-1]):
                f.write(f"{symbol}: {weight:.2%}\n")

            # Currency pair group statistics (time-averaged weights)
            f.write("\nCurrency Pair Group Statistics\n")
            f.write("----------------------------\n")
            for group, symbols in self.symbols.items():
                group_weight = sum(average_weights[symbol] for symbol in symbols)
                f.write(f"\n{group.upper()}:\n")
                f.write(f"  Average Group Weight: {group_weight:.2%}\n")
                f.write(f"  Number of Pairs: {len(symbols)}\n")

# Example usage: a ~25-year synthetic backtest with quarterly rebalancing.
if __name__ == "__main__":
    setup_logging()

    start_date = datetime(2000, 1, 1)  # 1 January 2000
    end_date = datetime(2024, 12, 1)   # 1 December 2024

    backtester = ForexPortfolioBacktester(start_date, end_date, initial_capital=1000000)
    backtester.run_backtest(rebalance_freq='QE')  # quarter-end rebalancing
        
