Unbalanced Data Loading for Multi-Task Learning in PyTorch
A practical PyTorch guide for training multi-task models on multiple unbalanced datasets.
Originally published on Medium

Working on multi-task learning (MTL) problems requires a unique training setup, mainly in terms of data handling, model architecture, and performance evaluation metrics.
In this post, I review the data handling part: specifically, how to train a multi-task learning model on multiple datasets and how to handle tasks with highly unbalanced datasets.
I will describe my suggestion in three steps:
- Combining two (or more) datasets into a single PyTorch Dataset. This dataset will be the input for a PyTorch DataLoader.
- Modifying the batch preparation process to produce either one task in each batch or alternatively mix samples from both tasks in each batch.
- Handling the highly unbalanced datasets at the batch level by using a batch sampler as part of the DataLoader.
I only review Dataset- and DataLoader-related code, leaving out other important modules such as the model, optimizer, and metric definitions.
For simplicity, I use a generic two-dataset example. However, the number of datasets and the type of data should not affect the main setup. We can even use several instances of the same dataset when we have more than one set of labels for the same samples: for example, a dataset of images with both an object class and a spatial location, or a face dataset with both an emotion label and an age label per image.
A PyTorch Dataset class needs to implement the __getitem__() function, which fetches and prepares the sample at a given index. With two datasets, we can therefore have two different ways of creating samples. This means we can even use a single underlying dataset, return samples with different labels, and change the sample processing scheme per task (the output samples should have the same shape, since they are stacked into a batch tensor).
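To make the "one dataset, several label sets" idea concrete, here is a minimal plain-Python sketch of two task views over one shared list of samples. The record fields (`image`, `emotion`, `age`) and the class names `EmotionView` and `AgeView` are hypothetical; in a real project each view would subclass `torch.utils.data.Dataset`, but only `__getitem__` and `__len__` matter for the idea.

```python
# Hypothetical shared records: each sample carries two label sets (emotion and age).
RECORDS = [
    {"image": [0.1, 0.2], "emotion": 0, "age": 31},
    {"image": [0.3, 0.4], "emotion": 1, "age": 45},
    {"image": [0.5, 0.6], "emotion": 0, "age": 22},
]


class EmotionView:
    """Task 1 view (hypothetical): returns (input, emotion label)."""

    def __init__(self, records):
        self.records = records

    def __getitem__(self, index):
        record = self.records[index]
        return record["image"], record["emotion"]

    def __len__(self):
        return len(self.records)


class AgeView:
    """Task 2 view (hypothetical): same inputs, different label and processing."""

    def __init__(self, records):
        self.records = records

    def __getitem__(self, index):
        record = self.records[index]
        # Task-specific processing (age in decades); the input keeps its shape.
        return record["image"], record["age"] // 10

    def __len__(self):
        return len(self.records)


emotion_task = EmotionView(RECORDS)
age_task = AgeView(RECORDS)
print(emotion_task[1], age_task[1])  # ([0.3, 0.4], 1) ([0.3, 0.4], 4)
```

Because both views return inputs of the same shape, their samples can be stacked into the same kind of batch tensor.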
First, let’s define two datasets to work with:
import torch
from torch.utils.data.dataset import ConcatDataset
class MyFirstDataset(torch.utils.data.Dataset):
    def __init__(self):
        # dummy dataset
        self.samples = torch.cat((-torch.ones(5), torch.ones(5)))
    def __getitem__(self, index):
        # change this to your samples fetching logic
        return self.samples[index]
    def __len__(self):
        # change this to return number of samples in your dataset
        return self.samples.shape[0]
class MySecondDataset(torch.utils.data.Dataset):
    def __init__(self):
        # dummy dataset
        self.samples = torch.cat((torch.ones(50) * 5, torch.ones(5) * -5))
    def __getitem__(self, index):
        # change this to your samples fetching logic
        return self.samples[index]
    def __len__(self):
        # change this to return number of samples in your dataset
        return self.samples.shape[0]
first_dataset = MyFirstDataset()
second_dataset = MySecondDataset()
concat_dataset = ConcatDataset([first_dataset, second_dataset])
We define two (binary) datasets: the first has ten samples of ±1 (equally distributed), and the second has 55 samples, 50 with the value 5 and 5 with the value -5. These datasets are only for illustration. Real datasets should contain both samples and labels, and you will probably read the data from a database or parse it from data folders, but these simple datasets are enough to understand the main concepts.
Next, we need to define a DataLoader. We provide it with our concat_dataset and set the loader parameters, such as the batch size and whether or not to shuffle the samples.
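ConcatDataset exposes the combined datasets as a single index space. It stores the running totals of the dataset lengths in `cumulative_sizes` and maps a global index back to a dataset and a local index with a binary search. The plain-Python sketch below reproduces that arithmetic for our 10- and 55-sample datasets; `locate` is a hypothetical helper for illustration, not part of the PyTorch API.

```python
import bisect
from itertools import accumulate

dataset_sizes = [10, 55]  # len(first_dataset), len(second_dataset)
# Running totals, the same values ConcatDataset keeps in cumulative_sizes.
cumulative_sizes = list(accumulate(dataset_sizes))


def locate(global_idx):
    """Map a concatenated index to (dataset index, index inside that dataset)."""
    dataset_idx = bisect.bisect_right(cumulative_sizes, global_idx)
    offset = 0 if dataset_idx == 0 else cumulative_sizes[dataset_idx - 1]
    return dataset_idx, global_idx - offset


print(cumulative_sizes)                   # [10, 65]
print(locate(9), locate(10), locate(64))  # (0, 9) (1, 0) (1, 54)
```

Going the other way, adding the offset of dataset `i` (0 for the first one, `cumulative_sizes[i - 1]` otherwise) to a local index gives the global index. The scheduler samplers in the next steps rely on exactly this shift.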
batch_size = 8
# basic dataloader
dataloader = torch.utils.data.DataLoader(dataset=concat_dataset,
                                         batch_size=batch_size,
                                         shuffle=True)
for inputs in dataloader:
    print(inputs)
The output of this part looks like:
tensor([ 5., 5., 5., 5., -5., 5., -5., 5.])
tensor([5., 5., 5., 5., 5., 5., 5., 5.])
tensor([-1., -5., 5., 1., 5., -1., 5., -1.])
tensor([5., 5., 5., 5., 5., 5., 5., 5.])
tensor([ 5., 5., 5., 5., -5., 1., 5., 5.])
tensor([ 5., 5., 5., 1., 5., 5., 5., -1.])
tensor([ 5., 5., 5., 5., -1., 5., 1., 5.])
tensor([ 5., -5., 1., 5., 5., 5., 5., 5.])
tensor([5.])
Each batch is a tensor of 8 samples from concat_dataset, drawn in random order from the combined pool of both datasets. The pool holds 65 samples (10 + 55), so after 8 full batches the last batch contains the single remaining sample.
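The number of batches follows directly from the dataset size. Here is a quick check of the arithmetic for the 65 combined samples. DataLoader's `drop_last=True` option would discard the incomplete final batch instead of returning it.

```python
import math

num_samples = 10 + 55  # len(concat_dataset)
batch_size = 8

num_batches = math.ceil(num_samples / batch_size)  # batches with drop_last=False (the default)
last_batch_size = num_samples - (num_batches - 1) * batch_size
num_full_batches = num_samples // batch_size  # batches with drop_last=True

print(num_batches, last_batch_size, num_full_batches)  # 9 1 8
```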
Until now, everything was relatively straightforward. The datasets are combined into a single one, and samples are randomly picked from both original datasets to construct each mini-batch. Now let’s try to control the samples in each batch. We want each mini-batch to contain samples from only one dataset, alternating between the datasets from one batch to the next.
import math
import torch
from torch.utils.data.sampler import RandomSampler
class BatchSchedulerSampler(torch.utils.data.sampler.Sampler):
    """
    iterate over tasks and provide a random batch per task in each mini-batch
    """
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.number_of_datasets = len(dataset.datasets)
        # len() works for any Dataset, not only ones exposing a .samples attribute
        self.largest_dataset_size = max([len(cur_dataset) for cur_dataset in dataset.datasets])
    def __len__(self):
        return self.batch_size * math.ceil(self.largest_dataset_size / self.batch_size) * len(self.dataset.datasets)
    def __iter__(self):
        samplers_list = []
        sampler_iterators = []
        for dataset_idx in range(self.number_of_datasets):
            cur_dataset = self.dataset.datasets[dataset_idx]
            sampler = RandomSampler(cur_dataset)
            samplers_list.append(sampler)
            cur_sampler_iterator = iter(sampler)
            sampler_iterators.append(cur_sampler_iterator)
        # offset of each internal dataset inside the ConcatDataset index space
        push_index_val = [0] + self.dataset.cumulative_sizes[:-1]
        step = self.batch_size * self.number_of_datasets
        samples_to_grab = self.batch_size
        # for this case we want to get all samples in dataset, this forces us to resample from the smaller datasets
        epoch_samples = self.largest_dataset_size * self.number_of_datasets
        final_samples_list = []  # this is a list of indexes from the combined dataset
        for _ in range(0, epoch_samples, step):
            for i in range(self.number_of_datasets):
                cur_batch_sampler = sampler_iterators[i]
                cur_samples = []
                for _ in range(samples_to_grab):
                    try:
                        cur_sample_org = next(cur_batch_sampler)
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                    except StopIteration:
                        # got to the end of iterator - restart the iterator and continue to get samples
                        # until reaching "epoch_samples"
                        sampler_iterators[i] = iter(samplers_list[i])
                        cur_batch_sampler = sampler_iterators[i]
                        cur_sample_org = next(cur_batch_sampler)
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                final_samples_list.extend(cur_samples)
        return iter(final_samples_list)
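This is the definition of the BatchSchedulerSampler class, which builds a new iterator over sample indexes in two steps. First, it creates a RandomSampler for each internal dataset. Second, it pulls samples (actually sample indexes) from each internal dataset's iterator in turn, batch_size indexes at a time, and restarts an iterator whenever it is exhausted. Because each RandomSampler returns indexes local to its own dataset, every index is shifted by that dataset's offset in the ConcatDataset (push_index_val, built from cumulative_sizes). The sampler keeps going until the largest dataset has been covered, resampling the smaller datasets as needed. With a batch size of 8, we fetch 8 samples from each dataset in turn, and because the DataLoader uses the same batch size, every mini-batch comes from a single dataset.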
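Stripped of the PyTorch classes, the scheduling comes down to the index arithmetic below. This is a minimal plain-Python sketch under a few assumptions: `random.shuffle` stands in for RandomSampler, dataset sizes are passed in directly, and `cycling_permutation` and `schedule_indices` are hypothetical helper names. It emits one run of batch_size global indexes per dataset per round, reshuffling a dataset whenever its current pass runs out.

```python
import math
import random


def cycling_permutation(size, rng):
    """Yield shuffled permutations of range(size) forever, reshuffling on every pass."""
    while True:
        order = list(range(size))
        rng.shuffle(order)
        yield from order


def schedule_indices(dataset_sizes, batch_size, seed=0):
    """Global indexes in which every run of batch_size indexes comes from a single dataset."""
    rng = random.Random(seed)
    # Offset of each dataset in the concatenated index space (the role of push_index_val).
    offsets = [sum(dataset_sizes[:i]) for i in range(len(dataset_sizes))]
    streams = [cycling_permutation(size, rng) for size in dataset_sizes]
    # Enough rounds to see the largest dataset in full; smaller ones are resampled.
    rounds = math.ceil(max(dataset_sizes) / batch_size)
    indices = []
    for _ in range(rounds):
        for offset, stream in zip(offsets, streams):
            indices.extend(offset + next(stream) for _ in range(batch_size))
    return indices


indices = schedule_indices([10, 55], batch_size=8)
batches = [indices[i:i + 8] for i in range(0, len(indices), 8)]
print(len(indices), len(batches))  # 112 14
print([0 if max(batch) < 10 else 1 for batch in batches])  # [0, 1, 0, 1, ..., 0, 1]
```

The total of 112 indexes equals what BatchSchedulerSampler.__len__ returns for these datasets: 8 × ceil(55 / 8) × 2.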
Now let’s run and print the samples using a new DataLoader that receives our BatchSchedulerSampler as its sampler (shuffle can’t be set to True when a sampler is provided).
import torch
from multi_task_batch_scheduler import BatchSchedulerSampler
batch_size = 8
# dataloader with BatchSchedulerSampler
dataloader = torch.utils.data.DataLoader(dataset=concat_dataset,
                                         sampler=BatchSchedulerSampler(dataset=concat_dataset,
                                                                       batch_size=batch_size),
                                         batch_size=batch_size,
                                         shuffle=False)
for inputs in dataloader:
    print(inputs)
The output now looks like this:
tensor([-1., -1., 1., 1., -1., 1., 1., -1.])
tensor([5., 5., 5., 5., 5., 5., 5., 5.])
tensor([ 1., -1., -1., -1., 1., 1., -1., 1.])
tensor([5., 5., 5., 5., 5., 5., 5., 5.])
tensor([-1., -1., 1., 1., 1., -1., 1., -1.])
tensor([ 5., 5., -5., 5., 5., -5., 5., 5.])
tensor([ 1., 1., -1., -1., 1., -1., 1., 1.])
tensor([5., 5., 5., 5., 5., 5., 5., 5.])
tensor([-1., -1., -1., -1., 1., 1., 1., -1.])
tensor([ 5., -5., 5., 5., 5., 5., -5., 5.])
tensor([-1., 1., -1., 1., -1., 1., 1., -1.])
tensor([ 5., 5., 5., 5., 5., -5., 5., 5.])
tensor([ 1., -1., -1., 1., 1., 1., 1., -1.])
tensor([5., 5., 5., 5., 5., 5., 5.])
Hurray! Each mini-batch now contains samples from only one dataset. We can play with this type of scheduling to downsample or upsample more important tasks.
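One way to upsample a more important task is to replace the strict round-robin with a repeating task pattern. This is a suggestion, not part of the code above. In the plain-Python sketch below, the hypothetical `task_pattern` gives the first task two batches for every batch of the second task. The per-dataset offsets work exactly as push_index_val does, and `patterned_batches` is a hypothetical helper name.

```python
import itertools
import random


def cycling_permutation(size, rng):
    """Yield shuffled permutations of range(size) forever, reshuffling on every pass."""
    while True:
        order = list(range(size))
        rng.shuffle(order)
        yield from order


def patterned_batches(dataset_sizes, batch_size, task_pattern, num_batches, seed=0):
    """Build single-task batches whose task order repeats task_pattern."""
    rng = random.Random(seed)
    offsets = list(itertools.accumulate([0] + dataset_sizes[:-1]))
    streams = [cycling_permutation(size, rng) for size in dataset_sizes]
    batches = []
    for task in itertools.islice(itertools.cycle(task_pattern), num_batches):
        batches.append([offsets[task] + next(streams[task]) for _ in range(batch_size)])
    return batches


# Task 0 gets two batches for every batch of task 1.
batches = patterned_batches([10, 55], batch_size=8, task_pattern=[0, 0, 1], num_batches=12)
print([0 if max(b) < 10 else 1 for b in batches])
# [0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1]
```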
The remaining problem in our batches comes from the second, highly unbalanced dataset. This is often the case in MTL, where there is a main task and a few satellite sub-tasks. Training the main task and the sub-tasks together can improve performance and help the overall model generalize. The problem is that samples of the sub-tasks are often very sparse, with only a few positive (or negative) samples. Let’s keep the previous logic but also force each batch to be balanced with respect to the label distribution of its task.
To handle the imbalance, we replace the RandomSampler in the BatchSchedulerSampler class with an ImbalancedDatasetSampler (I am using a great implementation from the ufoym/imbalanced-dataset-sampler repository). This class rebalances the sampling so that each label is drawn with roughly equal probability. We can also mix samplers, using RandomSampler for some tasks and ImbalancedDatasetSampler for others.
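The balancing idea is simple. Weight every sample by the inverse of its label's frequency, then draw indexes with replacement, so that each label is picked with roughly equal probability. ImbalancedDatasetSampler follows this approach, and PyTorch ships the same mechanism as `torch.utils.data.WeightedRandomSampler`. Below is a plain-Python sketch on the second dataset's labels, using `random.choices`; `balanced_indices` is a hypothetical name.

```python
import random
from collections import Counter

# Labels of the second (unbalanced) dataset: 50 samples of 5 and 5 samples of -5.
labels = [5] * 50 + [-5] * 5


def balanced_indices(labels, num_samples, seed=0):
    """Draw indexes with replacement, weighting each sample by 1 / (count of its label)."""
    label_counts = Counter(labels)
    weights = [1.0 / label_counts[label] for label in labels]
    rng = random.Random(seed)
    return rng.choices(range(len(labels)), weights=weights, k=num_samples)


drawn = balanced_indices(labels, num_samples=10_000)
share_negative = sum(labels[i] == -5 for i in drawn) / len(drawn)
print(round(share_negative, 2))  # close to 0.5, versus 5 / 55 (about 0.09) without reweighting
```

Because the draws are made with replacement, the five -5 samples are repeated many times per pass. That repetition is the price of balancing the batches.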
import math
import torch
from torch.utils.data import RandomSampler
from sampler import ImbalancedDatasetSampler
class ExampleImbalancedDatasetSampler(ImbalancedDatasetSampler):
    """
    ImbalancedDatasetSampler is taken from:
    https://github.com/ufoym/imbalanced-dataset-sampler/blob/master/torchsampler/imbalanced.py
    In order to be able to show the usage of ImbalancedDatasetSampler in this example I am overriding _get_label
    to fit my datasets
    """
    def _get_label(self, dataset, idx):
        return dataset.samples[idx].item()
class BalancedBatchSchedulerSampler(torch.utils.data.sampler.Sampler):
    """
    iterate over tasks and provide a balanced batch per task in each mini-batch
    """
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.number_of_datasets = len(dataset.datasets)
        # len() works for any Dataset, not only ones exposing a .samples attribute
        self.largest_dataset_size = max([len(cur_dataset) for cur_dataset in dataset.datasets])
    def __len__(self):
        return self.batch_size * math.ceil(self.largest_dataset_size / self.batch_size) * len(self.dataset.datasets)
    def __iter__(self):
        samplers_list = []
        sampler_iterators = []
        for dataset_idx in range(self.number_of_datasets):
            cur_dataset = self.dataset.datasets[dataset_idx]
            if dataset_idx == 0:
                # the first dataset is kept at RandomSampler
                sampler = RandomSampler(cur_dataset)
            else:
                # the second unbalanced dataset is changed
                sampler = ExampleImbalancedDatasetSampler(cur_dataset)
            samplers_list.append(sampler)
            cur_sampler_iterator = iter(sampler)
            sampler_iterators.append(cur_sampler_iterator)
        # offset of each internal dataset inside the ConcatDataset index space
        push_index_val = [0] + self.dataset.cumulative_sizes[:-1]
        step = self.batch_size * self.number_of_datasets
        samples_to_grab = self.batch_size
        # for this case we want to get all samples in dataset, this forces us to resample from the smaller datasets
        epoch_samples = self.largest_dataset_size * self.number_of_datasets
        final_samples_list = []  # this is a list of indexes from the combined dataset
        for _ in range(0, epoch_samples, step):
            for i in range(self.number_of_datasets):
                cur_batch_sampler = sampler_iterators[i]
                cur_samples = []
                for _ in range(samples_to_grab):
                    try:
                        cur_sample_org = next(cur_batch_sampler)
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                    except StopIteration:
                        # got to the end of iterator - restart the iterator and continue to get samples
                        # until reaching "epoch_samples"
                        sampler_iterators[i] = iter(samplers_list[i])
                        cur_batch_sampler = sampler_iterators[i]
                        cur_sample_org = next(cur_batch_sampler)
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                final_samples_list.extend(cur_samples)
        return iter(final_samples_list)
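We first create ExampleImbalancedDatasetSampler, which inherits from ImbalancedDatasetSampler and only overrides the _get_label function to fit our use case (here, the label is simply the sample value). This hook matches the version of imbalanced.py linked in the docstring; later releases of that repository may expose labels through a different method, so check the copy you use.
Next, we use BalancedBatchSchedulerSampler, which is identical to the previous BatchSchedulerSampler class except that it replaces the RandomSampler of the unbalanced task (dataset index 1) with ExampleImbalancedDatasetSampler.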
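The two sampler classes differ only in which sampler each task uses. One option is to pass a per-task sampler factory instead of duplicating the class; this is an alternative structure, not the one used above. The plain-Python sketch below shows the idea with two hypothetical factories. `random_order` produces a shuffled pass, like RandomSampler, and `inverse_frequency` produces balanced draws with replacement, like the imbalanced sampler. In the PyTorch version, the factories would return RandomSampler or ExampleImbalancedDatasetSampler instances.

```python
import math
import random
from collections import Counter


def random_order(labels, rng):
    """One shuffled pass over the dataset (the role RandomSampler plays)."""
    order = list(range(len(labels)))
    rng.shuffle(order)
    return order


def inverse_frequency(labels, rng):
    """len(labels) draws with replacement, weighted by 1 / label count (the imbalanced sampler's role)."""
    counts = Counter(labels)
    weights = [1.0 / counts[label] for label in labels]
    return rng.choices(range(len(labels)), weights=weights, k=len(labels))


def scheduled_indices(task_labels, sampler_factories, batch_size, seed=0):
    """Round-robin single-task batches; each task refills from its own sampler factory."""
    rng = random.Random(seed)
    sizes = [len(labels) for labels in task_labels]
    offsets = [sum(sizes[:t]) for t in range(len(sizes))]
    pending = [[] for _ in sizes]  # indexes left in each task's current pass
    indices = []
    for _ in range(math.ceil(max(sizes) / batch_size)):
        for task, labels in enumerate(task_labels):
            for _ in range(batch_size):
                if not pending[task]:  # pass exhausted: ask the factory for a new one
                    pending[task] = sampler_factories[task](labels, rng)
                indices.append(offsets[task] + pending[task].pop(0))
    return indices


first_labels = [-1] * 5 + [1] * 5
second_labels = [5] * 50 + [-5] * 5
indices = scheduled_indices([first_labels, second_labels],
                            [random_order, inverse_frequency], batch_size=8)
all_labels = first_labels + second_labels
second_task = [all_labels[i] for i in indices if i >= 10]
share_negative = second_task.count(-5) / len(second_task)
print(len(indices), len(second_task))  # 112 56
```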
Let’s run the new DataLoader:
import torch
from balanced_sampler import BalancedBatchSchedulerSampler
batch_size = 8
# dataloader with BalancedBatchSchedulerSampler
dataloader = torch.utils.data.DataLoader(dataset=concat_dataset,
                                         sampler=BalancedBatchSchedulerSampler(dataset=concat_dataset,
                                                                               batch_size=batch_size),
                                         batch_size=batch_size,
                                         shuffle=False)
for inputs in dataloader:
    print(inputs)
The output looks like:
tensor([-1., 1., 1., -1., -1., -1., 1., -1.])
tensor([ 5., 5., 5., 5., -5., -5., -5., -5.])
tensor([ 1., 1., 1., -1., 1., -1., 1., 1.])
tensor([ 5., -5., 5., -5., -5., -5., 5., 5.])
tensor([-1., -1., 1., -1., -1., -1., -1., 1.])
tensor([-5., 5., 5., 5., 5., -5., 5., -5.])
tensor([-1., -1., 1., 1., 1., 1., -1., -1.])
tensor([-5., 5., 5., 5., 5., -5., 5., 5.])
tensor([ 1., -1., 1., 1., 1., -1., 1., -1.])
tensor([ 5., 5., 5., -5., 5., -5., 5., 5.])
tensor([-1., -1., -1., -1., 1., 1., 1., 1.])
tensor([-5., 5., 5., 5., 5., 5., -5., 5.])
tensor([-1., 1., -1., 1., 1., 1., 1., 1.])
tensor([-5., -5., 5., 5., -5., -5., 5.])
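The mini-batches of the unbalanced task are now much more balanced: 21 of the 55 samples printed for the second task are -5, compared with 5 of 55 in the previous run.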
There is a lot of room to push this setup further. We can combine the tasks in a balanced way: by setting samples_to_grab to 4, half of the batch size, we get mixed mini-batches with 4 samples taken from each task. To produce a 1:3 ratio in favor of a more important task, we can set samples_to_grab=2 for the first task and samples_to_grab=6 for the second task. In both cases, samples_to_grab becomes a per-task value. In addition, step should equal the number of indexes produced per round (here, the batch size), and __len__ must be updated to match the new number of indexes.
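A mixed mini-batch with per-task quotas can be sketched as follows. This is a plain-Python illustration, not the author's code: `mixed_batches` and its `quotas` parameter (one samples_to_grab value per task) are hypothetical. The number of batches is chosen so that every dataset is seen in full at least once.

```python
import math
import random


def cycling_permutation(size, rng):
    """Yield shuffled permutations of range(size) forever, reshuffling on every pass."""
    while True:
        order = list(range(size))
        rng.shuffle(order)
        yield from order


def mixed_batches(dataset_sizes, quotas, seed=0):
    """Each batch holds quotas[t] samples from task t; enough batches to cover every dataset."""
    rng = random.Random(seed)
    offsets = [sum(dataset_sizes[:t]) for t in range(len(dataset_sizes))]
    streams = [cycling_permutation(size, rng) for size in dataset_sizes]
    num_batches = max(math.ceil(size / quota) for size, quota in zip(dataset_sizes, quotas))
    batches = []
    for _ in range(num_batches):
        batch = []
        for offset, stream, quota in zip(offsets, streams, quotas):
            batch.extend(offset + next(stream) for _ in range(quota))
        batches.append(batch)
    return batches


# 1:3 split of a batch of 8 between the first and the second task.
batches = mixed_batches([10, 55], quotas=[2, 6])
print(len(batches), [sum(i < 10 for i in b) for b in batches[:3]])  # 10 [2, 2, 2]
```

A list of such batches can also be passed directly as the DataLoader's batch_sampler argument. That argument is mutually exclusive with batch_size, shuffle, sampler, and drop_last.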
That’s it. The full code can be downloaded from my repository.