"""
afml/transaction_costs.py

Loads broker transaction cost data exported by TransactionCostCollector.mq5
and derives the min_ret threshold for triple-barrier label construction.

The threshold is the minimum return a trade must achieve to cover the
round-trip transaction cost — spread at a chosen percentile, slippage,
commission, and holding-period-adjusted swap accrual — scaled by a safety
multiplier.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path

import numpy as np
import pandas as pd


@dataclass
class TransactionCostModel:
    """
    Broker-specific transaction cost model for a single symbol.

    All costs are expressed as fractional returns (e.g., 0.0001 = 1 pip
    on a 1.0000 priced instrument) so they can be directly compared to
    the return series used for triple-barrier labeling.

    Parameters
    ----------
    symbol : str
        Instrument identifier (e.g., "EURUSD").
    spread_pips : float
        Spread to use for cost calculation. Recommend p95 from the
        collected distribution — not the mean — because entries during
        high-spread periods are disproportionately costly.
    slippage_pips : float
        One-way slippage estimate. Derived from live/demo trade log.
        Default 0.5 pips is conservative for major forex pairs.
    commission_per_lot : float
        Round-trip (entry + exit) commission in account currency per
        standard lot. It is scaled by ``lot_size`` but NOT doubled.
        Must be confirmed from a reference trade (see MQL5 script note).
    swap_long_per_night : float
        Swap in native MQL5 units for long positions.
    swap_short_per_night : float
        Swap in native MQL5 units for short positions (usually negative).
    swap_mode : str
        MQL5 SYMBOL_SWAP_MODE string from the CSV. Recognised values are
        "points", "currency", "currency_mrgn", "currency_dep",
        "interest_open" and "interest_curr"; any other value yields zero
        swap cost.
    swap_triple_day : int
        Weekday on which triple swap is charged (0=Sun … 6=Sat). Stored for
        reference only; the swap heuristic in ``swap_cost_frac`` does not
        consult it.
    pip_factor : float
        Points per pip. 10 for 5-digit brokers, 1 for 4-digit.
    point : float
        Broker point size (e.g., 0.00001 for EURUSD 5-digit).
    tick_value : float
        Account currency value of one tick per standard lot.
    tick_size : float
        Minimum price movement.
    contract_size : float
        Units per standard lot (e.g., 100,000 for forex).
    lot_size : float
        Lot size used for this strategy (e.g., 0.1 mini-lot).
    account_currency_rate : float
        Exchange rate from profit currency to account currency, i.e.
        1 unit of profit currency = ``account_currency_rate`` units of
        account currency. Set to 1.0 if profit currency == account currency.
    spread_by_hour : dict[int, float]
        Hour-of-day mean spread in pips (broker time). Used for
        session-aware cost estimation.
    """

    symbol:               str
    spread_pips:          float
    slippage_pips:        float   = 0.5
    commission_per_lot:   float   = 0.0
    swap_long_per_night:  float   = 0.0
    swap_short_per_night: float   = 0.0
    swap_mode:            str     = "points"
    swap_triple_day:      int     = 3        # Wednesday default
    pip_factor:           float   = 10.0
    point:                float   = 0.00001
    tick_value:           float   = 10.0
    tick_size:            float   = 0.00001
    contract_size:        float   = 100_000.0
    lot_size:             float   = 0.01
    account_currency_rate: float  = 1.0
    spread_by_hour:       dict[int, float] = field(default_factory=dict)

    # ── Derived helpers ───────────────────────────────────────────────────────

    def _pips_to_price(self, pips: float) -> float:
        """Convert a distance expressed in pips into a price distance."""
        return pips * self.pip_factor * self.point

    @staticmethod
    def _swap_nights(holding_days: float) -> float:
        """
        Number of swap-nights charged for a holding period.

        Heuristic: ``holding_days`` nights plus 2 extra nights (the
        triple-swap surcharge) once ``holding_days >= 3``. It does not check
        which weekdays the position actually spans, and it adds the
        surcharge at most once, even for multi-week holds.
        """
        return holding_days + (2 if holding_days >= 3 else 0)

    @property
    def pip_value(self) -> float:
        """
        Account currency value of one pip for a position of ``lot_size`` lots.

        ``tick_value`` is already expressed in account currency, so no
        profit→account conversion (``account_currency_rate``) is applied.
        """
        pip_price = self.pip_factor * self.point
        return (self.tick_value / self.tick_size) * pip_price * self.lot_size

    def spread_cost_frac(self, entry_price: float) -> float:
        """
        Round-trip spread cost as a fraction of entry price.

        The full ``spread_pips`` is charged once per round trip. A market
        entry plus a market exit together cross the whole spread: half on
        each leg relative to mid.
        """
        return self._pips_to_price(self.spread_pips) / entry_price

    def slippage_cost_frac(self, entry_price: float) -> float:
        """Round-trip slippage as a fraction of entry price (2× one-way)."""
        return self._pips_to_price(self.slippage_pips) * 2 / entry_price

    def commission_cost_frac(self, entry_price: float) -> float:
        """
        Round-trip commission as a fraction of entry price.

        ``commission_per_lot`` is a round-trip, per-standard-lot figure in
        account currency. It is scaled to the traded ``lot_size`` and divided
        by the position notional, converted from profit currency to account
        currency. Returns 0.0 when the notional is zero.
        """
        notional_account = (
            entry_price * self.contract_size * self.lot_size
            * self.account_currency_rate
        )
        if notional_account == 0:
            return 0.0
        commission = self.commission_per_lot * self.lot_size  # account ccy
        return commission / notional_account

    def swap_cost_frac(
        self,
        entry_price: float,
        holding_days: float,
        side: int = 1,
    ) -> float:
        """
        Swap accrual as a fraction of entry price for a given holding period.

        The number of charged nights is ``holding_days``, plus 2 extra nights
        whenever ``holding_days >= 3``. This is a conservative stand-in for
        spanning the triple-swap day; ``swap_triple_day`` is not consulted.
        The broker swap's sign is discarded (``abs``), so a positive carry
        is counted as a cost of the same size. Unknown ``swap_mode`` values
        yield 0.0.

        Parameters
        ----------
        entry_price  : float  Entry price of the trade.
        holding_days : float  Expected holding period in calendar days.
        side         : int    +1 for long, -1 for short (>= 0 means long).
        """
        swap_rate = (
            self.swap_long_per_night if side >= 0
            else self.swap_short_per_night
        )

        if self.swap_mode == "points":
            # Swap quoted in points: a price distance per night.
            total_swap_price = (
                abs(swap_rate * self.point) * self._swap_nights(holding_days)
            )
            return total_swap_price / entry_price if entry_price > 0 else 0.0

        if self.swap_mode in ("currency", "currency_mrgn", "currency_dep"):
            # Treated as account currency per lot per night (NOTE(review):
            # confirm for the non-deposit currency modes). The notional is in
            # profit currency, so convert it to account currency first.
            total_swap = (
                abs(swap_rate) * self.lot_size * self._swap_nights(holding_days)
            )
            notional_account = (
                entry_price * self.contract_size * self.lot_size
                * self.account_currency_rate
            )
            return total_swap / notional_account if notional_account > 0 else 0.0

        if self.swap_mode in ("interest_open", "interest_curr"):
            # Annual interest rate in percent → nightly fraction.
            nightly_rate = abs(swap_rate) / 100.0 / 365.0
            return nightly_rate * self._swap_nights(holding_days)

        return 0.0

    def round_trip_cost_frac(
        self,
        entry_price: float,
        holding_days: float = 0.0,
        side: int = 1,
    ) -> float:
        """
        Total round-trip cost as a fraction of entry price.

        Sum of spread, slippage, commission and swap components.

        Parameters
        ----------
        entry_price  : Reference price (e.g., close at label bar).
        holding_days : Expected holding period. Used for swap accrual.
                       Pass 0.0 for intraday strategies (no overnight).
        side         : +1 long, -1 short. Affects swap direction.
        """
        return (
            self.spread_cost_frac(entry_price)
            + self.slippage_cost_frac(entry_price)
            + self.commission_cost_frac(entry_price)
            + self.swap_cost_frac(entry_price, holding_days, side)
        )

    def min_ret_for_symbol(
        self,
        price_series: pd.Series,
        holding_days: float = 0.0,
        side: int = 1,
        cost_multiplier: float = 1.5,
    ) -> float:
        """
        Derive the min_ret threshold for triple-barrier labeling.

        The threshold is set at cost_multiplier × median round-trip
        cost across the price series. A cost_multiplier of 1.5 means
        the profit barrier must be at least 1.5× the transaction cost
        to receive a positive label — trades that barely cover costs
        are treated as non-events.

        Parameters
        ----------
        price_series    : pd.Series  Close prices (same index as labels).
                                     NaN prices are ignored.
        holding_days    : float      Expected average holding period.
        side            : int        +1 or -1. Use 1 if unknown.
        cost_multiplier : float      Safety margin above break-even.
                                     1.0 = break-even; 1.5 = recommended.

        Returns
        -------
        float
            min_ret value to pass to get_events() or equivalent.

        Raises
        ------
        ValueError
            If ``price_series`` has no non-NaN prices. Without this check a
            NaN threshold would propagate silently into labeling.
        """
        # NaN prices would yield NaN costs, which median() skips anyway;
        # dropping them up front lets us detect the "nothing left" case.
        prices = price_series.dropna()
        if prices.empty:
            raise ValueError(
                "price_series contains no non-NaN prices; cannot derive min_ret"
            )
        costs = prices.apply(
            lambda p: self.round_trip_cost_frac(p, holding_days, side)
        )
        return float(costs.median() * cost_multiplier)

    def session_adjusted_spread_pips(self, hour: int) -> float:
        """
        Return mean spread for a given hour-of-day (broker time).
        Falls back to the model's spread_pips if hour is not in data.
        """
        return self.spread_by_hour.get(hour, self.spread_pips)

    def summary(self, entry_price: float, holding_days: float = 1.0) -> dict:
        """
        Human-readable cost breakdown for a reference trade.

        Reports each component as a fraction of ``entry_price``, plus the
        long/short totals both as fractions and in pips.
        """
        total_long = self.round_trip_cost_frac(entry_price, holding_days, 1)
        total_short = self.round_trip_cost_frac(entry_price, holding_days, -1)
        pip_price = self.pip_factor * self.point
        return {
            "spread_frac":      self.spread_cost_frac(entry_price),
            "slippage_frac":    self.slippage_cost_frac(entry_price),
            "commission_frac":  self.commission_cost_frac(entry_price),
            "swap_long_frac":   self.swap_cost_frac(entry_price, holding_days, 1),
            "swap_short_frac":  self.swap_cost_frac(entry_price, holding_days, -1),
            "total_long_frac":  total_long,
            "total_short_frac": total_short,
            "total_long_pips":  total_long * entry_price / pip_price,
            "total_short_pips": total_short * entry_price / pip_price,
        }


# ── CSV loader ────────────────────────────────────────────────────────────────

def load_cost_model(
    csv_path: Path,
    spread_percentile: str = "p95_pips",
    slippage_pips: float   = 0.5,
    commission_per_lot: float = 0.0,
    lot_size: float        = 0.01,
    account_currency_rate: float = 1.0,
) -> TransactionCostModel:
    """
    Construct a TransactionCostModel from a TransactionCostCollector.mq5 export.

    The CSV is read as a long table with ``section``, ``key`` and ``value``
    columns. Scalar properties are looked up by (section, key). Rows in the
    ``spread_by_hour`` section are keyed ``hour_<N>`` and become the
    hour → mean-spread mapping.

    Parameters
    ----------
    csv_path           : Path to the <symbol>_costs.csv file.
    spread_percentile  : Key in the ``spread_summary`` section to use as the
                         model spread (e.g. mean_pips, p50_pips, p95_pips,
                         p99_pips). Default p95_pips is recommended.
    slippage_pips      : One-way slippage. It is not in the CSV and must be
                         supplied from live/demo trade log analysis.
    commission_per_lot : Round-trip commission per standard lot in
                         account currency. Must be confirmed from a
                         reference trade (see MQL5 script note in CSV).
    lot_size           : Strategy lot size.
    account_currency_rate : Profit-to-account-currency exchange rate.

    Returns
    -------
    TransactionCostModel

    Raises
    ------
    KeyError
        If a required (section, key) pair is absent from the CSV.
    ValueError
        If a numeric field cannot be parsed.
    """
    table = pd.read_csv(csv_path)

    def lookup(section: str, key: str) -> str:
        # The first matching row wins; its value is returned as a stripped string.
        matches = table[(table["section"] == section) & (table["key"] == key)]
        if matches.empty:
            raise KeyError(f"Missing: section={section!r} key={key!r} in {csv_path}")
        return str(matches["value"].iloc[0]).strip()

    props = "symbol_properties"
    symbol = lookup(props, "symbol")
    # `digits` is not needed by the model. It is still parsed so that a
    # missing or malformed entry fails loudly here.
    int(lookup(props, "digits"))
    point = float(lookup(props, "point"))
    pip_factor = float(lookup(props, "pip_factor"))
    tick_size = float(lookup(props, "tick_size"))
    tick_value = float(lookup(props, "tick_value"))
    contract_size = float(lookup(props, "contract_size"))

    swap_long = float(lookup("swap", "swap_long"))
    swap_short = float(lookup("swap", "swap_short"))
    swap_mode = lookup("swap", "swap_mode")
    swap_triple_day = int(lookup("swap", "swap_3day"))

    spread_pips = float(lookup("spread_summary", spread_percentile))

    # Hour-of-day spreads: keys look like "hour_<N>"; later duplicates win.
    hourly = table[table["section"] == "spread_by_hour"]
    spread_by_hour: dict[int, float] = {
        int(str(hour_key).replace("hour_", "")): float(hour_spread)
        for hour_key, hour_spread in zip(hourly["key"], hourly["value"])
    }

    return TransactionCostModel(
        symbol=symbol,
        spread_pips=spread_pips,
        slippage_pips=slippage_pips,
        commission_per_lot=commission_per_lot,
        swap_long_per_night=swap_long,
        swap_short_per_night=swap_short,
        swap_mode=swap_mode,
        swap_triple_day=swap_triple_day,
        pip_factor=pip_factor,
        point=point,
        tick_value=tick_value,
        tick_size=tick_size,
        contract_size=contract_size,
        lot_size=lot_size,
        account_currency_rate=account_currency_rate,
        spread_by_hour=spread_by_hour,
    )