import MetaTrader5 as mt
import pandas as pd
import numpy as np
import os
import tiktoken


# Directory containing this script; all data artifacts are written beside it.
DATA_DIR = os.path.dirname(__file__)
# CSV of raw close-price windows produced by get_data() (and re-read by __main__).
data_file = os.path.join(DATA_DIR, "llm_data.csv")

def get_data():
    """Fetch close prices from MetaTrader5 and build sliding-window samples.

    For the first symbol matching '*micro*', downloads ``mt_data_len`` bars
    per timeframe and slices the close series into overlapping windows of
    length ``sr_len`` (one window per row). The result is saved to
    ``data_file`` as CSV (with the default index column) and returned.

    Returns:
        pd.DataFrame: one row per window, ``sr_len`` columns of close prices.

    Raises:
        RuntimeError: if the MT5 terminal cannot be initialized or no
            symbol data was retrieved (the original code crashed with a
            NameError in these cases).
    """
    mt_data_len = 2500  # bars requested per symbol
    sr_len = 60         # sliding-window length (columns of the output)

    if not mt.initialize():
        raise RuntimeError("mt initialize failed!")

    xy = None
    try:
        sbs = mt.symbols_get(group='*micro*')
        if sbs is not None:
            # Only M5 is active; other timeframes were commented out upstream.
            for timeframe in [mt.TIMEFRAME_M5, ]:
                rows = []
                for sym in sbs:
                    print(sym.name)
                    rates = mt.copy_rates_from_pos(sym.name, timeframe, 0, mt_data_len)
                    close = pd.DataFrame(rates)['close']
                    # All length-sr_len windows; assumes the terminal returned
                    # mt_data_len bars (same assumption as the original loop).
                    rows.extend(
                        close[k:k + sr_len].tolist()
                        for k in range(mt_data_len - sr_len + 1)
                    )
                    break  # original code processed only the first symbol (ct counter)
                # Build the frame once instead of appending row-by-row
                # (xy.loc[len(xy)] = ... is quadratic).
                xy = pd.DataFrame(rows)
    finally:
        # Always release the terminal, even when symbols_get returned None
        # or an exception occurred mid-download.
        mt.shutdown()

    if xy is None:
        raise RuntimeError("no symbol data retrieved from MetaTrader5")

    xy.to_csv(data_file)
    return xy

def data_to_file(path, tks):
    """Serialize a token list to a binary file.

    Layout: a 256-entry int32 header followed by the tokens as uint16.
    Header fields: [0] magic number 20240520, [1] format version 1,
    [2] token count; remaining entries are zero.
    """
    hdr = np.zeros(256, dtype=np.int32)
    hdr[0], hdr[1], hdr[2] = 20240520, 1, len(tks)
    payload = np.asarray(tks, dtype=np.uint16)
    with open(path, "wb") as out:
        out.write(hdr.tobytes())
        out.write(payload.tobytes())

if __name__ == "__main__":

    # Fetch fresh data (or uncomment below to reuse the cached CSV).
    data = get_data()
    # data = pd.read_csv(data_file)
    # data = data.iloc[1:, 1:]

    # GPT-2 BPE tokenizer; encode_ordinary skips special-token handling.
    enc = tiktoken.get_encoding("gpt2")
    encode = lambda s: enc.encode_ordinary(s)
    eot = enc._special_tokens['<|endoftext|>']

    train_tokens = []
    val_tokens = []
    val_cut = len(data) // 10  # first ~10% of rows go to the validation split
    for i, r in data.iterrows():
        # Serialize one price window as a concatenated digit string, then encode.
        ser = ''.join(str(elem) for elem in r.tolist())
        tokens = encode(ser)
        bucket = val_tokens if i < val_cut else train_tokens
        bucket.append(eot)  # document separator before each row's tokens
        bucket.extend(tokens)

    # Write each split once, after tokenization. The original called
    # data_to_file inside the loop, rewriting the whole .bin file on
    # every single row (quadratic I/O for identical final files).
    data_to_file(os.path.join(DATA_DIR, "val_data.bin"), val_tokens)
    data_to_file(os.path.join(DATA_DIR, "train_data.bin"), train_tokens)
    print(f"train:{len(train_tokens)}", f"val:{len(val_tokens)}")

