
(advanced) PyTorch 1.0 Distributed Trainer with Amazon AWS

Author: Nathan Inkawhich

Edited by: Teng Li

In this tutorial we will show how to set up, code, and run a PyTorch 1.0 distributed trainer across two multi-GPU Amazon AWS nodes. We will start by describing the AWS setup, then the PyTorch environment configuration, and finally the code for the distributed trainer. Hopefully you will find that very little code change is required to extend your current training code to a distributed application, and that most of the work is in the one-time environment setup.

Amazon AWS Setup

In this tutorial we will run distributed training across two multi-GPU nodes. In this section we will first cover how to create the nodes, then how to set up the security group so the nodes can communicate with each other.

Creating the Nodes

In Amazon AWS, there are seven steps to creating an instance. To get started, log in and select Launch Instance.

Step 1: Choose an Amazon Machine Image (AMI) - Here we will select the Deep Learning AMI (Ubuntu) Version 14.0. As described, this instance comes with many of the most popular deep learning frameworks installed and is preconfigured with CUDA, cuDNN, and NCCL. It is a very good starting point for this tutorial.

Step 2: Choose an Instance Type - Now, select the GPU compute unit called p2.8xlarge. Note that instance types differ in cost; this one provides 8 NVIDIA Tesla K80 GPUs per node, which makes it a good architecture for multi-GPU distributed training.

Step 3: Configure Instance Details - The only setting to change here is increasing the Number of instances to 2. All other configurations may be left at default.

Step 4: Add Storage - Note that by default these nodes do not come with a lot of storage (only 75 GB). For this tutorial, since we are only using the STL-10 dataset, this is plenty. However, if you want to train on a larger dataset such as ImageNet, you will have to add much more storage just to fit the dataset and any trained models you wish to save.

Step 5: Add Tags - Nothing to be done here, just move on.

Step 6: Configure Security Group - This is a critical step in the configuration process. By default, two nodes in the same security group cannot communicate in the distributed training setting. Here, we want to create a new security group for the two nodes to be in. However, we cannot finish configuring it in this step. For now, just remember the name of your new security group (e.g. launch-wizard-12), then move on to Step 7.

Step 7: Review Instance Launch - Here, review the instance then launch it. By default, this will automatically start initializing the two instances. You can monitor the initialization progress from the dashboard.

Configure Security Group

Recall that we were not able to properly configure the security group when creating the instances. Once you have launched the instances, select the Network & Security > Security Groups tab in the EC2 dashboard. This will bring up a list of security groups you have access to. Select the new security group you created in Step 6 (e.g. launch-wizard-12), which will bring up tabs called Description, Inbound, Outbound, and Tags. First, select the Inbound tab and click Edit to add a rule that allows “All Traffic” from “Sources” in the launch-wizard-12 security group. Then select the Outbound tab and do exactly the same. We have now allowed all inbound and outbound traffic of all types between nodes in the launch-wizard-12 security group.
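If you would rather script this step than click through the console, the same self-referencing rules can be added with the AWS SDK for Python (boto3). The sketch below is not part of the original setup and assumes that boto3 is installed, that AWS credentials are configured, and that you know the security group's ID (the sg-... value shown in the console, not the launch-wizard-12 name). The EC2 client is passed in as a parameter, so the helper itself does not depend on boto3.

```python
def allow_all_traffic_within_group(ec2_client, group_id):
    """Allow all traffic between instances that share the security group `group_id`.

    Assumes `ec2_client` behaves like a boto3 EC2 client; `group_id` is the
    security group's ID (hypothetical example: "sg-0123456789abcdef0").
    """
    # IpProtocol "-1" covers all protocols and ports; naming the group itself as
    # the peer restricts the rule to members of that same security group.
    permissions = [{"IpProtocol": "-1", "UserIdGroupPairs": [{"GroupId": group_id}]}]
    # Inbound rule: accept traffic from other members of the group.
    ec2_client.authorize_security_group_ingress(GroupId=group_id, IpPermissions=permissions)
    # Outbound rule: allow traffic to other members of the group.
    ec2_client.authorize_security_group_egress(GroupId=group_id, IpPermissions=permissions)
    return permissions
```

For example, on a machine with credentials you could call allow_all_traffic_within_group(boto3.client("ec2"), "sg-0123456789abcdef0"), where the group ID is a placeholder for your own. This mirrors the console steps above: all traffic is allowed, but only between members of the group.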

Necessary Information

Before continuing, we must find and remember the IP addresses of both nodes. In the EC2 dashboard, find your running instances. For both instances, write down the IPv4 Public IP and the Private IP. For the remainder of the document, we will refer to these as node0-publicIP, node0-privateIP, node1-publicIP, and node1-privateIP. The public IPs are the addresses we will use to SSH in, and the private IPs will be used for inter-node communication.
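The same addresses can also be read programmatically. The helper below is a hypothetical sketch that extracts the public and private IPs of running instances from the response of the boto3 EC2 client's describe_instances call. It assumes boto3 and credentials as before, but only processes a plain response dictionary, so it can be tried without AWS access. The order of instances in the response is not tied to our node0/node1 naming, so pick one instance as node0 yourself.

```python
def collect_node_ips(describe_response):
    """Return (public_ip, private_ip) for each running instance in a describe_instances response."""
    nodes = []
    for reservation in describe_response.get("Reservations", []):
        for instance in reservation.get("Instances", []):
            # Skip instances that are pending, stopped, terminated, etc.
            if instance.get("State", {}).get("Name") != "running":
                continue
            nodes.append((instance.get("PublicIpAddress"), instance.get("PrivateIpAddress")))
    return nodes
```

In practice, the response might come from ec2.describe_instances(Filters=[{"Name": "instance.group-name", "Values": ["launch-wizard-12"]}]), which restricts the lookup to the tutorial's security group.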

Environment Setup

The next critical step is the setup of each node. Unfortunately, we cannot configure both nodes at the same time, so this process must be done on each node separately. However, this is a one-time setup, so once you have the nodes configured properly you will not have to reconfigure them for future distributed training projects.

The first step, once logged onto the node, is to create a new conda environment with Python 3.6 and NumPy. Once it is created, activate the environment.

$ conda create -n nightly_pt python=3.6 numpy
$ source activate nightly_pt

Next, we will install a nightly build of CUDA 9.0-enabled PyTorch with pip in the conda environment.

$ pip install torch_nightly -f https://download.pytorch.org/whl/nightly/cu90/torch_nightly.html

We must also install torchvision so we can use its models and datasets. At this time, we must build torchvision from source, as the pip installation will by default install an old version of PyTorch on top of the nightly build we just installed.

$ cd
$ git clone https://github.com/pytorch/vision.git
$ cd vision
$ python setup.py install

Finally, a VERY IMPORTANT step is to set the network interface name for the NCCL socket. This is set with the environment variable NCCL_SOCKET_IFNAME. To get the correct name, run the ifconfig command on the node and look at the interface name that corresponds to the node’s privateIP (e.g. ens3). Then set the environment variable:

$ export NCCL_SOCKET_IFNAME=ens3

Remember, do this on both nodes. You may also consider adding the NCCL_SOCKET_IFNAME setting to your .bashrc. An important observation is that we did not set up a shared filesystem between the nodes. Therefore, each node must have its own copy of the code and of the datasets. For more information about setting up a shared network filesystem between nodes, see here.
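Rather than reading ifconfig output by eye, the interface can also be chosen in Python before the process group is initialized. The sketch below is hypothetical: it takes a mapping from interface name to address strings, finds the interface that owns the node's private IP, and exports NCCL_SOCKET_IFNAME for the current process. NCCL reads the variable when it initializes, so this must run before init_process_group.

```python
import os


def interface_for_ip(addresses_by_interface, private_ip):
    """Return the name of the interface that owns `private_ip`, or None if none does."""
    for name, addresses in addresses_by_interface.items():
        if private_ip in addresses:
            return name
    return None


def set_nccl_socket_ifname(addresses_by_interface, private_ip):
    """Export NCCL_SOCKET_IFNAME for this process (call before NCCL initializes)."""
    name = interface_for_ip(addresses_by_interface, private_ip)
    if name is None:
        raise ValueError("no interface has address {}".format(private_ip))
    os.environ["NCCL_SOCKET_IFNAME"] = name
    return name
```

Building the mapping is left to the caller. Assuming the third-party psutil package is installed, one option is {name: [a.address for a in addrs] for name, addrs in psutil.net_if_addrs().items()}, passed together with the node's privateIP. Setting the variable in .bashrc as described above remains the simpler option.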

Distributed Training Code

With the instances running and the environments set up, we can now get into the training code. Most of the code here has been taken from the PyTorch ImageNet Example, which also supports distributed training. This code provides a good starting point for a custom trainer, as it has much of the boilerplate training loop, validation loop, and accuracy tracking functionality. However, the argument parsing and other non-essential functions have been stripped out for simplicity.

In this example we will use the torchvision.models.resnet18 model and train it on the torchvision.datasets.STL10 dataset. To accommodate the dimensionality mismatch between STL-10 (96x96 images) and ResNet-18, we will resize each image to 224x224 with a transform. Note that the choice of model and dataset is orthogonal to the distributed training code; you may use any dataset and model you wish and the process is the same. Let's get started by handling the imports and discussing some helper functions. Then we will define the train and validation functions, which have been largely taken from the ImageNet Example. After that, we will build the main part of the code, which handles the distributed training setup. Finally, we will discuss how to actually run the code.

Imports

The important distributed training specific imports here are torch.nn.parallel, torch.distributed, torch.utils.data.distributed, and torch.multiprocessing. It is also important to set the multiprocessing start method to spawn or forkserver (only supported in Python 3), as the default is fork which may cause deadlocks when using multiple worker processes for dataloading.

import time
import sys
import torch

import torch.nn as nn
import torch.nn.parallel
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

from torch.multiprocessing import Pool, Process
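Note that the imports above do not themselves change the start method. A minimal sketch of doing so is shown below using the standard-library multiprocessing module; torch.multiprocessing re-exports set_start_method, so torch.multiprocessing.set_start_method(method, force=True) should behave the same. One caveat that the tutorial's module-level script does not address: with spawn or forkserver, worker processes re-import the main module, so the training code should sit under an if __name__ == "__main__": guard (or in a function called from one). Otherwise each DataLoader worker would re-run init_process_group and the rest of the script.

```python
import multiprocessing


def configure_start_method(method="spawn"):
    """Switch the process start method away from the default "fork" (on Linux).

    Call this once, at the top of the `if __name__ == "__main__":` block, before
    any DataLoader with num_workers > 0 is created.
    """
    if method not in ("spawn", "forkserver"):
        raise ValueError("expected 'spawn' or 'forkserver', got {!r}".format(method))
    # force=True replaces a start method that may already have been set in this process.
    multiprocessing.set_start_method(method, force=True)
    return multiprocessing.get_start_method()
```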

Helper Functions

We must also define some helper functions and classes that will make training easier. The AverageMeter class tracks training statistics like accuracy and iteration count. The accuracy function computes and returns the top-k accuracy of the model so we can track learning progress. Both are provided for training convenience, but neither is specific to distributed training.

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
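To see exactly what accuracy measures, it can help to check it against a plain-Python reference on a toy batch. The function below is an illustrative re-implementation, not part of the ImageNet Example: it takes per-sample lists of class scores instead of tensors and, like accuracy, returns precision@k in percent. Ties between equal scores may be broken differently than by torch.topk.

```python
def topk_accuracy(scores, targets, topk=(1,)):
    """Precision@k (percent) for a batch given as plain lists.

    scores: one list of class scores per sample; targets: the true class index per sample.
    """
    batch_size = len(targets)
    results = []
    for k in topk:
        correct = 0
        for row, target in zip(scores, targets):
            # Class indices ordered from highest to lowest score; keep the k best.
            best_k = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
            if target in best_k:
                correct += 1
        results.append(100.0 * correct / batch_size)
    return results
```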

Train Functions

To simplify the main loop, it is best to separate a training epoch step into a function called train. This function trains the input model for one epoch of the train_loader. The only distributed-training-related detail in this function is passing non_blocking=True when copying the data and label tensors to the GPU before the forward pass. This allows the host-to-GPU copies to run asynchronously, so transfers can overlap with computation; note that the copies are only asynchronous when the DataLoader returns pinned (page-locked) memory, i.e. when it is created with pin_memory=True. This function also outputs training statistics along the way so we can track progress throughout the epoch.

The other function to define here is adjust_learning_rate, which decays the initial learning rate on a fixed schedule (by a factor of 10 every 30 epochs). This is another boilerplate trainer function that is useful for training accurate models.

def train(train_loader, model, criterion, optimizer, epoch):

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):

        # measure data loading time
        data_time.update(time.time() - end)

        # Copy to the GPU asynchronously (overlaps with compute only when pin_memory=True)
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # compute gradients in a backward pass
        optimizer.zero_grad()
        loss.backward()

        # Call step of optimizer to update model params
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % 10 == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, top5=top5))

def adjust_learning_rate(initial_lr, optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = initial_lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
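The schedule implemented by adjust_learning_rate is a step decay, lr = initial_lr * 0.1 ** (epoch // 30). The short sketch below reproduces it as a pure function (the step and gamma parameter names are ours) so the values at a few epoch boundaries can be inspected without building an optimizer.

```python
def step_decay_lr(initial_lr, epoch, step=30, gamma=0.1):
    """Learning rate that adjust_learning_rate() assigns at the start of `epoch`."""
    return initial_lr * (gamma ** (epoch // step))


# Inspect the schedule at a few epoch boundaries for starting_lr = 0.1.
for epoch in (0, 1, 29, 30, 59, 60):
    print(epoch, step_decay_lr(0.1, epoch))
```

With num_epochs = 2, as set in the Inputs section, epoch // 30 is always 0, so this run trains at the constant starting learning rate of 0.1; the decay only matters for longer runs.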

Validation Function

To track generalization performance and simplify the main loop further, we can also extract the validation step into a function called validate. This function runs a full validation pass of the input model over the input validation dataloader and returns the top-1 accuracy of the model on the validation set. Again, the only distributed-training-related detail here is passing non_blocking=True when copying the validation inputs and labels to the GPU before they are passed to the model.

def validate(val_loader, model, criterion):

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (input, target) in enumerate(val_loader):

            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1[0], input.size(0))
            top5.update(prec5[0], input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 100 == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       i, len(val_loader), batch_time=batch_time, loss=losses,
                       top1=top1, top5=top5))

        print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg

Inputs

With the helper functions out of the way, we have now reached the interesting part. Here we define the inputs for the run. Some of them are standard model training inputs such as batch size and number of training epochs, and some are specific to our distributed training task. The required inputs are:

  • batch_size - batch size for each process in the distributed training group. The total batch size across the distributed model is batch_size*world_size

  • workers - number of worker processes used with the dataloaders in each process

  • num_epochs - total number of epochs to train for

  • starting_lr - starting learning rate for training

  • world_size - number of processes in the distributed training environment

  • dist_backend - backend to use for distributed training communication (e.g. NCCL, Gloo, MPI). In this tutorial, since we are using several multi-GPU nodes, NCCL is suggested.

  • dist_url - URL that specifies the initialization method of the process group. This may contain the IP address and port of the rank 0 process, or be a non-existent file on a shared file system. Here, since we do not have a shared file system, it incorporates the node0-privateIP and the port on node0 to use.

print("Collect Inputs...")

# Batch Size for training and testing
batch_size = 32

# Number of additional worker processes for dataloading
workers = 2

# Number of epochs to train for
num_epochs = 2

# Starting Learning Rate
starting_lr = 0.1

# Number of distributed processes
world_size = 4

# Distributed backend type
dist_backend = 'nccl'

# URL used to set up distributed training (replace the IP with your node0-privateIP)
dist_url = "tcp://172.31.22.234:23456"
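Because batch_size is per process, it is worth working out what these inputs imply for one epoch. The arithmetic below is a sketch that assumes the standard STL-10 splits (5,000 labeled training images and 8,000 test images) and the default DistributedSampler behaviour of padding the dataset to a multiple of world_size; it does not need PyTorch.

```python
import math

# Inputs from the cell above.
batch_size = 32
world_size = 4

# Sizes of the standard STL-10 labeled splits (assumption: no custom subset is used).
num_train = 5000
num_test = 8000

# Samples that contribute to one optimizer step across all processes.
global_batch = batch_size * world_size
# DistributedSampler pads the index list to a multiple of world_size, then splits it evenly.
samples_per_process = math.ceil(num_train / world_size)
# DataLoader keeps the final partial batch by default (drop_last=False).
train_steps_per_epoch = math.ceil(samples_per_process / batch_size)
# val_loader has no DistributedSampler, so each process iterates over the whole test split.
val_steps = math.ceil(num_test / batch_size)

print(global_batch, samples_per_process, train_steps_per_epoch, val_steps)
```

This prints 128 1250 40 250: each process sees a quarter of the training set per epoch, but every process evaluates the full test split.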

Initialize process group

One of the most important parts of distributed training in PyTorch is properly setting up the process group, which is the first step in initializing the torch.distributed package. To do this, we use the torch.distributed.init_process_group function, which takes several inputs. First, a backend input, which specifies the backend to use (e.g. NCCL, Gloo, MPI). Second, an init_method input, which is either a URL containing the address and port of the rank 0 machine or a path to a non-existent file on a shared file system. Note that to use the file init_method, all machines must have access to the file; similarly, for the URL method, all machines must be able to communicate on the network, so make sure to configure any firewalls and network settings to accommodate this. The init_process_group function also takes rank and world_size arguments, which specify the rank of this process and the number of processes in the collective, respectively. The init_method input can also be “env://”. In this case, the address and port of the rank 0 machine are read from the MASTER_ADDR and MASTER_PORT environment variables, respectively. If the rank and world_size arguments are not passed to init_process_group, they can likewise be read from the RANK and WORLD_SIZE environment variables.
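As an illustration of the env:// variant, the sketch below builds the four environment variables for one process and then initializes from them. The helper names are hypothetical, and the address and port are simply the tutorial's example dist_url values; in practice you would usually export these variables in the shell before running the script. All values must be strings because they live in os.environ.

```python
import os


def env_init_variables(master_addr, master_port, rank, world_size):
    """Variables read by init_method="env://"; os.environ values must be strings."""
    return {
        "MASTER_ADDR": master_addr,        # address of the rank 0 machine (node0-privateIP)
        "MASTER_PORT": str(master_port),   # a free port on that machine
        "RANK": str(rank),                 # global rank of this process
        "WORLD_SIZE": str(world_size),     # total number of processes
    }


def init_from_env(backend="nccl"):
    """Initialize the default process group purely from environment variables."""
    import torch.distributed as dist  # imported lazily so the helper above runs without torch
    # rank and world_size are not passed, so they are read from RANK and WORLD_SIZE.
    dist.init_process_group(backend=backend, init_method="env://")


# Example: global rank 2 of 4, using the tutorial's example master address and port.
os.environ.update(env_init_variables("172.31.22.234", 23456, rank=2, world_size=4))
```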

Another important step, especially when each node has multiple GPUs, is to set the local_rank of this process. For example, if you have two nodes, each with 8 GPUs, and you wish to train with all of them, then world_size=16 and each node will have processes with local ranks 0-7. This local_rank is used to set the device (i.e. which GPU to use) for the process, and later to set the device when creating the distributed data parallel model. The NCCL backend is also recommended in this hypothetical environment, as NCCL is preferred for multi-GPU nodes.

print("Initialize Process Group...")
# Initialize Process Group
# v1 - init with url
dist.init_process_group(backend=dist_backend, init_method=dist_url, rank=int(sys.argv[1]), world_size=world_size)
# v2 - init with file
# dist.init_process_group(backend="nccl", init_method="file:///home/ubuntu/pt-distributed-tutorial/trainfile", rank=int(sys.argv[1]), world_size=world_size)
# v3 - init with environment variables
# dist.init_process_group(backend="nccl", init_method="env://", rank=int(sys.argv[1]), world_size=world_size)


# Establish Local Rank and set device on this node
local_rank = int(sys.argv[2])
dp_device_ids = [local_rank]
torch.cuda.set_device(local_rank)
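The tutorial passes rank and local_rank as two independent command line arguments, so nothing keeps them consistent. A common convention, assumed here rather than required by torch.distributed, is to number processes node by node, which makes the global rank a simple function of the node index and the GPU index. The hypothetical helpers below encode that convention and its inverse.

```python
def global_rank(node_rank, gpus_per_node, local_rank):
    """Global rank under node-by-node numbering (node 0 holds ranks 0..gpus_per_node-1)."""
    if not 0 <= local_rank < gpus_per_node:
        raise ValueError("local_rank must be a valid GPU index on this node")
    return node_rank * gpus_per_node + local_rank


def node_and_local_rank(rank, gpus_per_node):
    """Inverse mapping: (node_rank, local_rank) for a given global rank."""
    return divmod(rank, gpus_per_node)
```

With two nodes of 8 GPUs (world_size=16), GPU 3 on node 1 gets global rank 11. With the two GPUs per node used in this tutorial, the same formula yields exactly the rank/local_rank pairs used when running the code.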

Initialize Model

The next major step is to initialize the model to be trained. Here, we use a resnet18 model from torchvision.models, but any model may be used. First, we initialize the model and place it in GPU memory. Next, we wrap the model in DistributedDataParallel, which keeps the model replicas in the different processes synchronized and is critical for distributed training. The DistributedDataParallel module also averages gradients across the world during the backward pass, so we do not have to explicitly average the gradients in the training step.

It is important to note that constructing the DistributedDataParallel model is a blocking call, meaning program execution will wait here until world_size processes have joined the process group. Also, notice that we pass our device ids list as a parameter, which contains the local rank (i.e. GPU) we are using. Finally, we specify the loss function and optimizer to train the model with.

print("Initialize Model...")
# Construct Model
model = models.resnet18(pretrained=False, num_classes=10).cuda()  # STL-10 has 10 classes
# Make model DistributedDataParallel
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=dp_device_ids, output_device=local_rank)

# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), starting_lr, momentum=0.9, weight_decay=1e-4)
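To make the gradient-averaging claim concrete, the pure-Python sketch below simulates what happens to each gradient entry: an all-reduce sums the values across processes, and the result is divided by world_size, so every process ends up with the same averaged gradient. This is only a model of the arithmetic; DistributedDataParallel's actual implementation performs the all-reduce on GPU tensors (via NCCL here), in buckets that overlap with the backward pass.

```python
def allreduce_mean(per_process_grads):
    """Average one flattened gradient across processes, as DDP's all-reduce does.

    per_process_grads: one list of gradient values per process, all the same length.
    Returns the averaged gradient that every process ends up holding.
    """
    world_size = len(per_process_grads)
    if len({len(g) for g in per_process_grads}) != 1:
        raise ValueError("every process must contribute a gradient of the same size")
    # All-reduce (sum) element by element across processes.
    summed = [sum(values) for values in zip(*per_process_grads)]
    # Divide by the number of processes to get the mean.
    return [total / world_size for total in summed]
```

Because every process applies this identical gradient with an identical optimizer, and DistributedDataParallel broadcasts rank 0's initial parameters when it is constructed, the model replicas stay in sync from step to step.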

Initialize Dataloaders

The last step in preparation for training is to specify which dataset to use. Here we use the STL-10 dataset from torchvision.datasets.STL10. STL-10 is a 10-class dataset of 96x96px color images. For use with our model, we resize the images to 224x224px in the transform. One distributed-training-specific item in this section is the use of the DistributedSampler for the training set, which is designed to be used in conjunction with DistributedDataParallel models. This object partitions the dataset across the distributed environment so that each process trains on a different subset of the data, rather than every process training on the same data, which would be counterproductive. Finally, we create the DataLoaders, which are responsible for feeding the data to the processes.

The STL-10 dataset will automatically download on the nodes if it is not present. If you wish to use your own dataset, you should download the data, write your own dataset handler, and construct a dataloader for your dataset here.

print("Initialize Dataloaders...")
# Define the transform for the data. Notice, we must resize to 224x224 with this dataset and model.
transform = transforms.Compose(
    [transforms.Resize(224),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Initialize Datasets. STL10 will automatically download if not present
trainset = datasets.STL10(root='./data', split='train', download=True, transform=transform)
valset = datasets.STL10(root='./data', split='test', download=True, transform=transform)

# Create DistributedSampler to handle distributing the dataset across nodes when training
# This can only be called after torch.distributed.init_process_group is called
train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)

# Create the Dataloaders to feed data to the training and validation steps
# pin_memory=True lets the non_blocking=True GPU copies in train/validate run asynchronously
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=(train_sampler is None), num_workers=workers, pin_memory=True, sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True)
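The partitioning performed by DistributedSampler can be approximated in a few lines of plain Python. The sketch below is illustrative and not PyTorch's code: the real sampler shuffles with a torch.Generator seeded from the epoch (plus a seed argument in newer releases), so its concrete permutation differs, but the structure is the same. Every rank shuffles with the same seed, pads the index list so it divides evenly by the number of replicas, and then takes every num_replicas-th index starting at its own rank.

```python
import math
import random


def shard_indices(dataset_len, num_replicas, rank, epoch, shuffle=True):
    """Indices that process `rank` would see in one epoch (plain-Python approximation).

    Assumes dataset_len >= num_replicas, so one round of padding suffices.
    """
    indices = list(range(dataset_len))
    if shuffle:
        # Every rank uses the same epoch-derived seed, so all ranks agree on the permutation.
        random.Random(epoch).shuffle(indices)
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    # Pad by repeating leading indices so the list divides evenly among the replicas.
    indices += indices[: total_size - dataset_len]
    # Rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]
```

This also shows why the training loop calls train_sampler.set_epoch(epoch): if the epoch were never updated, every epoch would reuse the same permutation.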

Training Loop

The last step is to define the training loop. We have already done most of the work of setting up distributed training, so this loop is not distributed-training specific. The only detail is setting the current epoch count in the DistributedSampler, as the sampler shuffles the data going to each process deterministically based on the epoch. After updating the sampler, the loop runs a full training epoch, runs a full validation step, then prints the performance of the current model against the best performing model so far. After training for num_epochs, the loop exits and the tutorial is complete. Note that since this is an exercise we are not saving models, but one may wish to keep track of the best performing model and save it at the end of training (see here).

best_prec1 = 0

for epoch in range(num_epochs):
    # Set epoch count for DistributedSampler
    train_sampler.set_epoch(epoch)

    # Adjust learning rate according to schedule
    adjust_learning_rate(starting_lr, optimizer, epoch)

    # train for one epoch
    print("\nBegin Training Epoch {}".format(epoch+1))
    train(train_loader, model, criterion, optimizer, epoch)

    # evaluate on validation set
    print("Begin Validation @ Epoch {}".format(epoch+1))
    prec1 = validate(val_loader, model, criterion)

    # remember best prec@1 and save checkpoint if desired
    # is_best = prec1 > best_prec1
    best_prec1 = max(prec1, best_prec1)

    print("Epoch Summary: ")
    print("\tEpoch Accuracy: {}".format(prec1))
    print("\tBest Accuracy: {}".format(best_prec1))
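If you do want to keep the best model, one reasonable pattern is to save only from the global rank 0 process: all replicas hold identical weights, and on node0 two processes would otherwise write the same file. The helper below is a hypothetical sketch (the function name and the model_best.pth path are ours), and it assumes the rank is available, e.g. from dist.get_rank() or int(sys.argv[1]). It saves model.module, the network wrapped by DistributedDataParallel, so the state dict keys do not carry the module. prefix.

```python
def save_if_best(model, prec1, best_prec1, epoch, rank, path="model_best.pth"):
    """Save weights from global rank 0 when validation accuracy improves; return the new best."""
    if prec1 > best_prec1 and rank == 0:
        import torch  # only needed when a checkpoint is actually written
        torch.save({
            "epoch": epoch + 1,
            # model.module is the network wrapped by DistributedDataParallel.
            "state_dict": model.module.state_dict(),
            "best_prec1": prec1,
        }, path)
    return max(prec1, best_prec1)
```

Inside the loop, best_prec1 = max(prec1, best_prec1) would become best_prec1 = save_if_best(model, prec1, best_prec1, epoch, dist.get_rank()).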

Running the Code

Unlike most of the other PyTorch tutorials, this code may not be run directly out of this notebook. To run it, download the .py version of this file (or convert it using this) and upload a copy to both nodes. The astute reader will have noticed that we hardcoded the node0-privateIP and world_size=4, but read the rank and local_rank inputs from the sys.argv[1] and sys.argv[2] command line arguments, respectively. Once the file is uploaded, open two ssh terminals into each node.

  • On the first terminal for node0, run $ python main.py 0 0

  • On the second terminal for node0, run $ python main.py 1 1

  • On the first terminal for node1, run $ python main.py 2 0

  • On the second terminal for node1, run $ python main.py 3 1

The programs will start and wait after printing “Initialize Model…” for all four processes to join the process group. Notice the first argument is not repeated as this is the unique global rank of the process. The second argument is repeated as that is the local rank of the process running on the node. If you run nvidia-smi on each node, you will see two processes on each node, one running on GPU0 and one on GPU1.
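Opening one terminal per GPU becomes tedious on nodes with many GPUs. As a hypothetical convenience, not part of the original tutorial, a small Python launcher can start all of a node's processes from a single terminal, using the same python main.py <rank> <local_rank> interface and node-by-node rank numbering. Output from the processes will be interleaved in that terminal.

```python
import subprocess
import sys


def local_commands(node_rank, gpus_per_node, script="main.py"):
    """Argument lists for every process on one node: python <script> <rank> <local_rank>."""
    commands = []
    for local_rank in range(gpus_per_node):
        # Node-by-node numbering: node 0 holds ranks 0..gpus_per_node-1, and so on.
        rank = node_rank * gpus_per_node + local_rank
        commands.append([sys.executable, script, str(rank), str(local_rank)])
    return commands


def launch_node(node_rank, gpus_per_node, script="main.py"):
    """Start all of this node's processes, wait for them, and return their exit codes."""
    processes = [subprocess.Popen(cmd) for cmd in local_commands(node_rank, gpus_per_node, script)]
    return [p.wait() for p in processes]
```

You would run launch_node(0, 2) on node0 and launch_node(1, 2) on node1; the commands built for node1 are main.py 2 0 and main.py 3 1, matching the list above.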

We have now completed the distributed training example! Hopefully you can see how you would use this tutorial to help train your own models on your own datasets, even if you are not using the exact same distributed environment. If you are using AWS, don’t forget to SHUT DOWN YOUR NODES when you are not using them, or you may find an uncomfortably large bill at the end of the month.
