
import pandas as pd
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, modeling_outputs
import torch
from torch import nn
import torch.nn.functional as F

dvc = 'cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU when available
print(dvc)
model_name_or_path = 'gpt2'  # base checkpoint to adapt
Tuned_model = "gpt2_Adapter-tuning"  # output directory for the adapter-tuned model

# Define the Adapter module
class Adapter(nn.Module):
    def __init__(self, in_features, bottleneck_features=64):
        super(Adapter, self).__init__()
        self.down_project = nn.Linear(in_features, bottleneck_features)
        self.up_project = nn.Linear(bottleneck_features, in_features)
        self.dropout = nn.Dropout(0.1)
        self.init_weights()

    def init_weights(self):
        nn.init.normal_(self.down_project.weight, mean=0.0, std=0.02)
        nn.init.constant_(self.down_project.bias, 0)
        nn.init.normal_(self.up_project.weight, mean=0.0, std=0.02)
        nn.init.constant_(self.up_project.bias, 0)

    def forward(self, hidden_states):
        # Down-project to the bottleneck, apply a non-linearity, then project
        # back up to the model's hidden size.
        hidden_states = self.down_project(hidden_states)
        hidden_states = F.relu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.up_project(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states
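
# Illustrative note (assumption, not in the original recipe): an Adapter maps a
# hidden-state tensor back to its original width, e.g. Adapter(768) applied to a
# (batch, seq, 768) tensor returns a (batch, seq, 768) tensor, so its output can
# be added residually to the transformer's hidden states.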

# Integrate the adapters into the model
class GPT2LMHeadModelWithAdapters(GPT2LMHeadModel):
    def __init__(self, config):
        super().__init__(config)
        # One adapter per transformer layer (config.n_layer adapters of width config.n_embd)
        self.adapters = nn.ModuleList([Adapter(config.n_embd) for _ in range(config.n_layer)])

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # Apply the adapters sequentially to the final hidden states, each with a
        # residual connection (in this simplified setup the adapters sit on top of
        # the transformer output rather than inside every block)
        for adapter in self.adapters:
            hidden_states = hidden_states + adapter(hidden_states)

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict the next token
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return modeling_outputs.CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )
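

# Hedged sketch (assumption, not part of the original recipe): adapter tuning is
# usually run with the GPT-2 backbone frozen so that only the adapter parameters
# are updated. Calling this helper on the model before training would make the
# run parameter-efficient; the name freeze_backbone is illustrative.
def freeze_backbone(model):
    for name, param in model.named_parameters():
        # Parameters registered under self.adapters are named "adapters.<i>...."
        param.requires_grad = name.startswith("adapters.")
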
if __name__ == "__main__":
    # Load the raw data used later to build the generation prompt
    df = pd.read_csv('llm_data.csv')

    tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)

    train_dataset = TextDataset(tokenizer=tokenizer,
                                file_path="train.txt", 
                                block_size=60)
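
    # Hedged alternative (assumption, not in the original script): TextDataset is
    # deprecated in recent transformers releases; a tokenized dataset can instead
    # be built with the separate `datasets` library, for example:
    #
    #   from datasets import load_dataset
    #   raw = load_dataset('text', data_files={'train': 'train.txt'})
    #   train_dataset = raw['train'].map(
    #       lambda ex: tokenizer(ex['text'], truncation=True, max_length=60),
    #       remove_columns=['text'],
    #   )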

    # mlm=False selects causal-LM collation: input_ids are copied into labels
    # for next-token prediction
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(output_dir=Tuned_model,
                                      overwrite_output_dir=True,
                                      num_train_epochs=3,
                                      per_device_train_batch_size=32,
                                      save_strategy='no',
                                      )

    # Initialize model with adapters
    model = GPT2LMHeadModelWithAdapters.from_pretrained(model_name_or_path)

    trainer = Trainer(model=model,
                      args=training_args,
                      data_collator=data_collator,
                      train_dataset=train_dataset,
                      )

    trainer.train()

    trainer.save_model(Tuned_model)

    # Load the model for inference
    model = GPT2LMHeadModelWithAdapters.from_pretrained(Tuned_model)
    model.to(dvc)
    model.eval()

    # Use the last row of the data (columns 1-19) as the generation prompt
    prompt = ' '.join(map(str, df.iloc[:, 1:20].values[-1]))
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(dvc)
    output_ids = model.generate(input_ids,
                                do_sample=True,
                                max_length=200,
                                pad_token_id=tokenizer.eos_token_id)  # GPT-2 has no pad token
    generated = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    print(f"test the model:{generated}")

