Fine-Tuning BERT for Text Classification
A hackable example with Python code
Although today’s 100B+ parameter transformer models are state-of-the-art in AI, there’s still much we can accomplish with smaller (< 1B parameter) models. In this article, I will walk through one such example, fine-tuning BERT (110M parameters) to classify phishing URLs. I’ll start by covering key concepts and then share example Python code.
Fine-tuning
Fine-tuning involves adapting a pre-trained model to a particular use case through additional training.
Pre-trained models are developed via unsupervised learning, which precludes the need for large-scale labeled datasets. Fine-tuned models can then exploit pre-trained model representations to significantly reduce training costs and improve model performance compared to training from scratch [1].
Fine-Tuning Large Language Models (LLMs)
Splitting the training process into multiple phases has led to today’s state-of-the-art transformer models, such as GPT-4o, Claude, and Llama 3.2. It also enables the democratization of AI since the expensive undertaking of model pre-training can be done by specialized research labs, who can then make these models publicly available for fine-tuning.
BERT
While model fine-tuning gained tremendous popularity post-ChatGPT, it’s been around since (at least) 2015 [2]. One of the early language models developed specifically for fine-tuning was Google’s BERT model, which was pre-trained on two unsupervised tasks: 1) masked language modeling (MLM) and 2) next sentence prediction [1].
The MLM pre-training task consists of predicting arbitrarily masked words in a sequence. This is in contrast to causal language modeling, where the model can only use the preceding text to predict the next word. MLM therefore enables the model to leverage more context (i.e., text before AND after the masked word) when making predictions [1].
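To make MLM concrete, here is a minimal sketch using the Hugging Face fill-mask pipeline with the same bert-base-uncased checkpoint we'll fine-tune later (this snippet is purely illustrative and not part of the fine-tuning workflow below).

# a minimal illustration of masked language modeling (not part of the fine-tuning code below)
from transformers import pipeline

fill_mask = pipeline("fill-mask", model="bert-base-uncased")

# BERT predicts the [MASK] token using context on both sides of it
for pred in fill_mask("The capital of France is [MASK]."):
    print(pred["token_str"], round(pred["score"], 3))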
Next sentence prediction is important for downstream tasks that require understanding the relationship between two sentences (e.g., question answering and semantic similarity). This is implemented using special input tokens (and segment embeddings) that mark where one sentence ends and the other begins [1].
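To see what those markers look like, the (optional) snippet below tokenizes a sentence pair: the [CLS] and [SEP] tokens and the token_type_ids are what BERT uses to tell the two sentences apart.

# illustrative only: how BERT's tokenizer marks a sentence pair with special tokens
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoding = tokenizer("The man went to the store.", "He bought a gallon of milk.")

print(tokenizer.convert_ids_to_tokens(encoding["input_ids"]))
# [CLS] ...sentence A... [SEP] ...sentence B... [SEP]
print(encoding["token_type_ids"])  # 0s for sentence A, 1s for sentence B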
These pre-training tasks enable BERT to be fine-tuned on a wide range of tasks such as sentiment analysis, sentence similarity, question answering, named entity recognition, common sense reasoning, and many others [1].
Text Classification
Many of the tasks mentioned above (e.g. sentiment analysis, sentence similarity, named entity recognition) fall under the category of text classification, i.e., assigning a label to input text sequences.
There are countless practical applications of text classification, such as detecting spam in emails, categorizing IT support tickets, detecting toxic or harmful speech, and analyzing the sentiment of customer reviews. While these tasks serve very different purposes, their implementations are nearly identical from a technical standpoint.
Example Code: Fine-tuning BERT for Phishing URL Identification
Here, we will walk through an example of BERT fine-tuning to classify phishing URLs. We will use the bert-base-uncased model freely available on the Hugging Face (HF) hub.
The model consists of 110M parameters, of which we will only train a small percentage. Therefore, this example should easily run on most consumer hardware (no GPU required).
The fine-tuned model is also available on the HF hub, and an example notebook is available on GitHub.
We’ll start by importing a few handy libraries.
from datasets import load_dataset, DatasetDict, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)
import evaluate
import numpy as np
Next, we’ll load the training dataset. It consists of 3,000 text-label pairs with a 70–15–15 train-test-validation split. The data are originally from here (open database license).
dataset_dict = load_dataset("shawhin/phishing-site-classification")
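If you'd like to sanity-check the splits, a quick (optional) inspection looks like this; the printed counts should reflect the 70–15–15 split described above.

# optional: confirm the train/test/validation splits and peek at one record
print(dataset_dict)
print(dataset_dict["train"][0])  # a dict with the URL text and its label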
The Transformer library makes it super easy to load and adapt pre-trained models. Here’s what that looks like for the BERT model.
# define pre-trained model path
model_path = "google-bert/bert-base-uncased"

# load model tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)

# load model with binary classification head
id2label = {0: "Safe", 1: "Not Safe"}
label2id = {"Safe": 0, "Not Safe": 1}
model = AutoModelForSequenceClassification.from_pretrained(
    model_path,
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
When we load a model like this, all the parameters will be set as trainable by default. However, training all 110M parameters will be computationally costly and potentially unnecessary.
Instead, we can freeze most of the model parameters and only train the model’s final layer and classification head.
# freeze all base model parameters
for name, param in model.base_model.named_parameters():
    param.requires_grad = False

# unfreeze base model pooling layers
for name, param in model.base_model.named_parameters():
    if "pooler" in name:
        param.requires_grad = True
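To confirm the freeze worked, an optional parameter count might look like this; only the pooler and the classification head should remain trainable.

# optional check: count trainable vs. total parameters after freezing
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable_params:,} of {total_params:,} parameters")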
Next, we will need to preprocess our data. This will consist of two key operations: tokenizing the URLs (i.e., converting them into integers) and truncating them.
# define text preprocessing
def preprocess_function(examples):
    # return tokenized text with truncation
    return tokenizer(examples["text"], truncation=True)

# preprocess all datasets
tokenized_data = dataset_dict.map(preprocess_function, batched=True)
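To see what the preprocessing adds, you can (optionally) peek at one tokenized example; the new fields come from the BERT tokenizer.

# optional: each example now carries input_ids, token_type_ids, and attention_mask
example = tokenized_data["train"][0]
print(list(example.keys()))
print(tokenizer.convert_ids_to_tokens(example["input_ids"])[:10])  # first few sub-word tokens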
Another important step is creating a data collator that will dynamically pad token sequences in a batch during training so they have the same length. We can do this in one line of code.
# create data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
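Although the Trainer will call the collator for us, here is an optional illustration of what it does: two examples of different lengths get padded to a common length within the batch.

# optional illustration: pad two differently sized examples into one batch
features = [
    {k: tokenized_data["train"][i][k] for k in ("input_ids", "attention_mask")}
    for i in range(2)
]
batch = data_collator(features)
print(batch["input_ids"].shape)  # (2, length of the longer sequence)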
As a final step before training, we can define a function to compute a set of metrics to help us monitor training progress. Here, we will consider model accuracy and AUC.
# load metrics
accuracy = evaluate.load("accuracy")
auc_score = evaluate.load("roc_auc")

def compute_metrics(eval_pred):
    # get predictions
    predictions, labels = eval_pred

    # apply softmax to get probabilities
    probabilities = np.exp(predictions) / np.exp(predictions).sum(-1, keepdims=True)

    # use probabilities of the positive class for ROC AUC
    positive_class_probs = probabilities[:, 1]

    # compute auc
    auc = np.round(
        auc_score.compute(prediction_scores=positive_class_probs, references=labels)["roc_auc"], 3
    )

    # predict most probable class
    predicted_classes = np.argmax(predictions, axis=1)

    # compute accuracy
    acc = np.round(
        accuracy.compute(predictions=predicted_classes, references=labels)["accuracy"], 3
    )

    return {"Accuracy": acc, "AUC": auc}
Now, we are ready to fine-tune our model. We start by defining hyperparameters and other training arguments.
# hyperparameters
lr = 2e-4
batch_size = 8
num_epochs = 10

training_args = TrainingArguments(
    output_dir="bert-phishing-classifier_teacher",
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    logging_strategy="epoch",
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)
Then, we pass our training arguments into a trainer class and train the model.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()
The training results are shown below. We can see that the training and validation loss are monotonically decreasing while the accuracy and AUC increase with each epoch.
Training results. Image by author.
As a final test, we can evaluate the performance of the model on the independent validation data, i.e., data not used for training or setting hyperparameters.
# apply model to validation dataset
predictions = trainer.predict(tokenized_data["validation"])

# extract the logits and labels from the predictions object
logits = predictions.predictions
labels = predictions.label_ids

# reuse the compute_metrics function defined above
metrics = compute_metrics((logits, labels))
print(metrics)
# >> {'Accuracy': 0.889, 'AUC': 0.946}
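With training and evaluation done, the model can be put to work. Below is a minimal inference sketch using the text-classification pipeline; the save path is a placeholder I chose for illustration (swap in your own output directory or Hugging Face hub ID), and the example URL is made up.

# save the best checkpoint (loaded at the end of training) along with the tokenizer
trainer.save_model("bert-phishing-classifier_teacher/final")  # placeholder path

# wrap the saved model in a text-classification pipeline
from transformers import pipeline
classifier = pipeline("text-classification", model="bert-phishing-classifier_teacher/final")

# classify a made-up URL; labels follow the id2label mapping defined earlier
print(classifier("http://secure-login-update.example.com/verify"))
# e.g. [{'label': 'Not Safe', 'score': ...}]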
Bonus: Although a 110M parameter model is tiny compared to modern language models, we can reduce its computational requirements using model compression techniques. I cover how to reduce the model's memory footprint by 7X in the article below.
Compressing Large Language Models (LLMs)
Conclusion
Fine-tuning pre-trained models is a powerful paradigm for developing better models at a lower cost than training them from scratch. Here, we saw how to do this with BERT using the Hugging Face Transformers library.
While the example code was for URL classification, it can be readily adapted to other text classification tasks.
My website: https://www.shawhintalebi.com/
[1] J. Devlin, M.-W. Chang, K. Lee, and K. Toutanova, "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" (2018), arXiv:1810.04805
[2] A. M. Dai and Q. V. Le, "Semi-supervised Sequence Learning" (2015), arXiv:1511.01432