How to build your very first SimSiam model with FashionMNIST
Contrastive learning has many use cases these days. From NLP and computer vision to recommendation systems, contrastive learning can be used to learn underlying data representations without any explicit labels, which can then be used for downstream classification, detection, similarity search, etc.
There are already many online resources that explain the basic ideas of contrastive learning, so I won’t add one more blog post repeating them. Instead, this article shows how to convert a supervised learning problem into a contrastive learning problem. Specifically, I will start with a basic classification model for FashionMNIST (MIT licence). Then, I will move on to a harder problem with limited training labels (reducing the full training set of 60,000 labels to 1,000). I will introduce SimSiam, a state-of-the-art method for contrastive learning, and give step-by-step instructions for modifying the original linear layers in the SimSiam style. Finally, I’ll show the results: SimSiam could improve the F1 score by 15% with a very basic configuration.
Image source: https://pxhere.com/en/photo/395408
Now, let’s start. First, we’ll load the FashionMNIST dataset. A custom FashionMNIST class is used to obtain a subset of the training set, named finetune_dataset. The source code for the custom FashionMNIST class will be given at the end of this article.
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from FashionMNIST import FashionMNIST

train_dataset = FashionMNIST("./FashionMNIST",
                             train=True,
                             transform=transforms.ToTensor(),
                             download=True,
                             )
test_dataset = FashionMNIST("./FashionMNIST",
                            train=False,
                            transform=transforms.ToTensor(),
                            download=True,
                            )
finetune_dataset = FashionMNIST("./FashionMNIST",
                                train=True,
                                transform=transforms.ToTensor(),
                                download=True,
                                first_k=1000,
                                )

# Create a subplot with 4×4 grid
fig, axs = plt.subplots(4, 4, figsize=(8, 8))
# Loop through each subplot and plot an image
for i in range(4):
    for j in range(4):
        image, label = train_dataset[i * 4 + j]  # Get image and label
        image_numpy = image.numpy().squeeze()  # Convert image tensor to numpy array
        axs[i, j].imshow(image_numpy, cmap='gray')  # Plot the image
        axs[i, j].axis('off')  # Turn off axis
        axs[i, j].set_title(f"Label: {label}")  # Set title with label
plt.tight_layout()  # Adjust layout
plt.show()  # Show plot
The code shows a grid of the first 16 images from train_dataset.
First 16 images from the FashionMNIST training set. Image by author.
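The custom FashionMNIST class itself is not reproduced in this section. As a rough sketch of what a first_k argument might do, the wrapper below exposes only the first k samples of any indexable dataset. The name FirstKSubset and its behavior are assumptions made for illustration, not the author's implementation; torch.utils.data.Subset with range(k) would be an off-the-shelf alternative.

```python
class FirstKSubset:
    """Hypothetical wrapper: expose only the first k samples of a dataset."""

    def __init__(self, dataset, first_k=None):
        self.dataset = dataset
        # With first_k=None the full dataset is kept
        self.length = len(dataset) if first_k is None else min(first_k, len(dataset))

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        if not 0 <= index < self.length:
            raise IndexError(index)
        return self.dataset[index]


# Toy stand-in for (image, label) pairs, not real FashionMNIST data
toy_dataset = [(f"image_{i}", i % 10) for i in range(60)]
finetune_subset = FirstKSubset(toy_dataset, first_k=10)
print(len(finetune_subset), finetune_subset[0])
```

Because the wrapper only implements __len__ and __getitem__, it can be passed to a DataLoader in the same way as the full dataset.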
Next, we’ll define the supervised classification model. The architecture contains a backbone of convolutional layers and an MLP head of two linear layers. This will set a consistent baseline for the following experiments, as SimSiam will only replace the MLP head for contrastive learning purposes.
import torch.nn as nn

class supervised_classification(nn.Module):
    def __init__(self):
        super(supervised_classification, self).__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
        )
        self.fc = nn.Sequential(
            nn.Linear(128*4*4, 32),
            nn.ReLU(),
            nn.Linear(32, 10),
        )

    def forward(self, x):
        x = self.backbone(x).view(-1, 128 * 4 * 4)
        return self.fc(x)
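The 128*4*4 input size of the first linear layer follows from the three stride-2 convolutions. The short check below applies the usual convolution output-size formula (dilation 1) to a 28x28 FashionMNIST image; the helper name conv_output_size is just for illustration.

```python
def conv_output_size(size, kernel_size=3, stride=2, padding=1):
    """Spatial size after one Conv2d layer (dilation 1), using floor division."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 28  # FashionMNIST images are 28x28
sizes = [size]
for _ in range(3):  # three stride-2 convolutions in the backbone
    size = conv_output_size(size)
    sizes.append(size)

flattened = 128 * size * size  # 128 channels after the last convolution
print(sizes, flattened)  # [28, 14, 7, 4] 2048
```

If the input resolution or the number of convolutions changes, the flattened size in both the model definition and the view call has to change with it.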
We’ll train the model for 10 epochs:
import os
import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb

wandb_config = {
    "learning_rate": 0.001,
    "architecture": "fashion mnist classification full training",
    "dataset": "FashionMNIST",
    "epochs": 10,
    "batch_size": 64,
}
wandb.init(
    # set the wandb project where this run will be logged
    project="supervised_classification",
    # track hyperparameters and run metadata
    config=wandb_config,
)

# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
supervised = supervised_classification().to(device)
optimizer = optim.SGD(supervised.parameters(),
                      lr=wandb_config["learning_rate"],
                      momentum=0.9,
                      weight_decay=1e-5,
                      )
train_dataloader = DataLoader(train_dataset,
                              batch_size=wandb_config["batch_size"],
                              shuffle=True,
                              )

# Training loop
loss_fun = nn.CrossEntropyLoss()
for epoch in range(wandb_config["epochs"]):
    supervised.train()
    for batch_idx, (image, target) in enumerate(tqdm.tqdm(train_dataloader, total=len(train_dataloader))):
        # Move the batch to the same device as the model
        image, target = image.to(device), target.to(device)
        optimizer.zero_grad()
        prediction = supervised(image)
        loss = loss_fun(prediction, target)
        loss.backward()
        optimizer.step()
        wandb.log({"training loss": loss.item()})

# Save the trained weights (create the output folder if needed)
os.makedirs("weights", exist_ok=True)
torch.save(supervised.state_dict(), "weights/fully_supervised.pt")
Using the classification_report from the scikit-learn package, we’ll get the following results:
from sklearn.metrics import classification_report

# Test loader (batch size 32, as in the later evaluation code)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

supervised = supervised_classification()
supervised.load_state_dict(torch.load("weights/fully_supervised.pt", map_location=device))
supervised.eval()
supervised.to(device)

target_list = []
prediction_list = []
for batch_idx, (image, target) in enumerate(tqdm.tqdm(test_dataloader, total=len(test_dataloader))):
    with torch.no_grad():
        prediction = supervised(image.to(device))
    prediction_list.extend(torch.argmax(prediction, dim=1).detach().cpu().numpy())
    target_list.extend(target.detach().cpu().numpy())
print(classification_report(target_list, prediction_list))

# Create a subplot with 4×4 grid
fig, axs = plt.subplots(4, 4, figsize=(8, 8))
# Loop through each subplot and plot an image
for i in range(4):
    for j in range(4):
        image, label = test_dataset[i * 4 + j]  # Get image and label
        image_numpy = image.numpy().squeeze()  # Convert image tensor to numpy array
        with torch.no_grad():
            prediction = supervised(torch.unsqueeze(image, dim=0).to(device))
        prediction = torch.argmax(prediction, dim=1).item()  # Predicted class index
        axs[i, j].imshow(image_numpy, cmap='gray')  # Plot the image
        axs[i, j].axis('off')  # Turn off axis
        axs[i, j].set_title(f"Label: {label}, Pred: {prediction}")  # Set title with label and prediction
plt.tight_layout()  # Adjust layout
plt.show()  # Show plot

Classification results of the fully supervised model. Image by author.
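The classification_report output lists precision, recall, and F1 per class, plus their averages. As a minimal sketch of the arithmetic behind the per-class F1 and its unweighted (macro) average, here is a plain-Python version on made-up toy labels. The helper f1_per_class is illustrative only, and the article does not state which average its "average F1" refers to.

```python
def f1_per_class(targets, predictions, num_classes):
    """Per-class F1 = 2 * precision * recall / (precision + recall)."""
    scores = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(targets, predictions) if t == c and p == c)
        fp = sum(1 for t, p in zip(targets, predictions) if t != c and p == c)
        fn = sum(1 for t, p in zip(targets, predictions) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores.append(f1)
    return scores


# Made-up toy labels, only to illustrate the arithmetic
targets = [0, 0, 1, 1, 2, 2]
predictions = [0, 1, 1, 1, 2, 0]
per_class = f1_per_class(targets, predictions, num_classes=3)
macro_f1 = sum(per_class) / len(per_class)
print(per_class, macro_f1)
```

The macro average weights every class equally, which is a reasonable summary here because the FashionMNIST test set is balanced across its ten classes.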
Now, let’s think about a new problem. What should we do if only a limited subset of the training labels is available, e.g., only 1,000 of the 60,000 training images are annotated? The natural idea is to simply train the model on the limited annotated dataset. So, without changing the architecture, we train the model on the limited subset for 100 epochs (we increase the number of epochs for a fair comparison with our SimSiam training):
import os
import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb

wandb_config = {
    "learning_rate": 0.001,
    "architecture": "fashion mnist classification full training on finetune set",
    "dataset": "FashionMNIST",
    "epochs": 100,
    "batch_size": 64,
}
wandb.init(
    # set the wandb project where this run will be logged
    project="supervised_classification",
    # track hyperparameters and run metadata
    config=wandb_config,
)

# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
supervised = supervised_classification().to(device)
optimizer = optim.SGD(supervised.parameters(),
                      lr=wandb_config["learning_rate"],
                      momentum=0.9,
                      weight_decay=1e-5,
                      )
finetune_dataloader = DataLoader(finetune_dataset,
                                 batch_size=wandb_config["batch_size"],
                                 shuffle=True,
                                 )

# Training loop
loss_fun = nn.CrossEntropyLoss()
for epoch in range(wandb_config["epochs"]):
    supervised.train()
    for batch_idx, (image, target) in enumerate(tqdm.tqdm(finetune_dataloader, total=len(finetune_dataloader))):
        # Move the batch to the same device as the model
        image, target = image.to(device), target.to(device)
        optimizer.zero_grad()
        prediction = supervised(image)
        loss = loss_fun(prediction, target)
        loss.backward()
        optimizer.step()
        wandb.log({"training loss": loss.item()})

# Save the trained weights (create the output folder if needed)
os.makedirs("weights", exist_ok=True)
torch.save(supervised.state_dict(), "weights/fully_supervised_finetunedataset.pt")

Fully supervised training loss on the limited training set. Image by author.
Quantitative evaluation results on the testing set. Note that the performance drops by more than 25% when the training size is reduced. Image by author.
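To put the two supervised runs side by side, here is a quick back-of-the-envelope calculation of the label fraction and the number of optimizer updates each run performs. It assumes the batch size of 64 and the epoch counts from the configurations above, and that the last partial batch is kept (the DataLoader default).

```python
import math

full_size, limited_size = 60_000, 1_000
batch_size = 64

# Share of the training labels that remains in the limited setting
label_fraction = limited_size / full_size
# Number of optimizer updates = batches per epoch * epochs
full_steps = math.ceil(full_size / batch_size) * 10
limited_steps = math.ceil(limited_size / batch_size) * 100

print(f"{label_fraction:.2%}", full_steps, limited_steps)
```

Even with ten times as many epochs, the limited run takes far fewer update steps and sees far less variety, which is consistent with the drop in test performance reported above.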
Now it’s time for some contrastive learning. To mitigate the shortage of annotated labels and make full use of the large quantity of unlabelled data, contrastive learning can be used to help the backbone learn data representations without a specific task. For a given downstream task, the backbone can then be frozen, and only a shallow network needs to be trained on the limited annotated dataset to achieve satisfactory results.
The most commonly used contrastive learning approaches include SimCLR, SimSiam, and MoCo (see my previous article on MoCo). Here, we compare SimCLR and SimSiam.
SimCLR contrasts positive and negative pairs within each data batch. It uses the NT-Xent loss (a temperature-scaled cross-entropy over the cosine similarities in a batch) and treats the other samples in the batch as negatives, so it benefits from a large batch size. The original SimCLR recipe also uses the LARS optimizer to accommodate the large batch size.
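To make the role of the in-batch negatives concrete, here is a minimal plain-Python sketch of the NT-Xent loss on made-up 2-D embeddings. The pairing convention (z[2k] and z[2k + 1] are two views of the same sample) and the temperature of 0.5 are assumptions for this illustration, not settings taken from this article.

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def nt_xent(z, temperature=0.5):
    """NT-Xent over 2N embeddings, where z[2k] and z[2k + 1] are two views of sample k."""
    n = len(z)
    total = 0.0
    for i in range(n):
        j = i + 1 if i % 2 == 0 else i - 1  # index of the positive partner
        # Denominator: similarities to every other embedding in the batch
        denom = sum(math.exp(cosine(z[i], z[k]) / temperature) for k in range(n) if k != i)
        total += -math.log(math.exp(cosine(z[i], z[j]) / temperature) / denom)
    return total / n


# Made-up embeddings: views of the same sample point in similar directions
aligned = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
# Same vectors, but paired so that the "positives" are dissimilar
mismatched = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [0.1, 0.9]]
print(nt_xent(aligned), nt_xent(mismatched))
```

The denominator grows with the number of other samples in the batch, which is why this family of losses tends to rely on large batches for informative negatives.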
SimSiam, however, uses a Siamese architecture, which avoids using negative pairs and further avoids the need for large batch sizes. The differences between SimSiam and SimCLR are given in the table below.
Comparison between SimCLR and SimSiam. Image by author.
The SimSiam architecture. Image source: https://arxiv.org/pdf/2011.10566
We can see from the figure above that the SimSiam architecture contains only two parts: the encoder/backbone and the predictor. Both augmented views pass through the same encoder, and the predictor is applied on one branch. During training, gradient propagation is stopped on the other branch, and the cosine similarity is calculated between the predictor output of one view and the encoder output of the other.
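Written out, the objective from the SimSiam paper looks as follows, where z_1 and z_2 denote the encoder outputs of the two augmented views and p_1 and p_2 the corresponding predictor outputs (the symbols follow the paper's notation rather than any variable names in this article):

```latex
\mathcal{D}(p_1, z_2) = -\frac{p_1}{\lVert p_1 \rVert_2} \cdot \frac{z_2}{\lVert z_2 \rVert_2},
\qquad
\mathcal{L} = \tfrac{1}{2}\,\mathcal{D}\bigl(p_1, \operatorname{stopgrad}(z_2)\bigr)
            + \tfrac{1}{2}\,\mathcal{D}\bigl(p_2, \operatorname{stopgrad}(z_1)\bigr)
```

The stop-gradient means each encoder output is treated as a constant target when it appears as the second argument of D.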
So, how do we implement this architecture in practice? Continuing from the supervised classification design, we keep the backbone the same and modify only the MLP head. In the supervised learning architecture, the MLP outputs a 10-element vector of class scores (logits) for the 10 classes. For SimSiam, however, the purpose is not to perform “classification” but to learn the “representation,” so the predictor output needs to have the same dimension as the backbone output for the loss calculation. The model and the negative_cosine_similarity_stopgradient loss are given below:
import torch.nn as nn
import matplotlib.pyplot as plt

class SimSiam(nn.Module):
    def __init__(self):
        super(SimSiam, self).__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
        )
        self.prediction_mlp = nn.Sequential(nn.Linear(128*4*4, 64),
                                            nn.BatchNorm1d(64),
                                            nn.ReLU(),
                                            nn.Linear(64, 128*4*4),
                                            )

    def forward(self, x):
        x = self.backbone(x)
        x = x.view(-1, 128 * 4 * 4)
        pred_output = self.prediction_mlp(x)
        return x, pred_output

cos = nn.CosineSimilarity(dim=1, eps=1e-6)

def negative_cosine_similarity_stopgradient(pred, proj):
    return -cos(pred, proj.detach()).mean()
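To get a feel for the values this loss can take, here is a small plain-Python sketch of the symmetric negative cosine similarity on made-up vectors. The stop-gradient only affects backpropagation, so it has no numerical counterpart here; the helper names are illustrative and do not replace the PyTorch function above.

```python
import math


def negative_cosine(p, z):
    """D(p, z) = -cos(p, z); z is treated as a constant (stop-gradient) during training."""
    dot = sum(a * b for a, b in zip(p, z))
    norm_p = math.sqrt(sum(a * a for a in p))
    norm_z = math.sqrt(sum(b * b for b in z))
    return -dot / (norm_p * norm_z)


def symmetric_loss(pred1, proj1, pred2, proj2):
    """Average of the two cross-view terms, as in the training loop."""
    return negative_cosine(pred1, proj2) / 2 + negative_cosine(pred2, proj1) / 2


# Made-up vectors that all point in the same direction: the loss reaches its minimum
aligned_loss = symmetric_loss([1.0, 2.0], [2.0, 4.0], [3.0, 6.0], [0.5, 1.0])
# Predictions orthogonal to the other view's projection: the loss is zero
orthogonal_loss = symmetric_loss([1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0])
print(aligned_loss, orthogonal_loss)
```

The loss is bounded between -1 and 1, so a training curve that approaches -1 means the predictor output of one view is closely aligned with the backbone output of the other.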
The original paper gives the following pseudo-code for training SimSiam:
Training pseudo-code for SimSiam. Source: https://arxiv.org/pdf/2011.10566
And we convert it into real training code:
import os
import tqdm
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import RandAugment
import wandb

wandb_config = {
    "learning_rate": 0.0001,
    "architecture": "simsiam",
    "dataset": "FashionMNIST",
    "epochs": 100,
    "batch_size": 256,
}
wandb.init(
    # set the wandb project where this run will be logged
    project="simsiam",
    # track hyperparameters and run metadata
    config=wandb_config,
)

# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
simsiam = SimSiam().to(device)
random_augmenter = RandAugment(num_ops=5)
optimizer = optim.SGD(simsiam.parameters(),
                      lr=wandb_config["learning_rate"],
                      momentum=0.9,
                      weight_decay=1e-5,
                      )
train_dataloader = DataLoader(train_dataset, batch_size=wandb_config["batch_size"], shuffle=True)
os.makedirs("weights", exist_ok=True)  # Folder for the checkpoints

# Training loop
for epoch in range(wandb_config["epochs"]):
    simsiam.train()
    print(f"Epoch {epoch}")
    for batch_idx, (image, _) in enumerate(tqdm.tqdm(train_dataloader, total=len(train_dataloader))):
        optimizer.zero_grad()
        # RandAugment expects uint8 tensors, so convert to uint8 and back to [0, 1] floats
        image_uint8 = (image * 255).to(dtype=torch.uint8)
        aug1 = random_augmenter(image_uint8).to(dtype=torch.float32) / 255.0
        aug2 = random_augmenter(image_uint8).to(dtype=torch.float32) / 255.0
        proj1, pred1 = simsiam(aug1.to(device))
        proj2, pred2 = simsiam(aug2.to(device))
        # Symmetric loss: each prediction is compared with the other view's (detached) projection
        loss = negative_cosine_similarity_stopgradient(pred1, proj2) / 2 + negative_cosine_similarity_stopgradient(pred2, proj1) / 2
        loss.backward()
        optimizer.step()
        wandb.log({"training loss": loss.item()})
    # Save a checkpoint every 10 epochs
    if (epoch+1) % 10 == 0:
        torch.save(simsiam.state_dict(), f"weights/simsiam_epoch{epoch+1}.pt")
We trained for 100 epochs as a fair comparison to the limited supervised training; the training loss is shown below. Note: Due to its Siamese design, SimSiam can be very sensitive to hyperparameters such as the learning rate and the size of the MLP hidden layers. The original SimSiam paper provides a detailed configuration for the ResNet50 backbone. For a ViT-based backbone, we recommend reading the MoCo v3 paper, which adopts the SimSiam model in a momentum update scheme.
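One way to check whether a particular hyperparameter setting has gone wrong is to look for collapsed outputs. The SimSiam paper tracks the per-channel standard deviation of the L2-normalized output for this purpose: it is zero for a constant output and roughly 1/sqrt(d) for well-spread d-dimensional outputs. The plain-Python sketch below illustrates the idea on made-up vectors; the function name is hypothetical and this check is not part of the training code above.

```python
import math
import statistics


def normalized_output_std(outputs):
    """Mean per-dimension std of L2-normalized vectors (near 0 suggests collapse)."""
    normalized = []
    for vec in outputs:
        norm = math.sqrt(sum(a * a for a in vec))
        normalized.append([a / norm for a in vec])
    dims = len(normalized[0])
    per_dim_std = [statistics.pstdev([vec[d] for vec in normalized]) for d in range(dims)]
    return sum(per_dim_std) / dims


# Made-up outputs: a collapsed batch (all identical) versus a spread-out batch
collapsed = [[0.3, 0.4], [0.3, 0.4], [0.3, 0.4]]
spread = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
print(normalized_output_std(collapsed), normalized_output_std(spread))
```

In practice the same statistic could be computed on a batch of backbone outputs every few epochs and logged next to the training loss.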
Training loss for SimSiam. Image by author.
Then, we run the trained SimSiam on the testing set and visualize the representations using UMAP reduction:
import tqdm
import numpy as np
import torch
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
simsiam = SimSiam()
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
simsiam.load_state_dict(torch.load("weights/simsiam_epoch100.pt", map_location=device))
simsiam.eval()
simsiam.to(device)

features = []
labels = []
for batch_idx, (image, target) in enumerate(tqdm.tqdm(test_dataloader, total=len(test_dataloader))):
    with torch.no_grad():
        proj, pred = simsiam(image.to(device))
    features.extend(np.squeeze(pred.detach().cpu().numpy()).tolist())
    labels.extend(target.detach().cpu().numpy().tolist())

import plotly.express as px
import umap.umap_ as umap

# Reduce to 3 components; only the first two are plotted below
reducer = umap.UMAP(n_components=3, n_neighbors=10, metric="cosine")
projections = reducer.fit_transform(np.array(features))
px.scatter(projections, x=0, y=1,
           color=labels, labels={'color': 'Fashion MNIST Labels'}
           )

The UMAP of the SimSiam representation over the testing set. Image by author.
It’s interesting to see two small islands in the reduced-dimension map above, made up of classes 5, 7, 8, and some of class 9. If we pull out the FashionMNIST class list, we see that these are the footwear classes “Sandal,” “Sneaker,” and “Ankle boot,” plus “Bag.” The big purple cluster corresponds to clothing classes such as “T-shirt/top,” “Trouser,” “Pullover,” “Dress,” “Coat,” and “Shirt.” This suggests that SimSiam has learned a meaningful representation in the vision domain.
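For reference when reading the plot, here is the standard FashionMNIST mapping from label index to class name, together with a lookup for the classes that form the islands. The FOOTWEAR grouping is my own shorthand for this illustration, not part of the dataset.

```python
# Standard FashionMNIST label indices and class names
FASHION_MNIST_CLASSES = {
    0: "T-shirt/top", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat",
    5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle boot",
}

# Hypothetical grouping, used only to read the UMAP plot
FOOTWEAR = {5, 7, 9}
island_classes = [5, 7, 8, 9]
for index in island_classes:
    kind = "footwear" if index in FOOTWEAR else "other"
    print(index, FASHION_MNIST_CLASSES[index], kind)
```

Passing the class names instead of the raw indices as the color argument of the scatter plot would make the legend easier to read.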
Now that we have meaningful representations, how can they benefit our classification problem? We simply load the trained SimSiam backbone into our classification model. However, instead of fine-tuning the whole architecture on the limited training set, we fine-tune only the linear layers and freeze the backbone, because we don’t want to corrupt the representation already learned.
import os
import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb

wandb_config = {
    "learning_rate": 0.001,
    "architecture": "supervised learning with simsiam backbone",
    "dataset": "FashionMNIST",
    "epochs": 100,
    "batch_size": 64,
}
wandb.init(
    # set the wandb project where this run will be logged
    project="simsiam-finetune",
    # track hyperparameters and run metadata
    config=wandb_config,
)

# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
supervised = supervised_classification()

# Copy the backbone weights from the SimSiam checkpoint written by the training loop above.
# The values must come from the checkpoint; prediction_mlp keys are skipped because
# the classifier has no such layers.
model_dict = supervised.state_dict()
simsiam_state = torch.load("weights/simsiam_epoch100.pt", map_location=device)
simsiam_dict = {k: v for k, v in simsiam_state.items() if k in model_dict}
supervised.load_state_dict(simsiam_dict, strict=False)
supervised.to(device)

# Note: this loader uses batch size 32, not the 64 recorded in wandb_config
finetune_dataloader = DataLoader(finetune_dataset, batch_size=32, shuffle=True)

# Freeze the backbone so that only the linear layers are trained
for param in supervised.backbone.parameters():
    param.requires_grad = False
parameters = [para for para in supervised.parameters() if para.requires_grad]
optimizer = optim.SGD(parameters,
                      lr=wandb_config["learning_rate"],
                      momentum=0.9,
                      weight_decay=1e-5,
                      )

# Training loop
loss_fun = nn.CrossEntropyLoss()
for epoch in range(wandb_config["epochs"]):
    supervised.train()
    for batch_idx, (image, target) in enumerate(tqdm.tqdm(finetune_dataloader)):
        image, target = image.to(device), target.to(device)
        optimizer.zero_grad()
        prediction = supervised(image)
        loss = loss_fun(prediction, target)
        loss.backward()
        optimizer.step()
        wandb.log({"training loss": loss.item()})

# Save the fine-tuned weights (create the output folder if needed)
os.makedirs("weights", exist_ok=True)
torch.save(supervised.state_dict(), "weights/supervised_with_simsiam.pt")
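The key step in this transfer is picking the right entries out of the SimSiam checkpoint. The toy sketch below uses plain dictionaries with placeholder strings instead of tensors, purely to show the filtering logic: the values have to come from the pre-trained checkpoint, and only keys that also exist in the classifier are kept. The key names mirror the layer names used in this article, but the dictionaries themselves are made up.

```python
# Toy stand-ins for state dicts: keys mirror the layer names, values are placeholders
simsiam_state = {
    "backbone.0.weight": "pretrained conv1 weights",
    "backbone.0.bias": "pretrained conv1 bias",
    "prediction_mlp.0.weight": "predictor weights",
}
classifier_state = {
    "backbone.0.weight": "random conv1 weights",
    "backbone.0.bias": "random conv1 bias",
    "fc.0.weight": "random head weights",
}

# Keep checkpoint entries whose names also exist in the classifier (the shared backbone)
transferable = {k: v for k, v in simsiam_state.items() if k in classifier_state}
# Classifier keys that stay at their initial values (the new MLP head)
untouched = sorted(set(classifier_state) - set(transferable))
print(sorted(transferable), untouched)
```

With real state dicts, passing strict=False to load_state_dict lets the classifier accept a dictionary that is missing the fc keys, so the head keeps its random initialization.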
Here is the evaluation result of the SimSiam-pre-trained classification model. The average F1 score is increased by 15% compared to the supervised-only method.
The classification scores of the SimSiam model fine-tuned on the limited set. Image by author.
Summary. We showcased a simple but intuitive example of contrastive learning using FashionMNIST. By using SimSiam for backbone pre-training and fine-tuning only the linear layers on the limited training set (1,000 labels, less than 2% of the full training set), we increased the average F1 score by 15% over the supervised-only method. The trained weights, the notebook, and the customized FashionMNIST dataset class are all included in this GitHub repository.
Give it a try!
References:
Chen and He, Exploring Simple Siamese Representation Learning. CVPR 2021.
Chen et al., A Simple Framework for Contrastive Learning of Visual Representations. ICML 2020.
Chen et al., An Empirical Study of Training Self-Supervised Vision Transformers. ICCV 2021.
Xiao et al., Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms. arXiv preprint 2017. GitHub: https://github.com/zalandoresearch/fashion-mnist
A Practical Guide to Contrastive Learning was originally published in Towards Data Science on Medium.
How to build your very first SimSiam model with FashionMNISTContrastive learning has many use cases these days. From NLP and computer vision to recommendation systems, contrastive learning can be used to learn underlying data representations without any explicit labels, which can then be used for downstream classification, detection, similarity search, etc.There are many online resources to help the audience understand the basic ideas of contrastive learning so that I won’t add one more blog post repeating the information. Instead, I will show you how to convert your supervised learning problem into a contrastive learning problem in this article. Specifically, I will start with a basic classification model for the FashionMNIST (MIT licence). Then, I will proceed to an advanced problem with limited training labels (e.g., reducing the full training set of 60,000 labels to 1,000). I will introduce SimSiam, a state-of-the-art method for contrastive learning, and show step-by-step instructions on modifying the original linear layers in the SimSiam style. Ultimately, I’ll show the results — SimSiam could improve the F1 score by 15% with a very basic configuration.Image source: https://pxhere.com/en/photo/395408Now, let’s start. First, we’ll load in the FashionMNIST dataset. A custom FashionMNIST class is used to obtain a subset of the training set named the finetune_dataset. 
The source code for the customer FashionMNIST class will be given at the end of this article.import matplotlib.pyplot as pltimport torchvision.transforms as transformsfrom FashionMNIST import FashionMNISTtrain_dataset = FashionMNIST(“./FashionMNIST”, train=True, transform=transforms.ToTensor(), download=True, )test_dataset = FashionMNIST(“./FashionMNIST”, train=False, transform=transforms.ToTensor(), download=True, )finetune_dataset = FashionMNIST(“./FashionMNIST”, train=True, transform=transforms.ToTensor(), download=True, first_k=1000, )# Create a subplot with 4×4 gridfig, axs = plt.subplots(4, 4, figsize=(8, 8))# Loop through each subplot and plot an imagefor i in range(4): for j in range(4): image, label = train_dataset[i * 4 + j] # Get image and label image_numpy = image.numpy().squeeze() # Convert image tensor to numpy array axs[i, j].imshow(image_numpy, cmap=’gray’) # Plot the image axs[i, j].axis(‘off’) # Turn off axis axs[i, j].set_title(f”Label: {label}”) # Set title with labelplt.tight_layout() # Adjust layoutplt.show() # Show plotThe code will show a grid of images from the train_datasetFirst 16 images from the FashionMNIST training set. Image by author.Next, we’ll define the supervised classification model. The architecture contains a backbone of convolutional layers and an MLP head of two linear layers. 
This will set a consistent baseline for the following experiments, as SimSiam will only replace the MLP head for contrastive learning purposes.import torch.nn as nnclass supervised_classification(nn.Module): def __init__(self): super(supervised_classification, self).__init__() self.backbone = nn.Sequential( nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.BatchNorm2d(32), nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.BatchNorm2d(64), nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.BatchNorm2d(128), ) self.fc = nn.Sequential( nn.Linear(128*4*4, 32), nn.ReLU(), nn.Linear(32, 10), ) def forward(self, x): x = self.backbone(x).view(-1, 128 * 4 * 4) return self.fc(x)We’ll train the model for 10 epochs:import tqdmimport torchimport torch.optim as optimfrom torch.utils.data import DataLoaderimport wandbwandb_config = { “learning_rate”: 0.001, “architecture”: “fashion mnist classification full training”, “dataset”: “FashionMNIST”, “epochs”: 10, “batch_size”: 64, }wandb.init( # set the wandb project where this run will be logged project=”supervised_classification”, # track hyperparameters and run metadata config=wandb_config,)# Initialize model and optimizerdevice = torch.device(‘cuda’ if torch.cuda.is_available() else ‘cpu’)supervised = supervised_classification()optimizer = optim.SGD(supervised.parameters(), lr=wandb_config[“learning_rate”], momentum=0.9, weight_decay=1e-5, )train_dataloader = DataLoader(train_dataset, batch_size=wandb_config[“batch_size”], shuffle=True, )# Training looploss_fun = nn.CrossEntropyLoss()for epoch in range(wandb_config[“epochs”]): supervised.train() train_loss = 0 for batch_idx, (image, target) in enumerate(tqdm.tqdm(train_dataloader, total=len(train_dataloader))): optimizer.zero_grad() prediction = supervised(image) loss = loss_fun(prediction, target) loss.backward() optimizer.step() wandb.log({“training loss”: loss}) torch.save(supervised.state_dict(), 
“weights/fully_supervised.pt”)Using the classification_report from the scikit-learn package, we’ll get the following results:from sklearn.metrics import classification_reportsupervised = supervised_classification() supervised.load_state_dict(torch.load(“weights/fully_supervised.pt”))supervised.eval()supervised.to(device)target_list = []prediction_list = []for batch_idx, (image, target) in enumerate(tqdm.tqdm(test_dataloader, total=len(test_dataloader))): with torch.no_grad(): prediction = supervised(image.to(device)) prediction_list.extend(torch.argmax(prediction, dim=1).detach().cpu().numpy()) target_list.extend(target.detach().cpu().numpy())print(classification_report(target_list, prediction_list))# Create a subplot with 4×4 gridfig, axs = plt.subplots(4, 4, figsize=(8, 8))# Loop through each subplot and plot an imagefor i in range(4): for j in range(4): image, label = test_dataset[i * 4 + j] # Get image and label image_numpy = image.numpy().squeeze() # Convert image tensor to numpy array prediction = supervised(torch.unsqueeze(image, dim=0).to(device)) prediction = torch.argmax(prediction, dim=1).detach().cpu().numpy() axs[i, j].imshow(image_numpy, cmap=’gray’) # Plot the image axs[i, j].axis(‘off’) # Turn off axis axs[i, j].set_title(f”Label: {label}, Pred: {prediction}”) # Set title with labelplt.tight_layout() # Adjust layoutplt.show() # Show plotClassification results of the fully supervised model. Image by author.Now, let’s think about a new problem. What should we do if we’re given a limited subset of the training set labels, e.g., only 1000 images out of the total 60,000 images are annotated? The natural idea is to simply train the model on the limited annotated dataset. 
So without changing the backbone, we let the model train on the limited subset for 100 epochs (we increase the epochs to have a fair comparison to our SimSiam training):import tqdmimport torchimport torch.optim as optimfrom torch.utils.data import DataLoaderimport wandbwandb_config = { “learning_rate”: 0.001, “architecture”: “fashion mnist classification full training on finetune set”, “dataset”: “FashionMNIST”, “epochs”: 100, “batch_size”: 64, }wandb.init( # set the wandb project where this run will be logged project=”supervised_classification”, # track hyperparameters and run metadata config=wandb_config,)# Initialize model and optimizerdevice = torch.device(‘cuda’ if torch.cuda.is_available() else ‘cpu’)supervised = supervised_classification()optimizer = optim.SGD(supervised.parameters(), lr=wandb_config[“learning_rate”], momentum=0.9, weight_decay=1e-5, )finetune_dataloader = DataLoader(finetune_dataset, batch_size=wandb_config[“batch_size”], shuffle=True, )# Training looploss_fun = nn.CrossEntropyLoss()for epoch in range(wandb_config[“epochs”]): supervised.train() train_loss = 0 for batch_idx, (image, target) in enumerate(tqdm.tqdm(finetune_dataloader, total=len(finetune_dataloader))): optimizer.zero_grad() prediction = supervised(image) loss = loss_fun(prediction, target) loss.backward() optimizer.step() wandb.log({“training loss”: loss}) torch.save(supervised.state_dict(), “weights/fully_supervised_finetunedataset.pt”)Fully supervised training loss on the limited training set. Image by author.Quantitative evaluation results on the testing set. Note the performance drops more than 25% by reducing the training size. Image by author.Now it’s time for some contrastive learning. To mitigate the issue of insufficient annotation labels and fully utilize the large quantity of unlabelled data, contrastive learning could be used to effectively help the backbone learn the data representations without a specific task. 
The backbone could be frozen for a given downstream task and only train a shallow network on a limited annotated dataset to achieve satisfactory results.The most commonly used contrastive learning approaches include SimCLR, SimSiam, and MOCO (see my previous article on MOCO). Here, we compare SimCLR and SimSiam.SimCLR calculates over positive and negative pairs within the data batch, which requires hard negative mining, NT-Xent loss (which extends the cosine similarity loss over a batch) and a large batch size. SimCLR also requires the LARS optimizer to accommodate a large batch size.SimSiam, however, uses a Siamese architecture, which avoids using negative pairs and further avoids the need for large batch sizes. The differences between SimSiam and SimCLR are given in the table below.Comparison between SimCLR and SimSiam. Image by author.The SimSiam architecture. Image source: https://arxiv.org/pdf/2011.10566We can see from the figure above that the SimSiam architecture only contains two parts: the encoder/backbone and the predictor. During training time, the gradient propagation of the Siamese part is stopped, and the cosine similarity is calculated between the outputs of the predictors and the backbone.So, how do we implement this architecture in reality? Continuing on the supervised classification design, we keep the backbone the same and only modify the MLP layer. In the supervised learning architecture, the MLP outputs a 10-element vector indicating the probabilities of the 10 classes. But for SimSiam, the purpose is not to perform “classification” but to learn the “representation,” so we need the output to be of the same dimension as the backbone output for loss calculation. 
And the negative_cosine_similarity is given below:import torch.nn as nnimport matplotlib.pyplot as pltclass SimSiam(nn.Module): def __init__(self): super(SimSiam, self).__init__() self.backbone = nn.Sequential( nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.BatchNorm2d(32), nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.BatchNorm2d(64), nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.BatchNorm2d(128), ) self.prediction_mlp = nn.Sequential(nn.Linear(128*4*4, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Linear(64, 128*4*4), ) def forward(self, x): x = self.backbone(x) x = x.view(-1, 128 * 4 * 4) pred_output = self.prediction_mlp(x) return x, pred_output cos = nn.CosineSimilarity(dim=1, eps=1e-6)def negative_cosine_similarity_stopgradient(pred, proj): return -cos(pred, proj.detach()).mean()The pseudo-code for training the SimSiam is given in the original paper below:Training pseudo-code for SimSiam. Source: https://arxiv.org/pdf/2011.10566And we convert it into real training code:import tqdmimport torchimport torch.optim as optimfrom torch.utils.data import DataLoaderfrom torchvision.transforms import RandAugmentimport wandbwandb_config = { “learning_rate”: 0.0001, “architecture”: “simsiam”, “dataset”: “FashionMNIST”, “epochs”: 100, “batch_size”: 256, }wandb.init( # set the wandb project where this run will be logged project=”simsiam”, # track hyperparameters and run metadata config=wandb_config,)# Initialize model and optimizerdevice = torch.device(‘cuda’ if torch.cuda.is_available() else ‘cpu’)simsiam = SimSiam()random_augmenter = RandAugment(num_ops=5)optimizer = optim.SGD(simsiam.parameters(), lr=wandb_config[“learning_rate”], momentum=0.9, weight_decay=1e-5, )train_dataloader = DataLoader(train_dataset, batch_size=wandb_config[“batch_size”], shuffle=True)# Training loopfor epoch in range(wandb_config[“epochs”]): simsiam.train() print(f”Epoch {epoch}”) train_loss = 0 for batch_idx, (image, _) in 
enumerate(tqdm.tqdm(train_dataloader, total=len(train_dataloader))): optimizer.zero_grad() aug1, aug2 = random_augmenter((image*255).to(dtype=torch.uint8)).to(dtype=torch.float32) / 255.0, random_augmenter((image*255).to(dtype=torch.uint8)).to(dtype=torch.float32) / 255.0 proj1, pred1 = simsiam(aug1) proj2, pred2 = simsiam(aug2) loss = negative_cosine_similarity_stopgradient(pred1, proj2) / 2 + negative_cosine_similarity_stopgradient(pred2, proj1) / 2 loss.backward() optimizer.step() wandb.log({“training loss”: loss}) if (epoch+1) % 10 == 0: torch.save(simsiam.state_dict(), f”weights/simsiam_epoch{epoch+1}.pt”)We trained for 100 epochs as a fair comparison to the limited supervised training; the training loss is shown below. Note: Due to its Siamese design, SimSiam could be very sensitive to hyperparameters like learning rate and MLP hidden layers. The original SimSiam paper provides a detailed configuration for the ResNet50 backbone. For the ViT-based backbone, we recommend reading the MOCO v3 paper, which adopts the SimSiam model in a momentum update scheme.Training loss for SimSiam. 
Image by author.Then, we run the trained SimSiam on the testing set and visualize the representations using UMAP reduction:import tqdmimport numpy as npimport torchdevice = torch.device(‘cuda’ if torch.cuda.is_available() else ‘cpu’)simsiam = SimSiam() test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)simsiam.load_state_dict(torch.load(“weights/simsiam_epoch100.pt”))simsiam.eval()simsiam.to(device)features = []labels = []for batch_idx, (image, target) in enumerate(tqdm.tqdm(test_dataloader, total=len(test_dataloader))): with torch.no_grad(): proj, pred = simsiam(image.to(device)) features.extend(np.squeeze(pred.detach().cpu().numpy()).tolist()) labels.extend(target.detach().cpu().numpy().tolist())import plotly.express as pximport umap.umap_ as umapreducer = umap.UMAP(n_components=3, n_neighbors=10, metric=”cosine”)projections = reducer.fit_transform(np.array(features))px.scatter(projections, x=0, y=1, color=labels, labels={‘color’: ‘Fashion MNIST Labels’})The UMAP of the SimSiam representation over the testing set. Image by author.It’s interesting to see that there are two small islands in the reduced-dimension map above: class 5, 7, 8, and some 9. If we pull out the FashionMNIST class list, we know that these classes correspond to footwear such as “Sandal,” “Sneaker,” “Bag,” and “Ankle boot.” The big purple cluster corresponds to clothing classes like “T-shirt/top,” “Trousers,” “Pullover,” “Dress,” “Coat,” and “Shirt.” The SimSiam demonstrates learning a meaningful representation in the vision domain.Now that we have the correct representations, how can they benefit our classification problem? We simply load the trained SimSiam backbone into our classification model. 
However, instead of fine-tuning the whole architecture on the limited training set, we fine-tune only the linear layers and freeze the backbone, because we don’t want to corrupt the representation already learned.
import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb
wandb_config = {
    "learning_rate": 0.001,
    "architecture": "supervised learning with simsiam backbone",
    "dataset": "FashionMNIST",
    "epochs": 100,
    "batch_size": 64,
}
wandb.init(
    # set the wandb project where this run will be logged
    project="simsiam-finetune",
    # track hyperparameters and run metadata
    config=wandb_config,
)
# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
supervised = supervised_classification()
model_dict = supervised.state_dict()
# Take the values from the pre-trained SimSiam checkpoint and keep only the
# entries whose names also exist in the classification model (the backbone)
simsiam_state = torch.load("simsiam.pt", map_location=device)
simsiam_dict = {k: v for k, v in simsiam_state.items() if k in model_dict}
supervised.load_state_dict(simsiam_dict, strict=False)
supervised.to(device)
finetune_dataloader = DataLoader(finetune_dataset, batch_size=32, shuffle=True)
# Freeze the backbone so that only the linear layers are updated
for param in supervised.backbone.parameters():
    param.requires_grad = False
parameters = [para for para in supervised.parameters() if para.requires_grad]
optimizer = optim.SGD(parameters,
                      lr=wandb_config["learning_rate"],
                      momentum=0.9,
                      weight_decay=1e-5,
                      )
criterion = nn.CrossEntropyLoss()
# Training loop
for epoch in range(wandb_config["epochs"]):
    supervised.train()
    for batch_idx, (image, target) in enumerate(tqdm.tqdm(finetune_dataloader)):
        optimizer.zero_grad()
        prediction = supervised(image.to(device))
        loss = criterion(prediction, target.to(device))
        loss.backward()
        optimizer.step()
        wandb.log({"training loss": loss})
    torch.save(supervised.state_dict(), "weights/supervised_with_simsiam.pt")
Here is the evaluation result of the SimSiam-pre-trained classification model. The average F1 score increases by 15% compared to the supervised-only method.
The classification scores of the SimSiam model fine-tuned on the limited set. Image by author.
Summary.
We showcase a simple but intuitive example of contrastive learning using FashionMNIST. By using SimSiam for backbone pre-training and fine-tuning only the linear layers on the limited training set (1,000 of the 60,000 training labels, under 2%), we increased the average F1 score by 15% over the supervised-only method. The trained weights, the notebook, and the customized FashionMNIST dataset class are all included in this GitHub repository.
Give it a try!
References:
Chen et al., Exploring Simple Siamese Representation Learning. CVPR 2021.
Chen et al., A Simple Framework for Contrastive Learning of Visual Representations. ICML 2020.
Chen et al., An Empirical Study of Training Self-Supervised Vision Transformers. ICCV 2021.
Xiao et al., Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms. arXiv preprint 2017. GitHub: https://github.com/zalandoresearch/fashion-mnist
A Practical Guide to Contrastive Learning was originally published in Towards Data Science on Medium.

