"""
Implementation of Sequentially Bootstrapped Bagging Classifier using sklearn's library as base class
"""

import numbers
from abc import ABCMeta, abstractmethod
from warnings import warn

import numpy as np
from joblib import Parallel, delayed
from sklearn.base import ClassifierMixin, RegressorMixin
from sklearn.ensemble import BaggingClassifier, BaggingRegressor
from sklearn.ensemble._bagging import BaseBagging
from sklearn.ensemble._base import _partition_estimators
from sklearn.metrics import accuracy_score, r2_score
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.utils import (
    check_array,
    check_consistent_length,
    check_random_state,
    check_X_y,
)
from sklearn.utils.random import sample_without_replacement
from sklearn.utils.validation import has_fit_parameter

from .bootstrapping import get_active_indices, seq_bootstrap
from .misc import indices_to_mask

MAX_INT = np.iinfo(np.int32).max


# pylint: disable=too-many-ancestors
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-branches
# pylint: disable=too-many-locals
# pylint: disable=too-many-arguments
# pylint: disable=too-many-statements
# pylint: disable=invalid-name
# pylint: disable=protected-access
# pylint: disable=len-as-condition
# pylint: disable=attribute-defined-outside-init
# pylint: disable=bad-super-call
# pylint: disable=no-else-raise


def _generate_random_features(random_state, bootstrap, n_population, n_samples):
    """
    Pick ``n_samples`` feature indices out of ``range(n_population)``.

    When ``bootstrap`` is truthy the indices are drawn with replacement through
    ``random_state.randint`` (repeats are possible); otherwise the draw is
    delegated to ``sample_without_replacement``.
    """
    if not bootstrap:
        return sample_without_replacement(n_population, n_samples, random_state=random_state)
    return random_state.randint(0, n_population, n_samples)


def _generate_bagging_indices(
    random_state, bootstrap_features, n_features, max_features, max_samples, active_indices
):
    """
    Draw the row and column indices used to train a single base estimator.

    Rows come from ``seq_bootstrap`` applied to ``active_indices``; columns come
    from ``_generate_random_features``. Both draws share one ``RandomState``
    built from ``random_state``.

    :param random_state: seed or RandomState accepted by ``check_random_state``.
    :param bootstrap_features: (bool) draw features with replacement when True.
    :param n_features: (int) total number of available features.
    :param max_features: (int or float) feature count, or fraction of ``n_features``.
    :param max_samples: (int, float or other) row count, fraction of
        ``len(active_indices)``, or any non-numeric value to let
        ``seq_bootstrap`` receive ``sample_length=None``.
    :param active_indices: structure handed to ``seq_bootstrap`` unchanged.
    :return: (tuple) ``(sample_indices, feature_indices)``.
    :raises ValueError: if ``max_features`` is neither an int nor a float.
    """
    rng = check_random_state(random_state)

    # Resolve how many rows the sequential bootstrap should draw.
    if isinstance(max_samples, numbers.Integral):
        sample_length = max_samples
    elif isinstance(max_samples, numbers.Real):
        sample_length = int(round(max_samples * len(active_indices)))
    else:
        sample_length = None

    sample_indices = seq_bootstrap(active_indices, sample_length=sample_length, random_seed=rng)

    # Resolve how many columns to draw; the same RNG is reused for the feature draw.
    if isinstance(max_features, numbers.Integral):
        n_drawn_features = max_features
    elif isinstance(max_features, numbers.Real):
        n_drawn_features = int(round(max_features * n_features))
    else:
        raise ValueError("max_features must be int or float")

    feature_indices = _generate_random_features(
        rng, bootstrap_features, n_features, n_drawn_features
    )
    return sample_indices, feature_indices


def _parallel_build_estimators(
    n_estimators, ensemble, X, y, active_indices, sample_weight, seeds, total_n_estimators, verbose
):
    """
    Private function used to build a batch of estimators within a job.

    :param n_estimators: (int) number of estimators to build in this batch.
    :param ensemble: the bagging ensemble; its ``_max_samples``, ``max_features``,
        ``bootstrap_features``, ``estimator_`` and ``_make_estimator`` are used.
    :param X: 2-D training data, indexed as ``X[rows][:, cols]``.
    :param y: training targets, indexed by the drawn row indices.
    :param active_indices: structure forwarded to the sequential bootstrap.
    :param sample_weight: per-sample weights or None. Weights are forwarded only
        when the base estimator's ``fit`` accepts ``sample_weight``; otherwise
        they are ignored.
    :param seeds: one seed per estimator of this batch.
    :param total_n_estimators: (int) ensemble-wide total, used for logging only.
    :param verbose: (int) progress is printed when greater than 1.
    :return: (tuple) ``(estimators, estimators_features, estimators_samples)``,
        three parallel lists with one entry per fitted estimator.
    """
    # Retrieve settings
    _, n_features = X.shape
    max_samples = ensemble._max_samples
    max_features = ensemble.max_features
    bootstrap_features = ensemble.bootstrap_features

    # Decide once whether weights can be forwarded to the base estimator.
    use_sample_weight = sample_weight is not None and has_fit_parameter(
        ensemble.estimator_, "sample_weight"
    )

    # Build estimators
    estimators = []
    estimators_samples = []
    estimators_features = []

    for i in range(n_estimators):
        if verbose > 1:
            print(
                "Building estimator %d of %d for this parallel run (total %d)..."
                % (i + 1, n_estimators, total_n_estimators)
            )

        random_state = seeds[i]
        estimator = ensemble._make_estimator(append=False, random_state=random_state)

        # Draw samples and features
        sample_indices, feature_indices = _generate_bagging_indices(
            random_state, bootstrap_features, n_features, max_features, max_samples, active_indices
        )

        estimators_features.append(feature_indices)
        estimators_samples.append(sample_indices)

        X_ = X[sample_indices][:, feature_indices]
        y_ = y[sample_indices]

        # Only pass the ``sample_weight`` keyword when it is usable: estimators
        # whose ``fit`` has no such parameter would raise TypeError even for None.
        if use_sample_weight:
            estimator.fit(X_, y_, sample_weight=sample_weight[sample_indices])
        else:
            estimator.fit(X_, y_)
        estimators.append(estimator)

    return estimators, estimators_features, estimators_samples


class SequentiallyBootstrappedBaseBagging(BaseBagging, metaclass=ABCMeta):
    """
    Common machinery for sequentially bootstrapped bagging ensembles.

    Rows for every base estimator are drawn with the sequential bootstrap applied
    to the active indices computed from ``samples_info_sets`` and
    ``price_bars_index``; columns are optionally subsampled per estimator. The
    class handles validation, seed generation, (parallel) estimator fitting,
    warm starts and the bookkeeping of the drawn rows/columns. Concrete
    subclasses provide the default base estimator (``_validate_estimator``) and
    the out-of-bag scoring (``_set_oob_score``).

    Parameters
    ----------
    samples_info_sets : pd.Series
        Per-sample information ranges handed to ``get_active_indices``.
    price_bars_index : index-like
        Price bar index handed to ``get_active_indices`` together with
        ``samples_info_sets``.
    estimator : estimator object or None, optional (default=None)
        Unfitted scikit-learn-compatible base estimator. When None, the concrete
        subclass supplies its default.
    n_estimators : int, optional (default=10)
        Number of base estimators in the ensemble.
    max_samples : int or float, optional (default=1.0)
        If int, the number of rows drawn per estimator (capped at the number of
        training rows); if float, it must lie in (0, 1] and is the fraction of
        the training rows drawn per estimator.
    max_features : int or float, optional (default=1.0)
        If int, the number of columns drawn per estimator; if float, the
        fraction of the columns drawn per estimator.
    bootstrap_features : bool, optional (default=False)
        Whether columns are drawn with replacement.
    oob_score : bool, optional (default=False)
        Whether to call ``_set_oob_score`` at the end of ``fit``.
    warm_start : bool, optional (default=False)
        When True, a later ``fit`` keeps the already fitted estimators and only
        adds ``n_estimators - len(estimators_)`` new ones.
    n_jobs : int or None, optional (default=None)
        Number of jobs passed to ``_partition_estimators`` and joblib.
    random_state : int, RandomState instance or None, optional (default=None)
        Passed through ``check_random_state``; used to draw one integer seed per
        base estimator.
    verbose : int, optional (default=0)
        Verbosity level forwarded to joblib and to the estimator builder.

    Attributes
    ----------
    estimators_ : list
        Fitted base estimators.
    estimators_features_ : list
        Column indices drawn for each fitted estimator.
    estimators_samples_ : list
        Row indices drawn for each fitted estimator (read-only property).
    active_indices_ : object or None
        Result of ``get_active_indices``. It is computed on the first ``fit`` and
        cached; set it back to None to force recomputation.

    Notes
    -----
    - Sequential bootstrap samplers are designed for settings where observations
      overlap in label exposure (for example, overlapping trade labels in
      financial datasets). Use purging and embargo strategies when evaluating to
      avoid temporal leakage.
    - ``oob_score`` is a plain boolean flag. Custom OOB metrics can be derived
      from ``estimators_``, ``estimators_samples_`` and ``estimators_features_``.
    - ``bootstrap=True`` is always forwarded to the scikit-learn base class.

    Examples
    --------
    >>> from sklearn.tree import DecisionTreeClassifier
    >>> clf = SequentiallyBootstrappedBaggingClassifier(  # doctest: +SKIP
    ...     samples_info_sets=samples_info_sets,
    ...     price_bars_index=price_bars_index,
    ...     estimator=DecisionTreeClassifier(max_depth=6),
    ...     n_estimators=50,
    ...     max_samples=0.5,
    ...     random_state=42,
    ... )
    >>> clf.fit(X_train, y_train)  # doctest: +SKIP
    """

    @abstractmethod
    def __init__(
        self,
        samples_info_sets,
        price_bars_index,
        estimator=None,
        n_estimators=10,
        max_samples=1.0,
        max_features=1.0,
        bootstrap_features=False,
        oob_score=False,
        warm_start=False,
        n_jobs=None,
        random_state=None,
        verbose=0,
    ):
        super().__init__(
            estimator=estimator,
            n_estimators=n_estimators,
            bootstrap=True,  # Always use bootstrap for sequential bootstrap
            max_samples=max_samples,
            max_features=max_features,
            bootstrap_features=bootstrap_features,
            oob_score=oob_score,
            warm_start=warm_start,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
        )

        self.samples_info_sets = samples_info_sets
        self.price_bars_index = price_bars_index
        self.active_indices_ = None  # computed lazily on the first fit
        self._estimators_samples = []  # backing store of ``estimators_samples_``
        self._seeds = None  # one seed per fitted estimator once fitted

    @property
    def estimators_samples_(self):
        """Return the stored sample indices for each estimator."""
        return self._estimators_samples

    def fit(self, X, y, sample_weight=None):
        """
        Build a Sequentially Bootstrapped Bagging ensemble of estimators from the training
        set (X, y).

        Parameters
        ----------
        X : (array-like, sparse matrix) of shape = [n_samples, n_features]
            The training input samples. Sparse matrices are accepted only if
            they are supported by the base estimator.
        y : (array-like), shape = [n_samples]
            The target values (class labels in classification, real numbers in
            regression).
        sample_weight : (array-like), shape = [n_samples] or None
            Sample weights. If None, then samples are equally weighted.
            Note that this is supported only if the base estimator supports
            sample weighting.
        Returns
        -------
        self : (object)
        """
        return self._fit(X, y, self.max_samples, sample_weight=sample_weight)

    def _fit(self, X, y, max_samples=None, sample_weight=None):
        """
        Build a Sequentially Bootstrapped Bagging ensemble of estimators from the training
        set (X, y).

        Parameters
        ----------
        X : (array-like, sparse matrix) of shape = [n_samples, n_features]
            The training input samples. Sparse matrices are accepted only if
            they are supported by the base estimator.
        y : (array-like), shape = [n_samples]
            The target values (class labels in classification, real numbers in
            regression).
        max_samples : (int or float), optional (default=None)
            Overrides ``self.max_samples`` when not None.
        sample_weight : (array-like), shape = [n_samples] or None
            Sample weights. If None, then samples are equally weighted.
            Note that this is supported only if the base estimator supports
            sample weighting.
        Returns
        -------
        self : (object)

        Raises
        ------
        ValueError
            If ``max_samples`` is not numeric, a float ``max_samples`` is outside
            (0, 1], the cached active indices do not match the number of rows,
            or ``n_estimators`` is smaller than the number of already fitted
            estimators during a warm start.
        """
        # Fallback for classifier-like objects that already expose ``classes_``
        # but not ``n_classes_``.
        if hasattr(self, "classes_") and not hasattr(self, "n_classes_"):
            self.classes_ = np.unique(y)
            self.n_classes_ = len(self.classes_)

        # A single RNG drives the seed generation of this fit.
        random_state = check_random_state(self.random_state)

        # Convert data and validate
        X, y = check_X_y(X, y, ["csr", "csc"])
        n_samples = X.shape[0]

        # Check sample weight
        if sample_weight is not None:
            sample_weight = check_array(sample_weight, ensure_2d=False)
            check_consistent_length(y, sample_weight)

        # Resolve the base estimator (sets ``estimator_``)
        self._validate_estimator()

        # Validate max_samples
        if max_samples is None:
            max_samples = self.max_samples

        if not isinstance(max_samples, (numbers.Integral, numbers.Real)):
            raise ValueError("max_samples must be int or float, got %s" % type(max_samples))

        if isinstance(max_samples, numbers.Integral):
            max_samples = min(max_samples, n_samples)
        else:  # float
            if not 0.0 < max_samples <= 1.0:
                raise ValueError("max_samples must be in (0, 1], got %r" % max_samples)
            max_samples = int(round(max_samples * n_samples))

        # Store the resolved (integer) number of rows per estimator
        self._max_samples = max_samples

        # Compute the active indices for sequential bootstrap once and cache them
        if getattr(self, "active_indices_", None) is None:
            self.active_indices_ = get_active_indices(self.samples_info_sets, self.price_bars_index)

        # Check that the active indices match the data shape
        if len(self.active_indices_) != n_samples:
            raise ValueError(
                f"Indicator matrix shape {len(self.active_indices_)} "
                f"does not match number of samples {n_samples}"
            )

        # Warm start handling
        if not self.warm_start or not hasattr(self, "estimators_"):
            # Start from scratch: drop estimators and all per-estimator bookkeeping
            self.estimators_ = []
            self.estimators_features_ = []
            self._estimators_samples = []
            self._seeds = None

        n_more_estimators = self.n_estimators - len(self.estimators_)

        if n_more_estimators < 0:
            raise ValueError(
                "n_estimators=%d must be larger or equal to "
                "len(estimators_)=%d when warm_start==True"
                % (self.n_estimators, len(self.estimators_))
            )

        if n_more_estimators == 0:
            warn("Warm-start fitting without increasing n_estimators does not " "fit new trees.")
            return self

        # Parallel or sequential construction
        n_jobs, n_estimators, starts = _partition_estimators(n_more_estimators, self.n_jobs)
        total_n_estimators = sum(n_estimators)

        # When warm-starting with a fixed ``random_state`` the RNG restarts from
        # the same state on every fit. Skip as many draws as there are fitted
        # estimators so the new estimators do not reuse earlier seeds.
        if self.warm_start and len(self.estimators_) > 0:
            random_state.randint(MAX_INT, size=len(self.estimators_))

        # Generate one seed per new estimator
        seeds = random_state.randint(MAX_INT, size=n_more_estimators)

        # Keep ``_seeds`` aligned with ``estimators_``: one entry per estimator
        previous_seeds = getattr(self, "_seeds", None)
        if previous_seeds is None:
            self._seeds = seeds
        else:
            self._seeds = np.concatenate([previous_seeds, seeds])

        # Build estimators in parallel
        all_results = Parallel(n_jobs=n_jobs, verbose=self.verbose)(
            delayed(_parallel_build_estimators)(
                n_estimators[i],
                self,
                X,
                y,
                self.active_indices_,
                sample_weight,
                seeds[starts[i] : starts[i + 1]],
                total_n_estimators,
                verbose=self.verbose,
            )
            for i in range(n_jobs)
        )

        # Unpack results
        for batch_estimators, batch_features, batch_samples in all_results:
            self.estimators_ += batch_estimators
            self.estimators_features_ += batch_features
            self._estimators_samples += batch_samples

        # Compute OOB score if requested
        if self.oob_score:
            self._set_oob_score(X, y)

        return self


class SequentiallyBootstrappedBaggingClassifier(
    SequentiallyBootstrappedBaseBagging, BaggingClassifier, ClassifierMixin
):
    """
    A Sequentially Bootstrapped Bagging classifier is an ensemble meta-estimator that fits base
    classifiers each on random subsets of the original dataset generated using
    Sequential Bootstrapping sampling procedure and then aggregate their individual predictions
    to form a final prediction. Such a meta-estimator can typically be used as
    a way to reduce the variance of a black-box estimator (e.g., a decision
    tree), by introducing randomization into its construction procedure and
    then making an ensemble out of it.

    :param samples_info_sets: (pd.Series), The information range on which each record is constructed from
        *samples_info_sets.index*: Time when the information extraction started.
        *samples_info_sets.value*: Time when the information extraction ended.
    :param price_bars_index: (pd.DatetimeIndex)
        Index of price bars used in samples_info_sets generation
    :param estimator: (object or None), optional (default=None)
        The base estimator to fit on random subsets of the dataset.
        If None, then the base estimator is a decision tree.
    :param n_estimators: (int), optional (default=10)
        The number of base estimators in the ensemble.
    :param max_samples: (int or float), optional (default=1.0)
        The number of samples to draw from X to train each base estimator.
        If int, then draw `max_samples` samples. If float, then draw `max_samples * X.shape[0]` samples.
    :param max_features: (int or float), optional (default=1.0)
        The number of features to draw from X to train each base estimator.
        If int, then draw `max_features` features. If float, then draw `max_features * X.shape[1]` features.
    :param bootstrap_features: (bool), optional (default=False)
        Whether features are drawn with replacement.
    :param oob_score: (bool), optional (default=False)
        Whether to use out-of-bag samples to estimate
        the generalization error.
    :param warm_start: (bool), optional (default=False)
        When set to True, reuse the solution of the previous call to fit
        and add more estimators to the ensemble, otherwise, just fit
        a whole new ensemble.
    :param n_jobs: (int or None), optional (default=None)
        The number of jobs to run in parallel for both `fit` and `predict`.
        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all processors.
    :param random_state: (int), RandomState instance or None, optional (default=None)
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.
    :param verbose: (int), optional (default=0)
        Controls the verbosity when fitting and predicting.

    :ivar estimator_: (estimator)
        The base estimator from which the ensemble is grown.
    :ivar estimators_: (list of estimators)
        The collection of fitted base estimators.
    :ivar estimators_samples_: (list of arrays)
        The subset of drawn samples (i.e., the in-bag samples) for each base
        estimator. Each subset is defined by an array of the indices selected.
    :ivar estimators_features_: (list of arrays)
        The subset of drawn features for each base estimator.
    :ivar classes_: (array) of shape = [n_classes]
        The classes labels.
    :ivar n_classes_: (int or list)
        The number of classes.
    :ivar oob_score_: (float)
        Score of the training dataset obtained using an out-of-bag estimate.
    :ivar oob_decision_function_: (array) of shape = [n_samples, n_classes]
        Out-of-bag class probabilities averaged over the estimators for which the
        sample was out-of-bag. Rows of samples that were never out-of-bag are all zero.
    :ivar oob_prediction_: (array) of shape = [n_samples]
        Out-of-bag predicted class labels.
    """

    def __init__(
        self,
        samples_info_sets,
        price_bars_index,
        estimator=None,
        n_estimators=10,
        max_samples=1.0,
        max_features=1.0,
        bootstrap_features=False,
        oob_score=False,
        warm_start=False,
        n_jobs=None,
        random_state=None,
        verbose=0,
    ):
        super().__init__(
            samples_info_sets=samples_info_sets,
            price_bars_index=price_bars_index,
            estimator=estimator,
            n_estimators=n_estimators,
            max_samples=max_samples,
            max_features=max_features,
            bootstrap_features=bootstrap_features,
            oob_score=oob_score,
            warm_start=warm_start,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
        )

    def _validate_estimator(self):
        """Check the estimator and set the estimator_ attribute."""
        super()._validate_estimator(default=DecisionTreeClassifier())

    def _fit(self, X, y, max_samples=None, sample_weight=None):
        """
        Override _fit to set classes_ and n_classes_ for classifier compatibility.

        The labels are forwarded to the parent ``_fit`` as given (they are not
        re-encoded here).
        """
        # Set classes_ and n_classes_ before calling parent _fit
        self.classes_ = np.unique(y)
        self.n_classes_ = len(self.classes_)

        # Call parent _fit method
        return super()._fit(X, y, max_samples, sample_weight)

    def _set_oob_score(self, X, y):
        """
        Compute the out-of-bag decision function, predictions and accuracy.

        For every estimator, the samples that were not drawn for it are scored
        with ``predict_proba``; probabilities are averaged per sample over the
        number of estimators that scored it. Sets ``oob_decision_function_``,
        ``oob_prediction_`` (class labels) and ``oob_score_``.

        :param X: training data, indexed as ``X[mask][:, features]``.
        :param y: (array) true labels of shape = [n_samples].
        """
        # Safeguard: Ensure n_classes_ is set
        if not hasattr(self, "n_classes_"):
            self.classes_ = np.unique(y)
            self.n_classes_ = len(self.classes_)

        n_samples = y.shape[0]
        classes = np.asarray(self.classes_)
        n_classes = self.n_classes_

        predictions = np.zeros((n_samples, n_classes))
        n_predictions = np.zeros(n_samples)

        for estimator, samples, features in zip(
            self.estimators_, self._estimators_samples, self.estimators_features_
        ):
            # Create mask for OOB samples
            mask = ~indices_to_mask(samples, n_samples)
            if not np.any(mask):
                continue

            proba = estimator.predict_proba(X[mask][:, features])

            if proba.shape[1] == n_classes:
                predictions[mask] += proba
            else:
                # The estimator saw only a subset of the classes: place its
                # columns at the positions of those classes in ``classes_``.
                columns = np.searchsorted(classes, estimator.classes_)
                predictions[np.ix_(np.flatnonzero(mask), columns)] += proba

            n_predictions[mask] += 1

        never_oob = n_predictions == 0
        if never_oob.any():
            warn(
                "Some inputs do not have OOB scores. "
                "This probably means too few estimators were used "
                "to compute any reliable oob estimates."
            )

        # Average over the number of contributing estimators (not over the number
        # of non-zero class columns); never-OOB rows keep their zeros.
        n_predictions[never_oob] = 1
        predictions /= n_predictions[:, np.newaxis]

        # Map the winning column back to its class label before scoring against y
        oob_prediction = classes.take(np.argmax(predictions, axis=1))

        self.oob_decision_function_ = predictions
        self.oob_prediction_ = oob_prediction
        self.oob_score_ = accuracy_score(y, oob_prediction)


class SequentiallyBootstrappedBaggingRegressor(
    SequentiallyBootstrappedBaseBagging, BaggingRegressor, RegressorMixin
):
    """
    Ensemble meta-estimator that fits base regressors on subsets of the original
    dataset drawn with Sequential Bootstrapping and aggregates their individual
    predictions to form a final prediction. Such a meta-estimator can typically
    be used to reduce the variance of a black-box estimator (e.g., a decision
    tree) by introducing randomization into its construction procedure and then
    making an ensemble out of it.

    :param samples_info_sets: (pd.Series), The information range on which each record is constructed from
        *samples_info_sets.index*: Time when the information extraction started.
        *samples_info_sets.value*: Time when the information extraction ended.

    :param price_bars_index: (pd.DatetimeIndex)
        Index of price bars used in samples_info_sets generation
    :param estimator: (object or None), optional (default=None)
        The base estimator to fit on random subsets of the dataset. If None, then the base estimator is a decision tree.
    :param n_estimators: (int), optional (default=10)
        The number of base estimators in the ensemble.
    :param max_samples: (int or float), optional (default=1.0)
        The number of samples to draw from X to train each base estimator.
        If int, then draw `max_samples` samples. If float, then draw `max_samples * X.shape[0]` samples.
    :param max_features: (int or float), optional (default=1.0)
        The number of features to draw from X to train each base estimator.
        If int, then draw `max_features` features. If float, then draw `max_features * X.shape[1]` features.
    :param bootstrap_features: (bool), optional (default=False)
        Whether features are drawn with replacement.
    :param oob_score: (bool), optional (default=False)
        Whether to use out-of-bag samples to estimate
        the generalization error.
    :param warm_start: (bool), optional (default=False)
        When set to True, reuse the solution of the previous call to fit
        and add more estimators to the ensemble, otherwise, just fit
        a whole new ensemble.
    :param n_jobs: (int or None), optional (default=None)
        The number of jobs to run in parallel for both `fit` and `predict`.
        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
        ``-1`` means using all processors.
    :param random_state: (int, RandomState instance or None), optional (default=None)
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.
    :param verbose: (int), optional (default=0)
        Controls the verbosity when fitting and predicting.

    :ivar estimators_: (list) of estimators
        The collection of fitted sub-estimators.
    :ivar estimators_samples_: (list) of arrays
        The subset of drawn samples (i.e., the in-bag samples) for each base
        estimator. Each subset is defined by an array of the indices selected.
    :ivar estimators_features_: (list) of arrays
        The subset of drawn features for each base estimator.
    :ivar oob_score_: (float)
        R^2 score obtained on the samples that received at least one
        out-of-bag prediction.
    :ivar oob_prediction_: (array) of shape = [n_samples]
        Prediction computed with out-of-bag estimate on the training
        set. If n_estimators is small it might be possible that a data point
        was never left out during the bootstrap. In this case its entry in
        `oob_prediction_` stays 0.
    """

    def __init__(
        self,
        samples_info_sets,
        price_bars_index,
        estimator=None,
        n_estimators=10,
        max_samples=1.0,
        max_features=1.0,
        bootstrap_features=False,
        oob_score=False,
        warm_start=False,
        n_jobs=None,
        random_state=None,
        verbose=0,
    ):
        # All arguments are forwarded untouched to the shared base implementation.
        super().__init__(
            samples_info_sets=samples_info_sets,
            price_bars_index=price_bars_index,
            estimator=estimator,
            n_estimators=n_estimators,
            max_samples=max_samples,
            max_features=max_features,
            bootstrap_features=bootstrap_features,
            oob_score=oob_score,
            warm_start=warm_start,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
        )

    def _validate_estimator(self):
        """Resolve the base estimator, defaulting to a decision tree regressor."""
        super()._validate_estimator(default=DecisionTreeRegressor())

    def _set_oob_score(self, X, y):
        """
        Average the out-of-bag predictions per sample and score them with R^2.

        Sets ``oob_prediction_`` (all samples) and ``oob_score_`` (computed only
        on samples that were out-of-bag for at least one estimator).
        """
        n_samples = y.shape[0]
        oob_sum = np.zeros(n_samples)
        oob_count = np.zeros(n_samples)

        fitted = zip(self.estimators_, self._estimators_samples, self.estimators_features_)
        for estimator, samples, features in fitted:
            # Rows this estimator did not train on
            out_of_bag = ~indices_to_mask(samples, n_samples)
            if not np.any(out_of_bag):
                continue

            oob_sum[out_of_bag] += estimator.predict(X[out_of_bag][:, features])
            oob_count[out_of_bag] += 1

        # Only divide where at least one prediction was accumulated
        covered = oob_count > 0
        if np.any(covered):
            oob_sum[covered] /= oob_count[covered]

        self.oob_prediction_ = oob_sum
        self.oob_score_ = r2_score(y[covered], oob_sum[covered])


def compute_custom_oob_metrics(clf, X, y, sample_weight=None):
    """
    Compute custom OOB metrics (F1, AUC, precision/recall) for a fitted ensemble.

    For every base estimator the samples it was not trained on are scored with
    ``predict_proba``; the probabilities are averaged per sample and the metrics
    are computed on the samples that received at least one OOB prediction.

    Args:
        clf: Fitted SequentiallyBootstrappedBaggingClassifier
        X: Feature matrix used in training (array-like; objects exposing
            ``to_numpy`` such as DataFrames are converted first)
        y: True labels, in the same label space as ``clf.classes_``
        sample_weight: Optional sample weights, forwarded to every metric

    Returns:
        dict: ``f1``, ``precision``, ``recall`` (weighted averages), ``coverage``
        (fraction of samples with an OOB prediction) and, for binary problems,
        ``auc`` (NaN when the OOB subset contains a single class).
    """
    from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score

    # Positional indexing below needs array semantics
    if hasattr(X, "to_numpy"):
        X = X.to_numpy()
    y = np.asarray(y)

    n_samples = y.shape[0]
    classes = np.asarray(clf.classes_)
    n_classes = clf.n_classes_

    # Accumulate OOB predictions
    oob_proba = np.zeros((n_samples, n_classes))
    oob_count = np.zeros(n_samples)

    for estimator, samples, features in zip(
        clf.estimators_, clf.estimators_samples_, clf.estimators_features_
    ):
        mask = ~indices_to_mask(samples, n_samples)
        if not np.any(mask):
            continue

        proba = estimator.predict_proba(X[mask][:, features])
        if proba.shape[1] == n_classes:
            oob_proba[mask] += proba
        else:
            # Estimator trained on a subset of the classes: align its columns
            columns = np.searchsorted(classes, estimator.classes_)
            oob_proba[np.ix_(np.flatnonzero(mask), columns)] += proba
        oob_count[mask] += 1

    # Average and get predictions
    oob_mask = oob_count > 0
    oob_proba[oob_mask] /= oob_count[oob_mask, np.newaxis]

    # argmax yields column positions; translate them to class labels so they are
    # comparable with ``y`` even when labels are not 0..n_classes-1
    oob_pred = classes.take(np.argmax(oob_proba, axis=1))

    # Compute metrics on samples with OOB predictions
    y_oob = y[oob_mask]
    pred_oob = oob_pred[oob_mask]
    proba_oob = oob_proba[oob_mask]
    weight_oob = None if sample_weight is None else np.asarray(sample_weight)[oob_mask]

    metrics = {
        "f1": f1_score(y_oob, pred_oob, average="weighted", sample_weight=weight_oob),
        "precision": precision_score(
            y_oob, pred_oob, average="weighted", sample_weight=weight_oob
        ),
        "recall": recall_score(y_oob, pred_oob, average="weighted", sample_weight=weight_oob),
        "coverage": oob_mask.sum() / n_samples,  # Fraction with OOB predictions
    }

    # Add AUC for binary classification
    if n_classes == 2:
        if np.unique(y_oob).size == 2:
            metrics["auc"] = roc_auc_score(y_oob, proba_oob[:, 1], sample_weight=weight_oob)
        else:
            # AUC is undefined when only one class is present
            metrics["auc"] = float("nan")

    return metrics
