"""
Enhanced implementation using separate bid/ask prices for long/short models.

This approach aligns training data with actual execution prices, improving
model performance in production trading.
"""

import pandas as pd
import numpy as np
from typing import Dict, Tuple
from loguru import logger


class BidAskLongShortPipeline:
    """
    Long/short model pipeline trained on side-specific quote prices.

    Two ``ModelDevelopmentPipeline`` instances are created from the same
    strategy and configuration: the long one is configured with
    ``price='ask'`` and the short one with ``price='bid'``. ``run`` loads
    the bars once, splits them into an ask frame and a bid frame, prepares
    features and events per side, trains both models and summarises the
    bid/ask spread.
    """

    def __init__(
        self,
        strategy,
        data_config: dict,
        feature_config: dict,
        target_config: dict,
        label_config: dict,
        model_params: dict,
        base_dir: str = "Models/BidAsk_LongShort",
    ):
        """
        Initialize bid/ask-aware long/short pipeline.

        Parameters
        ----------
        strategy : BaseStrategy
            Signal generating strategy
        data_config : dict
            Bar construction configuration. A shallow copy is stored, so the
            caller's dict is not modified by this class.
        feature_config : dict
            Feature engineering configuration
        target_config : dict
            Volatility target configuration
        label_config : dict
            Triple-barrier labeling configuration
        model_params : dict
            Model training configuration (each side gets its own shallow copy)
        base_dir : str
            Base directory for model artifacts; the sides use the
            ``Long_Ask`` and ``Short_Bid`` sub-directories.
        """
        from .model_development import ModelDevelopmentPipeline

        self.strategy = strategy
        # Private shallow copy: _create_side_specific_bars overrides 'price'
        # on this dict and that must not leak into the caller's object.
        self.data_config = dict(data_config)
        self.feature_config = feature_config
        self.target_config = target_config
        self.label_config = label_config
        self.model_params = model_params

        # Populated during run. tick_data / spread_series are optional
        # tick-level inputs; nothing in this class assigns them, so they
        # stay None unless set externally.
        self.tick_data = None
        self.spread_series = None
        self.bar_data = None
        self.long_bars = None
        self.short_bars = None
        self.bar_spread = None

        # Per-side data configurations, kept on the instance because
        # _create_side_specific_bars updates and saves them for tick bars.
        symbol = data_config.get('symbol', 'MODEL')
        self.long_config = {
            **data_config,
            'price': 'ask',  # Long positions use ask
            'model_name': f"{symbol}_LONG_ASK",
        }
        self.short_config = {
            **data_config,
            'price': 'bid',  # Short positions use bid
            'model_name': f"{symbol}_SHORT_BID",
        }

        self.long_pipeline = ModelDevelopmentPipeline(
            strategy=strategy,
            data_config=self.long_config,
            feature_config=feature_config,
            target_config=target_config,
            label_config=label_config,
            model_params=model_params.copy(),
            base_dir=f"{base_dir}/Long_Ask"
        )

        self.short_pipeline = ModelDevelopmentPipeline(
            strategy=strategy,
            data_config=self.short_config,
            feature_config=feature_config,
            target_config=target_config,
            label_config=label_config,
            model_params=model_params.copy(),
            base_dir=f"{base_dir}/Short_Bid"
        )

    def run(
        self,
        generate_reports: bool = True,
        save: bool = True,
        export_onnx: bool = False,
        verbose: bool = True
    ) -> Dict:
        """
        Run bid/ask-aware long/short model development pipeline.

        Parameters
        ----------
        generate_reports : bool
            Generate analysis reports for each side after training
        save : bool
            Save the artifacts of each side after training
        export_onnx : bool
            Forwarded to ``_train_side_model``, which currently does not
            act on it
        verbose : bool
            Print progress banners and the final summary to stdout

        Returns
        -------
        dict
            Dictionary containing models, metrics, and analysis

        Raises
        ------
        Exception
            Any exception raised by a step is logged and re-raised unchanged.
        """
        if verbose:
            print("\n" + "=" * 80)
            print("BID/ASK-AWARE LONG/SHORT MODEL DEVELOPMENT PIPELINE")
            print("=" * 80)
            print("\nUsing realistic execution prices:")
            print("  • Long model: ASK prices (your entry cost)")
            print("  • Short model: BID prices (your entry cost)")

        try:
            # Step 1: Create separate bars for long (ask) and short (bid)
            if verbose:
                print("\n[Step 1/6] Creating side-specific bars...")
                print("  • Long bars: using ASK prices")
                print("  • Short bars: using BID prices")
            self._create_side_specific_bars()

            # Step 2: Engineer features for each side
            if verbose:
                print("\n[Step 2/6] Engineering features for each side...")
            self._engineer_side_specific_features()

            # Step 3: Generate events for each side
            if verbose:
                print("\n[Step 3/6] Generating side-specific events...")
            self._generate_side_specific_events()

            # Step 4: Train long model
            if verbose:
                print("\n[Step 4/6] Training LONG model (ASK-based)...")
            long_results = self._train_side_model(
                self.long_pipeline,
                "LONG (ASK)",
                generate_reports,
                save,
                export_onnx,
                verbose
            )

            # Step 5: Train short model
            if verbose:
                print("\n[Step 5/6] Training SHORT model (BID-based)...")
            short_results = self._train_side_model(
                self.short_pipeline,
                "SHORT (BID)",
                generate_reports,
                save,
                export_onnx,
                verbose
            )

            # Step 6: Generate spread-aware analysis
            if verbose:
                print("\n[Step 6/6] Generating spread-aware analysis...")
            combined_metrics = self._generate_spread_analysis(long_results, short_results)

            long_model, long_features, long_metrics, long_config = long_results
            short_model, short_features, short_metrics, short_config = short_results

            results = {
                'long_model': long_model,
                'short_model': short_model,
                'long_features': long_features,
                'short_features': short_features,
                'long_metrics': long_metrics,
                'short_metrics': short_metrics,
                'combined_metrics': combined_metrics,
                # Reuse the statistics computed for the combined metrics
                # instead of recomputing identical values.
                'spread_stats': combined_metrics['spread_stats'],
                'long_config': long_config,
                'short_config': short_config,
            }

            if verbose:
                print("\n" + "=" * 80)
                print("✓ Bid/Ask-Aware Pipeline Completed Successfully")
                print("=" * 80)
                self._print_summary(results)

            return results

        except Exception as e:
            logger.error(f"Bid/Ask pipeline failed: {e}")
            raise

    def _extract_side_bars(self, side: str):
        """Return a copy of the ``bar_data`` columns matching ``side``, renamed to the second '_'-separated token."""
        # .copy() so that adding the 'spread' column later writes to an
        # independent frame rather than to a selection of bar_data.
        bars = self.bar_data.filter(regex=side).copy()
        # NOTE(review): assumes every matching column is named
        # '<side>_<field>' (e.g. 'ask_close') -- confirm against the loader.
        bars.columns = [name.split('_')[1] for name in bars.columns]
        return bars

    def _create_side_specific_bars(self):
        """
        Create separate bars using ask (long) and bid (short) prices.

        Loads the bars once with ``price='bid_ask'``, splits them into
        ``long_bars`` (ask columns) and ``short_bars`` (bid columns) and
        stores the close-to-close spread on both frames as ``'spread'``.
        For tick bars the bar size read from the loaded data is written to
        both side configurations and saved.
        """
        from .model_development import load_and_prepare_training_data

        bar_type = self.data_config['bar_type']
        # Request both quote sides in a single load (self.data_config is a
        # private copy, so the caller's dict is untouched).
        self.data_config['price'] = 'bid_ask'

        self.bar_data = load_and_prepare_training_data(**self.data_config)

        if bar_type == "tick":
            # The first bar's tick_volume is taken as the tick bar size.
            bar_size = self.bar_data["tick_volume"].iloc[0]
            # NOTE(review): relies on the side pipelines exposing
            # file_manager.save_config -- confirm in ModelDevelopmentPipeline.
            for side_pipeline, side_config in (
                (self.short_pipeline, self.short_config),
                (self.long_pipeline, self.long_config),
            ):
                side_config["tick_bar_size"] = bar_size
                side_pipeline.file_manager.save_config(side_config)

        # Long bars: use ask prices
        self.long_bars = self._extract_side_bars('ask')

        # Short bars: use bid prices
        self.short_bars = self._extract_side_bars('bid')

        # Calculate spread at bar frequency
        self.bar_spread = self.long_bars['close'] - self.short_bars['close']
        self.short_bars['spread'] = self.bar_spread
        self.long_bars['spread'] = self.bar_spread

        logger.info(f"Created {len(self.long_bars)} bars for each side")
        logger.info(f"Average bar-level spread: {self.bar_spread.mean():.5f}")

    def _engineer_side_specific_features(self):
        """
        Engineer features separately for each side.

        The resulting bars and features are attached to the matching side
        pipeline, and its ``data_loading`` and ``feature_engineering`` steps
        are marked as completed.
        """
        from .model_development import create_feature_engineering_pipeline

        # Long features (from ask prices)
        long_features = create_feature_engineering_pipeline(
            self.long_bars,
            self.feature_config,
            self.data_config
        )

        # Short features (from bid prices)
        short_features = create_feature_engineering_pipeline(
            self.short_bars,
            self.feature_config,
            self.data_config
        )

        self.long_pipeline.bar_data = self.long_bars
        self.short_pipeline.bar_data = self.short_bars
        self.long_pipeline.features = long_features
        self.short_pipeline.features = short_features

        for step in ('data_loading', 'feature_engineering'):
            self.long_pipeline.completed_steps[step] = True
            self.short_pipeline.completed_steps[step] = True

    def _generate_side_specific_events(self):
        """
        Generate events separately for each side using appropriate prices.

        Events are generated from the ask bars and filtered to ``side == 1``
        for the long pipeline, and from the bid bars filtered to
        ``side == -1`` for the short pipeline. Both pipelines then have
        their ``label_generation`` step marked as completed.
        """
        from .model_development import generate_events_triple_barrier

        # Generate long events using ask prices
        long_events = generate_events_triple_barrier(
            self.long_bars,
            self.strategy,
            **self.label_config
        )
        # Filter to only long signals
        long_events = long_events[long_events['side'] == 1]

        # Generate short events using bid prices
        short_events = generate_events_triple_barrier(
            self.short_bars,
            self.strategy,
            **self.label_config
        )
        # Filter to only short signals
        short_events = short_events[short_events['side'] == -1]

        self.long_pipeline.events = long_events
        self.short_pipeline.events = short_events

        self.long_pipeline.completed_steps['label_generation'] = True
        self.short_pipeline.completed_steps['label_generation'] = True

        logger.info(f"Generated {len(long_events)} long events (ask-based)")
        logger.info(f"Generated {len(short_events)} short events (bid-based)")

    def _train_side_model(
        self,
        pipeline,
        side_name: str,
        generate_reports: bool,
        save: bool,
        export_onnx: bool,
        verbose: bool
    ):
        """
        Train model for specific side.

        Parameters
        ----------
        pipeline : ModelDevelopmentPipeline
            Side pipeline whose bars, features and events are already set
        side_name : str
            Label used in progress messages
        generate_reports : bool
            Call the pipeline's report generation after training
        save : bool
            Call the pipeline's artifact saving after training
        export_onnx : bool
            Accepted for interface symmetry; not used by this method
        verbose : bool
            Print progress messages to stdout

        Returns
        -------
        tuple
            ``(best_model, feature_names, metrics, config)`` read from the
            pipeline after training
        """
        if verbose:
            print(f"  Computing {side_name} sample weights...")
        pipeline.compute_sample_weights()

        if verbose:
            print(f"  Adding {side_name} meta-features...")
        pipeline.add_meta_features()
        pipeline.preprocess_features()

        if verbose:
            print(f"  Training {side_name} model...")
        pipeline.train_model()

        if verbose:
            print(f"  Analyzing {side_name} features...")
        pipeline.analyze_features()

        pipeline._compile_metrics()

        if generate_reports:
            pipeline._generate_analysis_reports()

        if save:
            pipeline._save_all_artifacts()

        return (
            pipeline.best_model,
            pipeline._get_feature_names(),
            pipeline.metrics,
            pipeline.config
        )

    def _calculate_spread_statistics(self) -> Dict:
        """
        Calculate comprehensive spread statistics.

        The ``tick_spread_*`` and ``spread_bps`` entries use the tick-level
        ``spread_series`` / ``tick_data`` when they have been set. When they
        are ``None`` (nothing in this class assigns them), the bar-level
        spread and the mean close of the bid bars are used instead, so the
        result keeps the same keys either way.

        Returns
        -------
        dict
            Spread mean, std, median, 95th percentile, bar-level mean/std,
            and the mean spread in basis points of the reference bid price

        Raises
        ------
        RuntimeError
            If the bar-level spread has not been computed yet.
        """
        if self.bar_spread is None:
            raise RuntimeError(
                "Spread statistics requested before bars were created"
            )

        # Prefer tick-level spread; otherwise fall back to bar-level spread.
        spread = self.spread_series if self.spread_series is not None else self.bar_spread

        # Reference price for the basis-point conversion.
        if self.tick_data is not None:
            reference_bid = self.tick_data['bid'].mean()
        else:
            reference_bid = self.short_bars['close'].mean()

        return {
            'tick_spread_mean': float(spread.mean()),
            'tick_spread_std': float(spread.std()),
            'tick_spread_median': float(spread.median()),
            'tick_spread_95th': float(spread.quantile(0.95)),
            'bar_spread_mean': float(self.bar_spread.mean()),
            'bar_spread_std': float(self.bar_spread.std()),
            'spread_bps': float((spread.mean() / reference_bid) * 10000),
        }

    def _generate_spread_analysis(self, long_results, short_results) -> Dict:
        """
        Generate analysis including spread impact.

        Parameters
        ----------
        long_results, short_results : tuple
            Tuples as returned by ``_train_side_model``; only the metrics
            entry (index 2) is read.

        Returns
        -------
        dict
            Event counts, CV scores (0 when ``best_score`` is missing),
            feature counts and label distributions per side, plus the
            spread statistics.
        """
        long_metrics = long_results[2]
        short_metrics = short_results[2]

        return {
            'long_events': long_metrics['events_count'],
            'short_events': short_metrics['events_count'],
            'long_cv_score': long_metrics['cv_results'].get('best_score', 0),
            'short_cv_score': short_metrics['cv_results'].get('best_score', 0),
            'long_features': long_metrics['feature_count'],
            'short_features': short_metrics['feature_count'],
            'spread_stats': self._calculate_spread_statistics(),
            'long_label_dist': long_metrics['label_distribution'],
            'short_label_dist': short_metrics['label_distribution'],
        }

    def _print_summary(self, results: Dict):
        """
        Print comprehensive summary including spread analysis.

        Parameters
        ----------
        results : dict
            Result dictionary built by ``run``; the ``'spread_stats'`` and
            ``'combined_metrics'`` entries are read.
        """
        print("\nMODEL COMPARISON SUMMARY")
        print("-" * 80)

        spread_stats = results['spread_stats']
        combined = results['combined_metrics']

        print("\nSpread Statistics:")
        print(f"  Average Spread: {spread_stats['tick_spread_mean']:.5f} ({spread_stats['spread_bps']:.2f} bps)")
        print(f"  Spread StdDev: {spread_stats['tick_spread_std']:.5f}")
        print(f"  95th Percentile: {spread_stats['tick_spread_95th']:.5f}")

        print("\nLONG Model (ASK-based):")
        print(f"  Events: {combined['long_events']:,}")
        print(f"  CV Score: {combined['long_cv_score']:.4f}")
        print(f"  Features: {combined['long_features']}")

        print("\nSHORT Model (BID-based):")
        print(f"  Events: {combined['short_events']:,}")
        print(f"  CV Score: {combined['short_cv_score']:.4f}")
        print(f"  Features: {combined['short_features']}")


# Convenience function
def train_bidask_longshort_models(
    strategy,
    data_config: dict,
    feature_config: dict,
    target_config: dict,
    label_config: dict,
    model_params: dict,
    base_dir: str = "Models/BidAsk_LongShort",
    **kwargs
) -> Dict:
    """
    Build a ``BidAskLongShortPipeline`` and run it in one call.

    Parameters
    ----------
    strategy, data_config, feature_config, target_config, label_config,
    model_params, base_dir
        Passed unchanged to the ``BidAskLongShortPipeline`` constructor.
    **kwargs
        Passed unchanged to ``BidAskLongShortPipeline.run``.

    Returns
    -------
    dict
        Whatever ``BidAskLongShortPipeline.run`` returns.
    """
    constructor_args = dict(
        strategy=strategy,
        data_config=data_config,
        feature_config=feature_config,
        target_config=target_config,
        label_config=label_config,
        model_params=model_params,
        base_dir=base_dir,
    )
    runner = BidAskLongShortPipeline(**constructor_args)
    outcome = runner.run(**kwargs)
    return outcome
