import MetaTrader5 as mt5
import pandas as pd
import numpy as np
from meteostat import Point, Daily
from datetime import datetime, timedelta
import warnings
from sklearn.model_selection import TimeSeriesSplit
from catboost import CatBoostRegressor, CatBoostClassifier
from sklearn.metrics import mean_squared_error, accuracy_score
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

# Filter warnings
warnings.filterwarnings('ignore', category=FutureWarning)
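
# ---------------------------------------------------------------------------
# Pipeline overview:
#   1. fetch_agriculture_weather  - daily weather per agricultural region (Meteostat)
#   2. get_historical_forex_data  - daily OHLC for AUDUSD/NZDUSD/USDCAD (MetaTrader 5)
#   3. merge_weather_forex_data   - align the two sources by date
#   4. prepare_ml_features / create_prediction_targets - feature and target engineering
#   5. train_ml_models            - CatBoost models with walk-forward cross-validation
#   6. analyze_model_performance  - plots and text reports per region
# ---------------------------------------------------------------------------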

def fetch_agriculture_weather(years=5):
    """
    Fetch weather data for key agricultural regions
    """
    key_regions = {
        "AU_WheatBelt": {
            "lat": -31.95, 
            "lon": 116.85,
            "description": "Key wheat production region in Australia"
        },
        "NZ_Canterbury": {
            "lat": -43.53,
            "lon": 172.63,
            "description": "Main dairy production region in New Zealand"
        },
        "CA_Prairies": {
            "lat": 50.45, 
            "lon": -104.61,
            "description": "Canada's breadbasket, producing wheat and rapeseed"
        }
    }
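    # Each region is later paired with its commodity currency in
    # merge_weather_forex_data(): AU_WheatBelt -> AUDUSD,
    # NZ_Canterbury -> NZDUSD, CA_Prairies -> USDCAD.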
    
    weather_data = {}
    end = datetime.now()
    start = end - timedelta(days=365 * years)
    
    for region, coords in tqdm(key_regions.items(), desc="Loading weather data"):
        try:
            print(f"\nLoading historical data for {region} from {start.date()} to {end.date()}")
            
            location = Point(coords["lat"], coords["lon"])
            data = Daily(location, start, end)
            fetched_data = data.fetch()
            
            if fetched_data is not None and not fetched_data.empty:
                print(f"Fetched records: {len(fetched_data)}")
                processed_data = process_weather_data(fetched_data)
                weather_data[region] = processed_data
                print("Data successfully processed")
            else:
                print(f"No data returned for {region}")
                weather_data[region] = pd.DataFrame()
                
        except Exception as e:
            print(f"Error fetching data for {region}: {str(e)}")
            weather_data[region] = pd.DataFrame()
    
    return weather_data

def process_weather_data(raw_data):
    """
    Process weather data
    """
    if not isinstance(raw_data.index, pd.DatetimeIndex):
        raw_data.index = pd.to_datetime(raw_data.index)
    
    processed_data = pd.DataFrame(index=raw_data.index)
    
    processed_data['temperature'] = raw_data['tavg']
    processed_data['temp_min'] = raw_data['tmin']
    processed_data['temp_max'] = raw_data['tmax']
    processed_data['precipitation'] = raw_data['prcp']
    processed_data['wind_speed'] = raw_data['wspd']
    
    # Growing degree days, simplified: positive part of (tmax - base_temp).
    # (The classical agronomic formula averages daily tmax and tmin first.)
    processed_data['growing_degree_days'] = calculate_gdd(
        processed_data['temp_max'],
        base_temp=10
    )
    
    return processed_data

def get_historical_forex_data(years=5):
    """
    Fetch historical forex data
    """
    if not mt5.initialize():
        print("MT5 initialization error")
        return None
    
    pairs = ["AUDUSD", "NZDUSD", "USDCAD"]
    timeframes = {
        "D1": mt5.TIMEFRAME_D1
    }
    
    # Number of daily bars to request; markets trade ~260 days a year,
    # so this reaches back further than `years` calendar years
    bars_needed = years * 365
    forex_data = {}
    
    for pair in tqdm(pairs, desc="Loading forex data"):
        pair_data = {}
        for tf_name, tf in timeframes.items():
            try:
                rates = mt5.copy_rates_from_pos(pair, tf, 0, bars_needed)
                if rates is not None and len(rates) > 0:
                    df = pd.DataFrame(rates)
                    df['time'] = pd.to_datetime(df['time'], unit='s')
                    df.set_index('time', inplace=True)
                    
                    # Basic daily price features
                    df['volatility'] = df['high'] - df['low']            # daily range
                    df['range_pct'] = (df['high'] - df['low']) / df['low'] * 100
                    df['price_momentum'] = df['close'].pct_change()      # one-day return
                    df['monthly_change'] = df['close'].pct_change(20)    # ~one trading month
                    
                    pair_data[tf_name] = df
                else:
                    print(f"No data returned for {pair} {tf_name}")
                    pair_data[tf_name] = pd.DataFrame()
            except Exception as e:
                print(f"Error fetching data for {pair} {tf_name}: {str(e)}")
                pair_data[tf_name] = pd.DataFrame()
        
        forex_data[pair] = pair_data
    
    mt5.shutdown()
    return forex_data

def merge_weather_forex_data(weather_data, forex_data):
    """
    Merge weather and forex data
    """
    synchronized_data = {}
    
    region_pair_mapping = {
        'AU_WheatBelt': 'AUDUSD',
        'NZ_Canterbury': 'NZDUSD',
        'CA_Prairies': 'USDCAD'
    }
    
    for region, pair in region_pair_mapping.items():
        if region in weather_data and pair in forex_data:
            weather = weather_data[region]
            forex = forex_data[pair]['D1']
            
            if not weather.empty and not forex.empty:
                if not isinstance(weather.index, pd.DatetimeIndex):
                    weather.index = pd.to_datetime(weather.index)
                if not isinstance(forex.index, pd.DatetimeIndex):
                    forex.index = pd.to_datetime(forex.index)
                
                # merge_asof requires both frames to be sorted by the join key
                weather = weather.sort_index()
                forex = forex.sort_index()
                
                # Attach to each daily forex bar the nearest weather record
                # (same or previous day, within a one-day tolerance)
                merged = pd.merge_asof(
                    forex,
                    weather,
                    left_index=True,
                    right_index=True,
                    tolerance=pd.Timedelta('1d')
                )
                
                synchronized_data[region] = merged
            else:
                print(f"Empty data for region {region} or pair {pair}")
                synchronized_data[region] = pd.DataFrame()
    
    return synchronized_data

def calculate_gdd(temp_data, base_temp=10):
    """
    Calculate Growing Degree Days
    """
    return np.maximum(temp_data - base_temp, 0)
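
# For reference, a closer-to-standard GDD variant that averages the daily
# minimum and maximum temperatures. It is not wired into the pipeline above;
# shown only as an optional alternative to the simplified calculation.
def calculate_gdd_from_minmax(temp_min, temp_max, base_temp=10):
    """
    Growing Degree Days from the mean of daily tmin and tmax, floored at zero
    """
    return np.maximum((temp_min + temp_max) / 2 - base_temp, 0)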

def prepare_ml_features(data):
    """
    Prepare features for the model
    """
    features = pd.DataFrame(index=data.index)
    
    weather_cols = ['temperature', 'precipitation', 'wind_speed', 'growing_degree_days']
    for col in weather_cols:
        if col not in data.columns:
            continue
        # Current values
        features[col] = data[col]
        
        # Rolling means over 24 and 72 bars (days, since the data is daily)
        features[f"{col}_ma_24"] = data[col].rolling(24).mean()
        features[f"{col}_ma_72"] = data[col].rolling(72).mean()
        
        # Changes over 1 bar and 24 bars
        features[f"{col}_change"] = data[col].pct_change()
        features[f"{col}_change_24"] = data[col].pct_change(24)
        
        # Rolling volatility (standard deviation over 24 bars)
        features[f"{col}_volatility"] = data[col].rolling(24).std()
    
    # Price features
    price_cols = ['volatility', 'range_pct', 'monthly_change']
    for col in price_cols:
        if col not in data.columns:
            continue
        features[f"{col}_ma_24"] = data[col].rolling(24).mean()
    
    # Temporal features
    features['month'] = data.index.month
    features['day_of_week'] = data.index.dayofweek
    # April-September flag (Northern Hemisphere growing season; note that the
    # Australian and New Zealand regions are in the Southern Hemisphere, where
    # the growing season is reversed)
    features['growing_season'] = ((data.index.month >= 4) & 
                                  (data.index.month <= 9)).astype(int)
    
    # pct_change can yield +/-inf (e.g. precipitation rising from zero);
    # convert those to NaN, then drop incomplete rows
    features = features.replace([np.inf, -np.inf], np.nan).dropna()
    
    return features

def create_prediction_targets(data, forecast_horizon=24):
    """
    Create target variables
    """
    targets = pd.DataFrame(index=data.index)
    
    # Price change
    targets['price_change'] = data['close'].pct_change(forecast_horizon).shift(-forecast_horizon)
    
    # Direction of movement
    targets['direction'] = (targets['price_change'] > 0).astype(int)
    
    # Mean daily range over the next forecast_horizon bars
    if 'volatility' in data.columns:
        targets['volatility'] = data['volatility'].rolling(
            forecast_horizon
        ).mean().shift(-forecast_horizon)
    
    return targets.dropna()

def train_ml_models(merged_data, region):
    """
    Train machine learning models
    """
    data = merged_data[region]
    if data.empty:
        return None
    
    print(f"\nPreparing data for region {region}")
    features = prepare_ml_features(data)
    targets = create_prediction_targets(data)
    
    print("Features:", features.columns.tolist())
    print("Target variables:", targets.columns.tolist())
    
    # Align features and targets on their common index
    valid_idx = features.index.intersection(targets.index)
    features = features.loc[valid_idx]
    targets = targets.loc[valid_idx]
    
    # Walk-forward cross-validation with an expanding training window
    tscv = TimeSeriesSplit(n_splits=5)
    
    # Define categorical features
    cat_features = ['month', 'day_of_week', 'growing_season']
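    # CatBoost treats these integer-coded calendar columns as categorical
    # internally, so no one-hot encoding is needed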
    
    models = {
        'direction': CatBoostClassifier(
            iterations=1000,
            learning_rate=0.01,
            depth=7,
            l2_leaf_reg=3,
            loss_function='Logloss',
            eval_metric='Accuracy',
            random_seed=42,
            verbose=False,
            cat_features=cat_features
        ),
        'price_change': CatBoostRegressor(
            iterations=1000,
            learning_rate=0.01,
            depth=7,
            l2_leaf_reg=3,
            loss_function='RMSE',
            random_seed=42,
            verbose=False,
            cat_features=cat_features
        )
    }
    
    if 'volatility' in targets.columns:
        models['volatility'] = CatBoostRegressor(
            iterations=1000,
            learning_rate=0.01,
            depth=7,
            l2_leaf_reg=3,
            loss_function='RMSE',
            random_seed=42,
            verbose=False,
            cat_features=cat_features
        )
    
    results = {}
    for target_name, model in models.items():
        print(f"\nTraining model for {target_name}")
        target = targets[target_name]
        
        fold_metrics = []
        predictions = []
        test_indices = []
        feature_importance = pd.DataFrame()
        
        for fold_idx, (train_idx, test_idx) in enumerate(tscv.split(features)):
            print(f"Training fold {fold_idx + 1}/5")
            X_train = features.iloc[train_idx]
            y_train = target.iloc[train_idx]
            X_test = features.iloc[test_idx]
            y_test = target.iloc[test_idx]
            
            # Train model (fit() retrains from scratch on every fold; the stored
            # object ends up holding the model fitted on the last, most recent fold)
            model.fit(
                X_train, 
                y_train,
                eval_set=(X_test, y_test),
                early_stopping_rounds=50,
                verbose=False
            )
            
            # Predictions
            pred = model.predict(X_test)
            predictions.extend(pred)
            test_indices.extend(test_idx)
            
            # Calculate metrics: accuracy for the classifier, RMSE for the regressors
            if target_name == 'direction':
                metric = accuracy_score(y_test, pred)
            else:
                metric = np.sqrt(mean_squared_error(y_test, pred))  # RMSE; avoids the deprecated squared=False argument
            
            fold_metrics.append(metric)
            print(f"Fold {fold_idx + 1} metric: {metric:.4f}")
            
            # Save feature importance
            fold_importance = pd.DataFrame({
                'feature': features.columns,
                f'importance_{fold_idx}': model.feature_importances_
            })
            if feature_importance.empty:
                feature_importance = fold_importance
            else:
                feature_importance = feature_importance.merge(
                    fold_importance, on='feature'
                )
        
        # Calculate mean feature importance
        importance_cols = [col for col in feature_importance.columns if 'importance' in col]
        feature_importance['mean_importance'] = feature_importance[importance_cols].mean(axis=1)
        feature_importance = feature_importance.sort_values('mean_importance', ascending=False)
        
        results[target_name] = {
            'model': model,
            'metrics': fold_metrics,
            'mean_metric': np.mean(fold_metrics),
            'predictions': pd.Series(predictions, index=features.index[test_indices]),
            'feature_importance': feature_importance
        }
        
        print(f"Mean metric for {target_name}: {results[target_name]['mean_metric']:.4f}")
    
    return results

def analyze_model_performance(model_results, region):
    """
    Analyze model performance
    """
    # Create a folder for region plots
    region_dir = f"{region}_analysis"
    os.makedirs(region_dir, exist_ok=True)
    
    for target_name, results in model_results.items():
        # Plot metrics per fold
        plt.figure(figsize=(10, 6))
        plt.plot(results['metrics'], marker='o')
        plt.title(f'{target_name} - Metrics per Fold')
        plt.xlabel('Fold')
        plt.ylabel('Accuracy' if target_name == 'direction' else 'RMSE')
        plt.grid(True)
        plt.savefig(f'{region_dir}/{target_name}_metrics.png')
        plt.close()
        
        # Plot feature importance
        plt.figure(figsize=(12, 6))
        importance_data = results['feature_importance'].head(15)
        plt.barh(importance_data['feature'], importance_data['mean_importance'])
        plt.title(f'{target_name} - Feature Importance')
        plt.xlabel('Importance')
        plt.tight_layout()
        plt.savefig(f'{region_dir}/{target_name}_feature_importance.png')
        plt.close()
        
        # Plot predictions (for regression models)
        if target_name != 'direction':
            plt.figure(figsize=(15, 6))
            predictions = results['predictions'].rolling(20).mean()
            plt.plot(predictions, label='Predicted')
            plt.title(f'{target_name} - Rolling Mean of Predictions (20 days)')
            plt.xlabel('Date')
            plt.ylabel(target_name)
            plt.legend()
            plt.grid(True)
            plt.tight_layout()
            plt.savefig(f'{region_dir}/{target_name}_predictions.png')
            plt.close()
        
        # Save text report
        with open(f'{region_dir}/{target_name}_report.txt', 'w') as f:
            f.write(f"Model report for {target_name} in region {region}\n")
            f.write("=" * 50 + "\n\n")
            f.write(f"Mean metric: {results['mean_metric']:.4f}\n")
            f.write(f"Metrics per fold: {', '.join([f'{m:.4f}' for m in results['metrics']])}\n\n")
            f.write("Top 15 important features:\n")
            for _, row in importance_data.iterrows():
                f.write(f"{row['feature']}: {row['mean_importance']:.4f}\n")

def main():
    print("Fetching historical weather data...")
    weather_data = fetch_agriculture_weather(years=5)
    
    print("\nFetching historical forex data...")
    forex_data = get_historical_forex_data(years=5)
    
    if forex_data is None:
        print("Error fetching forex data")
        return
    
    print("\nMerging data...")
    merged_data = merge_weather_forex_data(weather_data, forex_data)
    
    print("\nTraining models and analyzing results...")
    results_dir = f"results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    os.makedirs(results_dir, exist_ok=True)
    os.chdir(results_dir)  # per-region plots and reports are written inside this folder
    
    for region in merged_data.keys():
        print(f"\nAnalyzing region {region}")
        model_results = train_ml_models(merged_data, region)
        if model_results:
            analyze_model_performance(model_results, region)
    
    os.chdir('..')
    print(f"\nResults saved to folder {results_dir}")

if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        import traceback
        traceback.print_exc()
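
# Note: running this script assumes the MetaTrader5, meteostat, catboost,
# scikit-learn, matplotlib and tqdm packages are installed and that a
# MetaTrader 5 terminal is available for the forex download (the pairs may
# need to be visible in Market Watch for copy_rates_from_pos to return data).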
