import MetaTrader5 as mt5
import pandas as pd
import numpy as np
from meteostat import Point, Daily, Hourly
from datetime import datetime, timedelta
import seaborn as sns
import matplotlib.pyplot as plt
import warnings
from tqdm import tqdm
import os

# Silence every FutureWarning for the whole process (not only those raised by pandas).
# NOTE(review): this also hides pandas/numpy deprecation notices that would flag
# code needing changes on library upgrades — consider narrowing with `module=`.
warnings.filterwarnings('ignore', category=FutureWarning)

def fetch_historical_weather(years=5):
    """
    Download and process daily weather history for each tracked agricultural region.

    The window ends now and starts ``365 * years`` days earlier. Each region is
    mapped to the frame returned by ``process_weather_data``; a region whose
    download fails, raises, or returns no rows is mapped to an empty DataFrame.

    Parameters
    ----------
    years : int
        Length of the history window, in 365-day years.

    Returns
    -------
    dict
        Region name -> processed (or empty) DataFrame.
    """
    regions = {
        "AU_WheatBelt": {
            "lat": -31.95,
            "lon": 116.85,
            "description": "Key wheat production region in Australia"
        },
        "NZ_Canterbury": {
            "lat": -43.53,
            "lon": 172.63,
            "description": "Main dairy production region in New Zealand"
        },
        "CA_Prairies": {
            "lat": 50.45,
            "lon": -104.61,
            "description": "Canada's breadbasket, producing wheat and rapeseed"
        }
    }

    period_end = datetime.now()
    period_start = period_end - timedelta(days=365 * years)

    results = {}
    for name, spec in tqdm(regions.items(), desc="Loading weather data"):
        try:
            print(f"\nLoading historical data for {name} from {period_start.date()} to {period_end.date()}")

            station_point = Point(spec["lat"], spec["lon"])
            observations = Daily(station_point, period_start, period_end).fetch()

            if observations is None or observations.empty:
                print("Failed to fetch data for this region")
                results[name] = pd.DataFrame()
                continue

            print(f"Fetched records: {len(observations)}")
            results[name] = process_weather_data(observations)
            print("Data successfully processed")

        except Exception as exc:
            # Best effort: one failing region must not abort the others.
            print(f"Error fetching data for region {name}: {str(exc)}")
            results[name] = pd.DataFrame()

    return results

def process_weather_data(raw_data):
    """
    Process weather data focusing on critical parameters for crop yield.

    Parameters
    ----------
    raw_data : pandas.DataFrame
        Daily observations using the meteostat ``Daily`` column names
        ``tavg``, ``tmin``, ``tmax``, ``prcp`` and ``wspd``. The frame is not
        modified.

    Returns
    -------
    pandas.DataFrame
        Same rows as ``raw_data`` (indexed by a DatetimeIndex) with columns
        ``temperature``, ``temp_min``, ``temp_max``, ``precipitation``,
        ``wind_speed``, ``growing_degree_days`` and the boolean
        ``extreme_temp`` flag.
    """
    # Convert the index on a new object instead of overwriting the caller's frame.
    if not isinstance(raw_data.index, pd.DatetimeIndex):
        raw_data = raw_data.set_axis(pd.to_datetime(raw_data.index), axis=0)

    processed_data = pd.DataFrame(index=raw_data.index)

    # In Daily data, different column names are used
    processed_data['temperature'] = raw_data['tavg']   # average daily temperature
    processed_data['temp_min'] = raw_data['tmin']      # minimum temperature
    processed_data['temp_max'] = raw_data['tmax']      # maximum temperature
    processed_data['precipitation'] = raw_data['prcp'] # precipitation
    processed_data['wind_speed'] = raw_data['wspd']    # wind speed

    # Growing Degree Days use the daily mean temperature (Tmax + Tmin) / 2 minus
    # the base; using Tmax alone overstates heat accumulation. Where Tmax or
    # Tmin is missing, fall back to the reported daily average.
    daily_mean = (processed_data['temp_max'] + processed_data['temp_min']) / 2
    daily_mean = daily_mean.fillna(processed_data['temperature'])
    processed_data['growing_degree_days'] = calculate_gdd(daily_mean, base_temp=10)

    # Flag days in the top or bottom 5% of this series' average temperatures
    processed_data['extreme_temp'] = (
        (processed_data['temperature'] > processed_data['temperature'].quantile(0.95)) |
        (processed_data['temperature'] < processed_data['temperature'].quantile(0.05))
    )

    return processed_data

def detect_critical_conditions(weather_data):
    """
    Detect critical weather conditions per region.

    Parameters
    ----------
    weather_data : dict
        Region name -> DataFrame with a DatetimeIndex and columns such as
        ``temperature``, ``temp_min``, ``precipitation`` and ``wind_speed``
        (as produced by ``process_weather_data``). Entries that are not
        non-empty DataFrames are skipped.

    Returns
    -------
    list of dict
        One entry per (region, event type) that occurred, with keys
        ``region``, ``event_type``, ``dates`` (list of timestamps) and
        ``value`` (mean of the offending values).
    """
    critical_events = []

    thresholds = {
        'AU_WheatBelt': {
            'drought_temp': 35,
            'min_rain': 10,
            'max_wind': 30
        },
        'NZ_Canterbury': {
            'frost_temp': 0,
            'flood_rain': 50,
            'drought_days': 14
        },
        'CA_Prairies': {
            'frost_temp': -5,
            'snow_cover': 10,
            'growing_degree_min': 5
        }
    }

    def record(region, event_type, hits):
        """Append an event for ``hits`` (offending values indexed by date), if any."""
        if not hits.empty:
            critical_events.append({
                'region': region,
                'event_type': event_type,
                'dates': hits.index.tolist(),
                'value': hits.mean()
            })

    for region, data in weather_data.items():
        if not isinstance(data, pd.DataFrame) or data.empty:
            continue

        limits = thresholds.get(region, {})
        columns = data.columns

        # Check temperature extremes
        if 'drought_temp' in limits and 'temperature' in columns:
            temps = data['temperature']
            record(region, 'drought_temperature', temps[temps > limits['drought_temp']])

        # Frost: daily minimum at or below the threshold
        if 'frost_temp' in limits and 'temp_min' in columns:
            lows = data['temp_min']
            record(region, 'frost', lows[lows <= limits['frost_temp']])

        # Check precipitation
        if 'min_rain' in limits and 'precipitation' in columns:
            # min_count=1 keeps days without any observation as NaN instead of 0,
            # so missing data is not reported as a dry day.
            daily_rain = data['precipitation'].resample('D').sum(min_count=1)
            record(region, 'low_precipitation', daily_rain[daily_rain < limits['min_rain']])

        if 'flood_rain' in limits and 'precipitation' in columns:
            rain = data['precipitation']
            record(region, 'heavy_precipitation', rain[rain > limits['flood_rain']])

        if 'max_wind' in limits and 'wind_speed' in columns:
            wind = data['wind_speed']
            record(region, 'high_wind', wind[wind > limits['max_wind']])

        # TODO: 'drought_days', 'snow_cover' and 'growing_degree_min' are defined
        # above but not evaluated yet (snow is not carried into the processed data).

    return critical_events

def get_historical_forex_data(years=5):
    """
    Fetch historical forex data from the MetaTrader 5 terminal.

    Parameters
    ----------
    years : int
        History depth. ``years * 365`` bars are requested per pair. On the D1
        timeframe this is a bar count, not a calendar span; because there are
        fewer daily bars than calendar days, it typically reaches further back
        than ``years`` calendar years (if the terminal has that much history).

    Returns
    -------
    dict or None
        ``{pair: {"D1": DataFrame}}`` indexed by bar time, with the raw rate
        columns plus ``volatility``, ``range_pct``, ``price_momentum`` and
        ``monthly_change``. A pair/timeframe that failed maps to an empty
        DataFrame. Returns ``None`` if MT5 could not be initialized.
    """
    if not mt5.initialize():
        print("MT5 initialization error")
        return None

    pairs = ["AUDUSD", "NZDUSD", "USDCAD"]
    timeframes = {
        "D1": mt5.TIMEFRAME_D1  # Only daily timeframe for historical data
    }

    # Calculate the number of bars to fetch
    bars_needed = years * 365

    forex_data = {}
    try:
        for pair in tqdm(pairs, desc="Loading forex data"):
            pair_data = {}
            for tf_name, tf in timeframes.items():
                try:
                    rates = mt5.copy_rates_from_pos(pair, tf, 0, bars_needed)
                    # None signals an error; also guard against an empty result,
                    # which would otherwise surface as a KeyError on 'time'.
                    if rates is None or len(rates) == 0:
                        print(f"Failed to fetch data for pair {pair} timeframe {tf_name}")
                        pair_data[tf_name] = pd.DataFrame()
                        continue

                    df = pd.DataFrame(rates)
                    df['time'] = pd.to_datetime(df['time'], unit='s')
                    df.set_index('time', inplace=True)

                    # Add technical indicators for daily timeframe
                    df['volatility'] = df['high'] - df['low']
                    df['range_pct'] = (df['high'] - df['low']) / df['low'] * 100
                    df['price_momentum'] = df['close'].pct_change()

                    # Add monthly change (~20 trading days)
                    df['monthly_change'] = df['close'].pct_change(20)

                    pair_data[tf_name] = df
                except Exception as e:
                    print(f"Error fetching data for pair {pair} timeframe {tf_name}: {str(e)}")
                    pair_data[tf_name] = pd.DataFrame()

            forex_data[pair] = pair_data
    finally:
        # Always release the terminal connection, even if the loop is interrupted.
        mt5.shutdown()

    return forex_data

def _to_sorted_time_index(frame):
    """Return ``frame`` with an ascending DatetimeIndex, leaving the input untouched."""
    if not isinstance(frame.index, pd.DatetimeIndex):
        frame = frame.set_axis(pd.to_datetime(frame.index), axis=0)
    return frame.sort_index()


def merge_weather_forex_data(weather_data, forex_data):
    """
    Merge weather and forex data per region.

    Regions are paired with currencies as AU_WheatBelt/AUDUSD,
    NZ_Canterbury/NZDUSD and CA_Prairies/USDCAD. Each daily (``'D1'``) forex
    bar is matched with the most recent weather row at or before its
    timestamp, if one exists within one day (``pd.merge_asof`` default
    backward search). Derived features are then added and the result cleaned.

    Parameters
    ----------
    weather_data : dict
        Region name -> weather DataFrame.
    forex_data : dict
        Pair -> {timeframe name -> DataFrame}.

    Returns
    -------
    dict
        Region name -> merged DataFrame (empty when either side is empty or
        the pair has no ``'D1'`` data). Regions whose weather or pair is absent
        from the inputs are omitted. Input frames are not modified.
    """
    synchronized_data = {}

    region_pair_mapping = {
        'AU_WheatBelt': 'AUDUSD',
        'NZ_Canterbury': 'NZDUSD',
        'CA_Prairies': 'USDCAD'
    }

    for region, pair in region_pair_mapping.items():
        if region not in weather_data or pair not in forex_data:
            continue

        weather = weather_data[region]
        # Use D1 timeframe for historical data
        forex = forex_data[pair].get('D1', pd.DataFrame())

        if weather.empty or forex.empty:
            print(f"Empty data for region {region} or pair {pair}")
            synchronized_data[region] = pd.DataFrame()
            continue

        # merge_asof requires both keys to be sorted ascending.
        weather = _to_sorted_time_index(weather)
        forex = _to_sorted_time_index(forex)

        merged = pd.merge_asof(
            forex,
            weather,
            left_index=True,
            right_index=True,
            tolerance=pd.Timedelta('1d')
        )

        merged = calculate_derived_features(merged)
        merged = clean_merged_data(merged)

        synchronized_data[region] = merged

    return synchronized_data

def calculate_derived_features(data, growing_months=(4, 5, 6, 7, 8, 9)):
    """
    Calculate derived features on a merged forex/weather frame.

    The frame is modified in place and also returned.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame with a DatetimeIndex and at least the columns ``volatility``,
        ``temperature``, ``precipitation`` and ``close``.
    growing_months : iterable of int, optional
        Calendar months flagged in ``growing_season``. The default
        (April-September) matches the original behaviour; pass e.g.
        ``(10, 11, 12, 1, 2, 3)`` for a southern-hemisphere season.

    Returns
    -------
    pandas.DataFrame
        ``data`` with ``price_volatility``, ``temp_change``,
        ``precip_intensity``, ``growing_season``, ``monthly_price_change``
        and ``monthly_temp_change`` added (unchanged if empty).
    """
    if data.empty:
        return data

    # Rolling standard deviation of the daily high-low range
    data['price_volatility'] = data['volatility'].rolling(20).std()
    data['temp_change'] = data['temperature'].diff()
    data['precip_intensity'] = data['precipitation'].rolling(7).sum()  # sum over the last 7 rows

    # Seasonal features
    data['growing_season'] = data.index.month.isin(list(growing_months))

    # Monthly changes over 20 rows (~ one month of trading days) for both series.
    data['monthly_price_change'] = data['close'].pct_change(20)
    # Temperatures cross zero, so a percentage change is meaningless (division
    # by ~0, sign flips, inf); use the absolute change instead.
    data['monthly_temp_change'] = data['temperature'].diff(20)

    return data

def clean_merged_data(data):
    """
    Clean merged data by filling short gaps and removing weather outliers.

    Applied to whichever of ``temperature``, ``precipitation`` and
    ``wind_speed`` are present:

    1. Forward-fill gaps of at most 3 consecutive rows (in place on ``data``).
    2. Drop rows with no value in any of these columns, e.g. forex bars that
       found no weather row within the merge tolerance.
    3. Drop rows whose value lies outside the inclusive [1%, 99%] quantile
       range of its column. Inclusive bounds keep ties at the bound, such as
       the many zero-precipitation days. Missing values are not treated as
       outliers, so a column with no data at all does not empty the frame.
       All bounds are computed once, before any rows are removed.

    Returns
    -------
    pandas.DataFrame
        The cleaned frame (``data`` itself if empty or without weather columns).
    """
    if data.empty:
        return data

    weather_cols = [col for col in ('temperature', 'precipitation', 'wind_speed')
                    if col in data.columns]
    if not weather_cols:
        return data

    # Fill gaps
    for col in weather_cols:
        data[col] = data[col].ffill(limit=3)

    # Rows carrying no weather information at all cannot be analysed.
    data = data.dropna(subset=weather_cols, how='all')

    # Remove outliers, using bounds from the same (pre-filter) data for every column
    keep = np.ones(len(data), dtype=bool)
    for col in weather_cols:
        values = data[col]
        q_low = values.quantile(0.01)
        q_high = values.quantile(0.99)
        keep &= (values.isna() | values.between(q_low, q_high)).to_numpy()

    return data[keep]

def calculate_gdd(temp_data, base_temp=10):
    """
    Return Growing Degree Days: the amount by which ``temp_data`` exceeds ``base_temp``.

    Values at or below the base contribute zero; NaN inputs stay NaN. Works
    element-wise on scalars, NumPy arrays and pandas Series (a Series keeps
    its index).
    """
    excess = temp_data - base_temp
    return np.clip(excess, 0, None)

def _lag_correlation_table(data, weather_cols, price_cols, lags):
    """Return one row per (weather factor, price metric, lag) holding the Pearson correlation.

    A lag of ``k`` correlates each row's weather value with the price metric
    ``k`` rows later (``shift(-k)``). Rows are bars of the merged frame, so the
    lag is counted in rows rather than calendar days.
    """
    rows = []
    for w_col in weather_cols:
        if w_col not in data.columns:
            continue
        for p_col in price_cols:
            if p_col not in data.columns:
                continue
            for lag in lags:
                rows.append({
                    'weather_factor': w_col,
                    'price_metric': p_col,
                    'lag_days': lag,
                    'correlation': data[w_col].corr(data[p_col].shift(-lag))
                })
    return pd.DataFrame(rows, columns=['weather_factor', 'price_metric', 'lag_days', 'correlation'])


def _monthly_close_correlations(data, weather_cols):
    """Return a 12-row frame (index 1..12) of per-calendar-month correlations with ``close``."""
    months = data.index.month
    per_month = []
    for month in range(1, 13):
        month_data = data[months == month]
        per_month.append({
            w_col: month_data[w_col].corr(month_data['close'])
            for w_col in weather_cols
            if w_col in month_data.columns
        })
    return pd.DataFrame(per_month, index=range(1, 13))


def analyze_and_visualize_correlations(merged_data):
    """
    Analyze and visualize correlations between weather conditions and currency rates.

    For every non-empty region frame in ``merged_data`` this writes, to the
    current working directory:

    * a heatmap of correlations averaged over lags,
    * a lag-correlation plot (one line per weather factor, averaged over price metrics),
    * a seasonal plot of per-month correlations with ``close``,
    * a text report.

    All files for one region share one timestamp suffix. The input frames are
    not modified. Note that this function changes global ``plt.rcParams``.
    """
    # Use default matplotlib style instead of seaborn
    plt.style.use('default')

    # Set common style for plots
    plt.rcParams['figure.figsize'] = [15, 10]
    plt.rcParams['axes.grid'] = True
    plt.rcParams['axes.labelsize'] = 12
    plt.rcParams['xtick.labelsize'] = 10
    plt.rcParams['ytick.labelsize'] = 10

    weather_cols = ['temperature', 'precipitation', 'wind_speed', 'growing_degree_days']
    price_cols = ['close', 'volatility', 'range_pct', 'price_momentum', 'monthly_change']
    lags = [0, 5, 10, 20, 30]  # rows of lead applied to the price series
    report_cols = ['weather_factor', 'price_metric', 'lag_days', 'correlation']

    for region, data in merged_data.items():
        if data.empty:
            continue

        print(f"\nCorrelation analysis for region {region}")
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        correlation_matrix = _lag_correlation_table(data, weather_cols, price_cols, lags)
        if correlation_matrix.empty:
            print(f"No weather/price columns available for region {region}; skipping")
            continue

        # Heatmap of correlations averaged over all lags
        plt.figure()
        pivot_table = correlation_matrix.pivot_table(
            index='weather_factor',
            columns='price_metric',
            values='correlation',
            aggfunc='mean'
        )
        # Fix the colour scale to [-1, 1] so colours are comparable across regions.
        im = plt.imshow(pivot_table.values, cmap='RdYlBu', aspect='auto', vmin=-1, vmax=1)
        plt.colorbar(im)
        plt.xticks(range(len(pivot_table.columns)), pivot_table.columns, rotation=45)
        plt.yticks(range(len(pivot_table.index)), pivot_table.index)
        for i, row_values in enumerate(pivot_table.values):
            for j, value in enumerate(row_values):
                plt.text(j, i, f'{value:.2f}', ha='center', va='center')
        plt.title(f'Correlations of weather factors and price for {region}')
        plt.tight_layout()
        plt.savefig(f'{region}_correlations_{stamp}.png')
        plt.close()

        # Lag correlations: one point per lag per weather factor. Averaging over
        # price metrics avoids joining lag 30 of one metric to lag 0 of the next.
        plt.figure()
        for w_col, rows in correlation_matrix.groupby('weather_factor', sort=False):
            by_lag = rows.groupby('lag_days')['correlation'].mean()
            plt.plot(by_lag.index, by_lag.values, marker='o', label=w_col)
        plt.title(f'Lag correlations for {region}')
        plt.xlabel('Lag (days)')
        plt.ylabel('Correlation')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(f'{region}_lag_correlations_{stamp}.png')
        plt.close()

        # "Strongest" means largest magnitude: rank by |r| so negative ones count.
        print("\nTop-5 strongest correlations:")
        strongest = correlation_matrix['correlation'].abs().nlargest(5).index
        top_correlations = correlation_matrix.loc[strongest, report_cols]
        print(top_correlations)

        # Seasonal plot (month taken from the index; no column is added to the input)
        plt.figure()
        monthly_df = _monthly_close_correlations(data, weather_cols)
        for column in monthly_df.columns:
            plt.plot(range(1, 13), monthly_df[column], marker='o', label=column)
        plt.title(f'Seasonal correlations for {region}')
        plt.xlabel('Month')
        plt.ylabel('Correlation with price')
        plt.grid(True)
        plt.legend(title='Weather Factor')
        plt.tight_layout()
        plt.savefig(f'{region}_seasonal_correlations_{stamp}.png')
        plt.close()

        # Save analysis results to a text file
        with open(f'{region}_analysis_{stamp}.txt', 'w', encoding='utf-8') as f:
            f.write(f"Correlation analysis for region {region}\n")
            f.write("=" * 50 + "\n\n")

            f.write("Average correlations:\n")
            f.write(pivot_table.to_string())
            f.write("\n\n")

            f.write("Top-5 strongest correlations:\n")
            f.write(top_correlations.to_string())
            f.write("\n\n")

            f.write("Seasonal correlations:\n")
            f.write(monthly_df.to_string())

def main():
    """
    Run the whole pipeline: fetch weather and forex history, merge, analyse, detect events.

    Returns
    -------
    tuple
        ``(merged_data, critical_events)``, or ``(None, None)`` when the forex
        history could not be fetched.
    """
    print("Fetching historical weather data for 5 years...")
    weather = fetch_historical_weather(years=5)

    print("Fetching historical forex data...")
    forex = get_historical_forex_data(years=5)
    if forex is None:
        print("Error fetching forex data")
        return None, None

    print("Merging data...")
    merged = merge_weather_forex_data(weather, forex)

    print("Analyzing correlations and creating visualizations...")
    analyze_and_visualize_correlations(merged)

    print("Analyzing critical conditions...")
    events = detect_critical_conditions(weather)
    return merged, events

if __name__ == "__main__":
    # Remember where we started so the finally clause can always restore it.
    original_dir = os.getcwd()
    try:
        # Create a folder for results
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        results_dir = f"results_{timestamp}"
        os.makedirs(results_dir, exist_ok=True)

        # Work inside the results folder: the analysis step writes its plots
        # and reports to the current working directory.
        os.chdir(results_dir)

        merged_data, critical_events = main()

        if merged_data is not None:
            print("\nData successfully fetched and processed")

            # Save merged data
            for region, data in merged_data.items():
                if not data.empty:
                    filename = f"{region}_merged_data.csv"
                    data.to_csv(filename)
                    print(f"Data for {region} saved to file {filename}")

            # Save critical events
            if critical_events:
                critical_events_df = pd.DataFrame(critical_events)
                filename = "critical_events.csv"
                critical_events_df.to_csv(filename)
                print(f"Critical events saved to file {filename}")

                # Create a summary of critical events
                summary_filename = "critical_events_summary.txt"
                with open(summary_filename, 'w', encoding='utf-8') as f:
                    f.write("Summary of critical events\n")
                    f.write("=" * 50 + "\n\n")
                    for region in merged_data:
                        region_events = critical_events_df[critical_events_df['region'] == region]
                        f.write(f"\nRegion: {region}\n")
                        f.write(f"Total events: {len(region_events)}\n")
                        if not region_events.empty:
                            f.write("Event types:\n")
                            event_types = region_events['event_type']
                            for event_type in event_types.unique():
                                count = int((event_types == event_type).sum())
                                f.write(f"- {event_type}: {count}\n")

            print(f"\nAll results saved to folder {results_dir}")

    except Exception as e:
        print(f"An error occurred while running the program: {str(e)}")
        print("Details:")
        import traceback
        traceback.print_exc()

    finally:
        # Always return to the original directory (success, error or interrupt).
        os.chdir(original_dir)
        
