Finetuning Foundation Models

In the previous post, I talked about how to build a foundation model from scratch, following a GPT model architecture. In this one, I plan to write about a new type of training on top of that model implementation.
This type of training is famously called finetuning. The idea behind this ML technique is to train a pre-trained foundation model on a task-specific dataset. It's a type of transfer learning that uses a trained model as the base model and repurposes it for a specific task.
With a foundation model as the base model, we can create smaller models, targeting other problems, and there is no need to train them from scratch. We transform the generalist AI model into a domain-specific specialist model.
Understanding the problem
The example is inspired by the Build a Large Language Model book. We will use a spam or not spam dataset for a binary classification task.
I tried using the model built in the previous post, but because it's very rudimentary and wasn't trained on a large data corpus, it performed poorly on the classification task.
So instead, we will use a pretrained GPT-2 model as the base model for finetuning and repurpose it for binary classification. We still need to build the dataset and train the model on it. Finally, we evaluate and test it on new texts to check how it performs.
Create the dataset and dataloader
For finetuning, we'll use the UCI SMS Spam Collection dataset to train the pretrained model. In this part, we need to have some scripts to download the dataset, transform it into a Pandas dataframe, split the dataset into train, validation and test sets, build the dataset, and finish it by wrapping it into a dataloader so we can work with batches.
Downloading the dataset:
import requests
import zipfile
import os
import pandas as pd
from pathlib import Path
url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
zip_path = "sms_spam_collection.zip"
extracted_path = "src/data/sms_spam_collection"
data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"
def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path):
    if data_file_path.exists():
        print(f"{data_file_path} already exists. Skipping download and extraction.")
        return

    # Stream the zip file to disk in chunks
    response = requests.get(url, stream=True, timeout=60)
    response.raise_for_status()
    with open(zip_path, "wb") as out_file:
        for chunk in response.iter_content(chunk_size=8192):
            if chunk:
                out_file.write(chunk)

    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    # The extracted file has no extension; rename it to .tsv
    original_file_path = Path(extracted_path) / "SMSSpamCollection"
    os.rename(original_file_path, data_file_path)
    print(f"File downloaded and saved as {data_file_path}")

try:
    download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
except (requests.exceptions.RequestException, TimeoutError) as e:
    print(f"Primary URL failed: {e}. Trying backup URL...")
    url = "https://f001.backblazeb2.com/file/LLMs-from-scratch/sms%2Bspam%2Bcollection.zip"
    download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
Here, we not only download the dataset, but wrap it into a dataframe to prepare it for splitting and the custom dataset and dataloader we are going to build.
| # | Label | Text |
|---|---|---|
| 0 | ham | Go until jurong point, crazy.. Available only ... |
| 1 | ham | Ok lar... Joking wif u oni... |
| 2 | spam | Free entry in 2 a wkly comp to win FA Cup fina... |
| 3 | ham | U dun say so early hor... U c already then say... |
| 4 | ham | Nah I don't think he goes to usf, he lives aro... |
In this dataset, we have 747 “spams” and 4825 “hams”. To balance it, we build a helper function:
def create_balanced_dataset(df):
    # Count the spam rows, then sample the same number of ham rows
    num_spam = df[df["Label"] == "spam"].shape[0]
    ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)
    balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])
    return balanced_df
df = create_balanced_dataset(df)
We use 747 as the maximum count for both spam and ham.
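To make the undersampling step concrete, here is a minimal sketch of the same idea in plain Python, without pandas. The toy `rows` and the `undersample` helper are made up for illustration and are not part of the post's code.

```python
import random
from collections import Counter

# Toy stand-in for the SMS dataset: (label, text) pairs, 3 spam vs 9 ham.
rows = [("spam", f"win prize {i}") for i in range(3)] + \
       [("ham", f"see you at {i}") for i in range(9)]


def undersample(rows, minority="spam", seed=123):
    """Keep all minority rows plus an equally sized random sample of the rest."""
    minority_rows = [r for r in rows if r[0] == minority]
    majority_rows = [r for r in rows if r[0] != minority]
    rng = random.Random(seed)  # fixed seed so the sample is reproducible
    return rng.sample(majority_rows, len(minority_rows)) + minority_rows


balanced = undersample(rows)
counts = Counter(label for label, _ in balanced)
print(counts["spam"], counts["ham"])  # 3 3
```

The trade-off is that undersampling throws data away: on the real dataset it keeps 747 of the 4825 ham messages and discards the other 4078.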
The label is still represented as a string. Let's map it to an integer: ham should be 0 and spam 1:
df["Label"] = df["Label"].map({"ham": 0, "spam": 1})
With the dataset preprocessed, only the splitting strategy is missing.
Here, we use a random split:
def random_split(df, train_size, validation_size):
    # Shuffle the whole dataframe before slicing it
    df = df.sample(frac=1, random_state=123).reset_index(drop=True)
    train_end = int(len(df) * train_size)
    validation_end = train_end + int(len(df) * validation_size)
    train_df = df[:train_end]
    validation_df = df[train_end:validation_end]
    test_df = df[validation_end:]
    return train_df, validation_df, test_df
train_df, validation_df, test_df = random_split(df, 0.7, 0.1)
train_path = "src/finetuning/dataset/train.csv"
validation_path = "src/finetuning/dataset/validation.csv"
test_path = "src/finetuning/dataset/test.csv"
train_df.to_csv(train_path, index=None)
validation_df.to_csv(validation_path, index=None)
test_df.to_csv(test_path, index=None)
It's split into:
- Train: 70%
- Validation: 10%
- Test: 20%
Each split is then saved to its own CSV file, which can be reused later on.
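As a quick sanity check on the split, the sketch below repeats the index arithmetic of `random_split` for the balanced dataset (747 spam + 747 ham = 1494 rows). `split_sizes` is a hypothetical helper added only for this illustration.

```python
def split_sizes(n_rows, train_size, validation_size):
    """Mirror the slicing arithmetic of random_split and return the three sizes."""
    train_end = int(n_rows * train_size)
    validation_end = train_end + int(n_rows * validation_size)
    return train_end, validation_end - train_end, n_rows - validation_end


n_rows = 747 * 2  # balanced dataset: same number of spam and ham rows
train_n, val_n, test_n = split_sizes(n_rows, 0.7, 0.1)
print(train_n, val_n, test_n)  # 1045 149 300
```

Because `int()` truncates, the test split absorbs the rounding remainder, so the proportions are approximately, not exactly, 70/10/20.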
With the preprocessing done, we need to build a custom dataset that will be used for all 3 splits.
There are 3 main parts of this custom dataset:
- Encoding the input data
- Padding to the longest length
- Handling each item
Encoding the input data requires using the tokenizer. Because we can have different tokenizers, we let the user of this class decide which tokenizer to apply:
import torch
from torch.utils.data import Dataset

class SpamDataset(Dataset):
    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        self.data = pd.read_csv(csv_file)
        # Tokenize every message once, up front
        self.encoded_texts = [
            tokenizer.encode(text) for text in self.data["Text"]
        ]
Moving on to padding:
class SpamDataset(Dataset):
    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        # ...
        if max_length is None:
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length
            # Truncate sequences that are longer than max_length
            self.encoded_texts = [
                encoded_text[:self.max_length]
                for encoded_text in self.encoded_texts
            ]
        # Pad every sequence up to max_length
        self.encoded_texts = [
            encoded_text + [pad_token_id] *
            (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]

    def _longest_encoded_length(self):
        max_length = 0
        for encoded_text in self.encoded_texts:
            encoded_length = len(encoded_text)
            if encoded_length > max_length:
                max_length = encoded_length
        return max_length
The idea is simple: if no max length is passed in, use the longest encoded length in the input; otherwise, truncate anything longer than the given max length. Then, we pad every sequence of token IDs up to that length. The number of pad token IDs appended is the difference between the max length and the actual length of the sequence.
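Here is a small standalone sketch of that truncate-and-pad logic on toy data. The token IDs are made up and `truncate_and_pad` is a hypothetical helper, not a method of `SpamDataset`; only the pad ID 50256 comes from the class above.

```python
PAD_TOKEN_ID = 50256  # same default as pad_token_id in SpamDataset


def truncate_and_pad(encoded_texts, max_length=None, pad_token_id=PAD_TOKEN_ID):
    """Truncate to max_length (or the longest sequence) and right-pad the rest."""
    if max_length is None:
        max_length = max(len(ids) for ids in encoded_texts)
    result = []
    for ids in encoded_texts:
        ids = ids[:max_length]
        result.append(ids + [pad_token_id] * (max_length - len(ids)))
    return result


toy = [[11, 22, 33, 44], [55], [66, 77]]  # made-up token IDs
padded = truncate_and_pad(toy)
shorter = truncate_and_pad(toy, max_length=2)
print(padded)
print(shorter)
```

After this step every sequence has the same length, which is what allows a batch to be stacked into a single tensor.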
The third and last part: handling the items:
class SpamDataset(Dataset):
    def __getitem__(self, index):
        encoded = self.encoded_texts[index]
        label = self.data.iloc[index]["Label"]
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(label, dtype=torch.long)
        )

    def __len__(self):
        return len(self.data)
We should return each item based on its index. It returns the encoded IDs and the label (target) as a tensor.
The __len__ method is just the length of the data.
Here is how we use it:
import tiktoken

tokenizer = tiktoken.get_encoding("gpt2")
train_dataset = SpamDataset(
    csv_file=train_path,
    tokenizer=tokenizer
)
# Validation and test reuse the training max_length so all splits share one sequence length
val_dataset = SpamDataset(
    csv_file=validation_path,
    max_length=train_dataset.max_length,
    tokenizer=tokenizer
)
test_dataset = SpamDataset(
    csv_file=test_path,
    max_length=train_dataset.max_length,
    tokenizer=tokenizer
)
Then, we just build a dataloader so we can train the model using batches:
from torch.utils.data import DataLoader

batch_size = 8
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
)
And that's it. We have our dataset ready to be used in our model training.
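To get a feel for how many batches each loader yields, here is a rough sketch of the arithmetic. It assumes split sizes of 1045, 149, and 300 rows (what a 70/10/20 split of the 1494 balanced rows works out to); `num_batches` is a hypothetical helper, not a PyTorch function.

```python
import math


def num_batches(n_examples, batch_size, drop_last=False):
    """Number of batches a loader yields for a dataset of n_examples."""
    if drop_last:
        # The incomplete final batch is discarded
        return n_examples // batch_size
    return math.ceil(n_examples / batch_size)


batch_size = 8
train_batches = num_batches(1045, batch_size, drop_last=True)
val_batches = num_batches(149, batch_size)
test_batches = num_batches(300, batch_size)
print(train_batches, val_batches, test_batches)  # 130 19 38
```

With `drop_last=True`, the training loader skips the 5 leftover rows in each epoch, but since it also shuffles, different rows are left out each time.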
Finetuning on a spam dataset
To fine-tune our model, we first need to build and repurpose it for a classification problem.
CHOOSE_MODEL = "gpt2-small (124M)"
INPUT_PROMPT = "Every effort moves"
BASE_CONFIG = {
    "vocab_size": 50257,
    "context_length": 1024,
    "drop_rate": 0.0,
    "qkv_bias": True
}
model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
model = GPTModel(BASE_CONFIG)
This is pretty similar to what we've already done before.
Because this freshly built model has not been trained on a large amount of data, we load the pretrained parameters from GPT-2:
import numpy as np

def load_weights_into_gpt(gpt, params):
    gpt.positional_embedding.weight = torch.nn.Parameter(torch.tensor(params["wpe"]))
    gpt.token_embedding.weight = torch.nn.Parameter(torch.tensor(params["wte"]))
    for b in range(len(params["blocks"])):
        # GPT-2 stores query, key and value weights fused in c_attn; split them into three
        q_w, k_w, v_w = np.split((params["blocks"][b]["attn"]["c_attn"])["w"], 3, axis=-1)
        gpt.transformer_blocks[b].attention.W_query.weight = torch.nn.Parameter(torch.tensor(q_w.T))
        gpt.transformer_blocks[b].attention.W_key.weight = torch.nn.Parameter(torch.tensor(k_w.T))
        gpt.transformer_blocks[b].attention.W_value.weight = torch.nn.Parameter(torch.tensor(v_w.T))
        q_b, k_b, v_b = np.split((params["blocks"][b]["attn"]["c_attn"])["b"], 3, axis=-1)
        gpt.transformer_blocks[b].attention.W_query.bias = torch.nn.Parameter(torch.tensor(q_b))
        gpt.transformer_blocks[b].attention.W_key.bias = torch.nn.Parameter(torch.tensor(k_b))
        gpt.transformer_blocks[b].attention.W_value.bias = torch.nn.Parameter(torch.tensor(v_b))
        gpt.transformer_blocks[b].attention.out_proj.weight = torch.nn.Parameter(
            torch.tensor(params["blocks"][b]["attn"]["c_proj"]["w"].T))
        gpt.transformer_blocks[b].attention.out_proj.bias = torch.nn.Parameter(
            torch.tensor(params["blocks"][b]["attn"]["c_proj"]["b"]))
        # Feed-forward layers
        gpt.transformer_blocks[b].ff.layers[0].weight = torch.nn.Parameter(
            torch.tensor(params["blocks"][b]["mlp"]["c_fc"]["w"].T))
        gpt.transformer_blocks[b].ff.layers[0].bias = torch.nn.Parameter(
            torch.tensor(params["blocks"][b]["mlp"]["c_fc"]["b"]))
        gpt.transformer_blocks[b].ff.layers[2].weight = torch.nn.Parameter(
            torch.tensor(params["blocks"][b]["mlp"]["c_proj"]["w"].T))
        gpt.transformer_blocks[b].ff.layers[2].bias = torch.nn.Parameter(
            torch.tensor(params["blocks"][b]["mlp"]["c_proj"]["b"]))
        # Layer norms
        gpt.transformer_blocks[b].layer_norm1.scale = torch.nn.Parameter(
            torch.tensor(params["blocks"][b]["ln_1"]["g"]))
        gpt.transformer_blocks[b].layer_norm1.shift = torch.nn.Parameter(
            torch.tensor(params["blocks"][b]["ln_1"]["b"]))
        gpt.transformer_blocks[b].layer_norm2.scale = torch.nn.Parameter(
            torch.tensor(params["blocks"][b]["ln_2"]["g"]))
        gpt.transformer_blocks[b].layer_norm2.shift = torch.nn.Parameter(
            torch.tensor(params["blocks"][b]["ln_2"]["b"]))
    gpt.final_norm.scale = torch.nn.Parameter(torch.tensor(params["g"]))
    gpt.final_norm.shift = torch.nn.Parameter(torch.tensor(params["b"]))
    # The output head reuses the token embedding weights
    gpt.out_head.weight = torch.nn.Parameter(torch.tensor(params["wte"]))
With weights loaded into our model, we replace the output head to do a binary classification.
num_classes = 2
def build_classifier(model, config, num_classes):
    # Swap the vocabulary-sized output head for a small classification head
    model.out_head = torch.nn.Linear(
        in_features=config["emb_dim"],
        out_features=num_classes
    )
    return model
build_classifier(model, BASE_CONFIG, num_classes)
The third step is to freeze parts of the model parameters to preserve patterns and learning from the pretrained model.
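for param in model.parameters():
    param.requires_grad = False
for param in model.transformer_blocks[-1].parameters():
    param.requires_grad = True
for param in model.final_norm.parameters():
    param.requires_grad = True
# The new out_head was created before the freeze loop above, so it is frozen too;
# re-enable it explicitly so the classification head can be trained
for param in model.out_head.parameters():
    param.requires_grad = True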
We are doing selective layer freezing for fine-tuning:
- Freeze everything: sets requires_grad = False on all parameters, meaning no gradients will be computed and nothing will be updated during backprop.
- Unfreeze the last transformer block: re-enables gradient updates only for the final transformer block (transformer_blocks[-1]).
- Unfreeze the final layer norm: also re-enables gradients for final_norm.
Most of the parameters are frozen to preserve what was learned in the pretrained model and reduce the risk of catastrophic forgetting, while a few layers stay trainable so the model can learn the patterns needed for the classification task.
Let's implement the model training:
def train_classifier(model, train_loader, val_loader, optimizer, num_epochs, eval_freq, eval_iter):
    """
    Runs the full fine-tuning loop to train the classification head on the spam-detection task.
    - For each epoch, iterates over all batches in `train_loader`,
    - Computes the loss,
    - Backpropagates,
    - Updates weights with `optimizer`.
    - Every `eval_freq` global steps, losses are evaluated on both splits and printed.
    - At the end of each epoch, accuracy is computed and printed for both splits.
    Output: a tuple (train_losses, val_losses, train_accuracies, val_accuracies)
    where each element is a Python list of floats: losses sampled at every
    `eval_freq` steps, and accuracies recorded once per epoch.
    """
    train_losses, val_losses, train_accuracies, val_accuracies = [], [], [], []
    global_step = -1
    for epoch in range(num_epochs):
        model.train()
        print(f"Epoch {epoch+1}: ")
        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()
            loss = calculate_loss_batch(input_batch, target_batch, model)
            loss.backward()
            optimizer.step()
            global_step += 1
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(model, train_loader, val_loader, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                print(f"Step {global_step:06d}: "
                      f"Train loss {train_loss:.3f} | "
                      f"Val loss {val_loss:.3f}")
        train_accuracy = calculate_accuracy_loader(train_loader, model, num_batches=eval_iter)
        val_accuracy = calculate_accuracy_loader(val_loader, model, num_batches=eval_iter)
        print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
        print(f"Validation accuracy: {val_accuracy*100:.2f}%\n")
        train_accuracies.append(train_accuracy)
        val_accuracies.append(val_accuracy)
    return train_losses, val_losses, train_accuracies, val_accuracies
The implementation follows the standard training loop in deep learning: we iterate over epochs and, within each epoch, over batches.
As the docstring states, these are the steps:
- For each epoch, iterate over all batches in train_loader
- Compute the loss
- Backpropagate
- Update the weights with optimizer
- Every eval_freq global steps, evaluate the losses on both splits and print them
- At the end of each epoch, compute and print the accuracy for both splits
We need to understand what these 4 helper functions are doing: calculate_loss_batch, calculate_loss_loader, evaluate_model, and calculate_accuracy_loader.
calculate_loss_batch computes the cross-entropy for a single batch:
def calculate_loss_batch(input_batch, target_batch, model):
    """
    Computes the cross-entropy loss for a single batch of inputs and targets.
    - Passes `input_batch` through the model
    - Extracts the last-token logits for each sequence
    - Evaluates cross-entropy against `target_batch`
    Output: a scalar PyTorch tensor containing the mean cross-entropy loss for
    the batch.
    """
    logits = model(input_batch)[:, -1, :]
    loss = torch.nn.functional.cross_entropy(logits, target_batch)
    return loss
This returns the mean loss over the entire batch, using only the logits at the last token position of each sequence.
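To see what that number means, here is a hand-rolled sketch of mean cross-entropy over a tiny batch in plain Python. The logits are hypothetical `[ham_score, spam_score]` pairs, and `cross_entropy` and `batch_loss` are illustrative helpers, not the PyTorch functions.

```python
import math


def cross_entropy(logits, target):
    """Negative log-softmax probability of the target class for one example."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_sum_exp - logits[target]


def batch_loss(batch_logits, targets):
    """Mean cross-entropy over a batch (the 'mean' reduction)."""
    losses = [cross_entropy(lg, t) for lg, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)


# Hypothetical last-token logits for three messages: [ham_score, spam_score]
batch_logits = [[2.0, 0.0], [0.0, 2.0], [0.0, 0.0]]
targets = [0, 1, 1]
loss = batch_loss(batch_logits, targets)
print(f"{loss:.3f}")
```

The third example has equal logits, so its loss is ln 2 ≈ 0.693. That is the loss of a binary classifier with no preference, which is a useful baseline when reading the first training steps.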
calculate_loss_loader computes the average loss for the dataloader:
def calculate_loss_loader(data_loader, model, num_batches=None):
    """
    Computes the average cross-entropy loss of the model over a data loader.
    - Accumulates per-batch losses from `calculate_loss_batch` across up to `num_batches` batches
    - Returns their mean.
    - Returns NaN if the loader is empty.
    Output: a Python float representing the average cross-entropy loss; NaN when
    the data loader contains no batches
    """
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calculate_loss_batch(
                input_batch, target_batch, model
            )
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches
It computes the loss for each batch and averages them (total loss divided by the number of batches).
evaluate_model handles the model evaluation: it computes the loss on the training and the validation sets.
def evaluate_model(model, train_loader, val_loader, eval_iter):
    """
    Evaluates the model's loss on the training and validation sets without
    updating weights.
    - Temporarily switches the model to evaluation mode
    - Computes the average loss over `eval_iter` batches from each loader via `calculate_loss_loader`
    Output: a tuple (train_loss, val_loss) of Python floats,
    each representing the average cross-entropy loss on the respective split
    """
    model.eval()
    with torch.no_grad():
        train_loss = calculate_loss_loader(train_loader, model, num_batches=eval_iter)
        val_loss = calculate_loss_loader(val_loader, model, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss
This is called every eval_freq steps during training to check how the model is learning.
calculate_accuracy_loader computes the classification accuracy:
def calculate_accuracy_loader(data_loader, model, num_batches=None):
    """
    Computes the classification accuracy of the model over a data loader.
    - Iterates through up to `num_batches` batches from `data_loader`
    - Runs the model in evaluation mode without gradient tracking
    - Compares predicted class labels (argmax of the last-token logits) against the ground-truth targets.
    Output: a single float in [0.0, 1.0] representing the fraction of examples
    classified correctly across all evaluated batches.
    """
    model.eval()
    correct_predictions, num_examples = 0, 0
    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            with torch.no_grad():
                logits = model(input_batch)[:, -1, :]
            predicted_labels = torch.argmax(logits, dim=-1)
            num_examples += predicted_labels.shape[0]
            correct_predictions += ((predicted_labels == target_batch).sum().item())
        else:
            break
    return correct_predictions / num_examples
Accuracy is calculated by dividing the correct predictions by the number of examples.
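The same bookkeeping can be sketched without PyTorch. The logits and targets below are made up, and `argmax` and `accuracy` are illustrative helpers that mirror the counting in `calculate_accuracy_loader`.

```python
def argmax(values):
    """Index of the largest value (the predicted class)."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(batches):
    """Fraction of correct predictions across a list of (logits, targets) batches."""
    correct, total = 0, 0
    for batch_logits, targets in batches:
        predictions = [argmax(lg) for lg in batch_logits]
        correct += sum(p == t for p, t in zip(predictions, targets))
        total += len(targets)
    return correct / total


# Each batch: (last-token logits per example, ground-truth labels)
batches = [
    ([[1.5, -0.3], [0.2, 0.9]], [0, 1]),  # both predictions correct
    ([[0.1, 0.4], [2.0, 1.0]], [0, 0]),   # first wrong, second correct
]
acc = accuracy(batches)
print(acc)  # 0.75
```

Counting correct predictions and examples across batches, rather than averaging per-batch accuracies, keeps the result right even when the last batch is smaller than the others.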
With the model training implemented, we can put everything together and train the model:
import time

start_time = time.time()
torch.manual_seed(123)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
num_epochs = 5
train_losses, val_losses, train_accuracies, val_accuracies = train_classifier(
    model, train_loader, val_loader, optimizer,
    num_epochs=num_epochs, eval_freq=50, eval_iter=5
)
end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")
For 5 epochs, these are the results:
Epoch 1:
Step 000000: Train loss 0.730 | Val loss 0.766
Step 000050: Train loss 0.631 | Val loss 0.639
Step 000100: Train loss 0.482 | Val loss 0.515
Training accuracy: 90.00% | Validation accuracy: 97.50%
Epoch 2:
Step 000150: Train loss 0.287 | Val loss 0.170
Step 000200: Train loss 0.046 | Val loss 0.043
Step 000250: Train loss 0.030 | Val loss 0.060
Training accuracy: 95.00% | Validation accuracy: 92.50%
Epoch 3:
Step 000300: Train loss 0.114 | Val loss 0.125
Step 000350: Train loss 0.071 | Val loss 0.031
Training accuracy: 97.50% | Validation accuracy: 97.50%
Epoch 4:
Step 000400: Train loss 0.013 | Val loss 0.034
Step 000450: Train loss 0.095 | Val loss 0.060
Step 000500: Train loss 0.120 | Val loss 0.036
Training accuracy: 97.50% | Validation accuracy: 100.00%
Epoch 5:
Step 000550: Train loss 0.180 | Val loss 0.023
Step 000600: Train loss 0.040 | Val loss 0.031
Training accuracy: 100.00% | Validation accuracy: 100.00%
Training completed in 3.92 minutes.
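We can see the model learning: the loss drops over time and both accuracies reach 100% by the last epoch. Keep in mind that these accuracies are estimated on only eval_iter=5 batches of 8 examples (40 messages), which is why they move in 2.5% steps. For this specific run, it took only ~4 minutes to complete the entire training.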
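The step numbers in the log above follow directly from the loop's bookkeeping. Here is a small sketch that replays only the counter logic of `train_classifier`, assuming 130 training batches per epoch (1045 training rows, batch size 8, drop_last=True). `eval_steps_by_epoch` is a hypothetical helper.

```python
def eval_steps_by_epoch(steps_per_epoch, num_epochs, eval_freq):
    """Replay the global_step counter and record when evaluation is triggered."""
    schedule = {}
    global_step = -1  # same starting value as in train_classifier
    for epoch in range(1, num_epochs + 1):
        schedule[epoch] = []
        for _ in range(steps_per_epoch):
            global_step += 1
            if global_step % eval_freq == 0:
                schedule[epoch].append(global_step)
    return schedule


schedule = eval_steps_by_epoch(steps_per_epoch=130, num_epochs=5, eval_freq=50)
for epoch, steps in schedule.items():
    print(epoch, steps)
```

Since 130 is not a multiple of 50, epochs 3 and 5 only get two loss evaluations each, which matches the log.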
Before testing on new text inputs, let's just save the model:
torch.save(model.state_dict(), "src/finetuning/classifier.pth")
Classification: testing on new text inputs
In this final section, we will run some tests for the model.
First, let's build it:
CHOOSE_MODEL = "gpt2-small (124M)"
INPUT_PROMPT = "Every effort moves"
BASE_CONFIG = {
    "vocab_size": 50257,
    "context_length": 1024,
    "drop_rate": 0.0,
    "qkv_bias": True
}
model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
model = GPTModel(BASE_CONFIG)
build_classifier(model, BASE_CONFIG, num_classes=2)
The model structure is here. Now, we load the parameters learned before:
model.load_state_dict(torch.load("src/finetuning/classifier.pth", weights_only=True))
Let's build the classification evaluation:
def classify_review(text, model, max_length=None, pad_token_id=50256):
    model.eval()
    tokenizer = tiktoken.get_encoding("gpt2")
    input_ids = tokenizer.encode(text)
    supported_context_length = model.positional_embedding.weight.shape[0]
    # Fall back to the model's context length when max_length is not given,
    # and never exceed what the positional embeddings support
    if max_length is None:
        max_length = supported_context_length
    max_length = min(max_length, supported_context_length)
    input_ids = input_ids[:max_length]
    input_ids += [pad_token_id] * (max_length - len(input_ids))
    input_tensor = torch.tensor(input_ids).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        logits = model(input_tensor)[:, -1, :]
    predicted_label = torch.argmax(logits, dim=-1).item()
    return "spam" if predicted_label == 1 else "not spam"
The idea of this function is to call the model on the input text we are testing and classify if it is "spam" or "not spam".
We use the supported_context_length to make sure the model never receives more tokens than it was trained to handle.
The model produces the logits for the last token, and argmax over the two output logits gives the predicted class, which is mapped back to a label: 1 → "spam" and 0 → "not spam".
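The input preparation and the label mapping can be tried in isolation, without the model or the tokenizer. In this sketch, `prepare_input_ids` and `label_from_logits` are hypothetical helpers and the token IDs and logits are made up; only the pad ID 50256 and the 0/1 label mapping come from the post.

```python
def prepare_input_ids(input_ids, max_length, context_length, pad_token_id=50256):
    """Clamp to the context length, truncate, and right-pad to a fixed length."""
    limit = context_length if max_length is None else min(max_length, context_length)
    input_ids = input_ids[:limit]
    return input_ids + [pad_token_id] * (limit - len(input_ids))


LABELS = {0: "not spam", 1: "spam"}


def label_from_logits(logits):
    """Pick the class with the largest logit and map it to its label."""
    predicted = max(range(len(logits)), key=lambda i: logits[i])
    return LABELS[predicted]


print(prepare_input_ids([1, 2, 3], max_length=5, context_length=1024))
print(prepare_input_ids(list(range(10)), max_length=4, context_length=1024))
print(label_from_logits([-1.2, 3.4]))  # spam
```

Padding new inputs to the same length used in training keeps inference consistent with training, since the prediction is read from the last position of the padded sequence.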
Let's test it:
# === spam ===
text = (
"🛑 CONGRATULATIONS!! 🛑 Click HERE immediately to claim your"
" FREE $5,000 Walmart Gift Card before your link EXPIRES in 2 minutes!!! ⏳💰"
)
classify_review(text, model, max_length=train_dataset.max_length)
# === // ===
# === not spam (but can be considered as spam) ===
text = (
"Hey there, I noticed your profile online and wanted to briefly"
" touch base about our AI-powered synergy solutions that can scale"
" your revenue by 400% this quarter—got 5 mins for a quick call"
" tomorrow at 10 AM?"
)
classify_review(text, model, max_length=train_dataset.max_length)
# === // ===
# === not spam ===
text = (
"Hey, just wanted to check if we're still on"
" for dinner tonight? Let me know!"
)
classify_review(text, model, max_length=train_dataset.max_length)
# === // ===
The second test is a borderline case: whether it counts as spam depends on the reader. The other two are much more clear-cut.
Resources
- Self-Attention, Foundation Models, and the GPT Architecture from Scratch
- LLM from Scratch Repo
- Build a Large Language Model
- Mastering PyTorch: From Linear Regression to Computer Vision