import time
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from sklearn.metrics import mean_squared_error
from peft import PeftModel

# Load the prepared price dataset and pick the compute device.
df = pd.read_csv('llm_data.csv')
dvc = 'cuda' if torch.cuda.is_available() else 'cpu'
base_model = 'gpt2'
fine_tuning_path = './gpt2_stock'
lora_tuning_path = './gpt2_LORA_None'
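# do_sample=True below makes generation stochastic; seeding torch keeps the
# fine-tuning vs. LoRA comparison repeatable (the seed value is arbitrary).
torch.manual_seed(42)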

# Load the shared tokenizer and the fully fine-tuned checkpoint
# (a complete GPT-2 model, so it loads directly).
tokenizer = GPT2Tokenizer.from_pretrained(base_model)
model_fine_tuning = GPT2LMHeadModel.from_pretrained(fine_tuning_path).to(dvc)


# For LoRA, load the stock GPT-2 first, then attach the trained adapters on top.
model_lora_tuning = GPT2LMHeadModel.from_pretrained(base_model)
model_lora_tuning = PeftModel.from_pretrained(model_lora_tuning, lora_tuning_path).to(dvc)
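# If adapter overhead matters at inference time, PEFT can fold the LoRA
# weights into the base model (merge_and_unload() is a standard PeftModel
# method for LoRA adapters):
# model_lora_tuning = model_lora_tuning.merge_and_unload()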


# Use the last row of the dataset: columns 1-19 form the input window,
# columns 21 onward hold the true prices to predict.
input_data = df.iloc[:, 1:20].values[-1]
true_prices = df.iloc[-1:, 21:].values.tolist()[0]
prompt = ' '.join(map(str, input_data))
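# The prompt mirrors the plain-text format the models were trained on:
# a single space-separated run of prices (assumed from the CSV layout above).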

def generate_prices(model):
    # Generate a continuation of the price prompt and score it against the truth.
    global true_prices
    model.eval()
    token = tokenizer.encode(prompt, return_tensors='pt').to(dvc)
    start_ = time.time()
    with torch.no_grad():  # inference only, no gradients needed
        generated = tokenizer.decode(
            model.generate(token, do_sample=True, max_length=200,
                           pad_token_id=tokenizer.eos_token_id)[0],
            skip_special_tokens=True)
    end_ = time.time()
    print(f'generate time: {end_ - start_}')
    # Keep the first line of the decoded text and parse it back into floats
    # (this assumes the model emits only well-formed numbers).
    generated_prices = generated.split('\n')[0]
    generated_prices = list(map(float, generated_prices.split()))

    def trim_lists(a, b):
        # Trim both series to a common length so the metrics are well defined.
        min_len = min(len(a), len(b))
        return a[:min_len], b[:min_len]

    # Note: this trims the global true_prices in place, so the later plots
    # use the same aligned series.
    true_prices, generated_prices = trim_lists(true_prices, generated_prices)
    print(f"input data: {input_data}")
    print(f"true prices: {true_prices}")
    print(f"generated prices: {generated_prices}")
    mse = mean_squared_error(true_prices, generated_prices)
    print('MSE:', mse)
    rmse = np.sqrt(mse)
    # Normalize RMSE by the range of the true series.
    nrmse = rmse / (np.max(true_prices) - np.min(true_prices))
    print(f"RMSE: {rmse}, NRMSE: {nrmse}")
    return generated_prices, mse, rmse, nrmse

def plot_(a, b, title):
    plt.figure(figsize=(10, 6))
    # Only the prediction chart needs the ground-truth curve.
    if title == 'prediction':
        plt.plot(true_prices, label='True Values', marker='o')
    plt.plot(a, label='fine_tuning', marker='x')
    plt.plot(b, label='lora_tuning', marker='s')
    plt.title(title)
    plt.xlabel('Index')
    plt.ylabel('Value')
    plt.legend()
    plt.savefig(f"{title}.png")

def groups_chart(a, b, models):
    metrics = ['Train Time (s)', 'Inference Time (s)', 'Memory Usage (GB)', 'MSE', 'RMSE', 'NRMSE']
    plt.figure(figsize=(10, 6))
    # The first three values (train time, inference time, memory usage) are
    # hard-coded measurements; the error metrics come from the tuples
    # returned by generate_prices().
    a = [101.7946, 1.243, 5.67, a[1], a[2], a[3]]
    b = [69.5605, 0.877, 4.10, b[1], b[2], b[3]]

    bar_width = 0.2

    r1 = np.arange(len(metrics))
    r2 = [x + bar_width for x in r1]

    plt.bar(r1, a, color='r', width=bar_width, edgecolor='grey', label=models[0])
    plt.bar(r2, b, color='b', width=bar_width, edgecolor='grey', label=models[1])

    plt.yscale('log')
    plt.xlabel('Metrics', fontweight='bold')
    plt.xticks([r + bar_width / 2 for r in range(len(metrics))], metrics)  # center each label between its pair of bars
    plt.ylabel('Values (log scale)', fontweight='bold')
    plt.title('Model Comparison')
    plt.legend()
    # plt.show()
    plt.savefig('Comparison.png')
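    plt.close()  # free the figure after saving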
    
fine_tuning_result = generate_prices(model_fine_tuning)
lora_tuning_result = generate_prices(model_lora_tuning)

plot_(fine_tuning_result[0], lora_tuning_result[0], title='prediction')
groups_chart(fine_tuning_result, lora_tuning_result, models=['fine-tuning', 'lora-tuning'])