from bots.botlibs.tester_lib import test_model
from datetime import datetime
import pandas as pd
from numba import jit
import numpy as np
from catboost import CatBoostClassifier
from sklearn.model_selection import train_test_split
from scipy.optimize import minimize
import time


def get_prices() -> pd.DataFrame:
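    """Load quotes for the configured symbol from a whitespace-separated CSV
    (presumably a MetaTrader export with <DATE>/<TIME>/<CLOSE> columns) and
    return a close-price frame indexed by timestamp."""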
    p = pd.read_csv("files/" + hyper_params["symbol"] + ".csv", sep=r"\s+")
    pFixed = pd.DataFrame(columns=["time", "close"])
    pFixed["time"] = p["<DATE>"] + " " + p["<TIME>"]
    pFixed["time"] = pd.to_datetime(pFixed["time"], format="mixed")
    pFixed["close"] = p["<CLOSE>"]
    pFixed.set_index("time", inplace=True)
    return pFixed.dropna()


def get_features(data: pd.DataFrame) -> pd.DataFrame:
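    """Build features: for each period, the deviation of the close price from
    its rolling mean over that period, stored in columns "0", "1", ..."""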
    pFixed = data.copy()
    close = data["close"]
    for count, period in enumerate(hyper_params["periods"]):
        pFixed[str(count)] = close - close.rolling(period).mean()
    return pFixed.dropna()


@jit(nopython=True)
def get_labels_numba(close_prices, min_val, max_val, markup):
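    """Compare each close price with the price a random min_val..max_val bars
    ahead; markup sets the minimum move that counts as a directional label."""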
    labels = np.empty(len(close_prices) - max_val, dtype=np.float64)
    for i in range(len(close_prices) - max_val):
        rand = np.random.randint(min_val, max_val + 1)
        curr_pr = close_prices[i]
        future_pr = close_prices[i + rand]

        if (future_pr + markup) < curr_pr:
            labels[i] = 1.0  # price fell by more than the markup
        elif (future_pr - markup) > curr_pr:
            labels[i] = 0.0  # price rose by more than the markup
        else:
            labels[i] = 2.0  # no clear move; dropped later in get_labels_fast

    return labels


def get_labels_fast(dataset, min_val=1, max_val=15):
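    """Attach labels computed by get_labels_numba, trim the last max_val bars
    that lack a full look-ahead window, and drop the undecided (2.0) samples."""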
    close_prices = dataset["close"].values
    markup = hyper_params["markup"]

    labels = get_labels_numba(close_prices, min_val, max_val, markup)

    dataset = dataset.iloc[: len(labels)].copy()
    dataset["labels"] = labels
    dataset = dataset.dropna()
    dataset = dataset.drop(dataset[dataset.labels == 2.0].index)

    return dataset


hyper_params = {
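    # symbol and per-trade markup, default SL/TP levels, the training window
    # (backward..forward) and the rolling-mean lookbacks used as features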
    "symbol": "EURGBP_H1",
    "markup": 0.00010,
    "stop_loss": 0.0200,
    "take_profit": 0.0200,
    "backward": datetime(2010, 1, 1),
    "forward": datetime(2023, 1, 1),
    "periods": [i for i in range(50, 300, 50)],
}


# CatBoost training
dataset = get_labels_fast(get_features(get_prices()))
dataset["meta_labels"] = 1.0  # dummy meta labels (presumably expected by test_model)
data = dataset[
    (dataset.index < hyper_params["forward"])
    & (dataset.index > hyper_params["backward"])
].copy()

# features only: skip "close" (first column) and "labels"/"meta_labels" (last two)
X = data[data.columns[1:-2]]
y = data["labels"]

train_X, test_X, train_y, test_y = train_test_split(
    X, y, train_size=0.7, test_size=0.3, shuffle=True
)

model = CatBoostClassifier(
    iterations=500,
    thread_count=8,
    custom_loss=["Accuracy"],
    eval_metric="Accuracy",
    verbose=True,
    use_best_model=True,
    task_type="CPU",
)

model.fit(
    train_X, train_y, eval_set=(test_X, test_y), early_stopping_rounds=25, plot=False
)

# test catboost model
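# (the trailing True/False argument presumably toggles plotting inside tester_lib)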
test_model(
    dataset,
    [model],
    hyper_params["stop_loss"],
    hyper_params["take_profit"],
    hyper_params["forward"],
    hyper_params["backward"],
    hyper_params["markup"],
    True,
)


# stop loss / take profit grid search
def optimize_params_GRID_SEARCH(pr, model, hyper_params, test_model_func):
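    """Exhaustively scan stop_loss/take_profit over 0.001..0.020 in 0.001 steps
    and keep the pair with the highest R^2 reported by test_model_func."""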
    best_r2 = -np.inf
    best_stop_loss = None
    best_take_profit = None

    # Ranges for stop_loss and take_profit
    stop_loss_range = np.arange(0.00100, 0.02001, 0.00100)
    take_profit_range = np.arange(0.00100, 0.02001, 0.00100)

    total_iterations = len(stop_loss_range) * len(take_profit_range)
    start_time = time.time()

    for stop_loss in stop_loss_range:
        for take_profit in take_profit_range:
            # Create a copy of hyper_params
            current_hyper_params = hyper_params.copy()
            current_hyper_params["stop_loss"] = stop_loss
            current_hyper_params["take_profit"] = take_profit

            r2 = test_model_func(
                pr,
                [model],
                current_hyper_params["stop_loss"],
                current_hyper_params["take_profit"],
                current_hyper_params["forward"],
                current_hyper_params["backward"],
                current_hyper_params["markup"],
                False,
            )

            if r2 > best_r2:
                best_r2 = r2
                best_stop_loss = stop_loss
                best_take_profit = take_profit

    end_time = time.time()
    total_time = end_time - start_time
    average_time_per_iteration = total_time / total_iterations

    print(f"Total iterations: {total_iterations}")
    print(f"Average time per iteration: {average_time_per_iteration:.6f} seconds")
    print(f"Total time: {total_time:.6f} seconds")

    return best_stop_loss, best_take_profit, best_r2


def optimize_params_L_BFGS_B(pr, model, hyper_params, test_model_func):
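    """Maximize the backtest R^2 over (stop_loss, take_profit) with L-BFGS-B,
    restarting from several random points to reduce the risk of local optima."""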
    def objective(x):
        current_hyper_params = hyper_params.copy()
        current_hyper_params["stop_loss"] = x[0]
        current_hyper_params["take_profit"] = x[1]

        r2 = test_model_func(
            pr,
            [model],
            current_hyper_params["stop_loss"],
            current_hyper_params["take_profit"],
            current_hyper_params["forward"],
            current_hyper_params["backward"],
            current_hyper_params["markup"],
            False,
        )
        return -r2

    bounds = ((0.001, 0.02), (0.001, 0.02))

    # Try several random starting points
    n_attempts = 50
    best_result = None
    best_fun = float("inf")

    start_time = time.time()
    for _ in range(n_attempts):
        # Random starting point
        x0 = np.random.uniform(0.001, 0.02, 2)

        result = minimize(
            objective,
            x0,
            method="L-BFGS-B",
            bounds=bounds,
            # tighter tolerance and more iterations for better accuracy
            options={"ftol": 1e-5, "disp": False, "maxiter": 100},
        )

        if result.fun < best_fun:
            best_fun = result.fun
            best_result = result
    # Get the end time and calculate the total time
    end_time = time.time()
    total_time = end_time - start_time
    print(f"Total time: {total_time:.6f} seconds")

    return best_result.x[0], best_result.x[1], -best_result.fun


# run both optimizers and report their results
best_stop_loss, best_take_profit, best_r2 = optimize_params_GRID_SEARCH(
    dataset, model, hyper_params, test_model
)
print(
    f"Grid search: stop_loss={best_stop_loss}, take_profit={best_take_profit}, R^2={best_r2}"
)

best_stop_loss, best_take_profit, best_r2 = optimize_params_L_BFGS_B(
    dataset, model, hyper_params, test_model
)
print(
    f"L-BFGS-B: stop_loss={best_stop_loss}, take_profit={best_take_profit}, R^2={best_r2}"
)


# test with optimal sl/tp
test_model(
    dataset,
    [model],
    best_stop_loss,
    best_take_profit,
    hyper_params["forward"],
    hyper_params["backward"],
    hyper_params["markup"],
    True,
)
