import requests
import pandas as pd
import numpy as np
import json
import time
import logging
from typing import Dict, List, Optional
from datetime import datetime, timedelta
from pathlib import Path
from scipy.stats import pearsonr
import warnings

# Visualization
import matplotlib

matplotlib.use("Agg")  # For saving without GUI
import matplotlib.pyplot as plt
import seaborn as sns

# MetaTrader 5
import MetaTrader5 as mt5

# NOTE: this silences *every* warning process-wide, including pandas and
# matplotlib deprecation notices; narrow the filter if those need to surface.
warnings.filterwarnings("ignore")
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Configure charts for a 700 px wide canvas.
# The style sheet is applied FIRST: "seaborn-v0_8" defines its own
# figure.figsize, so applying it after the overrides below would silently
# replace the 7 x 4.2 inch size with the style's default.
plt.style.use("seaborn-v0_8")
plt.rcParams["figure.figsize"] = (7, 4.2)  # 7 in x 100 DPI = 700 px wide
plt.rcParams["figure.dpi"] = 100
plt.rcParams["savefig.dpi"] = 100  # keep saved PNGs at the same 100 DPI
# "tight" crops/expands the saved image to the drawn artists, so the final
# PNG width can deviate somewhat from the 700 px canvas.
plt.rcParams["savefig.bbox"] = "tight"


class MetaTrader5Connector:
    """Connect to MetaTrader 5 to list forex pairs and download price bars.

    Failures are logged rather than raised: pair discovery falls back to small
    hard-coded lists, and data requests return an empty DataFrame.
    """

    # Currencies allowed on both legs of a pair returned by get_currency_pairs().
    MAJOR_CURRENCIES = frozenset(
        ("USD", "EUR", "GBP", "JPY", "AUD", "CAD", "CHF", "NZD")
    )
    # Returned when the terminal is not connected.
    DEFAULT_PAIRS = (
        "EURUSD",
        "GBPUSD",
        "USDJPY",
        "AUDUSD",
        "USDCAD",
        "USDCHF",
        "NZDUSD",
    )
    # Returned when symbol discovery fails, raises, or matches nothing.
    FALLBACK_PAIRS = ("EURUSD", "GBPUSD")
    # Maximum number of pairs handed on for analysis.
    MAX_PAIRS = 10

    def __init__(self):
        self.connected = False
        self.currency_symbols = []

    def connect(self) -> bool:
        """Initialize the MT5 terminal connection.

        Returns:
            True on success; False if ``mt5.initialize()`` fails or raises.
        """
        try:
            if not mt5.initialize():
                logger.error(f"MT5 initialization failed: {mt5.last_error()}")
                return False

            self.connected = True
            logger.info("MT5 connected successfully")
            return True

        except Exception as e:
            logger.error(f"MT5 connection error: {e}")
            return False

    @classmethod
    def _is_major_pair(cls, symbol) -> bool:
        """Return True if an MT5 symbol is a pair of two distinct major currencies.

        The symbol must sit under a forex path (or be a bare six-letter name),
        and characters 0-3 and 3-6 of its name must both be in MAJOR_CURRENCIES.
        Broker suffixes such as "EURUSD.m" are therefore accepted.
        """
        path, name = symbol.path, symbol.name
        looks_like_forex = (
            path.startswith("Forex")
            or "forex" in path.lower()
            or (len(name) == 6 and name.isalpha())
        )
        if not looks_like_forex:
            return False

        base, quote = name[:3], name[3:6]
        return (
            base in cls.MAJOR_CURRENCIES
            and quote in cls.MAJOR_CURRENCIES
            and base != quote
        )

    def get_currency_pairs(self) -> List[str]:
        """Return up to MAX_PAIRS alphabetically sorted major-currency pairs from MT5.

        Falls back to DEFAULT_PAIRS when not connected. Falls back to
        FALLBACK_PAIRS when the symbol query returns None, raises, or finds no
        matching pair.
        """
        if not self.connected:
            logger.warning("MT5 not connected, using default pairs")
            return list(self.DEFAULT_PAIRS)

        try:
            symbols = mt5.symbols_get()
            if symbols is None:
                logger.error("Failed to get symbols from MT5")
                return list(self.FALLBACK_PAIRS)

            # The set removes duplicate names; sorted() gives a stable order.
            currency_pairs = sorted(
                {symbol.name for symbol in symbols if self._is_major_pair(symbol)}
            )

            logger.info(
                f"Found {len(currency_pairs)} currency pairs from MT5: {currency_pairs[:10]}"
            )

            # Return the first MAX_PAIRS (alphabetically) for analysis
            if not currency_pairs:
                return list(self.FALLBACK_PAIRS)
            return currency_pairs[: self.MAX_PAIRS]

        except Exception as e:
            logger.error(f"Error getting currency pairs from MT5: {e}")
            return list(self.FALLBACK_PAIRS)

    def get_currency_data(
        self, symbol: str, timeframe=None, count: int = 100
    ) -> pd.DataFrame:
        """Download the latest ``count`` bars of ``symbol`` from MT5.

        Args:
            symbol: MT5 symbol name, e.g. "EURUSD".
            timeframe: An ``mt5.TIMEFRAME_*`` constant. None (the default) means
                ``mt5.TIMEFRAME_D1``.
            count: Number of bars requested, starting at position 0.

        Returns:
            The rates as a DataFrame, with "time" converted from epoch seconds
            to datetime and a "symbol" column added. Returns an empty DataFrame
            when not connected, when no bars come back, or on any error.
        """
        if not self.connected:
            logger.warning(f"MT5 not connected, cannot get data for {symbol}")
            return pd.DataFrame()

        if timeframe is None:
            # Resolved here rather than in the signature, so defining the class
            # does not touch the mt5 module.
            timeframe = mt5.TIMEFRAME_D1

        try:
            rates = mt5.copy_rates_from_pos(symbol, timeframe, 0, count)
            # Treat an empty result the same as a failed request.
            if rates is None or len(rates) == 0:
                logger.warning(f"No data for {symbol}")
                return pd.DataFrame()

            df = pd.DataFrame(rates)
            df["time"] = pd.to_datetime(df["time"], unit="s")
            df["symbol"] = symbol

            logger.info(f"Got {len(df)} records for {symbol}")
            return df

        except Exception as e:
            logger.error(f"Error getting data for {symbol}: {e}")
            return pd.DataFrame()

    def disconnect(self):
        """Shut down the MT5 connection if one is open."""
        if self.connected:
            mt5.shutdown()
            self.connected = False
            logger.info("MT5 disconnected")


class IMFDataMiner:
    """Fetch IFS data from the IMF SDMX-JSON REST service and chart the results.

    Network and parsing failures are logged and turned into empty results
    (False or an empty DataFrame) instead of exceptions. Charts are
    best-effort: a failing chart is logged and never discards the data it
    describes.
    """

    def __init__(self):
        self.base_url = "http://dataservices.imf.org/REST/SDMX_JSON.svc"
        self.session = requests.Session()
        # Currency code -> IFS area code used in request keys. EUR maps to "U2"
        # (presumably the euro-area aggregate; verify against CL_AREA_IFS).
        self.currency_country_map = {
            "USD": "US",
            "EUR": "U2",
            "GBP": "GB",
            "JPY": "JP",
            "AUD": "AU",
            "CAD": "CA",
            "CHF": "CH",
            "NZD": "NZ",
        }
        self.session.headers.update(
            {"User-Agent": "IMF-Data-Miner/1.0", "Accept": "application/json"}
        )

        # Chart folder, created relative to the current working directory
        self.viz_dir = Path("imf_visualizations_700px")
        self.viz_dir.mkdir(exist_ok=True)

    # ------------------------------------------------------------------
    # Chart helpers
    # ------------------------------------------------------------------

    def _save_figure(self, fig, filename: str, message: str) -> None:
        """Lay out *fig*, save it as ``viz_dir / filename``, close it, then log *message*.

        The figure is closed even when layout or saving raises, so failures do
        not leave open pyplot figures behind.
        """
        try:
            fig.tight_layout()
            fig.savefig(self.viz_dir / filename, dpi=100, bbox_inches="tight")
        finally:
            plt.close(fig)
        logger.info(message)

    def _try_plot(self, plot_func, *args) -> None:
        """Call a chart method with *args*, logging any failure instead of raising."""
        try:
            plot_func(*args)
        except Exception as e:
            logger.warning(f"Chart {plot_func.__name__} failed: {e}")

    # ------------------------------------------------------------------
    # Indicator discovery
    # ------------------------------------------------------------------

    def test_indicator(self, indicator: str, country: str = "US") -> bool:
        """Return True if IFS returns annual 2020-2023 series for *indicator*.

        Sleeps 0.5 s before the request as crude rate limiting. Non-200
        responses, responses without a "Series" node, and any request or JSON
        error all count as "not working".
        """
        url = f"{self.base_url}/CompactData/IFS/A.{country}.{indicator}"
        params = {"startPeriod": "2020", "endPeriod": "2023"}

        try:
            time.sleep(0.5)  # Rate limiting
            response = self.session.get(url, params=params, timeout=30)

            if response.status_code == 200:
                data = response.json()
                # Check if we got actual data
                if (
                    "CompactData" in data
                    and "DataSet" in data["CompactData"]
                    and "Series" in data["CompactData"]["DataSet"]
                ):
                    logger.info(f"✓ {indicator} works")
                    return True

            logger.warning(f"✗ {indicator} - no data")
            return False

        except Exception as e:
            logger.warning(f"✗ {indicator} - error: {e}")
            return False

    def find_working_indicators(self) -> List[str]:
        """Probe a fixed list of IFS indicator codes (for "US") and return those with data.

        Also writes the indicator test results chart (GRAPH 1).
        """
        # Extended list of potential indicators to test
        test_indicators = [
            "NGDP_RPCH",
            "NGDP_XDC",
            "NGDP_USD",
            "PCPIPCH",
            "PCPI_PC_PP_PT",
            "PCPI_IX",
            "LUR",
            "LUR_PT",
            "LE_PT",
            "GGXWDG_NGDP",
            "GGR_NGDP",
            "GGX_NGDP",
            "GGSB_NPGDP",
            "BCA_BP6_USD",
            "BCA_NGDPD",
            "TXG_FOB_USD",
            "TMG_CIF_USD",
            "BXGS_BP6_USD",
            "TXG_RPCH_PA",
            "TMG_RPCH_PA",
            "FPOLM_PA",
            "FIGB_PA",
            "FILR_PA",
            "FIDR_PA",
            "FM1_XDC",
            "FM2_XDC",
            "FM3_XDC",
            "RAXG_USD",
            "RAFG_USD",
            "RAGG_XOZ",
            "ENEER_IX",
            "EREER_IX",
            "ENDA_XDC_USD_RATE",
            "LP",
            "LE",
            "LF",
            "NID_NGDP",
            "NCG_NGDP",
            "NCP_NGDP",
            "NGSD_NGDP",
            "NGDP_D",
            "NGDP_R_K_PT",
        ]

        working_indicators = []
        failed_indicators = []

        logger.info(f"Testing {len(test_indicators)} indicators...")

        for indicator in test_indicators:
            if self.test_indicator(indicator):
                working_indicators.append(indicator)
            else:
                failed_indicators.append(indicator)

        logger.info(
            f"Found {len(working_indicators)} working indicators: {working_indicators}"
        )

        # GRAPH 1: Indicators test results
        self._try_plot(
            self._plot_indicator_results, working_indicators, failed_indicators
        )

        return working_indicators

    def _plot_indicator_results(self, working: List[str], failed: List[str]):
        """Chart pass/fail counts and the working indicators per rough category.

        Categories are substring matches on the indicator code, so one code can
        count in several categories (e.g. "GGXWDG_NGDP" is both GDP and Gov).
        """
        fig, (ax1, ax2) = plt.subplots(1, 2)

        # Pie chart
        sizes = [len(working), len(failed)]
        labels = [f"Working ({len(working)})", f"Failed ({len(failed)})"]
        colors = ["#2ecc71", "#e74c3c"]

        ax1.pie(sizes, labels=labels, colors=colors, autopct="%1.1f%%", startangle=90)
        ax1.set_title("Indicator Testing Results")

        # Bar chart by categories ("CPI" also matches every "PCPI..." code)
        category_tokens = {
            "GDP": ("GDP",),
            "CPI": ("CPI",),
            "Labor": ("LUR", "LE", "LP"),
            "Gov": ("GG",),
            "External": ("BCA", "TX", "TM"),
            "Monetary": ("FP", "FI", "FM"),
            "FX": ("RA", "EN", "ER"),
        }
        categories = {
            name: sum(1 for code in working if any(tok in code for tok in tokens))
            for name, tokens in category_tokens.items()
        }

        ax2.bar(list(categories), list(categories.values()), color="skyblue")
        ax2.set_title("Working Indicators by Category")
        ax2.set_ylabel("Count")
        plt.setp(ax2.xaxis.get_majorticklabels(), rotation=45)

        self._save_figure(
            fig, "indicator_test_results.png", "Saved indicator test results chart"
        )

    # ------------------------------------------------------------------
    # Metadata
    # ------------------------------------------------------------------

    def get_available_datasets(self) -> pd.DataFrame:
        """Return the IMF dataflows as a DataFrame with "id" and "name" columns.

        Returns an empty DataFrame on any request or parsing error. Also writes
        the datasets overview chart (GRAPH 2).
        """
        url = f"{self.base_url}/Dataflow"
        try:
            response = self.session.get(url, timeout=30)
            response.raise_for_status()
            data = response.json()

            df = pd.DataFrame(
                [
                    {"id": flow["@id"], "name": flow["Name"]["#text"]}
                    for flow in data["Structure"]["Dataflows"]["Dataflow"]
                ]
            )
        except Exception as e:
            logger.error(f"Error fetching datasets: {e}")
            return pd.DataFrame()

        # GRAPH 2: Reviewing datasets (a chart failure must not lose the data)
        if not df.empty:
            self._try_plot(self._plot_datasets_overview, df)

        return df

    def _plot_datasets_overview(self, datasets_df: pd.DataFrame):
        """Chart how many dataset names mention each keyword group."""
        fig, (ax1, ax2) = plt.subplots(2, 1)

        # Analyze key words in names (one name may match several groups)
        names = datasets_df["name"].str.lower()
        keywords = {
            "Financial": names.str.contains("financial|monetary|bank").sum(),
            "Trade": names.str.contains("trade|export|import").sum(),
            "Economic": names.str.contains("economic|gdp|indicator").sum(),
            "Government": names.str.contains("government|fiscal|debt").sum(),
            "Price": names.str.contains("price|inflation|cpi").sum(),
            "Exchange": names.str.contains("exchange|currency|rate").sum(),
        }

        # Bar chart
        ax1.bar(list(keywords), list(keywords.values()), color="lightcoral")
        ax1.set_title(f"IMF Datasets by Category (Total: {len(datasets_df)})")
        ax1.set_ylabel("Number of Datasets")
        plt.setp(ax1.xaxis.get_majorticklabels(), rotation=45)

        # Pie chart: an all-zero pie cannot be drawn, so leave the panel blank then
        if sum(keywords.values()) > 0:
            ax2.pie(list(keywords.values()), labels=list(keywords), autopct="%1.1f%%")
            ax2.set_title("Dataset Category Distribution")
        else:
            ax2.set_axis_off()

        self._save_figure(fig, "datasets_overview.png", "Saved datasets overview chart")

    def get_countries_list(self) -> pd.DataFrame:
        """Return the IFS area codelist with "code", "name" and "has_currency" columns.

        "has_currency" is True for areas that appear in currency_country_map.
        Returns an empty DataFrame on any request or parsing error. Also writes
        the countries analysis chart (GRAPH 3).
        """
        url = f"{self.base_url}/CodeList/CL_AREA_IFS"
        try:
            response = self.session.get(url, timeout=30)
            response.raise_for_status()
            data = response.json()

            countries = []
            if "Structure" in data and "CodeLists" in data["Structure"]:
                codelist = data["Structure"]["CodeLists"]["CodeList"]
                if "Code" in codelist:
                    codes = codelist["Code"]
                    # A single code is returned as a dict rather than a list
                    if not isinstance(codes, list):
                        codes = [codes]

                    currency_areas = set(self.currency_country_map.values())
                    countries = [
                        {
                            "code": code["@value"],
                            "name": (
                                code["Description"]["#text"]
                                if "Description" in code
                                else code["@value"]
                            ),
                            "has_currency": code["@value"] in currency_areas,
                        }
                        for code in codes
                    ]

            df = pd.DataFrame(countries)
        except Exception as e:
            logger.error(f"Error fetching countries: {e}")
            return pd.DataFrame()

        # GRAPH 3: Analyze countries (a chart failure must not lose the data)
        if not df.empty:
            self._try_plot(self._plot_countries_analysis, df)

        return df

    def _plot_countries_analysis(self, countries_df: pd.DataFrame):
        """Chart area coverage. Panels 2 and 4 are random placeholders, not IMF data."""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)

        # 1. Currency vs regular countries
        currency_count = countries_df["has_currency"].sum()
        regular_count = len(countries_df) - currency_count

        ax1.pie(
            [currency_count, regular_count],
            labels=[
                f"Major Currencies ({currency_count})",
                f"Other Countries ({regular_count})",
            ],
            colors=["gold", "lightblue"],
            autopct="%1.1f%%",
        )
        ax1.set_title("Currency Coverage")

        # 2. Placeholder: random (unseeded) scores, not derived from any data
        currencies = ["EUR", "GBP", "JPY", "AUD", "CAD", "CHF", "NZD"]
        volumes = np.random.uniform(10, 100, len(currencies))

        ax2.bar(currencies, volumes, color="lightgreen")
        ax2.set_title("Major Currencies (Simulated Data)")
        ax2.set_ylabel("Activity Score")

        # 3. Regional distribution (hard-coded grouping of the mapped currencies)
        regions = ["North America", "Europe", "Asia-Pacific"]
        region_counts = [2, 3, 3]  # USD+CAD, EUR+GBP+CHF, JPY+AUD+NZD

        ax3.bar(regions, region_counts, color=["red", "green", "blue"])
        ax3.set_title("Currencies by Region")
        ax3.set_ylabel("Number of Currencies")

        # 4. Placeholder: random walk at quarter ends. QuarterEnd() yields the
        # same dates as the "Q" alias, which newer pandas deprecates.
        dates = pd.date_range("2020-01-01", "2024-01-01", freq=pd.offsets.QuarterEnd())
        activity = np.cumsum(np.random.normal(1, 0.5, len(dates)))

        ax4.plot(dates, activity, marker="o", color="purple")
        ax4.set_title("IMF Data Activity Over Time (Simulated)")
        ax4.set_ylabel("Activity Index")
        plt.setp(ax4.xaxis.get_majorticklabels(), rotation=45)

        self._save_figure(fig, "countries_analysis.png", "Saved countries analysis chart")

    # ------------------------------------------------------------------
    # Data retrieval
    # ------------------------------------------------------------------

    def fetch_specific_data(
        self,
        countries: List[str],
        indicators: List[str],
        start_year: int,
        end_year: int,
    ) -> pd.DataFrame:
        """Fetch annual IFS observations for *indicators* across *countries*.

        Each indicator is first re-checked with ``test_indicator`` against
        ``countries[0]`` only. Every remaining indicator is then requested for
        all countries in one call (area codes joined with "+"), sleeping 1 s
        before each request.

        Args:
            countries: IFS area codes, e.g. ["US", "JP"].
            indicators: IFS indicator codes.
            start_year, end_year: Passed as startPeriod / endPeriod.

        Returns:
            Long-format DataFrame produced by ``_parse_response_data``. It is
            empty when *countries* is empty, no indicator works, or nothing
            was fetched. Also writes the data fetch results chart (GRAPH 4).
        """
        if not countries:
            logger.warning("No countries given, nothing to fetch")
            return pd.DataFrame()

        # Test each indicator individually first
        working_indicators = [
            indicator
            for indicator in indicators
            if self.test_indicator(indicator, countries[0])
        ]

        if not working_indicators:
            logger.warning("No working indicators found")
            return pd.DataFrame()

        logger.info(
            f"Using {len(working_indicators)} working indicators: {working_indicators}"
        )

        countries_string = "+".join(countries)
        params = {"startPeriod": str(start_year), "endPeriod": str(end_year)}
        all_data = []

        for indicator in working_indicators:
            time.sleep(1)  # Rate limiting
            url = f"{self.base_url}/CompactData/IFS/A.{countries_string}.{indicator}"

            try:
                logger.info(f"Fetching {indicator} for {countries_string}")
                response = self.session.get(url, params=params, timeout=60)
                response.raise_for_status()
                batch_df = self._parse_response_data(response.json())
            except Exception as e:
                logger.warning(f"Failed {indicator}: {e}")
                continue

            if batch_df.empty:
                logger.warning(f"No data for {indicator}")
            else:
                all_data.append(batch_df)
                logger.info(f"Success: {len(batch_df)} records for {indicator}")

        if not all_data:
            return pd.DataFrame()

        df = pd.concat(all_data, ignore_index=True)
        logger.info(f"Total fetched: {len(df)} records")

        # GRAPH 4: Data retrieval results (one annual period per requested year)
        n_periods = max(end_year - start_year + 1, 0)
        self._try_plot(
            self._plot_data_fetch_results, df, working_indicators, countries, n_periods
        )

        return df

    def _plot_data_fetch_results(
        self,
        data_df: pd.DataFrame,
        indicators: List[str],
        countries: List[str],
        n_periods: int = 5,
    ):
        """Chart what a fetch returned. *data_df* is read only, never modified.

        Args:
            data_df: Output of ``fetch_specific_data``.
            indicators: Indicators that were requested.
            countries: Area codes that were requested.
            n_periods: Annual periods requested per (indicator, country); used
                to estimate completeness. Defaults to 5 for older callers.
        """
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)

        # 1. Indicator data
        if "INDICATOR" in data_df.columns:
            indicator_counts = data_df["INDICATOR"].value_counts()
            ax1.bar(
                range(len(indicator_counts)), indicator_counts.values, color="orange"
            )
            ax1.set_title("Data Points by Indicator")
            ax1.set_ylabel("Count")
            ax1.set_xticks(range(len(indicator_counts)))
            ax1.set_xticklabels(indicator_counts.index, rotation=45, ha="right")

        # 2. Data by country
        if "REF_AREA" in data_df.columns:
            country_counts = data_df["REF_AREA"].value_counts()
            ax2.pie(
                country_counts.values, labels=country_counts.index, autopct="%1.1f%%"
            )
            ax2.set_title("Data by Country")

        # 3. Time distribution (computed on a local Series, not added to data_df)
        if "time_period" in data_df.columns:
            years = pd.to_datetime(data_df["time_period"], errors="coerce").dt.year
            year_counts = years.value_counts().sort_index()
            ax3.plot(year_counts.index, year_counts.values, marker="o", color="red")
            ax3.set_title("Data Points Over Time")
            ax3.set_ylabel("Count")
            ax3.grid(True, alpha=0.3)

        # 4. Data quality: one observation expected per indicator, country and period
        total_possible = len(indicators) * len(countries) * n_periods
        actual = len(data_df)
        missing = max(0, total_possible - actual)

        ax4.pie(
            [actual, missing],
            labels=[f"Available ({actual})", f"Missing ({missing})"],
            colors=["green", "red"],
            autopct="%1.1f%%",
        )
        ax4.set_title("Data Completeness")

        self._save_figure(fig, "data_fetch_results.png", "Saved data fetch results chart")

    def _parse_response_data(self, data: Dict) -> pd.DataFrame:
        """Flatten a CompactData JSON response into one row per observation.

        Series attributes ("@REF_AREA", "@INDICATOR", ...) become columns with
        the "@" stripped. Each observation adds "time_period" (parsed to
        datetime, invalid values -> NaT), "value" (numeric, invalid -> NaN) and
        "status". Returns an empty DataFrame when there is no "Series" node or
        the payload is malformed.
        """
        records = []
        try:
            dataset = data["CompactData"]["DataSet"]

            if "Series" not in dataset:
                return pd.DataFrame()

            # A single series / observation is returned as a dict, not a list
            series_list = dataset["Series"]
            if not isinstance(series_list, list):
                series_list = [series_list]

            for series in series_list:
                series_attrs = {
                    k.replace("@", ""): v
                    for k, v in series.items()
                    if k.startswith("@")
                }
                obs_list = series.get("Obs", [])
                if not isinstance(obs_list, list):
                    obs_list = [obs_list]

                for obs in obs_list:
                    if not isinstance(obs, dict):
                        continue
                    records.append(
                        {
                            **series_attrs,
                            "time_period": obs.get("@TIME_PERIOD", ""),
                            "value": obs.get("@OBS_VALUE", ""),
                            "status": obs.get("@OBS_STATUS", ""),
                        }
                    )

            df = pd.DataFrame(records)
            if "value" in df.columns:
                df["value"] = pd.to_numeric(df["value"], errors="coerce")
            if "time_period" in df.columns:
                df["time_period"] = pd.to_datetime(df["time_period"], errors="coerce")

            return df
        except Exception as e:
            logger.error(f"Error parsing response: {e}")
            return pd.DataFrame()


class CurrencyCorrelationAnalyzer:
    """Correlation analyzer with working indicators"""

    def __init__(self, data_miner: IMFDataMiner):
        """Bind the analyzer to *data_miner* and reuse that miner's chart folder."""
        self.data_miner, self.viz_dir = data_miner, data_miner.viz_dir

    def analyze_correlations(self, currency_pairs: List[str]) -> Dict:
        """Correlate IMF indicator differentials with each currency pair.

        Steps:
        1. Discover the working indicators once.
        2. For every pair, fetch data for both countries over the last five
           years and analyze it with ``_analyze_pair``.
        3. Summarize and chart all results.

        Args:
            currency_pairs: Symbols whose first six characters are two currency
                codes, e.g. "EURUSD". Characters after the sixth (broker
                suffixes such as "EURUSD.m") are ignored.

        Returns:
            Dict with "pair_correlations" (pair -> per-pair result),
            "summary_stats" and "working_indicators". All three keys are present
            (with empty values) even when no indicator works.
        """
        # First, find what indicators actually work
        logger.info("Discovering working indicators...")
        working_indicators = self.data_miner.find_working_indicators()

        if not working_indicators:
            logger.error("No working indicators found!")
            return {
                "pair_correlations": {},
                "summary_stats": {},
                "working_indicators": [],
            }

        results = {
            "pair_correlations": {},
            "summary_stats": {},
            "working_indicators": working_indicators,
        }
        end_year = datetime.now().year
        start_year = end_year - 5

        currency_map = self.data_miner.currency_country_map
        for pair in currency_pairs:
            # Slice exactly two 3-letter codes so suffixed broker symbols still map
            base_currency, quote_currency = pair[:3], pair[3:6]

            base_country = currency_map.get(base_currency)
            quote_country = currency_map.get(quote_currency)

            if not base_country or not quote_country:
                logger.warning(f"Unknown currency mapping for {pair}")
                continue

            logger.info(f"Analyzing {pair} ({base_country} vs {quote_country})")

            # Fetch data
            countries = [base_country, quote_country]
            economic_data = self.data_miner.fetch_specific_data(
                countries, working_indicators, start_year, end_year
            )

            if economic_data.empty:
                logger.warning(f"No data for {pair}")
                continue

            # Analyze correlations
            results["pair_correlations"][pair] = self._analyze_pair(
                economic_data, base_country, quote_country, pair
            )

        # Calculate summary
        results["summary_stats"] = self._calculate_summary(results["pair_correlations"])

        # GRAPH 5: Final correlation dashboard
        self._plot_correlation_dashboard(results)

        return results

    def _analyze_pair(
        self, data: pd.DataFrame, base_country: str, quote_country: str, pair: str
    ) -> Dict:
        """Correlate base-minus-quote indicator differentials with the pair's returns.

        Args:
            data: Long-format observations with "INDICATOR", "REF_AREA",
                "time_period" and "value" columns.
            base_country: Area code of the pair's base currency.
            quote_country: Area code of the pair's quote currency.
            pair: Symbol, used for logging and the chart file name.

        Returns:
            {"correlations": {indicator: result of _calculate_correlation},
             "data_quality": {...}}. Both stay empty when the data is unusable.
            Errors are logged, never raised.
        """
        results = {"correlations": {}, "data_quality": {}}

        try:
            if "INDICATOR" not in data.columns:
                logger.warning(f"No INDICATOR column in data for {pair}")
                return results
            missing = {"REF_AREA", "time_period", "value"} - set(data.columns)
            if missing:
                logger.warning(f"Missing columns {sorted(missing)} in data for {pair}")
                return results

            logger.info(f"Available indicators in data: {data['INDICATOR'].unique()}")
            logger.info(f"Available countries in data: {data['REF_AREA'].unique()}")

            # One row per time_period, one column per (REF_AREA, INDICATOR)
            pivoted = data.pivot_table(
                index="time_period",
                columns=["REF_AREA", "INDICATOR"],
                values="value",
                aggfunc="first",
            )

            logger.info(f"Pivoted data shape: {pivoted.shape}")
            logger.info(f"Available columns: {pivoted.columns.tolist()}")

            if pivoted.empty:
                logger.warning(f"Empty pivoted data for {pair}")
                return results

            # Generate synthetic currency returns
            currency_returns = self._generate_currency_returns(
                pivoted, base_country, quote_country
            )

            if currency_returns.empty:
                logger.warning(f"No currency returns generated for {pair}")
                return results

            # GRAPH 6: Currency pair analysis. Best-effort, so a chart failure
            # does not skip the correlation computation below.
            try:
                self._plot_pair_analysis(
                    pivoted, currency_returns, pair, base_country, quote_country
                )
            except Exception as e:
                logger.warning(f"Could not plot {pair} analysis: {e}")

            # Calculate correlations for indicators present for both countries
            correlation_count = 0
            for indicator in data["INDICATOR"].unique():
                base_col = (base_country, indicator)
                quote_col = (quote_country, indicator)
                if base_col not in pivoted.columns or quote_col not in pivoted.columns:
                    continue

                # Subtraction aligns on time_period; periods missing on either
                # side become NaN and are dropped.
                differential = (
                    pivoted[base_col].dropna() - pivoted[quote_col].dropna()
                ).dropna()
                if len(differential) <= 2:
                    continue

                corr_result = self._calculate_correlation(
                    differential, currency_returns
                )
                if corr_result["sample_size"] > 2:
                    results["correlations"][indicator] = corr_result
                    correlation_count += 1
                    logger.info(
                        f"Calculated correlation for {indicator}: {corr_result['correlation']}"
                    )

            logger.info(f"Calculated {correlation_count} correlations for {pair}")

            results["data_quality"] = {
                "data_points": len(pivoted),
                "indicators_used": correlation_count,
                "time_span": (
                    f"{pivoted.index.min().year}-{pivoted.index.max().year}"
                    if len(pivoted) > 0
                    else "No data"
                ),
            }

        except Exception as e:
            logger.error(f"Error analyzing pair {pair}: {e}")

        return results

    def _plot_pair_analysis(
        self,
        pivoted: pd.DataFrame,
        returns: pd.Series,
        pair: str,
        base_country: str,
        quote_country: str,
    ):
        """Save a 2x2 chart for one pair as ``viz_dir / f"{pair}_analysis.png"``.

        Panels:
        1. Synthetic returns over time.
        2. Returns histogram.
        3. The first indicator available for BOTH countries.
        4. Rolling volatility (only when there are more than 5 returns).

        *pivoted* is expected to have (REF_AREA, INDICATOR) tuple columns.
        The figure is closed even if drawing or saving fails.
        """
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
        try:
            fig.suptitle(f"{pair} Analysis", fontsize=16)

            if not returns.empty:
                # 1. Currency returns
                ax1.plot(returns.index, returns.values, color="blue", alpha=0.7)
                ax1.fill_between(returns.index, returns.values, alpha=0.3, color="blue")
                ax1.set_title(f"{pair} Synthetic Returns")
                ax1.set_ylabel("Returns")
                ax1.grid(True, alpha=0.3)

                # 2. Returns distribution
                ax2.hist(
                    returns.values, bins=15, alpha=0.7, color="green", edgecolor="black"
                )
                ax2.axvline(
                    returns.mean(),
                    color="red",
                    linestyle="--",
                    label=f"Mean: {returns.mean():.3f}",
                )
                ax2.set_title("Returns Distribution")
                ax2.legend()

            # 3. Economic indicators comparison: pick the first base-country
            # column whose indicator also exists for the quote country.
            quote_by_indicator = {
                col[1]: col for col in pivoted.columns if col[0] == quote_country
            }
            shared = next(
                (
                    (col, quote_by_indicator[col[1]])
                    for col in pivoted.columns
                    if col[0] == base_country and col[1] in quote_by_indicator
                ),
                None,
            )

            if shared is not None:
                base_col, quote_col = shared
                base_data = pivoted[base_col].dropna()
                quote_data = pivoted[quote_col].dropna()

                ax3.plot(
                    base_data.index,
                    base_data.values,
                    label=f"{base_country}",
                    marker="o",
                )
                ax3.plot(
                    quote_data.index,
                    quote_data.values,
                    label=f"{quote_country}",
                    marker="s",
                )
                ax3.set_title(f"Indicator: {base_col[1]}")
                ax3.legend()
                ax3.grid(True, alpha=0.3)

            # 4. Volatility (window is 5, or half the sample if that is smaller)
            if not returns.empty and len(returns) > 5:
                rolling_vol = returns.rolling(window=min(5, len(returns) // 2)).std()
                ax4.plot(rolling_vol.index, rolling_vol.values, color="purple", marker="o")
                ax4.set_title("Rolling Volatility")
                ax4.set_ylabel("Volatility")
                ax4.grid(True, alpha=0.3)

            fig.tight_layout()
            fig.savefig(
                self.viz_dir / f"{pair}_analysis.png", dpi=100, bbox_inches="tight"
            )
        finally:
            plt.close(fig)
        logger.info(f"Saved {pair} analysis chart (700px width)")

    def _plot_correlation_dashboard(self, results: Dict):
        """Render the 2x2 correlation dashboard to ``correlation_dashboard.png``.

        Panels:
            1. Heatmap of signed correlations (indicator x pair; missing
               combinations are filled with 0).
            2. Histogram of |correlation| with the mean marked.
            3. The five strongest correlations ranked by magnitude.
            4. Pie chart of significant vs. not significant results.

        Args:
            results: Analysis output. Reads
                ``results["pair_correlations"][pair]["correlations"]``, a mapping
                of indicator -> dict holding ``correlation`` and optionally
                ``significant``. Entries that are not such dicts are ignored.

        The image is written to ``self.viz_dir``. Nothing is written (only a
        warning is logged) when there is no finite correlation to plot.
        """
        correlation_data = [
            {
                "pair": pair,
                "indicator": indicator,
                "correlation": corr_data["correlation"],
                "significant": bool(corr_data.get("significant", False)),
            }
            for pair, analysis in results.get("pair_correlations", {}).items()
            for indicator, corr_data in analysis.get("correlations", {}).items()
            if isinstance(corr_data, dict) and "correlation" in corr_data
        ]
        corr_df = pd.DataFrame(
            correlation_data,
            columns=["pair", "indicator", "correlation", "significant"],
        )
        # NaN correlations (e.g. from constant inputs) break the histogram range
        # and make the top-N ranking meaningless, so drop them up front.
        corr_df = corr_df.dropna(subset=["correlation"])
        if corr_df.empty:
            logger.warning("No correlation data to plot")
            return
        corr_df = corr_df.assign(abs_correlation=corr_df["correlation"].abs())

        # Dashboard 2x2
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
        try:
            fig.suptitle("IMF Currency Correlation Dashboard", fontsize=16)

            # 1. Correlation heatmap (signed values, fixed -1..1 colour scale)
            pivot_corr = corr_df.pivot_table(
                index="indicator", columns="pair", values="correlation", fill_value=0
            )
            im = ax1.imshow(
                pivot_corr.values, cmap="RdBu_r", aspect="auto", vmin=-1, vmax=1
            )
            ax1.set_xticks(range(len(pivot_corr.columns)))
            ax1.set_xticklabels(pivot_corr.columns)
            ax1.set_yticks(range(len(pivot_corr.index)))
            ax1.set_yticklabels(pivot_corr.index)
            ax1.set_title("Correlation Matrix")
            cbar = fig.colorbar(im, ax=ax1)
            cbar.set_label("Correlation")

            # 2. Distribution of correlation magnitudes
            abs_values = corr_df["abs_correlation"].to_numpy()
            mean_abs = float(np.mean(abs_values))
            ax2.hist(abs_values, bins=15, alpha=0.7, color="skyblue", edgecolor="black")
            ax2.axvline(
                mean_abs,
                color="red",
                linestyle="--",
                label=f"Mean: {mean_abs:.3f}",
            )
            ax2.set_title("Distribution of |Correlations|")
            ax2.set_xlabel("|Correlation|")
            ax2.set_ylabel("Frequency")
            ax2.legend()

            # 3. Top correlations by magnitude. Ranking on the signed value would
            # silently exclude strong negative relationships from this panel.
            top_corr = corr_df.nlargest(5, "abs_correlation")
            y_pos = range(len(top_corr))
            ax3.barh(y_pos, top_corr["abs_correlation"], color="lightgreen")
            ax3.set_yticks(y_pos)
            ax3.set_yticklabels(
                [
                    f"{pair}-{indicator[:8]}"
                    for pair, indicator in zip(top_corr["pair"], top_corr["indicator"])
                ]
            )
            ax3.invert_yaxis()  # strongest correlation on top
            ax3.set_title("Top Correlations")
            ax3.set_xlabel("|Correlation|")

            # 4. Significant vs insignificant
            significant_count = int(corr_df["significant"].sum())
            not_significant = len(corr_df) - significant_count
            ax4.pie(
                [significant_count, not_significant],
                labels=[
                    f"Significant ({significant_count})",
                    f"Not Significant ({not_significant})",
                ],
                colors=["green", "red"],
                autopct="%1.1f%%",
            )
            ax4.set_title("Statistical Significance")

            fig.tight_layout()
            fig.savefig(
                self.viz_dir / "correlation_dashboard.png", dpi=100, bbox_inches="tight"
            )
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)
        logger.info("Saved correlation dashboard (700px width)")

    def _generate_currency_returns(
        self,
        data: pd.DataFrame,
        base_country: str,
        quote_country: str,
        rng: Optional[np.random.Generator] = None,
    ) -> pd.Series:
        """Generate synthetic currency returns from cross-country indicator gaps.

        For every indicator present for both countries, the per-period
        difference ``base - quote`` is weighted by indicator family and summed.
        Gaussian noise (sigma 0.5) is then added. Periods where no indicator pair
        has both values get pure noise (sigma 1). The output is synthetic, not
        observed market data.

        Args:
            data: Frame whose columns are ``(country, indicator)`` tuples (read
                as ``col[0]`` / ``col[1]``) and whose index is the period.
            base_country: Country key of the base currency (``col[0]`` value).
            quote_country: Country key of the quote currency (``col[0]`` value).
            rng: Optional NumPy ``Generator`` for reproducible noise. When
                omitted, the global ``np.random`` state is used as before.

        Returns:
            Float Series indexed like ``data``, or an empty float Series if an
            error occurs (the error is logged).
        """
        # Indicator-family weights, matched by substring in this order. "BCA"
        # must come before "GDP": current-account codes such as "BCA_NGDPD" also
        # contain "GDP" and were previously weighted as GDP. "GDP" also covers
        # "NGDP*" and "CPI" also covers "PCPI*". Unmatched indicators get 0.1.
        weight_rules = (
            ("BCA", 0.2),  # Current account
            ("GDP", 0.3),
            ("CPI", -0.2),  # Inverted for PPP effect
            ("LUR", -0.15),  # Higher unemployment = weaker currency
        )
        normal = rng.normal if rng is not None else np.random.normal

        try:
            # Find any available economic indicators for both countries
            base_indicators = [col for col in data.columns if col[0] == base_country]
            quote_indicators = [col for col in data.columns if col[0] == quote_country]

            logger.info(
                f"Base country {base_country} indicators: {[col[1] for col in base_indicators]}"
            )
            logger.info(
                f"Quote country {quote_country} indicators: {[col[1] for col in quote_indicators]}"
            )

            # Per-period weighted differential, and how many indicators contributed.
            weighted_sum = pd.Series(0.0, index=data.index)
            factors_used = pd.Series(0, index=data.index)
            for base_col in base_indicators:
                indicator = base_col[1]
                quote_col = (quote_country, indicator)
                if quote_col not in data.columns:
                    continue

                weight = next(
                    (w for key, w in weight_rules if key in indicator), 0.1
                )
                diff = pd.to_numeric(data[base_col], errors="coerce") - pd.to_numeric(
                    data[quote_col], errors="coerce"
                )
                valid = diff.notna()
                weighted_sum = weighted_sum + (diff * weight).where(valid, 0.0)
                factors_used = factors_used + valid.astype(int)

            # Add noise row by row, in index order, with one draw per period.
            values = [
                total + normal(0, 0.5) if count > 0 else normal(0, 1)
                for total, count in zip(weighted_sum, factors_used)
            ]
            returns = pd.Series(values, index=data.index, dtype=float)

            logger.info(
                f"Generated {len(returns.dropna())} currency return observations"
            )
            return returns.dropna()

        except Exception as e:
            logger.error(f"Error generating returns: {e}")
            return pd.Series(dtype=float)

    def _calculate_correlation(self, x: pd.Series, y: pd.Series) -> Dict:
        """Calculate correlation with significance"""
        try:
            aligned_x, aligned_y = x.align(y, join="inner")
            aligned_x = aligned_x.dropna()
            aligned_y = aligned_y.dropna()

            if len(aligned_x) < 3 or len(aligned_y) < 3:
                return {"correlation": 0, "p_value": 1, "sample_size": 0}

            # Align again after dropping NAs
            common_index = aligned_x.index.intersection(aligned_y.index)
            if len(common_index) < 3:
                return {"correlation": 0, "p_value": 1, "sample_size": 0}

            x_clean = aligned_x[common_index]
            y_clean = aligned_y[common_index]

            correlation, p_value = pearsonr(x_clean, y_clean)

            return {
                "correlation": round(correlation, 4),
                "p_value": round(p_value, 4),
                "sample_size": len(common_index),
                "significant": p_value < 0.05,
            }
        except Exception as e:
            logger.error(f"Error calculating correlation: {e}")
            return {"correlation": 0, "p_value": 1, "sample_size": 0}

    def _calculate_summary(self, pair_correlations: Dict) -> Dict:
        """Calculate summary statistics"""
        all_correlations = []
        significant_count = 0
        total_count = 0

        for pair, analysis in pair_correlations.items():
            for corr_data in analysis.get("correlations", {}).values():
                if isinstance(corr_data, dict) and "correlation" in corr_data:
                    all_correlations.append(abs(corr_data["correlation"]))
                    total_count += 1
                    if corr_data.get("significant", False):
                        significant_count += 1

        if not all_correlations:
            return {}

        return {
            "average_correlation": round(np.mean(all_correlations), 4),
            "max_correlation": round(max(all_correlations), 4),
            "significant_percentage": (
                round((significant_count / total_count) * 100, 2)
                if total_count > 0
                else 0
            ),
            "total_correlations": total_count,
        }


def generate_report(
    datasets: pd.DataFrame, countries: pd.DataFrame, correlation_results: Dict
) -> str:
    """Build the plain-text analysis report.

    Args:
        datasets: Table of available datasets (only its row count is reported).
        countries: Table of countries. When it has a ``has_currency`` column,
            the number of rows equal to ``True`` is reported as major currencies.
        correlation_results: Correlation analysis output with optional keys
            ``working_indicators``, ``summary_stats`` and ``pair_correlations``.
            ``None`` is treated as an empty result.

    Returns:
        The report as a single newline-joined string.
    """
    results = correlation_results or {}
    report = ["=== IMF DATA MINING ANALYSIS WITH VISUALIZATIONS (700px) ===\n"]

    # Datasets
    report.append(f"1. DATASETS: {len(datasets)} available")

    # Countries
    report.append(f"2. COUNTRIES: {len(countries)} total")
    if not countries.empty and "has_currency" in countries.columns:
        # .eq(True) matches `== True`: only values equal to True are counted.
        fx_count = int(countries["has_currency"].eq(True).sum())
        report.append(f"   Major currencies: {fx_count}")

    # Working indicators
    if "working_indicators" in results:
        working = results["working_indicators"]
        report.append(f"3. WORKING INDICATORS: {len(working)}")
        report.extend(f"   - {indicator}" for indicator in working)

    # Correlations
    if "summary_stats" in results:
        summary = results["summary_stats"]
        report.append("4. CORRELATIONS:")
        report.append(
            f"   Average correlation: {summary.get('average_correlation', 'N/A')}"
        )
        report.append(f"   Max correlation: {summary.get('max_correlation', 'N/A')}")
        report.append(
            f"   Significant: {summary.get('significant_percentage', 'N/A')}%"
        )
        report.append(
            f"   Total correlations: {summary.get('total_correlations', 'N/A')}"
        )

        # Pair details
        if "pair_correlations" in results:
            report.append("5. ANALYZED PAIRS:")
            for pair, analysis in results["pair_correlations"].items():
                quality = analysis.get("data_quality", {})
                data_points = quality.get("data_points", 0)
                indicators_used = quality.get("indicators_used", 0)
                time_span = quality.get("time_span", "Unknown")
                report.append(
                    f"   - {pair}: {data_points} data points, {indicators_used} indicators, {time_span}"
                )

                # Skip entries that are not correlation dicts instead of
                # raising while sorting.
                valid_corrs = [
                    (indicator, corr_data)
                    for indicator, corr_data in analysis.get("correlations", {}).items()
                    if isinstance(corr_data, dict) and "correlation" in corr_data
                ]
                if valid_corrs:
                    # Strongest relationships first, regardless of sign.
                    valid_corrs.sort(
                        key=lambda item: abs(item[1]["correlation"]), reverse=True
                    )
                    report.append("     Top correlations:")
                    for indicator, corr_data in valid_corrs[:3]:
                        corr = corr_data["correlation"]
                        sig = "(*)" if corr_data.get("significant", False) else ""
                        report.append(f"       {indicator}: {corr:.3f}{sig}")

    # Visualizations created
    report.append("\n6. VISUALIZATIONS CREATED (700px width):")
    report.append("   - indicator_test_results.png")
    report.append("   - datasets_overview.png")
    report.append("   - countries_analysis.png")
    report.append("   - data_fetch_results.png")
    report.append("   - correlation_dashboard.png")

    if "pair_correlations" in results:
        report.extend(f"   - {pair}_analysis.png" for pair in results["pair_correlations"])

    report.append("\n   All charts saved in: imf_visualizations_700px/")
    report.append("\n=== END ===")
    return "\n".join(report)


def save_results(
    datasets: pd.DataFrame, countries: pd.DataFrame, correlation_results: Dict
):
    """Write the analysis outputs to ``imf_analysis_output/``.

    Files written:
        - ``datasets.csv`` / ``countries.csv`` (only when the frame is non-empty)
        - ``correlations.json`` (only when ``correlation_results`` is truthy)
        - ``report.txt`` (always; built with ``generate_report``)

    Args:
        datasets: Dataset listing to save as CSV.
        countries: Country listing to save as CSV.
        correlation_results: Analysis results to save as JSON and summarise.
    """

    def _to_json_compatible(value):
        """Fallback for json.dump: unwrap NumPy scalars/arrays, stringify the rest.

        Without this, NumPy booleans/ints (e.g. a ``significant`` flag) would be
        written as strings such as "True" instead of JSON booleans/numbers.
        """
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        return str(value)

    output_dir = Path("imf_analysis_output")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save CSVs
    if not datasets.empty:
        datasets.to_csv(output_dir / "datasets.csv", index=False)
    if not countries.empty:
        countries.to_csv(output_dir / "countries.csv", index=False)

    # Save JSON
    if correlation_results:
        with open(output_dir / "correlations.json", "w", encoding="utf-8") as f:
            json.dump(correlation_results, f, indent=2, default=_to_json_compatible)

    # Save report; explicit UTF-8 so output does not depend on the platform
    # default encoding (e.g. cp1252 on Windows).
    report = generate_report(datasets, countries, correlation_results)
    with open(output_dir / "report.txt", "w", encoding="utf-8") as f:
        f.write(report)

    logger.info(f"Results saved to: {output_dir.absolute()}")
    logger.info("Visualizations saved to: imf_visualizations_700px/")


def _plot_mt5_currency_data(mt5_data: List[pd.DataFrame]) -> None:
    """Save closing-price and tick-volume charts for MT5 bar data.

    Args:
        mt5_data: One frame per symbol. Each must have the ``symbol``, ``time``,
            ``close`` and ``tick_volume`` columns that are read below.

    Writes ``imf_visualizations_700px/mt5_currency_data.png``. The figure is
    always closed, even if drawing or saving fails.
    """
    fig, (ax1, ax2) = plt.subplots(2, 1)
    try:
        for data in mt5_data:
            symbol = data["symbol"].iloc[0]
            # Price chart
            ax1.plot(
                data["time"],
                data["close"],
                label=symbol,
                marker="o",
                markersize=3,
            )
            # Volume chart
            ax2.plot(data["time"], data["tick_volume"], label=symbol, alpha=0.7)

        ax1.set_title("MT5 Currency Pairs - Closing Prices")
        ax1.set_ylabel("Price")
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        ax2.set_title("MT5 Currency Pairs - Tick Volumes")
        ax2.set_ylabel("Volume")
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        fig.tight_layout()
        viz_dir = Path("imf_visualizations_700px")
        viz_dir.mkdir(exist_ok=True)
        fig.savefig(viz_dir / "mt5_currency_data.png", dpi=100, bbox_inches="tight")
    finally:
        plt.close(fig)
    logger.info("Saved MT5 currency data chart (700px width)")


def _collect_mt5_currency_pairs() -> List[str]:
    """Connect to MT5, chart a few pairs, and return the pairs to analyse.

    Falls back to a fixed list of major pairs when MT5 cannot be reached or
    returns no symbols. Once connected, the MT5 session is always disconnected,
    even if fetching or charting raises.
    """
    fallback_pairs = ["EURUSD", "GBPUSD", "USDJPY", "AUDUSD"]

    print("🔌 Connecting to MetaTrader 5...")
    mt5_connector = MetaTrader5Connector()
    if not mt5_connector.connect():
        print("⚠️ MT5 connection failed, using default currency pairs")
        return fallback_pairs

    try:
        currency_pairs = mt5_connector.get_currency_pairs()
        print(
            f"📈 Got {len(currency_pairs)} currency pairs from MT5: {currency_pairs}"
        )

        # Fetch bars for up to 5 pairs; chart the first 3 that returned data.
        mt5_data_all = []
        for pair in currency_pairs[:5]:
            pair_data = mt5_connector.get_currency_data(pair, count=30)
            if not pair_data.empty:
                mt5_data_all.append(pair_data)

        if mt5_data_all:
            _plot_mt5_currency_data(mt5_data_all[:3])
    finally:
        mt5_connector.disconnect()

    if not currency_pairs:
        print("⚠️ MT5 returned no currency pairs, using default currency pairs")
        return fallback_pairs
    return currency_pairs


def main():
    """Run the full pipeline: MT5 pairs -> IMF data -> correlations -> report.

    Any exception is logged with its traceback and printed. It is not re-raised.
    """
    print(
        "=== IMF Data Mining & Currency Analysis WITH MT5 & 700px VISUALIZATIONS ===\n"
    )

    try:
        currency_pairs = _collect_mt5_currency_pairs()

        # Initialize IMF
        data_miner = IMFDataMiner()
        analyzer = CurrencyCorrelationAnalyzer(data_miner)

        # Get basic info
        print("1. Fetching datasets...")
        datasets = data_miner.get_available_datasets()
        print(f"Found {len(datasets)} datasets")

        print("\n2. Fetching countries...")
        countries = data_miner.get_countries_list()
        print(f"Found {len(countries)} countries")

        # Analyze correlations for the first 3 pairs
        print(f"\n3. Analyzing correlations for MT5 pairs: {currency_pairs[:3]}...")
        correlation_results = analyzer.analyze_correlations(currency_pairs[:3])

        # Generate report
        print("\n4. Generating report...")
        report = generate_report(datasets, countries, correlation_results)
        print("\n" + report)

        # Save results
        save_results(datasets, countries, correlation_results)

        print("\n=== Analysis Complete! ===")
        print("📊 Check imf_visualizations_700px/ folder for all 700px width charts!")
        print("🔍 MT5 currency pairs were used for correlation analysis!")

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error would drop.
        logger.exception("Error in main: %s", e)
        print(f"Error: {e}")


if __name__ == "__main__":
    main()
