import socket
from time import sleep
import pandas as pd
import numpy as np
import warnings
import base64
import hashlib
import struct
from torch2onnx import GPT2LMHeadModelWithAdapters,Adapter
from transformers import AutoTokenizer
import logging
import torch
from statistics import mean
# Logging and warning configuration: INFO-level logging, all warnings silenced
logging.basicConfig(level=logging.INFO)
warnings.filterwarnings("ignore")

# Device selection: use the GPU when CUDA is available, otherwise the CPU
if torch.cuda.is_available():
    dvc = 'cuda'
else:
    dvc = 'cpu'

# Global settings
# Identifier handed to `from_pretrained` when the model is loaded
model_id = "gpt2_Adapter-tuning"
# Number of most recent input values used to build the prompt
encoder_length = 20
# Number of generated values averaged into the prediction
prediction_length = 10
# Not referenced anywhere in the visible part of this file
info_file = "results.json"
# Address and port the server socket binds to
host = "0.0.0.0"
port = 10055
# Function to load the model
def load_model():
    """Load the adapter model and the GPT-2 tokenizer.

    The model is created from ``model_id`` and moved to the device named by
    ``dvc``; the tokenizer is the stock ``'gpt2'`` tokenizer.

    Returns:
        A ``(model, tokenizer)`` pair.

    Raises:
        Exception: Any error raised while loading is logged and re-raised
            unchanged.
    """
    try:
        # Model first, then tokenizer (same order as before, so a model
        # failure never triggers a tokenizer load).
        adapter_model = GPT2LMHeadModelWithAdapters.from_pretrained(model_id)
        adapter_model = adapter_model.to(dvc)
        gpt2_tokenizer = AutoTokenizer.from_pretrained('gpt2')
        print("Model loaded!")
        return adapter_model, gpt2_tokenizer
    except Exception as err:
        # Record the failure, then let the caller decide what to do with it
        logging.error(f"Error loading model and tokenizer: {err}")
        raise

def eva(msg, model, tokenizer):
    """Turn a comma-separated price string into a "buy"/"sell" signal.

    The last ``encoder_length`` numbers of *msg* are joined with spaces and
    used as the prompt. The model continues the prompt, the first line of the
    decoded text is parsed, and the mean of up to ``prediction_length``
    predicted values is compared with the last input price.

    Args:
        msg: Text containing comma-separated numbers.
        model: Model object providing ``eval()`` and ``generate()``.
        tokenizer: Tokenizer providing ``encode()``, ``decode()`` and
            ``eos_token_id``.

    Returns:
        ``"sell"`` when the last input price is greater than or equal to the
        mean of the predicted prices, otherwise ``"buy"``.

    Raises:
        ValueError: If *msg* contains no numeric values.
        statistics.StatisticsError: If the first generated line contains no
            numeric value after the echoed prompt (the same exception type
            ``mean`` raises for an empty sequence).
    """
    # Local import: the file only imports `mean` from `statistics`.
    from statistics import StatisticsError

    # Parse the data and keep the most recent window
    prices = np.fromstring(msg, dtype=float, sep=',').tolist()
    input_data = prices[-encoder_length:]
    if not input_data:
        raise ValueError("eva() received no numeric price data")

    # Create the prompt
    prompt = ' '.join(map(str, input_data))

    # Generate the prediction
    token = tokenizer.encode(prompt, return_tensors='pt').to(dvc)
    attention_mask = torch.ones_like(token).to(dvc)
    model.eval()
    output_ids = model.generate(
        token,
        attention_mask=attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        max_length=200,
    )[0]
    generated = tokenizer.decode(output_ids, skip_special_tokens=True)
    first_line = generated.split('\n')[0]

    # Keep only the pieces that parse as numbers
    values = []
    for piece in first_line.split():
        try:
            values.append(float(piece))
        except ValueError:
            continue

    # For a decoder-only model, generate() returns the prompt followed by the
    # continuation, so the first len(input_data) numbers are just the input
    # echoed back. Skip them; otherwise the "prediction" would be the mean of
    # the oldest input prices and the model output would never be used.
    start = len(input_data)
    predicted = values[start:start + prediction_length]
    if not predicted:
        raise StatisticsError("model produced no numeric predictions")

    # Compare the last known price with the mean prediction
    last_price = input_data[-1]
    prediction_mean = mean(predicted)
    if (last_price - prediction_mean) >= 0:
        return "sell"
    return "buy"

class server_:
    """Single-client TCP/WebSocket server that answers price data with signals.

    The constructor binds and listens on ``(host, port)``, loads the model and
    tokenizer via ``load_model()`` and then blocks until one client connects.
    ``msg()`` reads data from that client (plain UTF-8 text or masked
    WebSocket frames), passes the accumulated text to ``eva()`` and sends the
    resulting ``"buy"``/``"sell"`` back.
    """

    def __init__(self, host = host, port = port):
        """Bind, load the model and block until the first client connects."""
        self.sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM,)
        self.host = host
        self.port = port
        self.sk.bind((self.host, self.port))
        # Text accumulated during the current msg() call
        self.re = ''
        self.model,self.tokenizer=load_model()
        # Set to True once a client message contains "stop"
        self.stop=None
        self.sk.listen(1)
        # sk_ is the connected client socket, ad_ its address
        self.sk_, self.ad_ = self.sk.accept()
        # Last signal sent, as bytes; used to print only on changes
        self.last_action=None
        print('server running：',self.sk_, self.ad_)

    def msg(self):
        """Serve the client until a break condition and return the buffer.

        Each iteration sleeps 0.5 s, reconnects if the client is gone, reads
        up to 2500 bytes, answers an HTTP upgrade request with the WebSocket
        handshake, and evaluates the accumulated text with ``eva()``. The
        loop ends when the read fails or returns nothing, the message
        contains "stop" (which also sets ``self.stop``), the message is
        shorter than 50 characters, evaluation fails, or the reply cannot be
        sent.

        Returns:
            The text accumulated in ``self.re`` during this call.
        """
        self.re = ''
        wsk=False
        while True:
            sleep(0.5)
            if not self.is_connected():
                self._reconnect()
                continue

            try:
                data = self.sk_.recv(2500)
            except OSError as e:
                print(f"Receive failed: {e}")
                break
            if not data:
                break

            # A set mask bit in the second byte marks a client WebSocket
            # frame; the length guard avoids an IndexError on 1-byte reads.
            if len(data) > 1 and data[1] & 0x80:
                data = self._decode_frame(data)
                wsk=True
            else:
                data=data.decode("utf-8")

            if '\r\n\r\n' in data:
                self.handshake(data)
                data=data.split('\r\n\r\n',1)[1]
            if "stop" in data:
                self.stop=True
                break
            if len(data)<50:
                break
            self.re+=data
            try:
                bt=eva(self.re, self.model,self.tokenizer)
            except (ValueError, IndexError) as e:
                # One unusable message or model output must not kill the
                # server; leaving the loop resets the buffer on the next call.
                print(f"Evaluation failed: {e}")
                break
            bt=bytes(bt, "utf-8")
            # If the signal changes, then print the information
            if bt != self.last_action:
                if bt == b'buy':
                    print('Send buy.')
                elif bt == b'sell':
                    print('Send sell.')
                self.last_action = bt
            if wsk:
                # Unmasked single-frame text message: FIN + text opcode,
                # then a one-byte length (replies are always short).
                bt=b'\x81'+struct.pack('B',len(bt))+bt
            try:
                self.sk_.sendall(bt)
            except OSError as e:
                print(f"Send failed: {e}")
                break
        return self.re

    @staticmethod
    def _decode_frame(frame):
        """Return the unmasked payload of one masked WebSocket frame as text.

        Handles the 7-bit, 16-bit (126) and 64-bit (127) payload-length
        encodings. Each unmasked byte is mapped to one character with
        ``chr``. A frame too short to contain its header yields ``''``.
        """
        length = frame[1] & 0x7f
        if length == 126:
            offset = 4
        elif length == 127:
            offset = 10
        else:
            offset = 2
        if len(frame) < offset + 4:
            return ''
        if length == 126:
            length = struct.unpack('>H', frame[2:4])[0]
        elif length == 127:
            length = struct.unpack('>Q', frame[2:10])[0]
        mask = frame[offset:offset + 4]
        payload = frame[offset + 4:offset + 4 + length]
        return ''.join(chr(b ^ mask[i % 4]) for i, b in enumerate(payload))

    def _reconnect(self):
        """Drop the current client, accept a new one and run the handshake."""
        print("Disconnected！Try to connect the client...")
        try:
            self.sk_.close()
            self.sk.listen(1)
            self.sk_, self.ad_ = self.sk.accept()
            print('Reconnected:', self.sk_, self.ad_)
            # Wait for the HTTP upgrade request of the new client
            while True:
                sleep(0.5)
                data = self.sk_.recv(2500)
                if not data:
                    # Peer closed before sending a request; without this
                    # check the loop would spin forever on empty reads.
                    raise ConnectionError(
                        "client closed the connection before the handshake")
                data=data.decode("utf-8")
                if '\r\n\r\n' in data:
                    self.handshake(data)
                    break
            print("Reconnection succeed！")
        except Exception as e:
            print(f"Reconnection failed: {e}")

    def __del__(self):
        """Close the listening socket and the client socket, if they exist."""
        print("server closed!")
        # getattr: __init__ may have failed before these were assigned.
        # The client address (ad_) is a plain tuple and needs no closing.
        for name in ("sk", "sk_"):
            sock = getattr(self, name, None)
            if sock is not None:
                sock.close()

    def handshake(self,data):
        """Answer a WebSocket upgrade request contained in *data*.

        The ``Sec-WebSocket-Key`` header is looked up by name
        (case-insensitive), the accept token is derived from it and the
        ``101 Switching Protocols`` response is sent to the client. Any
        failure is printed; nothing is raised.

        Args:
            data: Decoded request text, header lines separated by CRLF.

        Returns:
            None.
        """
        try:
            key = None
            for line in data.split("\r\n"):
                name, sep, value = line.partition(":")
                if sep and name.strip().lower() == "sec-websocket-key":
                    key = value.strip()
                    break
            if key is None:
                raise ValueError("Sec-WebSocket-Key header not found")
            GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
            ac = base64.b64encode(hashlib.sha1((key+GUID).encode('utf-8')).digest())
            response_tpl="HTTP/1.1 101 Switching Protocols\r\n" \
                        "Upgrade:websocket\r\n" \
                        "Connection: Upgrade\r\n" \
                        "Sec-WebSocket-Accept: %s\r\n" \
                        "WebSocket-Location: ws://%s/\r\n\r\n"
            response_str = response_tpl % (ac.decode('utf-8'), "127.0.0.1:10055")
            # sendall: a plain send() may transmit only part of the response
            self.sk_.sendall(bytes(response_str, encoding='utf-8'))
            print('Handshake succeed!')
        except Exception as e:
            print(f"Connection failed: {e}")
            return None

    def is_connected(self):
        """Report whether the client socket is still usable.

        Peeks one byte without consuming it. The peek is a blocking read on
        the client socket. An empty result means the peer closed the
        connection; a socket error means the connection is broken. Both
        reset ``last_action`` and return False.
        """
        try:
            data = self.sk_.recv(1, socket.MSG_PEEK)
        except socket.error:
            self.last_action=None
            return False
        if not data:
            self.last_action=None
            return False
        return True

if __name__ == "__main__":
    # Start the server; this blocks until the first client has connected
    server = server_()

    # Serve until a client message sets the stop flag, pausing 0.5 s
    # between consecutive msg() rounds
    server.msg()
    while not server.stop:
        sleep(0.5)
        server.msg()