import os
import logging
from pathlib import Path
from transformers.onnx import export, FeaturesManager
from transformers import AutoConfig, AutoTokenizer, GPT2LMHeadModel, modeling_outputs
from torch import nn
import torch.nn.functional as F
import onnx
# Set up basic configuration for logging
logging.basicConfig(level=logging.INFO)
# Define the Adapter class, which is a simple feed-forward network with dropout
class Adapter(nn.Module):
    def __init__(self, in_features, bottleneck_features=64):
        super().__init__()
        # Down projection layer
        self.down_project = nn.Linear(in_features, bottleneck_features)
        # Up projection layer
        self.up_project = nn.Linear(bottleneck_features, in_features)
        # Dropout layer for regularization
        self.dropout = nn.Dropout(0.1)
        # Initialize weights of the layers
        self.init_weights()

    def init_weights(self):
        # Initialize weights for down projection layer
        nn.init.normal_(self.down_project.weight, mean=0.0, std=0.02)
        nn.init.constant_(self.down_project.bias, 0)
        # Initialize weights for up projection layer
        nn.init.normal_(self.up_project.weight, mean=0.0, std=0.02)
        nn.init.constant_(self.up_project.bias, 0)

    def forward(self, hidden_states):
        # Apply down projection and ReLU activation
        hidden_states = self.down_project(hidden_states)
        hidden_states = F.relu(hidden_states)
        # Apply dropout
        hidden_states = self.dropout(hidden_states)
        # Apply up projection
        hidden_states = self.up_project(hidden_states)
        # Apply dropout again
        hidden_states = self.dropout(hidden_states)
        return hidden_states
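
# A minimal sanity-check sketch (optional, not called by the export flow): the
# adapter maps [batch, seq, in_features] back to the same shape, so it can be added
# residually to the transformer's hidden states. The sizes below are illustrative
# assumptions (768 is the hidden size of GPT-2 small).
def adapter_shape_demo():
    import torch
    adapter = Adapter(in_features=768)
    hidden = torch.randn(2, 5, 768)       # [batch, seq, n_embd]
    out = hidden + adapter(hidden)        # residual connection keeps the shape
    assert out.shape == hidden.shape
    return out.shape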

# Define the GPT2LMHeadModelWithAdapters class, which inherits from GPT2LMHeadModel.
# One Adapter is created per transformer layer, but they are all applied as a stack
# of residual blocks on the transformer's final hidden states rather than inside
# each individual transformer block.
class GPT2LMHeadModelWithAdapters(GPT2LMHeadModel):
    def __init__(self, config):
        super().__init__(config)
        # Create a list of adapter modules, one for each transformer layer
        self.adapters = nn.ModuleList([Adapter(config.n_embd) for _ in range(config.n_layer)])

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Fall back to the config default when return_dict is not given,
        # mirroring the behavior of GPT2LMHeadModel
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Get the outputs from the transformer
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # Apply each adapter to the hidden states as a residual connection
        for adapter in self.adapters:
            hidden_states = hidden_states + adapter(hidden_states)

        # Get the logits for the language modeling head
        lm_logits = self.lm_head(hidden_states)

        # Compute loss if labels are provided
        loss = None
        if labels is not None:
            # Shift logits and labels for loss computation
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the logits and labels for cross-entropy loss
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # Return the outputs in the appropriate format
        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return modeling_outputs.CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )
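
# A quick forward-pass sketch (optional; assumes the stock "gpt2" checkpoint can be
# downloaded). Loading the base weights leaves the adapters randomly initialized,
# which is enough to confirm that the adapter-augmented forward pass runs and
# produces logits of shape [batch, seq, vocab_size].
def forward_pass_demo():
    import torch
    demo_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    demo_model = GPT2LMHeadModelWithAdapters.from_pretrained("gpt2")
    demo_model.eval()
    inputs = demo_tokenizer("Hello, adapters!", return_tensors="pt")
    with torch.no_grad():
        outputs = demo_model(**inputs, return_dict=True)
    logging.info(f"Demo logits shape: {tuple(outputs.logits.shape)}")  # (1, seq_len, 50257)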

# Function to load the model and tokenizer
def load_model_and_tokenizer(model_id):
    try:
        # Load the model configuration
        config = AutoConfig.from_pretrained(model_id)
        # Load the adapter-tuned model weights
        model = GPT2LMHeadModelWithAdapters.from_pretrained(model_id)
        # Load the base GPT-2 tokenizer used by the adapter-tuned checkpoint
        tokenizer = AutoTokenizer.from_pretrained('gpt2')
        return config, model, tokenizer
    except Exception as e:
        # Log any errors that occur during loading
        logging.error(f"Error loading model and tokenizer: {e}")
        raise

# Function to export the model to ONNX format
def export_model_to_onnx(model, config, tokenizer, output_path, opset):
    try:
        # Get the appropriate feature for the model
        model_type = config.model_type.replace("-", "_")
        feature = "causal-lm-with-past"
        # Get the ONNX configuration
        onnx_config_constructor = FeaturesManager.get_config(model_type, feature=feature)
        onnx_config = onnx_config_constructor(config)
        # Create the output directory if it doesn't exist
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Export the model to ONNX
        export(
            model=model,
            config=onnx_config,
            opset=opset,
            output=output_path,
            preprocessor=tokenizer,
        )
        # Log success message
        logging.info(f"Model successfully converted to ONNX and saved in {output_path}")
    except Exception as e:
        # Log any errors that occur during export
        logging.error(f"Error exporting model to ONNX: {e}")
        raise

# Main function to orchestrate the process
def main():
    # Define the model ID, output path, and ONNX opset version
    model_id = "gpt2_Adapter-tuning"
    onnx_path = "./gpt2_onnx"
    out_path = Path(os.path.join(onnx_path, "gpt2_adapter_tuning.onnx"))
    opset = 14

    # Load the model and tokenizer
    config, model, tokenizer = load_model_and_tokenizer(model_id)
    # Export the model to ONNX
    export_model_to_onnx(model, config, tokenizer, out_path, opset)


def check_onnx():

    # Check the ONNX model
    onnx_model = onnx.load("gpt2_onnx/gpt2_adapter_tuning.onnx")
    onnx.checker.check_model(onnx_model)
    print("ONNX model check passed!")

def quantization():

    from onnxruntime.quantization import quantize_dynamic, QuantType

    # Paths for the original export and the quantized output
    model_path = "gpt2_onnx/gpt2_adapter_tuning.onnx"
    quantized_model_path = "gpt2_onnx/quantized_gpt2.onnx"

    # Dynamically quantize the weights to unsigned 8-bit integers
    # (quantize_dynamic works on 8-bit weight types; 4-bit weight-only
    # quantization needs a dedicated quantizer and is not used here)
    quantize_dynamic(model_path, quantized_model_path, weight_type=QuantType.QUInt8)

    print(f"Saved the quantized model to: {quantized_model_path}")

# Entry point of the script
if __name__ == "__main__":
    main()
    check_onnx()
    quantization()
