import pandas as pd
import wbdata
import MetaTrader5 as mt5
from catboost import CatBoostRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import warnings

# Silence UserWarnings raised from the wbdata module.
warnings.filterwarnings("ignore", category=UserWarning, module="wbdata")

# World Bank indicator code -> column label used in the downloaded DataFrame.
# Each code must appear only once: a repeated key in a dict literal silently
# overwrites the earlier label.
indicators = {
    "NY.GDP.MKTP.KD.ZG": "GDP growth",
    "FP.CPI.TOTL.ZG": "Inflation",
    "FR.INR.RINR": "Real interest rate",
    "NE.EXP.GNFS.ZS": "Exports",  # % of GDP
    "NE.IMP.GNFS.ZS": "Imports",  # % of GDP
    "BN.CAB.XOKA.GD.ZS": "Current account balance",  # % of GDP
    "GC.DOD.TOTL.GD.ZS": "Government debt",  # % of GDP
    "SL.UEM.TOTL.ZS": "Unemployment rate",  # % of total labor force
    "NY.GNP.PCAP.CD": "GNI per capita",  # current USD
    "NY.GDP.PCAP.KD.ZG": "GDP per capita growth (annual %)",
    # NE.RSB.GNFS.ZS is the external balance on goods and services; the
    # reserves-in-months-of-imports series is FI.RES.TOTL.MO.
    "NE.RSB.GNFS.ZS": "External balance on goods and services",  # % of GDP
    "NY.GDP.DEFL.KD.ZG": "GDP deflator",  # annual % change
    "NY.GDP.PCAP.KD": "GDP per capita (constant 2015 US$)",
    "NY.GDP.PCAP.PP.CD": "GDP per capita, PPP (current international $)",
    "NY.GDP.PCAP.PP.KD": "GDP per capita, PPP (constant 2017 international $)",
    "NY.GDP.PCAP.CN": "GDP per capita (current LCU)",
    "NY.GDP.PCAP.KN": "GDP per capita (constant LCU)",
    "NY.GDP.PCAP.CD": "GDP per capita (current US$)",
    # NOTE(review): confirm this code exists in WDI; if not, the fetch loop
    # reports an error for it and continues.
    "NY.GDP.PCAP.KN.ZG": "GDP per capita growth (constant LCU)",
}

# Download each World Bank indicator separately so that one unavailable code
# does not abort the whole download.
data_frames = []
for code, label in indicators.items():
    try:
        data_frames.append(wbdata.get_dataframe({code: label}, country="all"))
    except Exception as e:  # network / API errors: report and keep going
        print(f"Error fetching data for indicator '{code}': {e}")

# Combine data into a single DataFrame (outer-joined on the index). pd.concat
# raises on an empty list, so fall back to an empty frame when every request
# failed and let the script continue without economic features.
if data_frames:
    data = pd.concat(data_frames, axis=1)
else:
    print("Warning: no World Bank indicator could be downloaded.")
    data = pd.DataFrame()

# Display info on available indicators and their data
print("Available indicators and their data:")
print(data.columns)
print(data.head())

# Save data to CSV
data.to_csv("economic_data.csv", index=True)

# Display statistics; describe() raises on a frame with no columns.
print("Economic Data Statistics:")
if len(data.columns) > 0:
    print(data.describe())
else:
    print("No economic data available.")

# Connect to the MetaTrader 5 terminal. Every call below needs a live
# connection, so stop here instead of continuing and crashing later
# (symbols_get() would return None and the comprehension would raise).
if not mt5.initialize():
    print(f"initialize() failed, error code = {mt5.last_error()}")
    mt5.shutdown()
    raise SystemExit(1)

# symbols_get() returns every instrument the terminal exposes (not only
# currency pairs), or None on error.
symbols = mt5.symbols_get()
if symbols is None:
    print(f"symbols_get() failed, error code = {mt5.last_error()}")
    symbols = ()
symbol_names = [symbol.name for symbol in symbols]

# Load up to 1000 daily bars per symbol, counting back from the current bar.
historical_data = {}
for symbol in symbol_names:
    rates = mt5.copy_rates_from_pos(symbol, mt5.TIMEFRAME_D1, 0, 1000)
    # copy_rates_from_pos returns None on error; skip symbols without bars
    # instead of failing on the missing "time" column.
    if rates is None or len(rates) == 0:
        print(f"No daily rates for {symbol}: {mt5.last_error()}")
        continue
    df = pd.DataFrame(rates)
    df["time"] = pd.to_datetime(df["time"], unit="s")
    df.set_index("time", inplace=True)
    historical_data[symbol] = df


# Prepare data for forecasting
def prepare_data(symbol_data, economic_data, indicator_map=None):
    """Build the feature frame used to train the forecaster for one symbol.

    Adds two price-derived columns to a copy of ``symbol_data``. It then
    attaches every economic indicator whose data can be aligned to the price
    index, and finally drops rows that contain any NaN.

    Args:
        symbol_data: Price history for one symbol; must contain a ``close``
            column. It is not modified.
        economic_data: Indicator data. Each indicator's column is looked up
            first by its code, then by its label (``wbdata.get_dataframe``
            names columns by the label).
        indicator_map: Mapping of indicator code -> label. Defaults to the
            module-level ``indicators`` dict.

    Returns:
        A new DataFrame with ``close_diff``, ``close_corr`` and one column per
        aligned indicator (named by its code), with NaN rows removed.
    """
    if indicator_map is None:
        indicator_map = indicators

    data = symbol_data.copy()
    # Change in close from the previous row.
    data["close_diff"] = data["close"].diff()
    # Rolling 30-row correlation between the close and its one-row lag
    # (NaN until a full window of valid pairs exists).
    data["close_corr"] = data["close"].rolling(window=30).corr(data["close"].shift(1))

    for code, label in indicator_map.items():
        if code in economic_data.columns:
            column = code
        elif label in economic_data.columns:
            column = label
        else:
            print(f"Warning: Data for indicator '{code}' is not available.")
            continue

        # Assigning a Series reindexes it onto data.index. Indicator data keyed
        # differently (e.g. by (country, year)) comes through all-NaN, and the
        # dropna() below would then discard every row, so skip it instead.
        try:
            aligned = economic_data[column].reindex(data.index)
        except (TypeError, ValueError):
            aligned = None
        if aligned is None or aligned.isna().all():
            print(
                f"Warning: Data for indicator '{code}' does not align "
                "with the price index; skipped."
            )
            continue
        data[code] = aligned

    return data.dropna()


# Build the feature frame for every symbol loaded from MetaTrader 5
prepared_data = {
    name: prepare_data(history, data)
    for name, history in historical_data.items()
}


# Forecasting using CatBoost
def forecast(symbol_data, symbol=None):
    """Train a CatBoost regressor on one symbol's features and predict recent closes.

    Rows are split chronologically: the first 80% are used for training and
    the last 20% for testing. The test-set MSE is printed. The fitted model is
    then applied to the feature rows of the most recent 30 observations.

    Args:
        symbol_data: Feature frame for one symbol. ``close`` is the target;
            every other column is used as a feature.
        symbol: Name shown in the printed MSE line.

    Returns:
        The model's predictions for the last ``min(30, len(symbol_data))`` rows.

    Exceptions from train_test_split / CatBoost (e.g. on too few rows)
    propagate to the caller.
    """
    X = symbol_data.drop(columns=["close"])
    y = symbol_data["close"]

    # shuffle=False keeps time order, so the test rows follow the training rows.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, shuffle=False
    )

    model = CatBoostRegressor(
        iterations=1000, learning_rate=0.1, depth=8, loss_function="RMSE", verbose=False
    )
    model.fit(X_train, y_train)

    # NOTE(review): features include close_diff (close minus previous close)
    # and presumably the same bar's open/high/low, which leak the target, so
    # this MSE likely overstates out-of-sample skill.
    predictions = model.predict(X_test)
    mse = mean_squared_error(y_test, predictions)
    print(f"Mean Squared Error for {symbol}: {mse}")

    # No feature rows exist for future dates, so this is not a true
    # 30-days-ahead forecast. It predicts the close for the 30 most recent
    # rows we already have.
    recent_features = symbol_data.tail(30).drop(columns=["close"])
    return model.predict(recent_features)


# Forecasting for all currency pairs
forecasts = {}
for symbol, df in prepared_data.items():
    try:
        forecasts[symbol] = forecast(df, symbol)
    except Exception as e:  # report per-symbol failures and keep going
        print(f"Error forecasting for {symbol}: {e}")

# Release the MetaTrader 5 connection; nothing below talks to the terminal.
mt5.shutdown()

# Display forecasts. The loop variable is not named `forecast`, which would
# rebind the module-level forecast() function to an array.
for symbol, prediction in forecasts.items():
    print(f"Forecast for {symbol}: {prediction}")

# Save forecasts in CSV. Wrapping each array in a Series pads shorter
# forecasts with NaN instead of making the DataFrame constructor raise on
# unequal lengths.
forecasts_df = pd.DataFrame(
    {name: pd.Series(values) for name, values in forecasts.items()}
)
forecasts_df.to_csv("forecasts.csv", index=True)
