import MetaTrader5 as mt5
import pandas as pd
import numpy as np
from meteostat import Point, Daily, Hourly
from datetime import datetime, timedelta
import warnings
# Suppress FutureWarning noise (mainly from pandas)
warnings.filterwarnings('ignore', category=FutureWarning)
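
# The script assumes the MetaTrader5, meteostat, pandas and numpy packages are
# installed (pip install MetaTrader5 meteostat pandas numpy) and that an MT5
# terminal is running and logged in on the same machine.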

def fetch_agriculture_weather():
    """
    Fetch weather data for key agricultural regions via the nearest weather stations,
    with fallback to alternative coordinates when the primary point returns no data
    """
    key_regions = {
        "AU_WheatBelt": {
            "lat": -31.95, 
            "lon": 116.85,
            "description": "Key wheat production region in Australia"
        },
        "NZ_Canterbury": {
            # Christchurch, the largest city in the Canterbury region
            "lat": -43.53,
            "lon": 172.63,
            "description": "Major dairy production region in New Zealand"
        },
        "CA_Prairies": {
            "lat": 50.45, 
            "lon": -104.61,
            "description": "Canada's breadbasket, producing wheat and rapeseed"
        }
    }
    
    # Alternative coordinates for regions in case of missing data
    backup_coordinates = {
        "NZ_Canterbury": [
            {"lat": -43.53, "lon": 172.63},  # Christchurch
            {"lat": -44.01, "lon": 171.25},  # Timaru
            {"lat": -43.90, "lon": 171.75}   # Ashburton
        ]
    }
    
    weather_data = {}
    end = datetime.now()
    start = end - timedelta(days=30)
    
    for region, coords in key_regions.items():
        try:
            print(f"\nFetching data for {region}:")
            
            # Main attempt to fetch data
            location = Point(coords["lat"], coords["lon"])
            print(f"Trying main coordinates: lat={coords['lat']}, lon={coords['lon']}")
            data = Hourly(location, start, end)
            fetched_data = data.fetch()
            
            # If data is empty and there are backup coordinates, try them
            if (fetched_data is None or fetched_data.empty) and region in backup_coordinates:
                print(f"Data not fetched, trying alternative coordinates for {region}")
                
                for backup_coords in backup_coordinates[region]:
                    print(f"Trying coordinates: lat={backup_coords['lat']}, lon={backup_coords['lon']}")
                    location = Point(backup_coords["lat"], backup_coords["lon"])
                    data = Hourly(location, start, end)
                    fetched_data = data.fetch()
                    
                    if fetched_data is not None and not fetched_data.empty:
                        print("Successfully fetched data from alternative coordinates!")
                        break
            
            # Check fetched data
            if fetched_data is not None and not fetched_data.empty:
                print(f"Fetched records: {len(fetched_data)}")
                print("Available columns:", fetched_data.columns.tolist())
                processed_data = process_weather_data(fetched_data)
                weather_data[region] = processed_data
                print("Data successfully processed")
            else:
                print("Failed to fetch data for this region")
                weather_data[region] = pd.DataFrame()
                
        except Exception as e:
            print(f"Error fetching data for region {region}:")
            print(f"Error type: {type(e).__name__}")
            print(f"Error description: {str(e)}")
            weather_data[region] = pd.DataFrame()
    
    return weather_data

def process_weather_data(raw_data):
    """
    Process weather data focusing on critical parameters for crop yield
    """
    if not isinstance(raw_data.index, pd.DatetimeIndex):
        raw_data.index = pd.to_datetime(raw_data.index)
    
    processed_data = pd.DataFrame(index=raw_data.index)
    
    # Basic metrics
    processed_data['temperature'] = raw_data['temp']
    processed_data['precipitation'] = raw_data['prcp']
    processed_data['wind_speed'] = raw_data['wspd']
    
    # Calculate GDD (Growing Degree Days)
    processed_data['growing_degree_days'] = calculate_gdd(raw_data['temp'])
    
    # Add indicators for extreme conditions
    processed_data['extreme_temp'] = (
        (raw_data['temp'] > raw_data['temp'].quantile(0.95)) |
        (raw_data['temp'] < raw_data['temp'].quantile(0.05))
    )
    
    return processed_data

def detect_critical_conditions(weather_data):
    """
    Detect critical weather conditions
    """
    critical_events = []
    
    thresholds = {
        'AU_WheatBelt': {
            'drought_temp': 35,
            'min_rain': 10,
            'max_wind': 30
        },
        'NZ_Canterbury': {
            'frost_temp': 0,
            'flood_rain': 50,
            'drought_days': 14
        },
        'CA_Prairies': {
            'frost_temp': -5,
            'snow_cover': 10,
            'growing_degree_min': 5
        }
    }
    
    for region, data in weather_data.items():
        if not isinstance(data, pd.DataFrame) or data.empty:
            continue
            
        thresholds_region = thresholds.get(region, {})
        
        # Check temperature extremes: heat
        if 'drought_temp' in thresholds_region and 'temperature' in data.columns:
            drought_days = data[data['temperature'] > thresholds_region['drought_temp']]
            if not drought_days.empty:
                critical_events.append({
                    'region': region,
                    'event_type': 'drought_temperature',
                    'dates': drought_days.index.tolist(),
                    'value': drought_days['temperature'].mean()
                })

        # Check temperature extremes: frost (uses the frost_temp thresholds
        # defined above for NZ_Canterbury and CA_Prairies)
        if 'frost_temp' in thresholds_region and 'temperature' in data.columns:
            frost_hours = data[data['temperature'] < thresholds_region['frost_temp']]
            if not frost_hours.empty:
                critical_events.append({
                    'region': region,
                    'event_type': 'frost_temperature',
                    'dates': frost_hours.index.tolist(),
                    'value': frost_hours['temperature'].mean()
                })
                
        # Check precipitation
        if 'min_rain' in thresholds_region and 'precipitation' in data.columns:
            daily_rain = data['precipitation'].resample('d').sum()
            dry_days = daily_rain[daily_rain < thresholds_region['min_rain']]
            if not dry_days.empty:
                critical_events.append({
                    'region': region,
                    'event_type': 'low_precipitation',
                    'dates': dry_days.index.tolist(),
                    'value': dry_days.mean()
                })
    
    return critical_events

def merge_weather_forex_data(weather_data, forex_data):
    """
    Merge weather and forex data
    """
    synchronized_data = {}
    
    region_pair_mapping = {
        'AU_WheatBelt': 'AUDUSD',
        'NZ_Canterbury': 'NZDUSD',
        'CA_Prairies': 'USDCAD'  # CAD is the quote currency here, so price moves are inverted relative to the AUD/NZD pairs
    }
    
    for region, pair in region_pair_mapping.items():
        if region in weather_data and pair in forex_data:
            weather = weather_data[region]
            forex = forex_data[pair]['H1']
            
            if not weather.empty and not forex.empty:
                if not isinstance(weather.index, pd.DatetimeIndex):
                    weather.index = pd.to_datetime(weather.index)
                if not isinstance(forex.index, pd.DatetimeIndex):
                    forex.index = pd.to_datetime(forex.index)
                
                # Resample weather to hourly bars (lowercase 'h'; uppercase 'H' is deprecated in recent pandas)
                weather = weather.resample('h').mean()
                
                # Merge data
                merged = pd.merge_asof(
                    forex,
                    weather,
                    left_index=True,
                    right_index=True,
                    tolerance=pd.Timedelta('1h')
                )
                
                merged = calculate_derived_features(merged)
                merged = clean_merged_data(merged)
                
                synchronized_data[region] = merged
            else:
                print(f"Empty data for region {region} or pair {pair}")
                synchronized_data[region] = pd.DataFrame()
    
    return synchronized_data
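
# Optional sketch, not part of the original pipeline: MT5 timestamps are in the
# broker's server time while meteostat returns UTC, so the two naive indexes can
# be offset by a few hours. A hypothetical helper to compensate; the default
# offset is an assumption and must be checked against your broker.
def shift_forex_index_to_utc(forex_df, server_utc_offset_hours=2):
    """Shift an MT5 DataFrame index from server time to UTC (illustrative only)."""
    shifted = forex_df.copy()
    shifted.index = shifted.index - pd.Timedelta(hours=server_utc_offset_hours)
    return shifted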

def calculate_derived_features(data):
    """
    Calculate derived features
    """
    if not data.empty:
        data['price_volatility'] = data['volatility'].rolling(24).std()
        data['temp_change'] = data['temperature'].diff()
        data['precip_intensity'] = data['precipitation'].rolling(24).sum()
        
        # April-September corresponds to the Northern Hemisphere growing season;
        # for the Southern Hemisphere regions (AU_WheatBelt, NZ_Canterbury) the
        # flag would need to be inverted
        data['growing_season'] = (
            (data.index.month >= 4) & 
            (data.index.month <= 9)
        )
    
    return data

def clean_merged_data(data):
    """
    Clean merged data
    """
    if data.empty:
        return data
        
    weather_cols = ['temperature', 'precipitation', 'wind_speed']
    
    # Forward-fill short gaps with ffill() (the fillna(method='ffill') form is deprecated)
    for col in weather_cols:
        if col in data.columns:
            data[col] = data[col].ffill(limit=3)
    
    # Remove outliers; bounds are inclusive so that the many zero-precipitation
    # hours are not discarded along with genuine extremes
    for col in weather_cols:
        if col in data.columns:
            q_low = data[col].quantile(0.01)
            q_high = data[col].quantile(0.99)
            data = data[
                (data[col] >= q_low) & 
                (data[col] <= q_high)
            ]
    
    return data

def get_agricultural_forex_pairs():
    """
    Fetch forex data for agricultural currency pairs via MetaTrader5
    """
    if not mt5.initialize():
        print(f"MT5 initialization error: {mt5.last_error()}")
        return None
    
    pairs = ["AUDUSD", "NZDUSD", "USDCAD"]
    timeframes = {
        "H1": mt5.TIMEFRAME_H1,
        "H4": mt5.TIMEFRAME_H4,
        "D1": mt5.TIMEFRAME_D1
    }
    
    forex_data = {}
    
    for pair in pairs:
        pair_data = {}
        for tf_name, tf in timeframes.items():
            try:
                rates = mt5.copy_rates_from_pos(pair, tf, 0, 1000)
                if rates is not None:
                    df = pd.DataFrame(rates)
                    df['time'] = pd.to_datetime(df['time'], unit='s')
                    df.set_index('time', inplace=True)
                    
                    df['volatility'] = df['high'] - df['low']
                    df['range_pct'] = (df['high'] - df['low']) / df['low'] * 100
                    df['price_momentum'] = df['close'].pct_change()
                    
                    pair_data[tf_name] = df
                else:
                    print(f"Failed to fetch data for pair {pair} timeframe {tf_name}")
                    pair_data[tf_name] = pd.DataFrame()
            except Exception as e:
                print(f"Error fetching data for pair {pair} timeframe {tf_name}: {str(e)}")
                pair_data[tf_name] = pd.DataFrame()
        
        forex_data[pair] = pair_data
    
    mt5.shutdown()
    return forex_data
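
# Note (assumption about the terminal setup): if a pair is not visible in Market
# Watch, copy_rates_from_pos() can return None; calling mt5.symbol_select(pair, True)
# before requesting rates usually resolves this.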

def calculate_gdd(temp_data, base_temp=10):
    """
    Calculate Growing Degree Days as a simple per-observation approximation:
    degrees above base_temp, floored at zero (applied to hourly temperatures here)
    """
    return np.maximum(temp_data - base_temp, 0)
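
# For reference, the common daily GDD formulation uses daily min/max temperatures.
# A minimal sketch assuming hourly input as above; the helper name is illustrative
# and it is not called anywhere in this script.
def calculate_gdd_daily(temp_data, base_temp=10):
    """Daily GDD = max(((Tmin + Tmax) / 2) - base_temp, 0), built from hourly temperatures."""
    daily = temp_data.resample('d').agg(['min', 'max'])
    return ((daily['min'] + daily['max']) / 2 - base_temp).clip(lower=0)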

def main():
    print("Fetching weather data...")
    weather_data = fetch_agriculture_weather()
    
    print("Fetching forex data...")
    forex_data = get_agricultural_forex_pairs()
    
    if forex_data is None:
        print("Error fetching forex data")
        return None, None
    
    print("Merging data...")
    merged_data = merge_weather_forex_data(weather_data, forex_data)
    
    print("Analyzing critical conditions...")
    critical_events = detect_critical_conditions(weather_data)
    
    return merged_data, critical_events

if __name__ == "__main__":
    try:
        merged_data, critical_events = main()
        if merged_data is not None:
            print("Data successfully fetched and processed")
            
            # Save results
            for region, data in merged_data.items():
                if not data.empty:
                    filename = f"{region}_merged_data_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
                    data.to_csv(filename)
                    print(f"Data for {region} saved to file {filename}")
            
            # Save critical events
            if critical_events:
                critical_events_df = pd.DataFrame(critical_events)
                filename = f"critical_events_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
                critical_events_df.to_csv(filename)
                print(f"Critical events saved to file {filename}")
                
    except Exception as e:
        print(f"An error occurred: {str(e)}")
