BEE
Backend Engineering Essentials

[BEE-30092] Transfer Learning and Fine-Tuning Patterns

INFO

Transfer learning reuses representations learned on a large source dataset for a target task with fewer labeled examples, often cutting data and compute requirements by an order of magnitude. The benefit holds only when the source and target domains are sufficiently similar and the fine-tuning strategy matches the available data size.

Context

ImageNet (the ILSVRC subset, with roughly 1.2 million labeled training images across 1 000 categories) produced an unexpected result: features learned by convolutional networks on this dataset transfer to almost any visual recognition task, even tasks with very different target categories. Networks such as He et al.'s ResNet (arXiv:1512.03385, CVPR 2016), trained on ImageNet, learn edge detectors and texture recognizers in early layers and object-part detectors in deeper layers. These representations are useful for medical imaging, satellite imagery, product defect detection, and dozens of other domains that share no categories with ImageNet.

Howard & Ruder (2018) generalized this insight to NLP with ULMFiT (arXiv:1801.06146, ACL 2018), demonstrating that a language model pre-trained on Wikipedia text could be fine-tuned for text classification with as few as 100 labeled examples, matching the performance of training from scratch on 100× more data. They introduced three fine-tuning techniques that remain standard: discriminative learning rates (each layer uses a different learning rate), gradual unfreezing (unfreeze one layer group at a time, starting from the last), and slanted triangular learning rates (a short linear warmup followed by a long linear decay). Together, these three ideas address catastrophic forgetting: the risk that gradient updates on the target task overwrite learned representations.
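The slanted triangular schedule can be written as a small function. The sketch below follows the formula and default values given in the ULMFiT paper (cut_frac=0.1, ratio=32, peak rate 0.01); the function name and the guard against a zero-length warmup are illustrative additions, not part of the paper.

```python
import math


def slanted_triangular_lr(
    t: int,
    total_steps: int,
    lr_max: float = 0.01,
    cut_frac: float = 0.1,
    ratio: float = 32.0,
) -> float:
    """Learning rate at 0-indexed step t under a slanted triangular schedule."""
    # Step at which the rate peaks; max(1, ...) avoids dividing by zero on very short runs
    cut = max(1, math.floor(total_steps * cut_frac))
    if t < cut:
        p = t / cut  # linear warmup: p rises from 0 to 1
    else:
        p = 1 - (t - cut) / (cut * (1 / cut_frac - 1))  # linear decay: p falls from 1 to 0
    # The rate moves between lr_max / ratio (p = 0) and lr_max (p = 1)
    return lr_max * (1 + p * (ratio - 1)) / ratio


schedule = [slanted_triangular_lr(t, total_steps=1000) for t in range(1000)]
```

With 1 000 steps the rate climbs from lr_max / 32 to lr_max over the first 100 steps, then decays linearly back toward lr_max / 32 over the remaining 900.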

The practical consequence: a team with 5 000 labeled training examples can often achieve results that would otherwise require around 50 000, by starting from a pre-trained backbone rather than random initialization. The price is extra care: choosing a backbone, picking the right fine-tuning strategy, and recognizing when domain mismatch makes transfer harmful rather than helpful.

The Four-Quadrant Decision Framework

Andrej Karpathy's CS231n notes (http://cs231n.github.io/transfer-learning/) define the decision as a function of two axes: dataset size and domain similarity to the pre-training data.

| | Similar domain | Different domain |
|---|---|---|
| Small dataset | Feature extraction: freeze backbone, train head only | Difficult case: consider in-domain pre-training or proxy task |
| Large dataset | Fine-tune most layers with lower LR for early layers | Fine-tune all layers, possibly with a higher LR for all |

Feature extraction (freeze backbone): The pre-trained network is a fixed feature extractor. Only the final classification head is trained. This is the correct strategy when the target dataset is small (< 10 000 labeled samples) and the domain is similar to the pre-training source. The risk of fine-tuning with insufficient data is overfitting to the noise in the small target dataset.

Fine-tuning: Update weights in the backbone, not just the head. Use discriminative learning rates — lower LR for earlier layers, higher for later layers — because early layers contain generic features (edges, textures) that should change little, while later layers contain source-domain-specific features that need to adapt to the target domain.

Sun et al. (2017, arXiv:1707.02968) showed that downstream vision performance keeps improving as pre-training data grows, but only logarithmically, so returns diminish. How much of that benefit carries over also depends on source-target domain similarity. When domains diverge substantially (e.g., applying ImageNet features to X-ray images), the first few layers typically remain useful but the deeper layers need full fine-tuning.
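The four quadrants can be encoded as a small lookup. This is only a sketch: the function name is hypothetical, the 10 000-sample threshold is the rough figure quoted above rather than a hard rule, and judging whether a domain counts as "similar" is left to the caller.

```python
def choose_strategy(
    num_labeled: int,
    similar_domain: bool,
    small_threshold: int = 10_000,  # assumption: the rough cut-off quoted in the text
) -> str:
    """Map dataset size and domain similarity to one of the four quadrants."""
    small = num_labeled < small_threshold
    if small and similar_domain:
        return "feature extraction: freeze backbone, train head only"
    if small:
        return "difficult case: consider in-domain pre-training or a proxy task"
    if similar_domain:
        return "fine-tune most layers, lower LR for early layers"
    return "fine-tune all layers"


print(choose_strategy(2_000, similar_domain=True))
print(choose_strategy(200_000, similar_domain=False))
```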

Feature Extraction with TIMM

TIMM (PyTorch Image Models, created by Ross Wightman, now maintained at github.com/huggingface/pytorch-image-models) provides > 700 pre-trained models with a unified API. num_classes=0 removes the classification head, returning the raw feature vector:

python
import timm
import torch
import torch.nn as nn
from torch.optim import AdamW

# Load backbone without classification head — returns feature vectors
backbone = timm.create_model(
    "resnet50",
    pretrained=True,
    num_classes=0,          # removes the final FC layer
    global_pool="avg",      # global average pooling → (batch, 2048)
)

# Freeze all backbone parameters
for param in backbone.parameters():
    param.requires_grad = False

# Small custom head — only these parameters are trained
num_target_classes = 10
head = nn.Sequential(
    nn.Dropout(p=0.3),
    nn.Linear(backbone.num_features, num_target_classes),
)

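# Fall back to CPU when no GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(backbone, head).to(device)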

# Optimizer sees ONLY head parameters — backbone is frozen
optimizer = AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-3,  # higher LR acceptable since only the head updates
)

For the feature extraction phase, a higher learning rate is acceptable because random-initialized head weights need large steps to converge, and the frozen backbone cannot be harmed. Once the head converges (typically 3–5 epochs), switch to fine-tuning.
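One caveat to "cannot be harmed": setting requires_grad = False stops gradient updates, but BatchNorm layers in training mode still update their running statistics on every forward pass. The sketch below uses a tiny stand-in network (an assumption for illustration, not a real pre-trained backbone) to show the difference between train and eval mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in for a pre-trained backbone (hypothetical; a real one would come from timm)
backbone = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())
for param in backbone.parameters():
    param.requires_grad = False  # frozen: no gradient updates

bn = backbone[1]
x = torch.randn(4, 3, 16, 16)

# Train mode: running_mean is a buffer, not a parameter, so it still changes
backbone.train()
before = bn.running_mean.clone()
backbone(x)
changed_in_train_mode = not torch.equal(before, bn.running_mean)

# Eval mode: the stored statistics are used as-is and left untouched
backbone.eval()
before = bn.running_mean.clone()
with torch.no_grad():
    backbone(x)
changed_in_eval_mode = not torch.equal(before, bn.running_mean)

print(changed_in_train_mode, changed_in_eval_mode)
```

In a training loop, calling model.train() on the combined model flips the backbone back into training mode, so one option is to call backbone.eval() again right after it for as long as the backbone stays frozen.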

Fine-Tuning with Discriminative Learning Rates

ULMFiT's key insight: earlier layers of a network encode generic, broadly transferable features (edges, textures, syntactic structure). Later layers encode source-domain-specific features. Earlier layers should change less during fine-tuning. Discriminative learning rates implement this by assigning a lower learning rate to earlier parameter groups:

python
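import timm
import torch
import torch.nn as nn
from torch.optim import AdamW

# NUM_CLASSES, NUM_EPOCHS, and train_loader are assumed to be defined by the surrounding training script
model = timm.create_model("resnet50", pretrained=True, num_classes=NUM_CLASSES)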

# Divide model into layer groups for discriminative LRs
# ResNet50: layer1/2 = generic features; layer3/4 = specific; fc = head
# Note: the stem (conv1, bn1) is in no group, so the optimizer never updates it
param_groups = [
    {"params": model.layer1.parameters(), "lr": 1e-5},  # earliest, slowest
    {"params": model.layer2.parameters(), "lr": 3e-5},
    {"params": model.layer3.parameters(), "lr": 1e-4},
    {"params": model.layer4.parameters(), "lr": 3e-4},
    {"params": model.fc.parameters(),     "lr": 1e-3},  # head, fastest
]

optimizer = AdamW(param_groups, weight_decay=0.01)

# Slanted triangular LR schedule: short linear warmup, then linear decay
# OneCycleLR approximates this pattern when anneal_strategy="linear"
# (its default is cosine annealing)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=[g["lr"] for g in param_groups],
    steps_per_epoch=len(train_loader),
    epochs=NUM_EPOCHS,
    pct_start=0.1,             # 10% of training is warmup
    anneal_strategy="linear",  # linear ramps, matching the slanted triangular shape
)

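Rule of thumb for discriminative LRs: each successive layer group uses a learning rate a constant factor higher than the previous one. In the example above the factor is roughly 3× per group (ULMFiT itself suggests 2.6×), which adds up to a 100× spread: the head uses the highest rate (1e-3), the earliest layers use the lowest (1e-5).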
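The per-group rates in the example above form a geometric progression, which can be generated rather than typed by hand. The helper below is a hypothetical sketch: the default factor of 2.6 is the value suggested in the ULMFiT paper, while a factor of √10 ≈ 3.16 approximately reproduces the 1e-5 to 1e-3 ladder used in the example.

```python
def discriminative_lrs(head_lr: float, num_groups: int, factor: float = 2.6) -> list[float]:
    """Return learning rates ordered from the earliest layer group to the head.

    Each group's rate is the next group's rate divided by `factor`.
    """
    return [head_lr / factor ** (num_groups - 1 - i) for i in range(num_groups)]


# Five groups (layer1..layer4 + head) with a sqrt(10) step between neighbours
lrs = discriminative_lrs(head_lr=1e-3, num_groups=5, factor=10 ** 0.5)
for name, lr in zip(["layer1", "layer2", "layer3", "layer4", "fc"], lrs):
    print(f"{name}: {lr:.2e}")
```

The resulting list can be zipped with the model's layer groups to build the optimizer's param_groups.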

Gradual Unfreezing

Rather than unfreezing all layers simultaneously, unfreeze one layer group per training phase, starting with the group closest to the head. This reduces the risk of catastrophic forgetting by letting each newly unfrozen layer group adapt incrementally:

python
def set_layer_group_requires_grad(model, group_name: str, requires_grad: bool):
    layer = getattr(model, group_name, None)
    if layer is None:
        return
    for param in layer.parameters():
        param.requires_grad = requires_grad

# Phase 1: Train only the head. Freeze every parameter first (this also
# covers the conv1/bn1 stem), then re-enable the head.
for param in model.parameters():
    param.requires_grad = False
set_layer_group_requires_grad(model, "fc", True)
# train() is a placeholder for your own training loop
train(model, epochs=3, lr=1e-3)

# Phase 2: Unfreeze layer4, fine-tune with discriminative LRs
set_layer_group_requires_grad(model, "layer4", True)
train(model, epochs=3, lr_groups=[3e-4, 1e-3])  # layer4, head

# Phase 3: Unfreeze layer3
set_layer_group_requires_grad(model, "layer3", True)
train(model, epochs=3, lr_groups=[1e-4, 3e-4, 1e-3])

# Phase 4: Unfreeze the remaining stages (layer2 and layer1)
set_layer_group_requires_grad(model, "layer2", True)
set_layer_group_requires_grad(model, "layer1", True)
train(model, epochs=5, lr_groups=[1e-5, 3e-5, 1e-4, 3e-4, 1e-3])

Gradual unfreezing adds training phases, but the early phases are cheap because frozen layers need no gradient computation, so the total compute cost is usually similar to fine-tuning everything from the start. Accuracy and stability tend to improve because each phase lets the newly trainable layers settle before the next layer group is unlocked.
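The phase structure above can also be generated from a list of group names. The sketch below is a hypothetical helper that only computes which groups are trainable in each phase; it unfreezes strictly one group per phase, so it yields five phases where the example above merges layer2 and layer1 into a single final phase.

```python
from typing import Iterator


def unfreeze_schedule(groups: list[str], head: str = "fc") -> Iterator[list[str]]:
    """Yield the trainable groups for each phase, unfreezing from the head backward.

    `groups` lists the backbone layer groups from earliest to latest.
    """
    trainable = [head]
    yield list(trainable)  # phase 1: head only
    for name in reversed(groups):
        trainable.insert(0, name)  # unlock the next-deepest group
        yield list(trainable)


phases = list(unfreeze_schedule(["layer1", "layer2", "layer3", "layer4"]))
for i, names in enumerate(phases, start=1):
    print(f"phase {i}: {names}")
```

Each yielded list can drive calls to a helper such as set_layer_group_requires_grad before the corresponding training phase.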

Using Torchvision Pre-Trained Models

Torchvision provides official pre-trained weights with the weights= API (introduced in v0.13 to replace deprecated pretrained=True):

python
import torch
import torchvision.models as models
from torchvision.models import ResNet50_Weights, EfficientNet_B0_Weights

# NUM_TARGET_CLASSES is assumed to be defined by the surrounding script

# ResNet50 with best available weights
model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Replace final FC layer for a different number of classes
# ResNet50's fc: Linear(2048, 1000) — in_features depends on the architecture
in_features = model.fc.in_features  # 2048 for ResNet50
model.fc = torch.nn.Linear(in_features, NUM_TARGET_CLASSES)

# EfficientNet-B0 (smaller, faster than ResNet50)
model = models.efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
in_features = model.classifier[1].in_features  # 1280 for EfficientNet-B0
model.classifier[1] = torch.nn.Linear(in_features, NUM_TARGET_CLASSES)

# Use the same preprocessing as the pre-training dataset
transforms = ResNet50_Weights.IMAGENET1K_V2.transforms()

IMAGENET1K_V2 weights are preferred over V1: they come from an improved training recipe (including Mixup, CutMix, and TrivialAugment-based augmentation) and achieve higher accuracy with the same architecture. The weights.transforms() method returns the correct preprocessing pipeline for those weights, avoiding the common bug of using the wrong resize settings or normalization statistics.

Choosing the Learning Rate for Fine-Tuning

Leslie Smith's learning rate range test (arXiv:1506.01186) finds the optimal base learning rate empirically: run a single epoch with the learning rate increasing linearly from 1e-7 to 1e-1, plotting loss against learning rate. The optimal LR is in the zone of steepest loss decrease, just before the point where loss diverges.
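The selection step of the range test can be sketched without any training code. The helpers below are hypothetical: one builds the linear sweep, the other picks the learning rate where the recorded loss fell fastest. The loss values are made up for illustration; real curves are noisy and usually need smoothing before this kind of comparison.

```python
def linear_lr_sweep(lr_min: float, lr_max: float, num_steps: int) -> list[float]:
    """Learning rates increasing linearly from lr_min to lr_max."""
    step = (lr_max - lr_min) / (num_steps - 1)
    return [lr_min + i * step for i in range(num_steps)]


def steepest_descent_lr(lrs: list[float], losses: list[float]) -> float:
    """Return the LR at the start of the interval with the largest loss drop.

    Assumes evenly spaced learning rates, so raw loss differences are comparable.
    """
    drops = [losses[i + 1] - losses[i] for i in range(len(losses) - 1)]
    steepest = min(range(len(drops)), key=drops.__getitem__)
    return lrs[steepest]


lrs = linear_lr_sweep(1e-7, 1e-1, num_steps=6)
# Illustrative (fabricated) losses: slow start, steep drop, plateau, divergence
losses = [2.30, 2.25, 1.90, 1.20, 1.10, 4.00]
best = steepest_descent_lr(lrs, losses)
print(f"suggested LR: {best:.3g}")
```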

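As a starting point, learning rates for fine-tuning backbone layers are typically 10–100× smaller than for training from scratch, while a freshly initialized head can use a rate close to the from-scratch range: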

| Scenario | Typical LR range |
|---|---|
| Training from scratch | 1e-3 to 1e-2 |
| Fine-tuning (head only) | 1e-3 to 5e-3 |
| Fine-tuning (all layers) | 1e-4 to 3e-4 |
| Fine-tuning early layers | 1e-5 to 1e-4 |

Vision Transformers and Non-CNN Transfer Learning

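Dosovitskiy et al.'s ViT (arXiv:2010.11929, ICLR 2021) extends transfer learning to transformer architectures: a vision transformer pre-trained on a sufficiently large dataset transfers well to smaller benchmarks. Their largest model, ViT-H/14 pre-trained on JFT-300M, reaches 88.55% top-1 on ImageNet, and models pre-trained on the public ImageNet-21k also transfer well. HuggingFace provides ViT and Swin Transformer pre-trained weights for image classification: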

python
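from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch

# Load a pre-trained ViT for fine-tuning on a custom dataset
# NUM_TARGET_CLASSES and class_names are assumed to be defined by the surrounding script
# AutoImageProcessor supersedes the deprecated AutoFeatureExtractor
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")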
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=NUM_TARGET_CLASSES,
    id2label={i: label for i, label in enumerate(class_names)},
    label2id={label: i for i, label in enumerate(class_names)},
    ignore_mismatched_sizes=True,  # replaces the pre-trained classification head
)

ViT fine-tuning benefits from the same gradual unfreezing and discriminative LR strategy as CNNs. The key difference: ViTs tend to be more sensitive to the learning rate schedule and typically need a longer warmup (10–20% of training vs. about 5% for CNNs).

Common Mistakes

Using pretrained=True with the wrong preprocessing. Pre-trained weights are calibrated to specific normalization statistics (ImageNet mean/std). Omitting these statistics, or applying different ones, shifts the input distribution and degrades feature quality. Always use weights.transforms() from torchvision, or the image processor (feature extractor) loaded from the same HuggingFace checkpoint, to get the correct preprocessing for that checkpoint.

Fine-tuning with a uniform learning rate. Applying the same learning rate to all layers, particularly one appropriate for training from scratch, damages early-layer features that took millions of samples to learn. The early layers of a pre-trained ResNet encode edge and texture detectors that are near-universally useful. A learning rate of 1e-3 applied to layer1 can degrade these within a small number of gradient steps. Use discriminative learning rates: 1e-5 for early layers, 1e-3 for the head.

Replacing the classification head before loading weights. The correct order is to load the pre-trained weights first and then assign model.fc = nn.Linear(in_features, new_num_classes), which initializes the new head randomly while preserving the backbone. Replacing the head before loading makes the checkpoint's head tensors mismatch the new head's shape, so loading fails whenever the number of classes differs.

Skipping gradual unfreezing and fine-tuning everything at once. Fine-tuning all layers simultaneously with a uniform learning rate often converges to worse optima than gradual unfreezing on small-to-medium datasets. The head needs several epochs to reach a stable point before backbone weights can be usefully updated.

Assuming transfer learning always helps. When the source and target domains are very different — medical imaging from a model pre-trained on web photos, for example — ImageNet features in later layers may actively hurt by biasing the network toward irrelevant visual features. In this case, use in-domain pre-training (e.g., RadImageNet for radiology, SatMAE for satellite imagery) or limit transfer to early layers only.
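The head-replacement mistake listed above is easy to reproduce. The sketch below uses a two-layer stand-in model (an assumption for illustration; index 0 plays the backbone and index 1 the classification head) and an in-memory state dict in place of a real checkpoint file.

```python
import torch
import torch.nn as nn


def make_model(num_classes: int) -> nn.Sequential:
    # Tiny stand-in: index 0 plays the backbone, index 1 the classification head
    return nn.Sequential(nn.Linear(16, 8), nn.Linear(8, num_classes))


# Stands in for pre-trained weights with a 1000-class head
checkpoint = make_model(num_classes=1000).state_dict()

# Mistake: build the model with the new head first, then load the checkpoint
model = make_model(num_classes=10)
try:
    model.load_state_dict(checkpoint)
    load_failed = False
except RuntimeError:
    load_failed = True  # size mismatch for the head's weight and bias

# Fix: load with the original head shape, then swap in a freshly initialized head
model = make_model(num_classes=1000)
model.load_state_dict(checkpoint)
model[1] = nn.Linear(model[1].in_features, 10)

print(load_failed, model[1].out_features)
```

After the fix, the backbone layer holds the checkpoint's values while the new head starts from random initialization.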

References