May 25, 2021 · 9 min read

How to Convert a PyTorch DataParallel Project to Use DistributedDataParallel

A practical guide to migrating an existing DataParallel training setup to DistributedDataParallel, covering the wrapper, model, data loading, logging, and metrics.

Originally published on Medium


Photo by Taylor Vick on Unsplash

Many posts discuss the differences between PyTorch DataParallel and DistributedDataParallel and why it is best practice to use DistributedDataParallel.

PyTorch documentation summarizes this as:

DataParallel is usually slower than DistributedDataParallel even on a single machine due to GIL contention across threads, per-iteration replicated model, and additional overhead introduced by scattering inputs and gathering outputs.

But the truth is that for development and prototyping, when starting a new project, or when building on top of an existing GitHub repository [which already uses DataParallel], it is simpler to use DataParallel out of the box, especially on a single server [with one or several GPUs]. Its most prominent advantage is that it is easier to debug, since everything runs in a single process.

Still, there will come a point where you want to take your existing DataParallel project to the big league and use DistributedDataParallel, and this is not as trivial as it should be.


So in this post, I will not discuss the advantages and disadvantages of each method any further, but rather focus on the practical aspects of converting an existing DataParallel project into a DistributedDataParallel project.

I’ll try to describe the different pieces as generally as possible, but of course, your use case is unique 😃 and might require careful, specific adjustments depending on what your implementation and code look like.

Let’s go…

The first step is the wrapper. DataParallel uses a single process with multiple threads, whereas DistributedDataParallel is multi-process by design, so the first thing we should do is wrap the entire code (our main function) with a multi-process wrapper.

To do so, we are going to use a wrapper provided by FAIR in the Detectron2 repository. We will also need its comm module [detectron2/utils/comm.py], which provides some nice functionality for handling distributed resources.

```python
# Copyright (c) Facebook, Inc. and its affiliates.
# Adapted from detectron2/detectron2/engine/launch.py
from datetime import timedelta

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from detectron2.utils import comm

DEFAULT_TIMEOUT = timedelta(minutes=30)


def _find_free_port():
    import socket

    # Bind to port 0 so the OS picks a free port, then release it.
    # Another process could still grab the port before we use it.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    return port


def launch(
    main_func,
    num_gpus_per_machine,
    num_machines=1,
    machine_rank=0,
    dist_url=None,
    args=(),
    timeout=DEFAULT_TIMEOUT,
):
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        if dist_url == "auto":
            assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"

        # Start one child process per GPU; mp.spawn calls
        # _distributed_worker(local_rank, *args) in each of them.
        mp.spawn(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args, timeout),
            daemon=False,
        )
    else:
        # A single process: just call the main function directly.
        main_func(*args)


def _distributed_worker(
    local_rank, main_func, world_size, num_gpus_per_machine,
    machine_rank, dist_url, args, timeout=DEFAULT_TIMEOUT,
):
    assert torch.cuda.is_available(), "CUDA is not available."
    # Unique rank of this process across all machines.
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    dist.init_process_group(
        backend="NCCL",
        init_method=dist_url,
        world_size=world_size,
        rank=global_rank,
        timeout=timeout,
    )

    # Bind this process to its own GPU before running any collective
    # (such as the barrier below).
    assert num_gpus_per_machine <= torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    # Wait until all processes have joined the process group.
    comm.synchronize()

    # Create one process group per machine (used by comm.get_local_rank() etc.).
    # Every process must create every group, even those it does not belong to.
    assert comm._LOCAL_PROCESS_GROUP is None
    num_machines = world_size // num_gpus_per_machine
    for i in range(num_machines):
        ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            comm._LOCAL_PROCESS_GROUP = pg

    main_func(*args)
```

The wrapper code copied from Detectron2 GitHub repository

The function that we need to run is called launch(). Let’s review the input parameters:

```python
launch(
    main_func,               # your existing main function
    num_gpus_per_machine=4,
    num_machines=1,
    machine_rank=0,
    dist_url="auto",
    args=(),                 # arguments passed on to main_func
)
```

We need to tell launch() how many machines we are using. In this post, we use a single machine with four GPUs, so num_machines=1 and num_gpus_per_machine=4. machine_rank is the index of the current machine when training on more than one machine; in our example, machine_rank=0. dist_url provides the address [IP and port] of the master machine in multi-machine training, but since we use only one machine, we can set dist_url="auto", which uses localhost and a free port. main_func is the main entry point of our code: the function you have used so far to run your flow in the DataParallel case. Finally, args holds all the input arguments for main_func.
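To make the rank bookkeeping concrete, here is a small pure-Python sketch (no GPUs needed) of how these parameters determine world_size and each process's global rank, mirroring the arithmetic inside launch() and _distributed_worker(). The function name rank_layout is hypothetical.

```python
def rank_layout(num_machines, num_gpus_per_machine):
    """Map each (machine_rank, local_rank) pair to its global rank.

    Mirrors launch(): world_size = num_machines * num_gpus_per_machine and
    global_rank = machine_rank * num_gpus_per_machine + local_rank.
    """
    world_size = num_machines * num_gpus_per_machine
    layout = {}
    for machine_rank in range(num_machines):
        for local_rank in range(num_gpus_per_machine):
            global_rank = machine_rank * num_gpus_per_machine + local_rank
            layout[(machine_rank, local_rank)] = global_rank
    return world_size, layout


# Our setup: one machine with four GPUs.
world_size, layout = rank_layout(num_machines=1, num_gpus_per_machine=4)
print(world_size, layout)  # 4 {(0, 0): 0, (0, 1): 1, (0, 2): 2, (0, 3): 3}

# Two machines with four GPUs each: GPU 2 on machine 1 gets global rank 6.
world_size, layout = rank_layout(num_machines=2, num_gpus_per_machine=4)
print(world_size, layout[(1, 2)])  # 8 6
```

On a single machine, the global rank and the local rank coincide, which is why the local rank can be used directly as the GPU index.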

Under the hood, the launch function spins up multiple processes using the torch multiprocessing module [torch.multiprocessing] and its spawn function, which starts them as new child processes, one for each GPU.

It is important to note that once launch is called, the exact same code runs in each child process simultaneously. So all setup code that needs to run only once at the beginning should be called before launch. For example: connecting to your experiment manager application, reading arguments from a config file that should be passed to all processes in advance, syncing data to the local server, etc.
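The ordering described above can be sketched with torch.multiprocessing.spawn directly (the same call launch() makes under the hood), without detectron2 or GPUs. load_config, main_worker, and run are hypothetical names, and the config contents are made up for illustration.

```python
import torch.multiprocessing as mp


def load_config():
    # Hypothetical one-time setup (read a config file, sync data, connect to an
    # experiment manager, ...). It runs only in the parent process.
    return {"lr": 0.1, "batch_size_per_gpu": 32}


def main_worker(local_rank, cfg):
    # Runs once in every child process. mp.spawn passes the process index
    # (0 .. nprocs - 1) as the first argument, followed by the entries of args.
    print(f"process {local_rank} starting with lr={cfg['lr']}")


def run(num_gpus):
    cfg = load_config()  # executed once, before any child process exists
    # args are pickled and sent to each child, so cfg must be picklable.
    mp.spawn(main_worker, args=(cfg,), nprocs=num_gpus)
```

Call run(num_gpus=4) from under an `if __name__ == "__main__":` guard: with the spawn start method, each child re-imports your main module, and an unguarded call would try to spawn again from inside the children.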

Now that we have a working wrapper, we can adjust our model. We need to “convert” our model, which was initialized on the CPU, into a DistributedDataParallel GPU model.

In the DataParallel case, we wrapped our model with:

```python
model = torch.nn.DataParallel(model)
```

In the distributed case, we wrap the model with the DistributedDataParallel class instead, and each process makes this call independently. So each process first needs to find its local rank [on a single machine, this is simply its GPU index], move the model from the CPU to that specific GPU, and wrap it as a DistributedDataParallel model:

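```python
from detectron2.utils import comm
from torch.nn.parallel import DistributedDataParallel as DDP

# Local rank = index of this process's GPU on the current machine.
cur_rank = comm.get_local_rank()
model = DDP(model.to(cur_rank), device_ids=[cur_rank], broadcast_buffers=False)
```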

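broadcast_buffers is an important parameter. When it is True (the default), DDP broadcasts the module's buffers, such as the running mean and variance of BatchNorm layers, from the rank 0 process to all other processes at the beginning of every forward pass. Setting it to False lets each process keep its own buffers. Note that neither setting computes these statistics over the combined batch of all GPUs; that requires converting the BatchNorm layers with torch.nn.SyncBatchNorm.convert_sync_batchnorm().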
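If you want to try the wrapping API on a machine without GPUs, here is a minimal sketch. It assumes a one-process "gloo" group on the CPU as a stand-in for the NCCL group that launch() creates, and the small model is hypothetical. It also shows that the original model stays available as ddp_model.module, which is what you would normally save.

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumption: a single-process CPU "gloo" group, only so DDP can be built here.
# In real training, launch() has already initialized the (NCCL) process group.
if not dist.is_initialized():
    dist.init_process_group(backend="gloo", store=dist.HashStore(), rank=0, world_size=1)

model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.BatchNorm1d(4))
# For a CPU module, device_ids must be left as None.
ddp_model = DDP(model, broadcast_buffers=False)

out = ddp_model(torch.randn(16, 8))
out.sum().backward()  # gradients are all-reduced across processes here

# ddp_model.state_dict() keys start with "module."; the wrapped model's do not.
print(list(ddp_model.module.state_dict().keys())[:2])  # ['0.weight', '0.bias']
print(ddp_model.broadcast_buffers)
```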

So we have the main wrapper and a model; next, we need to adjust the data loading part.

In the DataParallel case, we have a single batch of size N, and during the forward pass it is automatically scattered evenly across our 4 GPUs, giving each GPU a mini-batch of size N/4. In the distributed case, every process has its own DataLoader, so each GPU process should read only its own share of the samples. This also means that the DataLoader's batch_size is now per process: to keep the same effective batch size N, set it to N/4. The splitting is done with PyTorch's DistributedSampler:

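```python
from torch.utils.data.distributed import DistributedSampler

# num_replicas and rank are read from the initialized process group.
# seed must be identical in all processes so they agree on the shuffle order.
train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=seed)
valid_sampler = DistributedSampler(valid_dataset, shuffle=False, seed=seed)
```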

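The sampler splits the sample indices according to the number of processes and gives each process its own share of indices to build its mini-batches from. After creating the samplers, pass each one to its DataLoader through the sampler argument and keep the DataLoader's shuffle parameter set to False, since the sampler now handles shuffling (the DataLoader does not accept both). We also need to call the sampler's set_epoch(epoch) at the beginning of every iteration of our epoch for-loop; otherwise, every epoch sees the same order of training samples. Note that, by default, the sampler pads the dataset with repeated samples so that it divides evenly among the processes, which can slightly affect validation metrics.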
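The sketch below simulates the four processes in a single script by passing num_replicas and rank to DistributedSampler explicitly (in real code they come from the process group). The toy dataset, the global batch size of 8, and the helper indices_for are assumptions for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset of 10 samples; each sample's value equals its index.
dataset = TensorDataset(torch.arange(10))
world_size = 4                                       # number of processes (GPUs)
global_batch_size = 8                                # the N used with DataParallel
per_process_batch = global_batch_size // world_size  # N / 4 = 2


def indices_for(rank, epoch):
    # Explicit num_replicas/rank let us simulate one process without a group.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                 shuffle=True, seed=0)
    sampler.set_epoch(epoch)  # a different shuffle order for every epoch
    # shuffle stays False in the DataLoader: the sampler does the shuffling.
    loader = DataLoader(dataset, batch_size=per_process_batch,
                        sampler=sampler, shuffle=False)
    return [int(x) for (batch,) in loader for x in batch]


for rank in range(world_size):
    print(rank, indices_for(rank, epoch=0))
```

Together, the four lists cover all 10 indices, with two repeats: the sampler pads 10 samples up to 12 so that every process gets 3.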

At this point, we have the three main components: a multi-process wrapper, a model wrapped with DistributedDataParallel, and a way to handle the data loading part. Next, we will review other important adjustments we can make in our code related to logging, reporting, and metrics calculation.

As all processes run the same code, prints and logger calls that report to the console [or a file] will write the same message once per process. Although some messages are worth reporting from every process, many are not, and we can use the following check to report a message only once, from the main process [which is our rank 0 process]:

```python
from detectron2.utils import comm

if comm.is_main_process():
    print("something that is printed only once")
```

This trick is also valuable when reporting metrics to an external tool or when saving the model weights. You can use it to make sure an action happens only once in the flow, in the main process.
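If you prefer not to depend on detectron2 for this, a minimal stand-in can be written with torch.distributed directly; like detectron2's helper, it treats an uninitialized process group as rank 0, so the same code also works in a single process. The helper names get_rank, is_main_process, and save_checkpoint are hypothetical.

```python
import os
import tempfile

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


def get_rank():
    # Rank 0 when not running distributed, so single-process runs still work.
    if not dist.is_available() or not dist.is_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_checkpoint(model, path):
    # Only the main process writes the file; the other processes skip it.
    if not is_main_process():
        return
    # Unwrap DDP/DataParallel so the keys don't carry the "module." prefix.
    if isinstance(model, (DistributedDataParallel, torch.nn.DataParallel)):
        model = model.module
    torch.save(model.state_dict(), path)


model = torch.nn.Linear(4, 2)
path = os.path.join(tempfile.mkdtemp(), "model.pt")
save_checkpoint(model, path)
if is_main_process():
    print("saved to", path)
```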

Now, what about the metrics?

With the PyTorch DistributedDataParallel module, you don’t need to manage and “collect” [gather] the loss values from all processes to run the backward step. Each process computes the loss on its own mini-batch, and during loss.backward() DDP averages the gradients across all processes under the hood [an all-reduce], so every model replica on every GPU receives the same gradient update.
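Why is averaging gradients enough? Here is a small pure-Python check with a made-up one-parameter model y = w·x and a mean-squared-error loss (the data and the function name mse_grad are illustrative). With equal-sized mini-batches, the average of the per-process gradients equals the gradient over the full batch of size N.

```python
def mse_grad(w, xs, ys):
    # d/dw of mean((w * x - y) ** 2) over the given samples
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 16.2]
world_size = 4
per_rank = len(xs) // world_size  # N / 4 samples per process

# Each "process" computes the gradient on its own mini-batch ...
local_grads = [
    mse_grad(w, xs[r * per_rank:(r + 1) * per_rank], ys[r * per_rank:(r + 1) * per_rank])
    for r in range(world_size)
]
# ... and the all-reduce leaves every process with their average.
averaged = sum(local_grads) / world_size
full_batch = mse_grad(w, xs, ys)
print(abs(averaged - full_batch) < 1e-9)  # True
```

The equality relies on equal per-process batch sizes; with an uneven last batch, it holds only approximately.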

However, this only covers the gradients used in the backward pass. Each process still holds only the loss of its own mini-batch, so to report the actual loss values or their average, you need to collect the loss from every process and have the main process report their average:

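```python
from detectron2.utils import comm

loss = criterion(outputs, labels)
...
# Gather a plain float (not the GPU tensor attached to the graph). comm.gather
# returns one value per process on the main process, an empty list elsewhere.
all_loss = comm.gather(loss.item())
if comm.is_main_process():
    report_avg(all_loss)  # report_avg: your own reporting function
```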

Now, all_loss has the loss values of all processes, and the main process can average them, accumulate them, and report what is needed.

Note that gather is a collective operation: every process must call it and wait until all values have reached the main process, which slows down the entire flow. A better approach is to have each process collect and store its loss values locally and run the gather only every few iterations or at the end of an epoch.
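One way to implement this is sketched below. Note that it swaps comm.gather for a different, commonly used technique: each process accumulates a (sum, count) pair locally, and a single torch.distributed.all_reduce combines them on demand, so every process ends up with the same global average and only the main process needs to report it. The class name LossMeter is hypothetical.

```python
import torch
import torch.distributed as dist


class LossMeter:
    """Accumulate losses locally and combine them across processes on demand."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, loss):
        # Store a plain float so no autograd graph or GPU memory is kept alive.
        self.total += loss.item() if torch.is_tensor(loss) else float(loss)
        self.count += 1

    def global_average(self, device="cpu"):
        # With the NCCL backend, pass this process's GPU as device.
        stats = torch.tensor([self.total, self.count], dtype=torch.float64, device=device)
        if dist.is_available() and dist.is_initialized():
            # Collective: every process must call this at the same point.
            dist.all_reduce(stats, op=dist.ReduceOp.SUM)
        total, count = stats.tolist()
        return total / count if count else float("nan")

    def reset(self):
        self.total, self.count = 0.0, 0


meter = LossMeter()
for loss in [torch.tensor(0.9), torch.tensor(0.7), torch.tensor(0.5)]:
    meter.update(loss)
# For example, every K iterations or at the end of an epoch:
print(meter.global_average())
```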

Additional important comments:

  • A valuable trick is the ability to sync all processes at specific points in the code and force them to wait for each other. This can be the cure for the following scenario: the flow runs and, for some reason, just hangs. Looking at nvidia-smi, you see that all GPUs are at 100% utilization except one, which is at 0%. This frustrating case can occur in distributed mode, for example, because the main process is busy with something that takes more time, like saving and uploading your model. Forcing the processes to sync at specific points in the flow might help, for instance, when going from the evaluation step back to the next training epoch. This can be done with a barrier, using comm.synchronize(). This function puts a barrier in the code, forcing all processes to wait for the rest to arrive and pass that point together 😃.

  • Multiple processes need to exchange objects, and to do so PyTorch serializes them with pickle: for example, the arguments handed to the spawned processes and, depending on the multiprocessing start method, the datasets used by DataLoader worker processes. Such objects must therefore be serializable. The error message is pretty informative, so you will know when you are facing this issue. I had trouble serializing part of my logger instances, since they had a filter that used a different class that was not suitable for serialization.
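Both tips can be sketched in a few lines. synchronize below is a minimal stand-in for detectron2's comm.synchronize() that is a no-op outside distributed runs, and find_unpicklable is a hypothetical debugging helper that tries to pickle each attribute of an object to locate the one that breaks serialization. MyDataset is a made-up example.

```python
import pickle

import torch.distributed as dist


def synchronize():
    """Barrier across all processes; a no-op when not running distributed."""
    if not dist.is_available() or not dist.is_initialized():
        return
    if dist.get_world_size() == 1:
        return
    # Every process must reach this call, or the others wait forever.
    dist.barrier()


def find_unpicklable(obj):
    """Return the names of obj's attributes that fail to pickle."""
    bad = []
    for name, value in vars(obj).items():
        try:
            pickle.dumps(value)
        except Exception:  # PicklingError, TypeError, AttributeError, ...
            bad.append(name)
    return bad


class MyDataset:
    def __init__(self):
        self.files = ["a.png", "b.png"]
        self.transform = lambda x: x  # lambdas cannot be pickled


print(find_unpicklable(MyDataset()))  # ['transform']
synchronize()  # e.g. after evaluation, before starting the next epoch
```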

That’s it. I hope this guide will help you transition from a PyTorch DataParallel implementation to a DistributedDataParallel mechanism and enjoy the benefits and speed it provides.

Further reading: