"""
Centralized caching system for AFML package.
Now with robust cache keys, MLflow integration, backtest caching, and monitoring.
"""

import json
import os
import threading
from collections import defaultdict
from pathlib import Path
from types import FunctionType
from typing import Callable, Dict, Optional, Union

from appdirs import user_cache_dir
from joblib import Memory
from loguru import logger

# =============================================================================
# 1) CACHE DIRECTORY SETUP
# =============================================================================


def _setup_cache_directories() -> Dict[str, Path]:
    """Resolve and create the centralized AFML cache directory tree.

    The root comes from the ``AFML_CACHE`` environment variable when it is set
    to a non-empty value; otherwise the platform user cache directory for
    "afml" is used. The joblib, numba and backtest caches live underneath it.

    Returns:
        Mapping of cache name ("base", "joblib", "numba", "backtest") to its
        directory. Every directory in the mapping exists on return.
    """
    override = os.getenv("AFML_CACHE")
    if override:
        root = Path(override)
    else:
        root = Path(user_cache_dir("afml"))

    layout = dict(
        base=root,
        joblib=root / "joblib_cache",
        numba=root / "numba_cache",
        backtest=root / "backtest_cache",
    )

    # "base" is created first (parents=True), then each sub-cache inside it.
    for directory in layout.values():
        directory.mkdir(parents=True, exist_ok=True)

    return layout


# Resolved once at import time; other components in this module read their paths from here.
CACHE_DIRS: Dict[str, Path] = _setup_cache_directories()

# =============================================================================
# 2) NUMBA CONFIGURATION
# =============================================================================


def _configure_numba():
    """Point Numba's on-disk cache at the centralized AFML numba directory.

    ``NUMBA_CACHE_DIR`` is set unconditionally, overriding any existing value.
    ``NUMBA_DISABLE_JIT`` and ``NUMBA_WARNINGS`` are only given a default of
    "0" when they are not already set, so explicit user settings win.
    """
    target = str(CACHE_DIRS["numba"])
    os.environ["NUMBA_CACHE_DIR"] = target

    # Defaults only; these do not override values already in the environment.
    for var_name, default in (("NUMBA_DISABLE_JIT", "0"), ("NUMBA_WARNINGS", "0")):
        os.environ.setdefault(var_name, default)

    logger.debug("Numba cache configured: {}", target)


# =============================================================================
# 3) SIMPLE CACHE STATISTICS
# =============================================================================


class CacheStats:
    """Lightweight, thread-safe per-function cache hit/miss counters.

    Counters live in memory and are persisted as JSON to
    ``CACHE_DIRS["base"] / "cache_stats.json"`` whenever a function's hit count
    or miss count reaches a multiple of ``_SAVE_EVERY`` (to limit disk I/O).
    Counts recorded since the last write are not persisted if the process
    exits first.
    """

    # Persist after every N-th hit (or miss) of a given function.
    _SAVE_EVERY = 25

    def __init__(self):
        self._lock = threading.Lock()
        self._stats = defaultdict(lambda: {"hits": 0, "misses": 0})
        self._stats_file = CACHE_DIRS["base"] / "cache_stats.json"
        self._load_stats()

    def _load_stats(self):
        """Load previously persisted stats from disk, if present.

        An unreadable or corrupted file is ignored (start fresh). Entries that
        are not ``{"hits": int, "misses": int}``-shaped are skipped so later
        arithmetic on the counters cannot fail.
        """
        if not self._stats_file.exists():
            return
        try:
            with open(self._stats_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            return  # Start fresh if unreadable or corrupted

        if not isinstance(data, dict):
            return

        loaded = {}
        for func_name, entry in data.items():
            if not isinstance(entry, dict):
                continue
            try:
                loaded[func_name] = {
                    "hits": int(entry.get("hits", 0)),
                    "misses": int(entry.get("misses", 0)),
                }
            except (TypeError, ValueError):
                continue
        self._stats.update(loaded)

    def _save_stats(self):
        """Persist stats to disk (best effort; callers hold ``self._lock``).

        Writes to a per-process temporary file and atomically swaps it into
        place with ``os.replace`` so a crash mid-write cannot leave a truncated
        JSON file behind.
        """
        tmp_file = self._stats_file.with_name(
            f"{self._stats_file.name}.{os.getpid()}.tmp"
        )
        try:
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(dict(self._stats), f)
            os.replace(tmp_file, self._stats_file)
        except (OSError, TypeError, ValueError):
            # Statistics are advisory: never let a failed write break caching.
            try:
                tmp_file.unlink()
            except OSError:
                pass

    def _record(self, func_name: str, outcome: str):
        """Increment ``outcome`` ("hits" or "misses") and periodically persist."""
        with self._lock:
            entry = self._stats[func_name]
            entry[outcome] += 1
            if entry[outcome] % self._SAVE_EVERY == 0:
                self._save_stats()

    def record_hit(self, func_name: str):
        """Record a cache hit for ``func_name``."""
        self._record(func_name, "hits")

    def record_miss(self, func_name: str):
        """Record a cache miss for ``func_name``."""
        self._record(func_name, "misses")

    def get_hit_rate(self, func_name: Optional[str] = None) -> float:
        """Return the hit rate for ``func_name``, or across all functions.

        Unknown functions (and no recorded calls) yield 0.0. Looking up an
        unknown name does not create an entry for it.
        """
        with self._lock:
            if func_name:
                entries = [self._stats[func_name]] if func_name in self._stats else []
            else:
                entries = list(self._stats.values())
            hits = sum(e["hits"] for e in entries)
            calls = sum(e["hits"] + e["misses"] for e in entries)
        return hits / calls if calls > 0 else 0.0

    def get_stats(self) -> Dict[str, Dict[str, int]]:
        """Return a snapshot of all statistics.

        Per-function dicts are copied too, so the snapshot does not change as
        further hits/misses are recorded.
        """
        with self._lock:
            return {name: dict(entry) for name, entry in self._stats.items()}

    def clear(self):
        """Reset all counters and delete the persisted stats file."""
        with self._lock:
            self._stats.clear()
            try:
                self._stats_file.unlink()
            except FileNotFoundError:
                pass


# Global stats instance
cache_stats = CacheStats()  # process-wide stats tracker used by the helpers below

# =============================================================================
# 4) JOBLIB MEMORY INSTANCE
# =============================================================================

# Shared joblib Memory rooted at the centralized joblib cache directory
# (verbose=0 keeps joblib quiet).
memory = Memory(verbose=0, location=str(CACHE_DIRS["joblib"]))


# =============================================================================
# 5) UTILITY FUNCTIONS
# =============================================================================


def get_cache_hit_rate(func_name: Optional[str] = None) -> float:
    """Return the hit rate for ``func_name``, or across all tracked functions.

    Thin wrapper around ``cache_stats.get_hit_rate``; yields 0.0 when no
    calls have been recorded.
    """
    rate = cache_stats.get_hit_rate(func_name)
    return rate


def get_cache_stats() -> Dict[str, Dict[str, int]]:
    """Return the per-function hit/miss counters from ``cache_stats``."""
    snapshot = cache_stats.get_stats()
    return snapshot


def clear_cache_stats():
    """Reset every hit/miss counter held by ``cache_stats``."""
    return cache_stats.clear()


def clear_afml_cache(warn: bool = True):
    """Clear the joblib memoization cache and reset hit/miss statistics.

    Only the joblib ``memory`` store and the statistics are cleared here; the
    numba and backtest cache directories are not touched by this function.

    Args:
        warn: When True, log a warning first and pass ``warn=True`` to joblib.
    """
    if not warn:
        memory.clear(warn=warn)
    else:
        logger.warning("Clearing AFML cache...")
        memory.clear(warn=warn)
    clear_cache_stats()


def get_cache_summary() -> Dict[str, Union[float, int]]:
    """Summarize cache performance from a single stats snapshot.

    Returns:
        Dict with "hit_rate" (0.0 when there were no calls), "total_calls"
        and "functions_tracked".
    """
    snapshot = cache_stats.get_stats()
    hits = 0
    calls = 0
    for counts in snapshot.values():
        hits += counts["hits"]
        calls += counts["hits"] + counts["misses"]

    return {
        "hit_rate": hits / calls if calls > 0 else 0.0,
        "total_calls": calls,
        "functions_tracked": len(snapshot),
    }


# =============================================================================
# 6) CACHE ANALYSIS CONTEXT MANAGER
# =============================================================================


class CacheAnalyzer:
    """Context manager that logs the cache activity that happened inside it.

    On entry the global ``cache_stats`` counters are snapshotted; on a clean
    exit the current counters are diffed against that snapshot and a one-line
    summary (new calls and their hit rate) is logged. Nothing is logged when
    the block raises.

    Example::

        with CacheAnalyzer("feature build"):
            ...  # code that calls cached functions
    """

    def __init__(self, name: str = "analysis"):
        self.name = name
        # None until __enter__ runs; {} is a valid snapshot (nothing tracked yet).
        self.start_stats: Optional[Dict[str, Dict[str, int]]] = None

    def __enter__(self):
        # Copy each per-function dict too: a shallow copy would alias the live
        # counters, so pre-existing functions would always diff to zero.
        self.start_stats = {
            func_name: dict(counts)
            for func_name, counts in cache_stats.get_stats().items()
        }
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            end_stats = cache_stats.get_stats()
            report = self._generate_report(end_stats)
            if report:
                logger.info("Cache analysis '{}': {}", self.name, report)

    def _generate_report(self, end_stats) -> Optional[str]:
        """Build the summary string from ``end_stats`` minus the entry snapshot.

        Returns None if no snapshot was taken (``__enter__`` never ran),
        "no cache activity" if nothing new was recorded, otherwise
        "<calls> calls, <rate> hit rate".
        """
        if self.start_stats is None:
            return None

        zero = {"hits": 0, "misses": 0}
        total_new_hits = 0
        total_new_calls = 0

        for func_name, end_data in end_stats.items():
            # Functions first seen inside the block start from zero.
            start_data = self.start_stats.get(func_name, zero)
            new_hits = end_data["hits"] - start_data["hits"]
            new_misses = end_data["misses"] - start_data["misses"]

            total_new_hits += new_hits
            total_new_calls += new_hits + new_misses

        if total_new_calls > 0:
            hit_rate = total_new_hits / total_new_calls
            return f"{total_new_calls} calls, {hit_rate:.1%} hit rate"

        return "no cache activity"


# =============================================================================
# 7) INITIALIZATION FUNCTION
# =============================================================================


def initialize_cache_system():
    """Configure Numba's cache location and log the cache setup.

    Numba is configured first, before any @njit functions are defined. Then
    the joblib and numba cache paths are logged, followed by a summary of
    previously persisted hit/miss statistics when any exist.
    """
    _configure_numba()

    logger.info("AFML cache system initialized:")
    logger.info("  Joblib cache: {}", CACHE_DIRS["joblib"])
    logger.info("  Numba cache: {}", CACHE_DIRS["numba"])

    persisted = cache_stats.get_stats()
    if not persisted:
        return

    logger.info(
        "  Loaded stats: {} functions, {:.1%} hit rate",
        len(persisted),
        cache_stats.get_hit_rate(),
    )


# =============================================================================
# 8) NOW SAFE TO IMPORT OTHER MODULES
# =============================================================================

# Import data-access tracking helpers - only after memory and cache_stats are defined above
from .data_access_tracker import (  # noqa: E402
    DataAccessTracker,  # noqa: E402
    clear_data_access_log,
    get_data_tracker,
    log_data_access,
    print_contamination_report,
)

# Import selective cleaner functions after base components are defined
from .selective_cleaner import (  # noqa: E402
    analyze_cache_versions,  # noqa: E402
    cache_maintenance,
    clean_orphaned_caches,
    cleanup_by_age,
    cleanup_by_size,
    clear_orphaned_features_caches,
    clear_orphaned_labeling_caches,
    clear_orphaned_ml_caches,
    find_orphaned_caches,
    get_version_tracker,
    print_version_analysis,
)

# Unified cacheable decorators and parameter-grid helpers
from .unified_cache_system import (
    cacheable,  # noqa: E402
    create_cacheable_param_grid,
    cv_cacheable,
    data_tracking_cacheable,
    print_cache_report,
    reconstruct_param_grid,
    robust_cacheable,
    time_aware_cacheable,
)

# MLflow integration (optional)
try:
    from .mlflow_integration import (
        MLFLOW_AVAILABLE,
        MLflowCacheIntegration,
        get_mlflow_cache,
        mlflow_cached,
        setup_mlflow_cache,
    )

    MLFLOW_INTEGRATION_AVAILABLE = True
except ImportError:
    MLFLOW_INTEGRATION_AVAILABLE = False
    # Bind every MLflow name listed in __all__ to a fallback so that
    # `from ... import *` and attribute access keep working without MLflow.
    MLFLOW_AVAILABLE = False
    MLflowCacheIntegration = None
    get_mlflow_cache = None
    mlflow_cached = None
    setup_mlflow_cache = None
    logger.debug("MLflow integration not available (install mlflow)")

# Backtest caching
from .backtest_cache import (
    BacktestCache,
    BacktestMetadata,  # noqa: E402
    BacktestResult,
    cached_backtest,
    get_backtest_cache,
)
# Cache monitoring
from .cache_monitoring import (
    CacheHealthReport,
    CacheMonitor,  # noqa: E402
    FunctionCacheStats,
    analyze_cache_patterns,
    debug_function_cache,
    diagnose_cache_issues,
    get_cache_efficiency_report,
    get_cache_monitor,
    print_cache_health,
)

# =============================================================================
# 9) ENHANCED CONVENIENCE FUNCTIONS
# =============================================================================


def get_comprehensive_cache_status() -> dict:
    """
    Collect a status snapshot across all cache subsystems.

    The core summary is always present. The health and backtest sections stay
    None if gathering them raises; such failures are logged at debug level only.

    Returns:
        Dict with keys "core", "health", "backtest" and "mlflow".
    """
    status = dict(
        core=get_cache_summary(),
        health=None,
        backtest=None,
        mlflow={"available": MLFLOW_INTEGRATION_AVAILABLE},
    )

    try:
        health_report = get_cache_monitor().generate_health_report()
        status["health"] = dict(
            total_functions=health_report.total_functions,
            hit_rate=health_report.overall_hit_rate,
            total_calls=health_report.total_calls,
            cache_size_mb=health_report.total_cache_size_mb,
        )
    except Exception as exc:
        logger.debug(f"Health report failed: {exc}")

    try:
        status["backtest"] = get_backtest_cache().get_cache_stats()
    except Exception as exc:
        logger.debug(f"Backtest cache stats failed: {exc}")

    return status


def optimize_cache_system(
    clear_changed: bool = True,
    max_size_mb: int = 1000,
    max_age_days: int = 30,
    print_report: bool = True,
) -> dict:
    """
    Run maintenance over every cache subsystem and collect the results.

    Three independent steps run in order. A failure in one step is logged as a
    warning and leaves its result entry as None without stopping the others:

    1. ``cache_maintenance`` with the changed/size/age settings below,
    2. a cache-monitor health report (printed in summary form if requested),
    3. removal of backtest runs older than ``max_age_days``.

    Args:
        clear_changed: Passed to ``cache_maintenance`` as ``auto_clear_changed``
        max_size_mb: Maximum total cache size in MB
        max_age_days: Age threshold for cache maintenance and backtest cleanup
        print_report: Print the (non-detailed) health report

    Returns:
        Dict with keys "maintenance", "health_report" and "backtest_cleanup"
    """
    logger.info("Running comprehensive cache optimization...")
    results = dict.fromkeys(("maintenance", "health_report", "backtest_cleanup"))

    try:
        results["maintenance"] = cache_maintenance(
            auto_clear_changed=clear_changed,
            max_cache_size_mb=max_size_mb,
            max_age_days=max_age_days,
        )
    except Exception as exc:
        logger.warning(f"Cache maintenance failed: {exc}")

    try:
        monitor = get_cache_monitor()
        results["health_report"] = monitor.generate_health_report()
        if print_report:
            monitor.print_health_report(detailed=False)
    except Exception as exc:
        logger.warning(f"Health report failed: {exc}")

    try:
        runs_cleared = get_backtest_cache().clear_old_runs(days=max_age_days)
        results["backtest_cleanup"] = {"runs_cleared": runs_cleared}
        logger.info(f"Cleared {runs_cleared} old backtest runs")
    except Exception as exc:
        logger.warning(f"Backtest cleanup failed: {exc}")

    return results


def setup_production_cache(
    enable_mlflow: bool = True,
    mlflow_experiment: str = "production",
    mlflow_uri: Optional[str] = None,
    max_cache_size_mb: int = 2000,
) -> dict:
    """
    Initialize every cache subsystem for production use.

    Steps, in order:

    1. core initialization (``initialize_cache_system``),
    2. MLflow tracking, when requested and available,
    3. the backtest cache,
    4. the cache monitor,
    5. an initial ``optimize_cache_system`` pass.

    Failures in steps 2-5 are logged as warnings and leave the corresponding
    component as None. Requesting MLflow when the integration could not be
    imported is also reported as a warning instead of being silently ignored.

    Args:
        enable_mlflow: Enable MLflow integration
        mlflow_experiment: MLflow experiment name
        mlflow_uri: MLflow tracking URI (None lets the integration choose)
        max_cache_size_mb: Maximum cache size passed to the initial optimization

    Returns:
        Dict with keys "core_cache" (True once core init succeeded),
        "mlflow_cache", "backtest_cache" and "monitor"
    """
    logger.info("Initializing production cache system...")

    components = {
        "core_cache": None,
        "mlflow_cache": None,
        "backtest_cache": None,
        "monitor": None,
    }

    # Core initialization is mandatory; errors here propagate to the caller.
    initialize_cache_system()
    components["core_cache"] = True

    if enable_mlflow:
        if MLFLOW_INTEGRATION_AVAILABLE:
            try:
                components["mlflow_cache"] = setup_mlflow_cache(
                    experiment_name=mlflow_experiment,
                    tracking_uri=mlflow_uri,
                )
                logger.info(f"MLflow tracking enabled: {mlflow_experiment}")
            except Exception as e:
                logger.warning(f"MLflow setup failed: {e}")
        else:
            logger.warning(
                "MLflow tracking requested but MLflow integration is not "
                "available (install mlflow); continuing without it"
            )

    try:
        components["backtest_cache"] = get_backtest_cache()
    except Exception as e:
        logger.warning(f"Backtest cache setup failed: {e}")

    try:
        components["monitor"] = get_cache_monitor()
    except Exception as e:
        logger.warning(f"Cache monitor setup failed: {e}")

    # Initial maintenance pass (uses optimize_cache_system's other defaults).
    try:
        optimize_cache_system(max_size_mb=max_cache_size_mb, print_report=False)
    except Exception as e:
        logger.warning(f"Initial optimization failed: {e}")

    logger.info("✅ Production cache system ready")
    return components


# =============================================================================
# 10) ADDITIONAL UTILITY FUNCTIONS
# =============================================================================


def get_cache_size_info(
    cache_dirs: Optional[Dict[str, Path]] = None,
) -> Dict[str, Dict[str, Union[int, float]]]:
    """
    Report the on-disk size of each cache directory.

    Args:
        cache_dirs: Mapping of cache name to directory to inspect. Defaults to
            ``CACHE_DIRS``. Directories that do not exist are omitted.

    Returns:
        Dict keyed by cache name. Each value holds "size_bytes", "size_mb"
        (rounded to 2 decimals) and "file_count". With the default
        directories, "base" is the parent of the other caches, so its totals
        include theirs.
    """
    if cache_dirs is None:
        cache_dirs = CACHE_DIRS

    size_info = {}

    for cache_name, cache_dir in cache_dirs.items():
        if not cache_dir.exists():
            continue

        total_size = 0
        file_count = 0
        for file_path in cache_dir.rglob("*"):
            try:
                if not file_path.is_file():
                    continue
                total_size += file_path.stat().st_size
            except OSError:
                # The file vanished (e.g. concurrent cleanup) or is unreadable.
                continue
            file_count += 1

        size_info[cache_name] = {
            "size_bytes": total_size,
            "size_mb": round(total_size / (1024 * 1024), 2),
            "file_count": file_count,
        }

    return size_info


def clear_cache_by_pattern(pattern: str, cache_type: str = "joblib") -> int:
    """
    Delete cache files whose filename contains ``pattern``.

    Args:
        pattern: Non-empty substring matched against each file's name (not its
            full path), searched recursively under the chosen cache directory.
        cache_type: Key of ``CACHE_DIRS`` to search ('joblib', 'numba',
            'backtest'; 'base' searches the whole cache tree).

    Returns:
        Number of files removed.

    Raises:
        ValueError: If ``pattern`` is empty (it would match, and delete, every
            file) or ``cache_type`` is unknown.
    """
    if not pattern:
        raise ValueError(
            "pattern must be a non-empty string; an empty pattern would match "
            "every cache file"
        )
    if cache_type not in CACHE_DIRS:
        raise ValueError(
            f"Invalid cache type: {cache_type}. Available: {list(CACHE_DIRS.keys())}"
        )

    cache_dir = CACHE_DIRS[cache_type]
    removed_count = 0

    # Snapshot the matches first so deletions don't interleave with the walk.
    matches = [
        cache_file
        for cache_file in cache_dir.rglob("*")
        if cache_file.is_file() and pattern in cache_file.name
    ]

    for cache_file in matches:
        try:
            cache_file.unlink()
            removed_count += 1
            logger.debug(f"Removed cache file: {cache_file.name}")
        except Exception as e:
            logger.warning(f"Failed to remove {cache_file}: {e}")

    logger.info(
        f"Removed {removed_count} cache files matching pattern '{pattern}' from {cache_type} cache"
    )
    return removed_count


def apply_decorator_to_methods(decorator: Callable, *, include_private: bool = False):
    """
    Build a class decorator that wraps the class's own methods with ``decorator``.

    Only attributes defined directly on the class are considered (inherited
    methods are untouched). Plain functions are wrapped directly; staticmethod
    and classmethod objects have their underlying function wrapped and are
    re-wrapped in the same descriptor type. Names starting with "_" are
    skipped unless ``include_private`` is True. Other attributes (properties,
    data) are left alone.
    """

    def class_decorator(cls):
        # Iterate over a copy because setattr mutates the class namespace.
        for attr_name, member in list(vars(cls).items()):
            if attr_name.startswith("_") and not include_private:
                continue

            for descriptor_type in (staticmethod, classmethod):
                if isinstance(member, descriptor_type):
                    setattr(
                        cls, attr_name, descriptor_type(decorator(member.__func__))
                    )
                    break
            else:
                if isinstance(member, FunctionType):
                    setattr(cls, attr_name, decorator(member))

        return cls

    return class_decorator


# =============================================================================
# 11) EXPORTS
# =============================================================================

__all__ = [
    # Core caching
    "memory",
    "cacheable",  # NEW: Universal decorator
    "initialize_cache_system",
    "cache_stats",
    "get_cache_hit_rate",
    "get_cache_stats",
    "clear_cache_stats",
    "get_cache_summary",
    "CacheAnalyzer",
    "clear_afml_cache",
    "CACHE_DIRS",
    # Selective cache management (NOW FOCUSED ON CLEANUP)
    "cache_maintenance",
    "find_orphaned_caches",
    "clean_orphaned_caches",
    "cleanup_by_size",
    "cleanup_by_age",
    "get_version_tracker",
    "clear_orphaned_ml_caches",
    "clear_orphaned_labeling_caches",
    "clear_orphaned_features_caches",
    "analyze_cache_versions",
    "print_version_analysis",
    # NOTE: Removed exports:
    # - smart_cacheable (replaced by auto_versioning parameter)
    # - clear_changed_* functions (replaced by clean_orphaned_* functions)
    # - selective_cache_clear (replaced by clean_orphaned_caches)
    # - CacheKeyGenerator (not imported by this module; listing it made
    #   `from ... import *` raise AttributeError)
    # Data access tracking and robust cache decorators
    "clear_data_access_log",
    "DataAccessTracker",
    "get_data_tracker",
    "log_data_access",
    "print_contamination_report",
    "robust_cacheable",  # Alias for cacheable()
    "time_aware_cacheable",  # Alias for cacheable(time_aware=True)
    "data_tracking_cacheable",
    # MLflow integration
    "MLflowCacheIntegration",
    "setup_mlflow_cache",
    "get_mlflow_cache",
    "mlflow_cached",
    "MLFLOW_AVAILABLE",
    "MLFLOW_INTEGRATION_AVAILABLE",
    # Backtest caching
    "BacktestCache",
    "BacktestMetadata",
    "BacktestResult",
    "get_backtest_cache",
    "cached_backtest",
    # Cache monitoring
    "CacheMonitor",
    "FunctionCacheStats",
    "CacheHealthReport",
    "get_cache_monitor",
    "print_cache_report",
    "print_cache_health",
    "get_cache_efficiency_report",
    "analyze_cache_patterns",
    "debug_function_cache",
    "diagnose_cache_issues",
    # Enhanced convenience functions
    "get_comprehensive_cache_status",
    "optimize_cache_system",
    "setup_production_cache",
    # Cache cross-validation
    "cv_cacheable",  # Alias for cacheable()
    # Additional utility functions
    "get_cache_size_info",
    "clear_cache_by_pattern",
    "apply_decorator_to_methods",
    # Hyper-parameter fit helpers
    "reconstruct_param_grid",
    "create_cacheable_param_grid",
]


# =============================================================================
# STARTUP MESSAGE UPDATE
# =============================================================================

# Log (at debug level) which enhanced cache features are available.
logger.debug("Enhanced cache features available:")
logger.debug("  - Unified cacheable() decorator with auto_versioning")
logger.debug("  - Robust cache keys for NumPy/Pandas")
logger.debug("  - MLflow integration: {}", "✓" if MLFLOW_INTEGRATION_AVAILABLE else "✗")
logger.debug("  - Backtest caching: ✓")
logger.debug("  - Cache monitoring: ✓")
logger.debug("  - Orphaned cache cleanup: ✓")
logger.debug("  - Cache size info and selective clearing: ✓")
