import MetaTrader5 as mt5
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg')  # Set the Agg backend for non-interactive plotting
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
import pytz
import os

# Create a directory for saving plots if it doesn't exist
def ensure_output_dir():
    """Create the results directory if needed and return its path.

    The directory is created relative to the current working directory.
    `exist_ok=True` makes the call idempotent and avoids the race between a
    separate existence check and the creation.

    Returns:
        The relative directory name, 'trend_analysis_results'.

    Raises:
        FileExistsError: if the path exists but is not a directory.
    """
    output_dir = 'trend_analysis_results'
    os.makedirs(output_dir, exist_ok=True)
    return output_dir

def initialize_mt5():
    """Open the MetaTrader5 connection.

    Returns:
        True when `mt5.initialize()` succeeds. On failure an error message is
        printed, `mt5.shutdown()` is called, and False is returned.
    """
    connected = mt5.initialize()
    if connected:
        return True

    # Initialization failed: report it and release whatever was set up
    print("MT5 initialization error")
    mt5.shutdown()
    return False

def get_eurusd_data(timeframe=mt5.TIMEFRAME_D1, bars_count=1000):
    """Load EURUSD historical bars into a time-indexed DataFrame.

    Args:
        timeframe: MetaTrader5 timeframe constant passed to
            `mt5.copy_rates_from_pos`.
        bars_count: number of bars requested, starting at position 0.

    Returns:
        A DataFrame built from the returned rates and indexed by 'time',
        or None (after printing an error) when no data was returned.
    """
    rates = mt5.copy_rates_from_pos("EURUSD", timeframe, 0, bars_count)

    # len() is used instead of truthiness: the emptiness of an array-like
    # result cannot be tested with a plain `not`
    if rates is None or len(rates) == 0:
        print("Error retrieving EURUSD data")
        return None

    df = pd.DataFrame(rates)

    # The 'time' column is interpreted as seconds since the epoch
    df['time'] = pd.to_datetime(df['time'], unit='s')
    df.set_index('time', inplace=True)

    return df

def identify_trends(df, window_size=5, point_multiplier=10000):
    """
    Identify alternating local highs and lows and turn them into trend legs.

    A bar is a local maximum when its 'high' equals the highest 'high' in the
    centered rolling window of `window_size` bars around it; local minimums
    are found the same way on 'low'. Bars without a complete window (near
    either end of the data) are never marked.

    When several extremes of the same kind follow each other without an
    opposite extreme in between, only the most extreme one is kept (the
    highest maximum / the lowest minimum), so every leg runs between the real
    turning points of the swing.

    Args:
        df: DataFrame with 'high' and 'low' columns, indexed by timestamps
            whose differences expose `.days`. It is not modified.
        window_size: size of the window for finding local extremes.
        point_multiplier: factor converting a price difference into points
            (10000 by default).

    Returns:
        DataFrame with one row per trend and the columns 'start_date',
        'start_price', 'end_date', 'end_price', 'duration' (whole days),
        'points', 'type' ('Uptrend' or 'Downtrend') and 'percentage'.
        The frame is empty (but still has these columns) when fewer than two
        alternating extremes are found.
    """
    columns = ['start_date', 'start_price', 'end_date', 'end_price',
               'duration', 'points', 'type', 'percentage']

    highs = df['high']
    lows = df['low']

    # Vectorized extreme detection: comparisons against NaN (incomplete
    # windows) are False, so edge bars are never flagged
    is_local_max = highs.eq(highs.rolling(window=window_size, center=True).max())
    is_local_min = lows.eq(lows.rolling(window=window_size, center=True).min())

    # (timestamp, kind, price) tuples; the sort is stable, so a bar that is
    # both a maximum and a minimum keeps the 'max' entry first
    extremes = [(ts, 'max', price) for ts, price in highs[is_local_max].items()]
    extremes += [(ts, 'min', price) for ts, price in lows[is_local_min].items()]
    extremes.sort(key=lambda item: item[0])

    # Collapse runs of same-kind extremes into their most extreme member
    pivots = []
    for extreme in extremes:
        if not pivots or pivots[-1][1] != extreme[1]:
            pivots.append(extreme)
            continue

        kind, price = extreme[1], extreme[2]
        last_price = pivots[-1][2]
        higher_max = kind == 'max' and price > last_price
        lower_min = kind == 'min' and price < last_price
        if higher_max or lower_min:
            pivots[-1] = extreme

    # Each pair of neighbouring pivots forms one trend leg
    trends = []
    for (start_ts, _, start_price), (end_ts, end_kind, end_price) in zip(pivots, pivots[1:]):
        move = abs(end_price - start_price)
        trends.append({
            'start_date': start_ts,
            'start_price': start_price,
            'end_date': end_ts,
            'end_price': end_price,
            'duration': (end_ts - start_ts).days,
            'points': move * point_multiplier,
            'type': 'Uptrend' if end_kind == 'max' else 'Downtrend',
            'percentage': (move / start_price) * 100,
        })

    return pd.DataFrame(trends, columns=columns)

def _format_direction_stats(title, subset):
    """Return the statistics text block for the trends of one direction."""
    parts = [f"{title}\n", f"Count: {len(subset)}\n"]
    if not subset.empty:
        parts.extend([
            f"Average duration: {subset['duration'].mean():.2f} days\n",
            f"Average magnitude: {subset['points'].mean():.2f} points\n",
            f"Average magnitude: {subset['percentage'].mean():.2f}%\n",
            f"Maximum magnitude: {subset['points'].max():.2f} points\n",
            f"Minimum magnitude: {subset['points'].min():.2f} points\n",
        ])
    parts.append("-" * 50 + "\n")
    return "".join(parts)


def _format_distribution(title, categories, total):
    """Return the text block listing how many trends fall into each category."""
    counts = categories.value_counts().sort_index()
    parts = [f"{title}\n"]
    parts.extend(
        f"{category}: {count} trends ({count/total*100:.2f}%)\n"
        for category, count in counts.items()
    )
    parts.append("-" * 50 + "\n")
    return "".join(parts)


def analyze_trends(trend_df, output_dir):
    """Analyze trends, print the statistics and save them to a text file.

    The report is written to 'trend_statistics.txt' inside `output_dir` and
    echoed to stdout; afterwards `save_distribution_plots` is called with the
    same arguments.

    Side effect: the columns 'size_category' and 'duration_category' are
    added to `trend_df` in place.

    Args:
        trend_df: DataFrame with the columns 'type', 'duration', 'points',
            'percentage', 'start_date' and 'end_date'.
        output_dir: existing directory that receives the report.

    Returns:
        None when `trend_df` is empty; otherwise a dict with the keys
        'total_trends', 'avg_duration', 'avg_magnitude', 'up_trends' and
        'down_trends'.
    """
    if trend_df.empty:
        print("No data for trend analysis")
        return

    total = len(trend_df)
    up_trends = trend_df[trend_df['type'] == 'Uptrend']
    down_trends = trend_df[trend_df['type'] == 'Downtrend']

    separator = "=" * 50
    header = separator + "\nGeneral Trend Statistics:\n" + separator

    overall = "".join([
        f"Total trends: {total}\n",
        f"Average trend duration: {trend_df['duration'].mean():.2f} days\n",
        f"Average trend magnitude: {trend_df['points'].mean():.2f} points\n",
        f"Average trend magnitude: {trend_df['percentage'].mean():.2f}%\n",
        "-" * 50 + "\n",
    ])

    # include_lowest=True keeps values of exactly 0 in the first bucket;
    # with the default right-closed bins they would become NaN and silently
    # drop out of the distribution (e.g. trends lasting 0 whole days)
    bins = [0, 50, 100, 200, 500, 1000, float('inf')]
    labels = ['0-50', '50-100', '100-200', '200-500', '500-1000', '1000+']
    trend_df['size_category'] = pd.cut(
        trend_df['points'], bins=bins, labels=labels, include_lowest=True
    )

    duration_bins = [0, 5, 10, 20, 30, 60, float('inf')]
    duration_labels = ['0-5', '5-10', '10-20', '20-30', '30-60', '60+']
    trend_df['duration_category'] = pd.cut(
        trend_df['duration'], bins=duration_bins, labels=duration_labels,
        include_lowest=True
    )

    # Top 5 strongest trends
    top_parts = ["Top 5 strongest trends (by points):\n"]
    strongest = trend_df.nlargest(5, 'points')
    for rank, (_, trend) in enumerate(strongest.iterrows(), 1):
        top_parts.append(
            f"{rank}. {trend['type']} from {trend['start_date'].date()} to {trend['end_date'].date()}: "
            f"{trend['points']:.2f} points ({trend['percentage']:.2f}%), "
            f"duration: {trend['duration']} days\n"
        )
    top_parts.append(separator + "\n")

    sections = [
        overall,
        _format_direction_stats("Uptrend Statistics:", up_trends),
        _format_direction_stats("Downtrend Statistics:", down_trends),
        _format_distribution(
            "Distribution of trends by magnitude (in points):",
            trend_df['size_category'], total
        ),
        _format_distribution(
            "Distribution of trends by duration (in days):",
            trend_df['duration_category'], total
        ),
        "".join(top_parts),
    ]

    stats_file = os.path.join(output_dir, 'trend_statistics.txt')
    with open(stats_file, 'w', encoding='utf-8') as f:
        print(header)
        f.write(header + "\n")
        for section in sections:
            print(section)
            f.write(section)

    print(f"Statistics saved to file '{stats_file}'")

    # Create additional distribution charts
    save_distribution_plots(trend_df, output_dir)

    return {
        'total_trends': total,
        'avg_duration': trend_df['duration'].mean(),
        'avg_magnitude': trend_df['points'].mean(),
        'up_trends': len(up_trends),
        'down_trends': len(down_trends)
    }

def _save_figure(fig, output_dir, filename, description, tight=True):
    """Save `fig` as `filename` in `output_dir`, close it and report the path."""
    if tight:
        fig.tight_layout()
    path = os.path.join(output_dir, filename)
    fig.savefig(path, dpi=100)
    # Close this specific figure so nothing accumulates between charts
    plt.close(fig)
    print(f"{description}: {path}")


def save_distribution_plots(trend_df, output_dir):
    """Create and save additional distribution charts.

    Five PNG files are written to `output_dir`: histograms of trend magnitude
    and duration, a duration/magnitude scatter chart, a pie chart of trend
    types and side-by-side box plots of magnitude and duration per type.

    Args:
        trend_df: DataFrame with the columns 'type', 'points' and 'duration'.
        output_dir: existing directory that receives the images.
    """
    # Fixed width of 750px at 100dpi, with an 8:12 height/width ratio
    figwidth_inches = 750/100
    figsize = (figwidth_inches, figwidth_inches * (8/12))

    up_trends = trend_df[trend_df['type'] == 'Uptrend']
    down_trends = trend_df[trend_df['type'] == 'Downtrend']
    type_colors = {'Uptrend': 'green', 'Downtrend': 'red'}

    # 1. Chart of trend magnitude distribution
    fig, ax = plt.subplots(figsize=figsize)
    ax.hist([up_trends['points'], down_trends['points']],
            bins=10, alpha=0.7, color=['green', 'red'],
            label=['Uptrends', 'Downtrends'])
    ax.set_title('Distribution of Trend Magnitude (in points)')
    ax.set_xlabel('Trend Magnitude (points)')
    ax.set_ylabel('Number of Trends')
    ax.grid(True, alpha=0.3)
    ax.legend()
    _save_figure(fig, output_dir, 'trend_magnitude_distribution.png',
                 'Trend magnitude distribution chart saved')

    # 2. Chart of trend duration distribution
    fig, ax = plt.subplots(figsize=figsize)
    ax.hist([up_trends['duration'], down_trends['duration']],
            bins=10, alpha=0.7, color=['green', 'red'],
            label=['Uptrends', 'Downtrends'])
    ax.set_title('Distribution of Trend Duration (in days)')
    ax.set_xlabel('Trend Duration (days)')
    ax.set_ylabel('Number of Trends')
    ax.grid(True, alpha=0.3)
    ax.legend()
    _save_figure(fig, output_dir, 'trend_duration_distribution.png',
                 'Trend duration distribution chart saved')

    # 3. Chart of relationship between duration and magnitude
    fig, ax = plt.subplots(figsize=figsize)
    ax.scatter(up_trends['duration'], up_trends['points'],
               color='green', alpha=0.7, label='Uptrends')
    ax.scatter(down_trends['duration'], down_trends['points'],
               color='red', alpha=0.7, label='Downtrends')
    ax.set_title('Relationship Between Duration and Magnitude of Trends')
    ax.set_xlabel('Trend Duration (days)')
    ax.set_ylabel('Trend Magnitude (points)')
    ax.grid(True, alpha=0.3)
    ax.legend()
    _save_figure(fig, output_dir, 'trend_correlation.png',
                 'Trend correlation chart saved')

    # 4. Pie chart of trend type distribution
    fig, ax = plt.subplots(figsize=figsize)
    trend_types = trend_df['type'].value_counts()
    # value_counts() orders by frequency, so colors are looked up per label
    # to keep uptrends green and downtrends red as in the other charts
    pie_colors = [type_colors.get(label, 'gray') for label in trend_types.index]
    ax.pie(trend_types, labels=trend_types.index, autopct='%1.1f%%',
           colors=pie_colors, startangle=90, shadow=True)
    ax.set_title('Distribution of Trend Types')
    ax.axis('equal')
    _save_figure(fig, output_dir, 'trend_types_pie.png',
                 'Trend types pie chart saved', tight=False)

    # 5. Box plots for trend characteristics comparison
    # DataFrame.boxplot(by=...) opens a figure of its own unless an axes is
    # passed, so both plots are drawn onto explicit axes of one figure
    fig, (ax_points, ax_duration) = plt.subplots(1, 2, figsize=figsize)

    trend_df.boxplot(column='points', by='type', ax=ax_points)
    ax_points.set_title('Distribution of Trend Magnitude (in points)')
    ax_points.set_ylabel('Trend Magnitude (points)')
    ax_points.grid(True, alpha=0.3)

    trend_df.boxplot(column='duration', by='type', ax=ax_duration)
    ax_duration.set_title('Distribution of Trend Duration (in days)')
    ax_duration.set_ylabel('Trend Duration (days)')
    ax_duration.grid(True, alpha=0.3)

    fig.suptitle('')  # Remove automatic title
    _save_figure(fig, output_dir, 'trend_boxplots.png', 'Trend boxplots saved')

def plot_trends(df, trend_df, output_dir):
    """Draw the price chart with trend legs plus a magnitude histogram.

    The two-panel figure is written to 'eurusd_trend_analysis.png' inside
    `output_dir`.

    Args:
        df: price DataFrame; its index and 'close' column are plotted.
        trend_df: trends with 'type', 'start_date', 'start_price',
            'end_date', 'end_price' and 'points' columns.
        output_dir: existing directory that receives the image.
    """
    # 750px wide at 100dpi, height/width ratio 10:15
    width_in = 750/100
    height_in = width_in * (10/15)
    fig = plt.figure(figsize=(width_in, height_in))

    # Upper panel: closing price
    price_ax = fig.add_subplot(2, 1, 1)
    price_ax.plot(df.index, df['close'], label='EURUSD Close', color='blue', alpha=0.7)
    price_ax.set_title('EURUSD Trend Analysis')
    price_ax.grid(True, alpha=0.3)
    price_ax.legend()

    # Overlay every trend: start marker, end marker and a dashed connector
    legs = zip(trend_df['type'], trend_df['start_date'], trend_df['start_price'],
               trend_df['end_date'], trend_df['end_price'])
    for kind, start_date, start_price, end_date, end_price in legs:
        if kind == 'Uptrend':
            start_color, end_color = 'green', 'red'
        else:
            start_color, end_color = 'red', 'green'

        price_ax.scatter(start_date, start_price, color=start_color, s=50)
        price_ax.scatter(end_date, end_price, color=end_color, s=50)
        price_ax.plot([start_date, end_date], [start_price, end_price],
                      color='orange', alpha=0.5, linestyle='--')

    # Lower panel: histogram of trend magnitude per direction
    hist_ax = fig.add_subplot(2, 1, 2)
    rising = trend_df[trend_df['type'] == 'Uptrend']
    falling = trend_df[trend_df['type'] == 'Downtrend']
    hist_ax.hist([rising['points'], falling['points']],
                 bins=10, alpha=0.7, color=['green', 'red'],
                 label=['Uptrends', 'Downtrends'])
    hist_ax.set_title('Distribution of Trend Magnitude (in points)')
    hist_ax.set_xlabel('Trend Magnitude (points)')
    hist_ax.set_ylabel('Number of Trends')
    hist_ax.grid(True, alpha=0.3)
    hist_ax.legend()

    fig.tight_layout()
    main_plot_path = os.path.join(output_dir, 'eurusd_trend_analysis.png')
    fig.savefig(main_plot_path, dpi=100)
    plt.close(fig)
    print(f"Main trend analysis chart saved: {main_plot_path}")

def _prompt_int(prompt, default, minimum=1):
    """Ask for an integer; fall back to `default` on invalid or too-small input."""
    raw = input(prompt)
    try:
        value = int(raw)
    except ValueError:
        print(f"Invalid number '{raw}'. Using {default}.")
        return default
    if value < minimum:
        print(f"Value must be at least {minimum}. Using {default}.")
        return default
    return value


def _export_trends(trend_df, output_dir):
    """Write the trend table to CSV and to a formatted Excel workbook."""
    csv_path = os.path.join(output_dir, 'eurusd_trends.csv')
    trend_df.to_csv(csv_path)
    print(f"Analysis results saved to file '{csv_path}'")

    excel_path = os.path.join(output_dir, 'eurusd_trends.xlsx')
    with pd.ExcelWriter(excel_path, engine='xlsxwriter') as writer:
        trend_df.to_excel(writer, sheet_name='Trends')
        workbook = writer.book
        worksheet = writer.sheets['Trends']

        # Rewrite the header cells with formatting; column 0 holds the index
        header_format = workbook.add_format({'bold': True, 'bg_color': '#D7E4BC', 'border': 1})
        for col_num, value in enumerate(trend_df.columns.values):
            worksheet.write(0, col_num + 1, value, header_format)

        # Same width for every data column
        if len(trend_df.columns) > 0:
            worksheet.set_column(1, len(trend_df.columns), 15)

    print(f"Analysis results saved to Excel file '{excel_path}'")


def _run_analysis():
    """Run the interactive analysis; return the output directory on success.

    Returns None when data could not be loaded or no trends were identified.
    Requires an already initialized MT5 connection.
    """
    print("EURUSD Trend Analysis")
    print("=" * 50)

    # Create directory for results
    output_dir = ensure_output_dir()
    print(f"Results will be saved to directory: {output_dir}")

    timeframe_map = {
        1: mt5.TIMEFRAME_D1,
        2: mt5.TIMEFRAME_H4,
        3: mt5.TIMEFRAME_H1,
        4: mt5.TIMEFRAME_M30,
        5: mt5.TIMEFRAME_M15
    }

    print("Select timeframe:")
    print("1. D1 (daily)")
    print("2. H4 (4 hours)")
    print("3. H1 (1 hour)")
    print("4. M30 (30 minutes)")
    print("5. M15 (15 minutes)")

    # Non-numeric input is treated like any other invalid choice
    try:
        choice = int(input("Enter timeframe number (1-5): "))
    except ValueError:
        choice = None

    if choice not in timeframe_map:
        print("Invalid choice. Using daily timeframe.")
        choice = 1

    timeframe = timeframe_map[choice]

    bars_count = _prompt_int(
        "Enter number of bars for analysis (recommended 500-1000): ", default=1000
    )
    window_size = _prompt_int(
        "Enter window size for finding extremes (recommended 5-15): ", default=5
    )

    # Load data
    df = get_eurusd_data(timeframe=timeframe, bars_count=bars_count)
    if df is None:
        print("Failed to get data. Terminating program.")
        return None

    print(f"Loaded EURUSD data from {df.index[0]} to {df.index[-1]}")

    # Identify trends
    trend_df = identify_trends(df, window_size=window_size)
    if trend_df.empty:
        print("Failed to identify trends. Try changing parameters.")
        return None

    # Analyze trends and save statistics
    analyze_trends(trend_df, output_dir)

    # Visualize trends and save charts
    plot_trends(df, trend_df, output_dir)

    # Save results to CSV and Excel
    _export_trends(trend_df, output_dir)

    return output_dir


def main():
    """Interactive entry point: connect to MT5, analyze EURUSD trends, export.

    The MT5 connection is shut down on every exit path, including errors
    raised during the analysis.
    """
    # Initialize MT5
    if not initialize_mt5():
        return

    try:
        output_dir = _run_analysis()
    finally:
        # Finalize MT5 work even if the analysis raised
        mt5.shutdown()

    if output_dir is None:
        return

    print("Analysis completed. MT5 disconnected.")
    print(f"All analysis results saved to directory: {output_dir}")

# Run the interactive analysis only when executed as a script, not on import
if __name__ == "__main__":
    main()
