import socket
import json
from time import sleep
import pandas as pd
import numpy as np
import warnings
from pytorch_forecasting import NBeats


import base64
import hashlib
import struct

# Silence every warning for the whole process (library deprecation/user
# warnings included).
warnings.filterwarnings("ignore")

# Window sizes used by eva(): how many trailing values feed the model, and
# how many placeholder rows are appended after them.
max_encoder_length, max_prediction_length = 96, 20
# JSON file whose "last_best_model" entry names the checkpoint load_model() loads.
info_file = "results.json"

def load_model():
    """Load the NBeats model whose checkpoint path is stored in ``info_file``.

    ``info_file`` is read as JSON; its ``"last_best_model"`` entry is handed to
    ``NBeats.load_from_checkpoint``.

    Returns:
        The restored ``NBeats`` model.
    """
    with open(info_file) as handle:
        checkpoint_path = json.load(handle)['last_best_model']
    return NBeats.load_from_checkpoint(checkpoint_path)


def eva(msg, model):
    """Turn a comma-separated price string into a "buy"/"sell" signal.

    The newest ``offset`` value of *msg* is dropped. Up to
    ``max_encoder_length`` values before it form the history window. The
    window's last value is then repeated ``max_prediction_length`` times so
    the frame also covers the prediction horizon. The frame is passed to
    ``model.predict`` in raw mode.

    Args:
        msg: comma-separated numbers, e.g. ``"1.0,1.1,1.2"``.
        model: a fitted model exposing pytorch_forecasting's ``predict``.

    Returns:
        ``"buy"`` if the last value of the predicted trend is >= its mean,
        otherwise ``"sell"`` (NaN comparisons fall through to ``"sell"``).

    Raises:
        ValueError: if *msg* holds too few values to form a non-empty window.
    """
    offset = 1
    values = np.fromstring(msg, dtype=float, sep=',')
    # Drop the newest `offset` value(s); presumably the latest bar is still
    # forming — TODO confirm with the client.
    history = values[-max_encoder_length - offset:-offset]
    if history.size == 0:
        raise ValueError(
            f"need at least {offset + 1} comma-separated values, got {values.size}"
        )

    # Repeat the last known value as a placeholder for every horizon step.
    close = np.concatenate([history, np.full(max_prediction_length, history[-1])])
    # Keep each row labelled by its position in `values` (continuing past the
    # end for the padding) so the frame's index matches the slice it came from.
    start = values.size - offset - history.size
    index = np.arange(start, start + close.size, dtype=np.int64)
    # Column 0 is kept alongside 'close' to preserve the frame layout the
    # model has always received.
    dt = pd.DataFrame({0: close, 'close': close}, index=index)
    dt['series'] = 0  # single time series
    dt['time_idx'] = np.arange(close.size, dtype=np.int64)

    predictions = model.predict(
        dt,
        mode='raw',
        trainer_kwargs=dict(accelerator="cpu", logger=False),
        return_x=True,
    )
    # First (and only) sample of the raw "trend" output; its exact length
    # (encoder + horizon vs. horizon only) is not visible here — TODO confirm.
    trend = predictions.output["trend"][0].detach().cpu()
    return "buy" if (trend[-1] - trend.mean()) >= 0 else "sell"

class server_:
    """Single-client TCP/WebSocket server that answers price series with "buy"/"sell".

    The constructor binds ``(host, port)`` and loads the model with
    ``load_model()``. It then blocks in ``accept()`` until one client connects.
    Only that single connection is ever served (``listen(1)`` plus one
    ``accept()``).

    Attributes:
        sk: the listening socket.
        sk_, ad_: the accepted connection socket and the peer address from ``accept()``.
        re: text accumulated during the current :meth:`msg` call.
        model: the model returned by ``load_model()``; passed to ``eva``.
        stop: ``None`` until the client sends "stop", a Close frame, or disconnects;
            then ``True``.
    """

    # RFC 6455 section 4.2.2: fixed GUID appended to the client key before hashing.
    _WS_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

    def __init__(self, host = '127.0.0.1', port = 8989):
        self.sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.host = host
        self.port = port
        self.sk.bind((self.host, self.port))
        self.re = ''
        self.model = load_model()
        self.stop = None
        self.sk.listen(1)
        # Blocks here until the (single) client connects.
        self.sk_, self.ad_ = self.sk.accept()
        print('server running：', self.sk_, self.ad_)

    def msg(self):
        """Serve the accepted connection until a short chunk, "stop", or a disconnect.

        Raw TCP text and masked WebSocket frames are both accepted on the same
        socket.
        - A plain-text chunk containing a blank line (``\\r\\n\\r\\n``) is treated
          as an HTTP upgrade request and answered with the WebSocket handshake.
        - Every chunk of at least 200 characters is appended to ``self.re``.
          The whole accumulated text is scored by ``eva`` and the result is
          sent back, wrapped in a text frame once a WebSocket frame was seen.
        - A shorter chunk ends this call.
        - "stop", a Close frame, or the peer disconnecting sets ``self.stop = True``.

        Returns:
            ``self.re``, the text accumulated during this call.
        """
        self.re = ''
        wsk = False
        while True:
            data = self.sk_.recv(2500)
            if not data:
                # Peer closed the connection. No other client can ever be
                # accepted, so signal the caller to stop instead of looping
                # on a dead socket.
                self.stop = True
                break
            if len(data) > 1 and data[1] & 0x80:
                # Mask bit set: a client-to-server WebSocket frame.
                try:
                    opcode, data = self._read_frame(data)
                except ConnectionError:
                    self.stop = True
                    break
                wsk = True
                if opcode == 0x8:
                    # Close frame: answer with an empty Close frame (RFC 6455
                    # section 5.5.1) and end the session. Best-effort only —
                    # the peer may already be gone.
                    try:
                        self.sk_.sendall(b'\x88\x00')
                    except OSError:
                        pass
                    self.stop = True
                    break
            else:
                data = data.decode("utf-8")
                if '\r\n\r\n' in data:
                    data = self._handshake(data)

            if "stop" in data:
                self.stop = True
                break
            if len(data) < 200:
                break
            self.re += data
            bt = bytes(eva(self.re, self.model), "utf-8")
            if wsk:
                # Unmasked text frame (FIN + opcode 0x1) with a 7-bit length:
                # eva() only returns "buy"/"sell", far below the 125-byte limit.
                bt = b'\x81' + struct.pack('B', len(bt)) + bt
            self.sk_.sendall(bt)
        return self.re

    def _recv_until(self, buf, size):
        """Extend bytearray *buf* with further recv() calls until it holds *size* bytes.

        Raises:
            ConnectionError: if the peer closes the connection first.
        """
        while len(buf) < size:
            # Never ask for more than this frame needs, so the next frame's
            # bytes stay in the socket; cap each read to bound allocation.
            chunk = self.sk_.recv(min(size - len(buf), 65536))
            if not chunk:
                raise ConnectionError("connection closed in the middle of a frame")
            buf += chunk
        return buf

    def _read_frame(self, data):
        """Decode one masked WebSocket frame that starts at the beginning of *data*.

        The 7-bit, 16-bit and 64-bit payload-length forms are supported
        (RFC 6455 section 5.2). More bytes are read from ``self.sk_`` when the
        frame is longer than what the first ``recv`` delivered. Bytes received
        past the end of this frame (a pipelined second frame) are discarded.
        Fragmented messages (FIN=0 plus continuation frames) are not reassembled.

        Returns:
            ``(opcode, text)``, where *text* is the unmasked payload decoded as
            UTF-8 (invalid bytes replaced).

        Raises:
            ConnectionError: if the peer disconnects before the frame is complete.
        """
        buf = self._recv_until(bytearray(data), 2)
        fin = (buf[0] & 0x80) >> 7  # FIN bit
        opcode = buf[0] & 0x0f  # opcode
        masked = (buf[1] & 0x80) >> 7  # mask bit
        length = buf[1] & 0x7f
        header = 2
        # 126 / 127 announce a 16-bit / 64-bit big-endian extended length.
        if length == 126:
            buf = self._recv_until(buf, 4)
            (length,) = struct.unpack('>H', bytes(buf[2:4]))
            header = 4
        elif length == 127:
            buf = self._recv_until(buf, 10)
            (length,) = struct.unpack('>Q', bytes(buf[2:10]))
            header = 10
        buf = self._recv_until(buf, header + 4 + length)
        mask = buf[header:header + 4]  # masking key follows the length field
        payload = buf[header + 4:header + 4 + length]

        print('fin is：{},opcode is：{}，mask:{}'.format(fin, opcode, masked))
        text = bytes(b ^ mask[i % 4] for i, b in enumerate(payload)).decode(
            'utf-8', errors='replace'
        )
        return opcode, text

    def _handshake(self, request):
        """Answer a WebSocket upgrade request and return the text after its headers.

        ``Sec-WebSocket-Key`` is located by header name (case-insensitively), so
        the order in which the client sends headers does not matter. If the
        request has no such header, no upgrade response is sent.
        """
        head, _, body = request.partition('\r\n\r\n')
        key = None
        for line in head.split('\r\n')[1:]:  # skip the request line
            name, sep, value = line.partition(':')
            if sep and name.strip().lower() == 'sec-websocket-key':
                key = value.strip()
                break
        if key is not None:
            print(key)
            ac = base64.b64encode(
                hashlib.sha1((key + self._WS_GUID).encode('utf-8')).digest()
            )
            response_tpl = "HTTP/1.1 101 Switching Protocols\r\n" \
                           "Upgrade:websocket\r\n" \
                           "Connection: Upgrade\r\n" \
                           "Sec-WebSocket-Accept: %s\r\n" \
                           "WebSocket-Location: ws://%s/\r\n\r\n"
            location = "%s:%s" % (self.host, self.port)
            response_str = response_tpl % (ac.decode('utf-8'), location)
            self.sk_.sendall(bytes(response_str, encoding='utf-8'))
        return body

    def __del__(self):
        print("server closed!")
        # __init__ may have failed before these were assigned (e.g. bind error).
        # ad_ is only the peer address tuple, so there is nothing to close there.
        for attr in ('sk', 'sk_'):
            conn = getattr(self, attr, None)
            if conn is not None:
                conn.close()
        


# Start the single-client server (blocks until a client connects), then keep
# servicing it, pausing half a second between rounds, until msg() sets the
# stop flag.
sv = server_()
rem = sv.msg()
while not sv.stop:
    sleep(0.5)
    rem = sv.msg()


