Jul 20, 2022

AI

A Comprehensive Guide to Fine-Tuning Llama 2 7B on Custom Datasets

This blog post explores fine-tuning the Llama 2 7B model using Hugging Face Transformers and TRL, focusing on prompt analysis and optimization techniques like LoRA and 4-bit quantization for efficient GPU usage.

Data Leap


Introduction

In this blog post, we will explore the fine-tuning process for Llama 2 7B. Leveraging the capabilities of Hugging Face Transformers and TRL, we will explore two major techniques:

  • Analysis of the Base Model with Prompts: Exploring the complexities of prompts and prompt templates, and their effect on the performance of the model.

  • Optimizing Large Language Models through Fine-Tuning: Fine-tuning of the model with a focus on efficiency on a single GPU. Exploring techniques like LoRA and 4-bit quantization.


LLaMA

LLaMA is a collection of foundation language models ranging from 7B to 65B parameters. The models were trained on trillions of tokens, demonstrating that state-of-the-art models can be trained solely on publicly available datasets, without relying on proprietary or inaccessible data sources. Notably, LLaMA-13B outperforms GPT-3 (175B) on a majority of benchmarks, while LLaMA-65B competes favourably with top-tier models such as Chinchilla-70B and PaLM-540B.


LLaMA, an auto-regressive language model, is built on the transformer architecture. Like other prominent language models, LLaMA functions by taking a sequence of words as input and predicting the next word, recursively generating text.

Training: When a model is constructed from the ground up, it undergoes training. This process entails adjusting all the model's coefficients or weights to grasp patterns and relationships within the data.

Fine-Tuning: Fine-tuning assumes that the model has acquired a foundational understanding of language through training. This phase involves making targeted adjustments to tailor the model for a specific task or domain. Think of it as honing a well-educated model for a particular task.

Prompt Engineering: Prompt engineering revolves around the crafting of input prompts or questions to guide the LLM in generating desired outputs. It's about customizing the interaction with the model to elicit specific results.


Prompt Engineering:

Step 1: Load the Dataset

Let's load the Alpaca dataset and work with a sample of it. Its fields are described below, followed by a loading sketch.


Dataset: The purpose of this dataset is to train a large language model to interpret instructions and generate code from natural language. Each entry in the dataset comprises:

  • An instruction that describes a specific task.

  • An input section, offering additional context when necessary for understanding the instruction.

  • The anticipated output that corresponds to the given instruction.
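The post does not show the loading code, so here is a minimal sketch. The dataset id below is a placeholder for any code-instruction dataset in this Alpaca format; substitute whichever dataset you are actually using.

from datasets import load_dataset

# Placeholder dataset id: any instruction/input/output code dataset in Alpaca format works here.
dataset = load_dataset("sahil2801/CodeAlpaca-20k", split="train")

# Work with a small sample and carve out a test split.
dataset = dataset.shuffle(seed=42).select(range(1000)).train_test_split(test_size=0.1)
print(dataset)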


Step 2: Load the model

Let us load the model and tokenizer.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# model_id = "linhvu/decapoda-research-llama-7b-hf"
model_id = "NousResearch/Llama-2-7b-hf"
# model_id = "meta-llama/Llama-2-13b-chat-hf"

# Load the weights in 4-bit NF4 with double quantization; compute in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map="auto")

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"


Step 3: Create a prompt and make predictions.

Let us create a prompt and check the generated output.
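The prompt references instruction and input variables that the post does not define; a minimal sketch pulling them from one test example, with field names taken from the dataset description above:

# Take one example from the test split (field names assumed from the dataset description).
sample = dataset["test"][0]
instruction = sample["instruction"]
input = sample["input"]  # shadows the built-in input(), kept to match the prompt below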

prompt = f"""### Instruction: {instruction}
### Input: {input}
### Response:
"""
inputs = tokenizer(prompt, return_tensors='pt').input_ids.to(model.device)
outputs = tokenizer.decode(model.generate(inputs, max_length=200)[0], skip_special_tokens=True)
print(outputs)

Output:


The base Llama model cannot comprehend the instruction and produce a relevant answer for our specific task; instead, it simply echoes the input we provided until it reaches the token limit.

In this blog, we will focus on fine-tuning the model. There are different ways to fine-tune an LLM.


Fine Tuning

Full Fine-Tuning:

  • Entails training the entire pre-trained model with new data.

  • Involves updating all model layers and parameters during the fine-tuning process.

  • While it can yield high accuracy, it demands substantial computational resources and time.

  • This presents a risk of catastrophic forgetting, where updating all weights may cause the algorithm to unintentionally lose knowledge acquired during pretraining. This can result in varied outcomes, ranging from increased error margins to complete erasure of specific task memories, leading to suboptimal performance.

  • Best suited for scenarios where the target task significantly differs from the original pre-training task.

Parameter Efficient Fine-Tuning (PEFT), e.g., LoRA:

  • Concentrates on updating only a subset of the model's parameters (see the sketch after this list).

  • Frequently, this involves freezing specific layers or portions of the model to prevent catastrophic forgetting. Alternatively, additional trainable layers may be introduced while keeping the original model's weights frozen.

  • Can facilitate faster fine-tuning with fewer computational resources, though it may sacrifice some accuracy compared to full fine-tuning.

  • Encompasses methods like LoRA, AdaLoRA, and Adaption Prompt (LLaMA Adapter).

  • Ideal when the new task shares similarities with the original pre-training task.
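To make the LoRA idea concrete, here is a rough numeric sketch (not how the peft library implements it internally, just the underlying math): instead of updating a full weight matrix W, LoRA trains two small matrices A and B whose product forms a low-rank update, so only r·(d+k) parameters are trainable instead of d·k. The sizes below are illustrative values for a Llama attention projection.

import torch

d, k, r = 4096, 4096, 16      # projection dimensions and LoRA rank (illustrative)
alpha = 64                    # LoRA scaling factor

W = torch.randn(d, k)         # frozen pretrained weight
A = torch.randn(r, k)         # trainable low-rank factor
B = torch.zeros(d, r)         # trainable low-rank factor, zero-initialized so the update starts as a no-op

# Effective weight used in the forward pass: W stays frozen, only A and B receive gradients.
W_effective = W + (alpha / r) * (B @ A)

full_params = d * k
lora_params = r * (d + k)
print(f"trainable fraction per matrix: {lora_params / full_params:.2%}")  # ~0.78% at r=16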

Quantization-Based Fine-Tuning (QLoRA):

  • Involves reducing the precision of model parameters, such as converting 32-bit floating-point values to 8-bit or 4-bit integers.

  • Results in reduced CPU and GPU memory requirements by a factor of roughly 4x with 8-bit integers or 8x with 4-bit integers, relative to 32-bit floats (see the quick arithmetic after this list).

  • However, this reduction in precision may lead to a loss in performance.

  • Can be advantageous for deploying models on resource-constrained devices like mobile phones or edge devices, as it reduces memory usage and enables faster inference on hardware with reduced precision support.
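As a rough back-of-the-envelope check of those factors, counting weights only and ignoring activations, the KV cache, and optimizer state:

params = 7e9  # Llama 2 7B

# Approximate weight memory at different precisions.
for name, bytes_per_param in [("fp32", 4), ("fp16", 2), ("int8", 1), ("4-bit", 0.5)]:
    print(f"{name:>5}: ~{params * bytes_per_param / 1e9:.1f} GB")
# fp32: ~28.0 GB, fp16: ~14.0 GB, int8: ~7.0 GB, 4-bit: ~3.5 GB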

We will use Quantization-Based Fine-Tuning (QLoRA) below, continuing from the previous steps.


Step 4: Dataset preparation

We need to modify our dataset in a manner consistent with how the model was trained. Let's create a new field called text which is in the format "### Instruction: <instruction> ### Input: <input> ### Response: <output>".

def format_instruction(instruction: str, input: str, output: str):
    return f"""### Instruction: {instruction.strip()}
### Input:
{input.strip()}
### Response:
{output.strip()}""".strip()


def generate_instruction_dataset(data_point):

    return {
        "instruction": data_point["instruction"],
        "input": data_point["input"],
        "output": data_point["output"],
        "text": format_instruction(data_point["instruction"],data_point["input"],data_point["output"])
    }

from datasets import Dataset

def process_dataset(data: Dataset):
    return (
        data.shuffle(seed=42)
        .map(generate_instruction_dataset)
    )

# Apply the preprocessing to every split.
dataset["train"] = process_dataset(dataset["train"])
dataset["test"] = process_dataset(dataset["test"])
dataset["validation"] = process_dataset(dataset["test"])  # reuse the test split as the validation set


Step 5: Prepare the model for training

We will create a Lora config and prepare the model for training using TrainingArguments.

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import TrainingArguments

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # attention projections in Llama models
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)

OUTPUT_DIR = "llama2-code-adapter"
training_arguments = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    optim="paged_adamw_32bit",
    logging_steps=1,
    learning_rate=1e-4,
    fp16=True,
    max_grad_norm=0.3,
    num_train_epochs=1,
    evaluation_strategy="steps",
    eval_steps=0.2,
    warmup_ratio=0.05,
    save_strategy="epoch",
    group_by_length=True,
    output_dir=OUTPUT_DIR,
    save_safetensors=True,
    lr_scheduler_type="cosine",
    seed=42,
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference
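To see how few parameters LoRA actually trains, peft models expose a print_trainable_parameters helper; a quick sanity check that is not part of the original post:

model.print_trainable_parameters()
# Prints trainable vs. total parameter counts; with the LoRA config above this is well under 1%.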


Step 6: Model Training

We will now train and save the Lora adapter.

from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=validation_data,
    peft_config=lora_config,
    dataset_text_field="text",
    max_seq_length=1024,
    tokenizer=tokenizer,
    args=training_arguments,
)

trainer.train()

peft_model_path = "./peft-model-code"

trainer.model.save_pretrained(peft_model_path)
tokenizer.save_pretrained(peft_model_path)
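If you prefer a single standalone checkpoint rather than base model plus adapter, the LoRA weights can optionally be merged back into the base model with peft's merge_and_unload. A sketch of this optional step (the adapter is reloaded in fp16 rather than 4-bit so the weights can be merged, and the output directory name is hypothetical):

import torch
from peft import AutoPeftModelForCausalLM

# Reload base model + adapter in fp16 so the LoRA weights can be folded into the base weights.
merged_model = AutoPeftModelForCausalLM.from_pretrained(
    peft_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
)
merged_model = merged_model.merge_and_unload()

merged_model.save_pretrained("llama2-code-merged")   # hypothetical output directory
tokenizer.save_pretrained("llama2-code-merged")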


Step 7: Predictions

Let's now load the saved model and use it to make predictions on some test examples to see if fine-tuning the model improved the capability of the model for our task.

import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

peft_model_dir = "peft-model-code"

# Load the base model with the trained LoRA adapter, and the saved tokenizer.
trained_model = AutoPeftModelForCausalLM.from_pretrained(
    peft_model_dir,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)
tokenizer = AutoTokenizer.from_pretrained(peft_model_dir)

prompt = f"""### Instruction: {instruction.strip()}
### Input: {input.strip()}
### Response:
"""

input_ids = tokenizer(prompt, return_tensors='pt', truncation=True).input_ids.cuda()

outputs = trained_model.generate(input_ids=input_ids, max_new_tokens=100)
model_output = tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt):]
print(model_output)

As we can see in the output above, the fine-tuned model generates much better responses for our specific task.


Conclusion:

Even with a modest fine-tuning effort involving just 500 examples from our dataset and about 5 minutes of training, we can see noticeable improvements in the model's generations for our specific task.

This fine-tuned model shows a significant improvement over relying solely on prompting, as indicated by the absence of repetition and the correctness of the code in the answers.