
import pandas as pd
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
import torch
from peft import get_peft_model, LoraConfig, PeftModel, TaskType

# Train on GPU when available, otherwise fall back to CPU.
dvc = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {dvc}")
model_name_or_path = 'gpt2'

# LoRA configuration: train small rank-8 update matrices on GPT-2's attention
# projections while the base weights stay frozen.
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
                         r=8,
                         lora_alpha=32,
                         lora_dropout=0.1,
                         )

# Directory name for the saved adapter, e.g. 'gpt2_LORA_CAUSAL_LM'.
peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"

# Each row of llm_data.csv holds one sequence of price values.
df = pd.read_csv('llm_data.csv')

# Build the training corpus: every row except the last 10 becomes one
# space-separated line in train.txt (the first column is skipped, and the
# last rows are held back so an unseen prompt is available for testing).
sentences = [' '.join(map(str, prices)) for prices in df.iloc[:-10, 1:].values]
with open('train.txt', 'w') as f:
    for sentence in sentences:
        f.write(sentence + '\n')

tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
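
# GPT-2's tokenizer ships without a pad token. TextDataset emits fixed-length
# blocks, so padding is never actually needed here, but aliasing the pad token
# to EOS is a cheap safeguard if the dataset pipeline is ever swapped out.
tokenizer.pad_token = tokenizer.eos_token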

# TextDataset is deprecated in recent transformers releases but still works for
# simple block-based corpora; block_size is the training chunk length in tokens.
train_dataset = TextDataset(tokenizer=tokenizer,
                            file_path="train.txt",
                            block_size=60)

# mlm=False selects the standard causal language-modelling objective.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
training_args = TrainingArguments(output_dir=peft_model_id,
                                  overwrite_output_dir=True,
                                  num_train_epochs=3,
                                  per_device_train_batch_size=32,
                                  save_strategy='no',  # the adapter is saved manually below
                                  # Alternatively, keep periodic checkpoints:
                                  # save_steps=10_000,
                                  # save_total_limit=2,
                                  # load_best_model_at_end=True,
                                  )

# Load the base GPT-2 weights and wrap them with the LoRA adapter; only the
# adapter matrices receive gradients.
model = GPT2LMHeadModel.from_pretrained(model_name_or_path)
model = get_peft_model(model, peft_config)
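
# Sanity check: with r=8 only the small adapter matrices are trainable, which
# should be well under 1% of GPT-2's ~124M parameters.
model.print_trainable_parameters()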


trainer = Trainer(model=model,
                  args=training_args,
                  data_collator=data_collator,
                  train_dataset=train_dataset)

trainer.train()
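
# The last entry of the trainer's log history summarises the run (final loss,
# runtime, etc.), which is handy when save_strategy='no' leaves no checkpoints.
if trainer.state.log_history:
    print(trainer.state.log_history[-1])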

# Persist only the trained LoRA adapter (a few MB), not the full GPT-2
# checkpoint; model.save_pretrained(peft_model_id) would do much the same.
trainer.save_model(peft_model_id)
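
# Save the tokenizer alongside the adapter so inference can load everything
# from one directory.
tokenizer.save_pretrained(peft_model_id)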

# Reload for inference: restore the frozen base model, then attach the trained
# adapter weights from peft_model_id.
model = GPT2LMHeadModel.from_pretrained(model_name_or_path)
model = PeftModel.from_pretrained(model, peft_model_id)
model.to(dvc)
model.eval()
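
# Optional: merge the LoRA weights into the base model for adapter-free
# deployment. A sketch, assuming a peft release that provides
# merge_and_unload(); the merged model no longer depends on the peft package.
# merged_model = model.merge_and_unload()
# merged_model.save_pretrained(f"{peft_model_id}_merged")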

# Prompt with 19 values from the last (held-out) row, skipping the first
# column as during training. Passing pad_token_id silences GPT-2's warning
# about the missing pad token.
prompt = ' '.join(map(str, df.iloc[:, 1:20].values[-1]))
input_ids = tokenizer.encode(prompt, return_tensors='pt').to(dvc)
generated = tokenizer.decode(model.generate(input_ids=input_ids,
                                            do_sample=True,
                                            max_length=200,
                                            pad_token_id=tokenizer.eos_token_id)[0],
                             skip_special_tokens=True)

print(f"Test the model: {generated}")