from bots.botlibs.tester_lib import tester, tester_slow, test_model
from datetime import datetime
import pandas as pd
import random
import time
from numba import jit
import numpy as np


def get_prices(symbol=None, folder="files") -> pd.DataFrame:
    """Load quotes from ``<folder>/<symbol>.csv`` as a ``close`` column indexed by time.

    The file is read with a whitespace separator and must contain the columns
    ``<DATE>``, ``<TIME>`` and ``<BID>``. ``<DATE>`` and ``<TIME>`` are joined
    and parsed into the ``time`` index, and ``<BID>`` becomes ``close``.

    Args:
        symbol: File stem to load. Defaults to ``hyper_params["symbol"]`` so
            existing zero-argument callers behave as before.
        folder: Directory holding the CSV files (default ``"files"``).

    Returns:
        DataFrame with a single ``close`` column. Only the first row of each
        repeated timestamp is kept, and rows with a missing price are dropped.
    """
    if symbol is None:
        symbol = hyper_params["symbol"]

    # Raw string: "\s" in a normal literal is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+).
    # NOTE(review): a regex whitespace separator collapses consecutive
    # delimiters. If the export can contain empty fields, columns would shift.
    # Confirm the file format.
    raw = pd.read_csv(f"{folder}/{symbol}.csv", sep=r"\s+")

    prices = pd.DataFrame(
        {
            "time": pd.to_datetime(raw["<DATE>"] + " " + raw["<TIME>"], format="mixed"),
            "close": raw["<BID>"],
        }
    ).set_index("time")

    # Keep the first row for each repeated timestamp.
    prices = prices[~prices.index.duplicated(keep="first")]
    return prices.dropna()


def get_labels(dataset, min=1, max=15, markup=None) -> pd.DataFrame:
    """Label each row by comparing its close with a randomly chosen future close.

    For row ``i`` a horizon is drawn with ``random.randint(min, max)``
    (inclusive). The future close ``close[i + horizon]`` is then compared with
    ``close[i]``:

    * ``1.0`` -- the future price is lower by more than ``markup``;
    * ``0.0`` -- the future price is higher by more than ``markup``;
    * ``2.0`` -- neither; such rows are removed from the result.

    The last ``max`` rows get no label because their horizon could run past
    the end of the data.

    Args:
        dataset: DataFrame with a ``close`` column.
        min: Smallest look-ahead horizon, in rows.
        max: Largest look-ahead horizon, in rows.
        markup: Price threshold a move must exceed. Defaults to
            ``hyper_params["markup"]``.

    Returns:
        A copy of the labelled rows with an added ``labels`` column. Rows
        containing NaN are dropped.
    """
    if markup is None:
        markup = hyper_params["markup"]

    # Pull the prices out once; per-row .iloc lookups are very slow.
    close = dataset["close"].to_numpy()

    labels = []
    # A negative range is empty, so inputs shorter than `max` yield no labels.
    for i in range(len(close) - max):
        curr_pr = close[i]
        future_pr = close[i + random.randint(min, max)]

        if (future_pr + markup) < curr_pr:
            labels.append(1.0)
        elif (future_pr - markup) > curr_pr:
            labels.append(0.0)
        else:
            labels.append(2.0)

    dataset = dataset.iloc[: len(labels)].copy()
    dataset["labels"] = labels
    dataset = dataset.dropna()
    # Filter with a positional boolean mask. Dropping by index label would also
    # remove labelled rows that share an index value with an unlabelled row.
    return dataset[dataset["labels"] != 2.0].copy()


@jit(nopython=True)
def get_labels_numba(close_prices, min_val, max_val, markup):
    """Label each price by comparing it with a randomly chosen future price.

    For every index ``i`` a horizon is drawn uniformly from
    ``[min_val, max_val]`` (inclusive). ``close_prices[i + horizon]`` is then
    compared with ``close_prices[i]``:

    * ``1.0`` -- the future price is lower by more than ``markup``;
    * ``0.0`` -- the future price is higher by more than ``markup``;
    * ``2.0`` -- neither (callers treat this as "no label").

    NOTE: ``np.random`` inside nopython code uses numba's own generator state,
    not NumPy's. Seeding it requires calling ``np.random.seed`` from jitted
    code.

    Args:
        close_prices: 1-D numeric array of prices.
        min_val: Smallest look-ahead horizon, in elements.
        max_val: Largest look-ahead horizon, in elements.
        markup: Price threshold a move must exceed to be labelled.

    Returns:
        float64 array of length ``max(len(close_prices) - max_val, 0)``. The
        last ``max_val`` prices are not labelled because their horizon could
        run past the end of the array.
    """
    # Clamp at zero: np.empty with a negative length raises, so inputs shorter
    # than max_val used to crash instead of yielding no labels.
    n_labels = max(len(close_prices) - max_val, 0)
    labels = np.empty(n_labels, dtype=np.float64)
    for i in range(n_labels):
        # randint's upper bound is exclusive, hence + 1 (matches random.randint).
        horizon = np.random.randint(min_val, max_val + 1)
        curr_pr = close_prices[i]
        future_pr = close_prices[i + horizon]

        if (future_pr + markup) < curr_pr:
            labels[i] = 1.0
        elif (future_pr - markup) > curr_pr:
            labels[i] = 0.0
        else:
            labels[i] = 2.0

    return labels


def get_labels_fast(dataset, min_val=1, max_val=15, markup=None):
    """Numba-backed equivalent of ``get_labels``.

    Computes labels with ``get_labels_numba`` (1.0 = future price lower by more
    than ``markup``, 0.0 = higher by more than ``markup``, 2.0 = neither), then
    keeps only the rows labelled 0.0 or 1.0.

    Args:
        dataset: DataFrame with a ``close`` column.
        min_val: Smallest look-ahead horizon, in rows.
        max_val: Largest look-ahead horizon, in rows.
        markup: Price threshold a move must exceed. Defaults to
            ``hyper_params["markup"]``.

    Returns:
        A copy of the labelled rows with an added ``labels`` column. Rows
        containing NaN are dropped.
    """
    if markup is None:
        markup = hyper_params["markup"]

    close_prices = dataset["close"].to_numpy()

    # With no more than max_val rows there is nothing to label. Skip the kernel
    # so short inputs return an empty frame instead of failing inside it.
    if len(close_prices) > max_val:
        labels = get_labels_numba(close_prices, min_val, max_val, markup)
    else:
        labels = np.empty(0, dtype=np.float64)

    dataset = dataset.iloc[: len(labels)].copy()
    dataset["labels"] = labels
    dataset = dataset.dropna()
    # Filter with a positional boolean mask. Dropping by index label would also
    # remove labelled rows that share an index value with an unlabelled row.
    return dataset[dataset["labels"] != 2.0].copy()


# Run configuration. It is read as a module-level global by get_prices and the
# labelling helpers, and its values are passed to the testers below.
hyper_params = dict(
    symbol="EURGBP_ticks",  # get_prices loads files/<symbol>.csv
    markup=0.00010,  # labelling threshold; also passed to both testers
    stop_loss=0.0100,  # passed to the testers (units not visible here -- confirm)
    take_profit=0.0100,  # passed to the testers (units not visible here -- confirm)
    backward=datetime(2010, 1, 1),  # presumably a date bound for tester -- confirm
    forward=datetime(2024, 9, 9),  # presumably a date bound for the testers -- confirm
)

def _print_elapsed(start_time: float) -> None:
    """Print the seconds elapsed since ``start_time``, a ``time.perf_counter()`` reading."""
    print(f"Execution time: {time.perf_counter() - start_time:.4f} seconds")


pr = get_prices()

# Warm-up: the first call to a numba-jitted function compiles it. Run the
# labelling pipeline once on a short slice so the timed call below measures
# labelling rather than JIT compilation. The result is discarded.
get_labels_fast(pr.iloc[:32])

# get labels test
# perf_counter is monotonic and high-resolution; time.time() can jump when the
# system clock is adjusted.
start_time = time.perf_counter()
pr = get_labels_fast(pr)
pr["meta_labels"] = 1.0
_print_elapsed(start_time)


# numba tester test
# NOTE(review): if `tester` is numba-compiled (as this section's label
# suggests), this timing also includes its JIT compilation on first call.
start_time = time.perf_counter()
tester(
    pr,
    hyper_params["stop_loss"],
    hyper_params["take_profit"],
    hyper_params["forward"],
    hyper_params["backward"],
    hyper_params["markup"],
    False,
)
_print_elapsed(start_time)

# native python tester test
# NOTE(review): the arguments differ from the `tester` call above (markup comes
# before forward, and backward is not passed). Confirm this matches
# tester_slow's signature in bots.botlibs.tester_lib before comparing timings.
start_time = time.perf_counter()
tester_slow(
    pr,
    hyper_params["stop_loss"],
    hyper_params["take_profit"],
    hyper_params["markup"],
    hyper_params["forward"],
    False,
)
_print_elapsed(start_time)
