Distributed communication package - torch.distributed

Note

Please refer to PyTorch Distributed Overview for a brief introduction to all features related to distributed training.

Backends

torch.distributed supports three built-in backends, each with different capabilities. The table below shows which functions are available for use with CPU / CUDA tensors. MPI supports CUDA only if the implementation used to build PyTorch supports it.

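Backend         gloo        mpi         nccl
Device          CPU   GPU   CPU   GPU   CPU   GPU
send            ✓     ✘     ✓     ?     ✘     ✓
recv            ✓     ✘     ✓     ?     ✘     ✓
broadcast       ✓     ✓     ✓     ?     ✘     ✓
all_reduce      ✓     ✓     ✓     ?     ✘     ✓
reduce          ✓     ✘     ✓     ?     ✘     ✓
all_gather      ✓     ✘     ✓     ?     ✘     ✓
gather          ✓     ✘     ✓     ?     ✘     ✓
scatter         ✓     ✘     ✓     ?     ✘     ✓
reduce_scatter  ✘     ✘     ✘     ✘     ✘     ✓
all_to_all      ✘     ✘     ✓     ?     ✘     ✓
barrier         ✓     ✘     ✓     ?     ✘     ✓

✓ = supported, ✘ = not supported, ? = supported only if the MPI implementation used to build PyTorch supports CUDA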

Backends that come with PyTorch

The PyTorch distributed package supports Linux (stable), MacOS (stable), and Windows (prototype). By default for Linux, the Gloo and NCCL backends are built and included in PyTorch distributed (NCCL only when building with CUDA). MPI is an optional backend that can only be included if you build PyTorch from source (for example, when building PyTorch on a host that has MPI installed).

Note

As of PyTorch v1.8, Windows supports all collective communication backends except NCCL. If the init_method argument of init_process_group() points to a file, it must adhere to the following schema:

  • Local file system, init_method="file:///d:/tmp/some_file"
  • Shared file system, init_method="file://////{machine_name}/{share_folder_name}/some_file"

As on Linux, you can enable TCPStore by setting the environment variables MASTER_ADDR and MASTER_PORT.

Which backend to use?

In the past, we were often asked: “Which backend should I use?”

  • Rule of thumb

    • Use the NCCL backend for distributed GPU training
    • Use the Gloo backend for distributed CPU training.
  • GPU hosts with InfiniBand interconnect

    • Use NCCL, since it’s the only backend that currently supports InfiniBand and GPUDirect.
  • GPU hosts with Ethernet interconnect

    • Use NCCL, since it currently provides the best distributed GPU training performance, especially for multiprocess single-node or multi-node distributed training. If you encounter any problem with NCCL, use Gloo as the fallback option. (Note that Gloo currently runs slower than NCCL for GPUs.)
  • CPU hosts with InfiniBand interconnect

    • If your InfiniBand network has IP over IB enabled, use Gloo; otherwise, use MPI instead. We are planning on adding InfiniBand support for Gloo in upcoming releases.
  • CPU hosts with Ethernet interconnect

    • Use Gloo, unless you have specific reasons to use MPI.
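The rule of thumb above can be encoded in a small helper. This is a minimal sketch: pick_backend is a hypothetical name, it assumes one process per GPU whenever CUDA is available, and it never selects MPI, which the recommendations above reserve for specific setups.

```python
import torch
import torch.distributed as dist


def pick_backend() -> str:
    """Hypothetical helper: NCCL for GPU training, Gloo for CPU training."""
    # NCCL is only usable if CUDA devices exist and PyTorch was built with NCCL.
    if torch.cuda.is_available() and dist.is_nccl_available():
        return "nccl"
    # Gloo is the recommended CPU backend and the fallback option for GPUs.
    return "gloo"


print(pick_backend())
```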

Common environment variables

Choosing the network interface to use

By default, both the NCCL and Gloo backends will try to find the right network interface to use. If the automatically detected interface is not correct, you can override it using the following environment variables (applicable to the respective backend):

  • NCCL_SOCKET_IFNAME, for example export NCCL_SOCKET_IFNAME=eth0
  • GLOO_SOCKET_IFNAME, for example export GLOO_SOCKET_IFNAME=eth0

If you’re using the Gloo backend, you can specify multiple interfaces by separating them by a comma, like this: export GLOO_SOCKET_IFNAME=eth0,eth1,eth2,eth3. The backend will dispatch operations in a round-robin fashion across these interfaces. It is imperative that all processes specify the same number of interfaces in this variable.
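A minimal sketch of building these overrides in a launch script, assuming interface names such as "eth0" and "eth1" (placeholders; use the names your hosts actually report). The helper socket_ifname_env is hypothetical; it only returns the variables so that a launcher can apply them before init_process_group() runs.

```python
from typing import Dict, Optional, Sequence


def socket_ifname_env(nccl_ifname: Optional[str] = None,
                      gloo_ifnames: Sequence[str] = ()) -> Dict[str, str]:
    """Hypothetical helper returning interface overrides to export."""
    env = {}
    if nccl_ifname:
        env["NCCL_SOCKET_IFNAME"] = nccl_ifname
    if gloo_ifnames:
        # Gloo takes a comma-separated list and round-robins across it;
        # every process must list the same number of interfaces.
        env["GLOO_SOCKET_IFNAME"] = ",".join(gloo_ifnames)
    return env


overrides = socket_ifname_env("eth0", ["eth0", "eth1"])
print(overrides)
# In a real launcher, apply these before init_process_group(), e.g.
# os.environ.update(overrides)
```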

Other NCCL environment variables

NCCL also provides a number of environment variables for fine-tuning.

Commonly used ones include the following for debugging purposes:

  • export NCCL_DEBUG=INFO
  • export NCCL_DEBUG_SUBSYS=ALL

For the full list of NCCL environment variables, please refer to NVIDIA NCCL’s official documentation.

Basics

The torch.distributed package provides PyTorch support and communication primitives for multiprocess parallelism across several computation nodes running on one or more machines. The class torch.nn.parallel.DistributedDataParallel() builds on this functionality to provide synchronous distributed training as a wrapper around any PyTorch model. This differs from the kinds of parallelism provided by the Multiprocessing package (torch.multiprocessing) and torch.nn.DataParallel() in that it supports multiple network-connected machines, and in that the user must explicitly launch a separate copy of the main training script for each process.

In the single-machine synchronous case, torch.distributed or the torch.nn.parallel.DistributedDataParallel() wrapper may still have advantages over other approaches to data-parallelism, including torch.nn.DataParallel():

  • Each process maintains its own optimizer and performs a complete optimization step with each iteration. This may appear redundant, but because the gradients have already been gathered and averaged across processes, and are therefore identical in every process, no parameter broadcast step is needed, which reduces the time spent transferring tensors between nodes.
  • Each process contains an independent Python interpreter, eliminating the extra interpreter overhead and “GIL-thrashing” that comes from driving several execution threads, model replicas, or GPUs from a single Python process. This is especially important for models that make heavy use of the Python runtime, including models with recurrent layers or many small components.

Initialization

The package needs to be initialized using the torch.distributed.init_process_group() function before calling any other methods. This blocks until all processes have joined.

torch.distributed.is_available() [source]

Returns True if the distributed package is available. Otherwise, torch.distributed does not expose any other APIs. Currently, torch.distributed is available on Linux, MacOS and Windows. Set USE_DISTRIBUTED=1 to enable it when building PyTorch from source. Currently, the default value is USE_DISTRIBUTED=1 for Linux and Windows, USE_DISTRIBUTED=0 for MacOS.

torch.distributed.init_process_group(backend, init_method=None, timeout=datetime.timedelta(seconds=1800), world_size=-1, rank=-1, store=None, group_name='') [source]

Initializes the default distributed process group, and this will also initialize the distributed package.

There are 2 main ways to initialize a process group:
  1. Specify store, rank, and world_size explicitly.
  2. Specify init_method (a URL string) which indicates where/how to discover peers. Optionally specify rank and world_size, or encode all required parameters in the URL and omit them.

If neither is specified, init_method is assumed to be “env://”.

Parameters
  • backend (str or Backend) – The backend to use. Depending on build-time configurations, valid values include mpi, gloo, and nccl. This field should be given as a lowercase string (e.g., "gloo"), which can also be accessed via Backend attributes (e.g., Backend.GLOO). If using multiple processes per machine with nccl backend, each process must have exclusive access to every GPU it uses, as sharing GPUs between processes can result in deadlocks.
  • init_method (str, optional) – URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified. Mutually exclusive with store.
  • world_size (int, optional) – Number of processes participating in the job. Required if store is specified.
  • rank (int, optional) – Rank of the current process (it should be a number between 0 and world_size-1). Required if store is specified.
  • store (Store, optional) – Key/value store accessible to all workers, used to exchange connection/address information. Mutually exclusive with init_method.
  • timeout (timedelta, optional) – Timeout for operations executed against the process group. Default value equals 30 minutes. This is applicable for the gloo backend. For nccl, this is applicable only if the environment variable NCCL_BLOCKING_WAIT or NCCL_ASYNC_ERROR_HANDLING is set to 1. When NCCL_BLOCKING_WAIT is set, this is the duration for which the process will block and wait for collectives to complete before throwing an exception. When NCCL_ASYNC_ERROR_HANDLING is set, this is the duration after which collectives will be aborted asynchronously and the process will crash. NCCL_BLOCKING_WAIT will provide errors to the user which can be caught and handled, but due to its blocking nature, it has a performance overhead. On the other hand, NCCL_ASYNC_ERROR_HANDLING has very little performance overhead, but crashes the process on errors. This is done since CUDA execution is async and it is no longer safe to continue executing user code since failed async NCCL operations might result in subsequent CUDA operations running on corrupted data. Only one of these two environment variables should be set.
  • group_name (str, optional, deprecated) – Group name.

To enable backend == Backend.MPI, PyTorch needs to be built from source on a system that supports MPI.
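A minimal sketch of the first initialization way (explicit store, rank and world_size), assuming a single process and the Gloo backend so that it runs on a CPU-only machine; a real job would run one such process per rank against a shared store. destroy_process_group() is not described in this section and is used here only to tear the group down afterwards.

```python
import os
import tempfile

import torch
import torch.distributed as dist

# FileStore needs a path to a file that does not exist yet; a fresh temporary
# directory guarantees that for this single-process demo.
store_path = os.path.join(tempfile.mkdtemp(), "store")
store = dist.FileStore(store_path, 1)  # world_size = 1

# First way: pass store, rank and world_size explicitly (no init_method).
dist.init_process_group("gloo", store=store, rank=0, world_size=1)

t = torch.ones(3)
dist.all_reduce(t)  # with a single process the sum leaves t unchanged
backend = dist.get_backend()
rank, world_size = dist.get_rank(), dist.get_world_size()
print(backend, rank, world_size, t)

# Tear down the default group so init_process_group() could be called again.
dist.destroy_process_group()
```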

class torch.distributed.Backend [source]

An enum-like class of available backends: GLOO, NCCL, MPI, and other registered backends.

The values of this class are lowercase strings, e.g., "gloo". They can be accessed as attributes, e.g., Backend.NCCL.

This class can be directly called to parse the string, e.g., Backend(backend_str) will check if backend_str is valid, and return the parsed lowercase string if so. It also accepts uppercase strings, e.g., Backend("GLOO") returns "gloo".

Note

The entry Backend.UNDEFINED is present but only used as initial value of some fields. Users should neither use it directly nor assume its existence.
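A short illustration of the parsing rules described above; the values in the comments follow directly from those rules (attribute access and string parsing both yield lowercase strings).

```python
import torch.distributed as dist

print(dist.Backend.GLOO)                          # gloo
print(dist.Backend("GLOO"))                       # gloo
print(dist.Backend("Nccl") == dist.Backend.NCCL)  # True
```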

torch.distributed.get_backend(group=None) [source]

Returns the backend of the given process group.

Parameters

group (ProcessGroup, optional) – The process group to work on. The default is the general main process group. If another specific group is specified, the calling process must be part of group.

Returns

The backend of the given process group as a lowercase string.

torch.distributed.get_rank(group=None) [source]

Returns the rank of the current process in the given process group.

Rank is a unique identifier assigned to each process within a distributed process group. Ranks are always consecutive integers ranging from 0 to world_size - 1.

Parameters

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

Returns

The rank of the current process within the group; -1 if it is not part of the group.

torch.distributed.get_world_size(group=None) [source]

Returns the number of processes in the given process group.

Parameters

group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.

Returns

The world size of the process group; -1 if the calling process is not part of the group.

torch.distributed.is_initialized() [source]

Checks if the default process group has been initialized.

torch.distributed.is_mpi_available() [source]

Checks if the MPI backend is available.

torch.distributed.is_nccl_available() [source]

Checks if the NCCL backend is available.
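A common pattern built on is_available() and is_initialized() lets the same script run with or without a process group. This is a sketch with hypothetical helper names; it assumes that "not distributed" should behave like rank 0 of a one-process world.

```python
import torch.distributed as dist


def get_rank_or_zero() -> int:
    """Rank in the default group, or 0 when no process group is running."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0


def get_world_size_or_one() -> int:
    """World size of the default group, or 1 when no process group is running."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1


def is_main_process() -> bool:
    # Typically used to restrict logging or checkpointing to a single process.
    return get_rank_or_zero() == 0


print(get_rank_or_zero(), get_world_size_or_one(), is_main_process())
```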

Currently three initialization methods are supported:

TCP initialization

TCP initialization requires a network address that is reachable from all processes and a desired world_size. The address must belong to the rank 0 process, and all processes must have manually specified ranks.

Note that multicast addresses are no longer supported in the latest distributed package, and group_name is deprecated as well.

import torch.distributed as dist

# Use address of one of the machines
dist.init_process_group(backend, init_method='tcp://10.1.1.20:23456',
                        rank=args.rank, world_size=4)

Shared file-system initialization

Another initialization method makes use of a file system that is shared and visible from all machines in a group, along with a desired world_size. The URL should start with file:// and contain a path to a non-existent file (in an existing directory) on a shared file system. File-system initialization will automatically create that file if it doesn’t exist, but it is not guaranteed to delete the file afterwards. Therefore, it is your responsibility to make sure that the file is cleaned up before the next init_process_group() call on the same file path/name.

Note that automatic rank assignment is not supported anymore in the latest distributed package and group_name is deprecated as well.

Warning

This method assumes that the file system supports locking using fcntl - most local systems and NFS support it.

Warning

This method always creates the file and tries its best to remove it at the end of the program, so each initialization with the file init method needs a brand-new, empty file in order to succeed. Reusing a file left over from a previous initialization (one that happened not to get cleaned up) is unexpected behavior and can often cause deadlocks and failures. Therefore, if the automatic deletion is unsuccessful, it is your responsibility to remove the file at the end of training so that it is not reused next time. This is especially important if you plan to call init_process_group() multiple times on the same file name. The rule of thumb: make sure that the file is non-existent or empty every time init_process_group() is called.

import torch.distributed as dist

# rank should always be specified
dist.init_process_group(backend, init_method='file:///mnt/nfs/sharedfile',
                        world_size=4, rank=args.rank)
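A runnable single-process variant of the snippet above, assuming the Gloo backend, world_size=1, and a temporary directory standing in for the shared file system (the URL is built from a POSIX absolute path; on Windows use the schema from the Note at the top of this page). It removes the file explicitly, following the rule of thumb in the warning rather than relying on automatic deletion.

```python
import os
import tempfile

import torch.distributed as dist

# Stand-in for a shared file system: the URL must name a file that does not
# exist yet, inside a directory that does exist.
shared_dir = tempfile.mkdtemp()
init_file = os.path.join(shared_dir, "dist_init")

try:
    # rank is always given explicitly; automatic rank assignment is gone.
    dist.init_process_group("gloo", init_method=f"file://{init_file}",
                            world_size=1, rank=0)
    rank, world_size = dist.get_rank(), dist.get_world_size()
    print(rank, world_size)
finally:
    if dist.is_initialized():
        dist.destroy_process_group()
    # Do not rely on automatic deletion: remove the file so that the next
    # init_process_group() on this path starts from a non-existent file.
    if os.path.exists(init_file):
        os.remove(init_file)
```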

Environment variable initialization

This method will read the configuration from environment variables, allowing one to fully customize how the information is obtained. The variables to be set are:

  • MASTER_PORT - required; has to be a free port on the machine with rank 0
  • MASTER_ADDR - required (except for rank 0); address of rank 0 node
  • WORLD_SIZE - required; can be set either here, or in a call to init function
  • RANK - required; can be set either here, or in a call to init function

The machine with rank 0 will be used to set up all connections.

This is the default method, meaning that init_method does not have to be specified (or can be env://).
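A minimal single-process sketch of environment variable initialization, assuming the Gloo backend and a local rank 0 on 127.0.0.1. Normally a launcher or the shell exports these variables; they are set from Python here only so the demo is self-contained, and find_free_port is a hypothetical helper.

```python
import os
import socket

import torch.distributed as dist


def find_free_port() -> int:
    # Ask the OS for an unused port; another process could grab it before
    # rank 0 binds it, which is acceptable for a local demo.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(find_free_port())
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"

# init_method defaults to "env://", so it can be omitted.
dist.init_process_group("gloo")
rank, world_size = dist.get_rank(), dist.get_world_size()
print(rank, world_size)
dist.destroy_process_group()
```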

Distributed Key-Value Store

The distributed package comes with a distributed key-value store, which can be used to share information between processes in the group as well as to initialize the distributed package in torch.distributed.init_process_group() (by explicitly creating the store as an alternative to specifying init_method). There are 3 choices for Key-Value Stores: TCPStore, FileStore, and HashStore.

class torch.distributed.Store

Base class for all store implementations, such as the 3 provided by PyTorch distributed: (TCPStore, FileStore, and HashStore).

class torch.distributed.TCPStore

A TCP-based distributed key-value store implementation. The server store holds the data, while the client stores can connect to the server store over TCP and perform actions such as set() to insert a key-value pair, get() to retrieve a key-value pair, etc.

Parameters
  • host_name (str) – The hostname or IP Address the server store should run on.
  • port (int) – The port on which the server store should listen for incoming requests.
  • world_size (int) – The total number of store users (number of clients + 1 for the server).
  • is_master (bool) – True when initializing the server store, False for client stores.
  • timeout (timedelta) – Timeout used by the store during initialization and for methods such as get() and wait().
Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Run on process 1 (server)
>>> server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30))
>>> # Run on process 2 (client)
>>> client_store = dist.TCPStore("127.0.0.1", 1234, 2, False)
>>> # Use any of the store methods from either the client or server after initialization
>>> server_store.set("first_key", "first_value")
>>> client_store.get("first_key")
class torch.distributed.HashStore

A thread-safe store implementation based on an underlying hashmap. This store can be used within the same process (for example, by other threads), but cannot be used across processes.

Example::
>>> import torch.distributed as dist
>>> store = dist.HashStore()
>>> # store can be used from other threads
>>> # Use any of the store methods after initialization
>>> store.set("first_key", "first_value")
class torch.distributed.FileStore

A store implementation that uses a file to store the underlying key-value pairs.

Parameters
  • file_name (str) – path of the file in which to store the key-value pairs
  • world_size (int) – The total number of processes using the store
Example::
>>> import torch.distributed as dist
>>> store1 = dist.FileStore("/tmp/filestore", 2)
>>> store2 = dist.FileStore("/tmp/filestore", 2)
>>> # Use any of the store methods from either store after initialization
>>> store1.set("first_key", "first_value")
>>> store2.get("first_key")
class torch.distributed.PrefixStore

A wrapper around any of the 3 key-value stores (TCPStore, FileStore, and HashStore) that adds a prefix to each key inserted to the store.

Parameters
  • prefix (str) – The prefix string that is prepended to each key before being inserted into the store.
  • store (torch.distributed.Store) – A store object that forms the underlying key-value store.
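A minimal sketch of how PrefixStore namespaces keys, assuming a HashStore as the underlying store so that everything stays in one process (any of the three stores could be wrapped instead). The prefixes "trainer" and "evaluator" are illustrative.

```python
import torch.distributed as dist

base = dist.HashStore()
trainer_store = dist.PrefixStore("trainer", base)
evaluator_store = dist.PrefixStore("evaluator", base)

# The same logical key is stored under two different prefixed keys in `base`,
# so the two users of the shared store do not overwrite each other.
trainer_store.set("status", "running")
evaluator_store.set("status", "idle")

print(trainer_store.get("status"))    # b'running'
print(evaluator_store.get("status"))  # b'idle'
```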
torch.distributed.Store.set(self: torch._C._distributed_c10d.Store, arg0: str, arg1: str) → None

Inserts the key-value pair into the store based on the supplied key and value. If key already exists in the store, it will overwrite the old value with the new supplied value.

Parameters
  • key (str) – The key to be added to the store.
  • value (str) – The value associated with key to be added to the store.
Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set("first_key", "first_value")
>>> # Should return b'first_value' (values are returned as bytes)
>>> store.get("first_key")
torch.distributed.Store.get(self: torch._C._distributed_c10d.Store, arg0: str) → bytes

Retrieves the value associated with the given key in the store. If key is not present in the store, the function will wait for timeout, which is defined when initializing the store, before throwing an exception.

Parameters

key (str) – The function will return the value associated with this key.

Returns

Value associated with key if key is in the store.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set("first_key", "first_value")
>>> # Should return b'first_value' (values are returned as bytes)
>>> store.get("first_key")
torch.distributed.Store.add(self: torch._C._distributed_c10d.Store, arg0: str, arg1: int) → int

The first call to add for a given key creates a counter associated with key in the store, initialized to amount. Subsequent calls to add with the same key increment the counter by the specified amount. Calling add() with a key that has already been set in the store by set() will result in an exception.

Parameters
  • key (str) – The key in the store whose counter will be incremented.
  • amount (int) – The quantity by which the counter will be incremented.
Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.add("first_key", 1)
>>> store.add("first_key", 6)
>>> # Should return b'7' (the counter is stored as a string)
>>> store.get("first_key")
torch.distributed.Store.wait(*args, **kwargs)

Overloaded function.

  1. wait(self: torch._C._distributed_c10d.Store, arg0: List[str]) -> None

Waits for each key in keys to be added to the store. If not all keys are set before the timeout (set during store initialization), then wait will throw an exception.

Parameters

keys (list) – List of keys on which to wait until they are set in the store.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> # This will throw an exception after 30 seconds
>>> store.wait(["bad_key"])
  2. wait(self: torch._C._distributed_c10d.Store, arg0: List[str], arg1: datetime.timedelta) -> None

Waits for each key in keys to be added to the store, and throws an exception if the keys have not been set by the supplied timeout.

Parameters
  • keys (list) – List of keys on which to wait until they are set in the store.
  • timeout (timedelta) – Time to wait for the keys to be added before throwing an exception.
Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> # This will throw an exception after 10 seconds
>>> store.wait(["bad_key"], timedelta(seconds=10))
torch.distributed.Store.num_keys(self: torch._C._distributed_c10d.Store) → int

Returns the number of keys set in the store. Note that this number will typically be one greater than the number of keys added by set() and add() since one key is used to coordinate all the workers using the store.

Warning

When used with the FileStore, num_keys returns the number of keys written to the underlying file. If the store is destructed and another store is created with the same file, the original keys will be retained.

Returns

The number of keys present in the store.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set("first_key", "first_value")
>>> # This should return 2
>>> store.num_keys()
torch.distributed.Store.delete_key(self: torch._C._distributed_c10d.Store, arg0: str) → bool

Deletes the key-value pair associated with key from the store. Returns true if the key was successfully deleted, and false if it was not.

Warning

The delete_key API is only supported by the TCPStore and HashStore. Using this API with the FileStore will result in an exception.

Parameters

key (str) – The key to be deleted from the store

Returns

True if key was deleted, otherwise False.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, HashStore can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set("first_key", "first_value")
>>> # This should return True
>>> store.delete_key("first_key")
>>> # This should return False
>>> store.delete_key("bad_key")
torch.distributed.Store.set_timeout(self: torch._C._distributed_c10d.Store, arg0: datetime.timedelta) → None

Sets the store’s default timeout. This timeout is used during initialization and in wait() and get().

Parameters

timeout (timedelta) – timeout to be set in the store.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set_timeout(timedelta(seconds=10))
>>> # This will throw an exception after 10 seconds
>>> store.wait(["bad_key"])

Groups

By default, collectives operate on the default group (also called the world) and require all processes to enter the distributed function call. However, some workloads can benefit from more fine-grained communication. This is where distributed groups come into play. The new_group() function can be used to create new groups with arbitrary subsets of all processes. It returns an opaque group handle that can be given as a group argument to all collectives (collectives are distributed functions to exchange information in certain well-known programming patterns).

torch.distributed.new_group(ranks=None, timeout=datetime.timedelta(seconds=1800), backend=None) [source]

Creates a new distributed group.

This function requires that all processes in the main group (i.e. all processes that are part of the distributed job) enter this function, even if they are not going to be members of the group. Additionally, groups should be created in the same order in all processes.

Warning

Using multiple process groups with the NCCL backend concurrently is not safe and the user should perform explicit synchronization in their application to ensure only one process group is used at a time. This means collectives from one process group should have completed execution on the device (not just enqueued since CUDA execution is async) before collectives from another process group are enqueued. See Using multiple NCCL communicators concurrently for more details.

Parameters
  • ranks (list[int]) – List of ranks of group members. If None, will be set to all ranks. Default is None.
  • timeout (timedelta, optional) – Timeout for operations executed against the process group. Default value equals 30 minutes. This is only applicable for the gloo backend.
  • backend (str or Backend, optional) – The backend to use. Depending on build-time configurations, valid values are gloo and nccl. By default uses the same backend as the global group. This field should be given as a lowercase string (e.g., "gloo"), which can also be accessed via Backend attributes (e.g., Backend.GLOO).
Returns

A handle of distributed group that can be given to collective calls.
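A minimal sketch of splitting a job into equal-sized subgroups, assuming the Gloo backend; partition_ranks and create_subgroups are hypothetical helpers. Note how every rank enters new_group() for every group, in the same order, even for groups it does not belong to. The demo at the bottom uses a single process (world_size=1), so the only possible subgroup is [0].

```python
import os
import tempfile
from typing import List

import torch
import torch.distributed as dist


def partition_ranks(world_size: int, group_size: int) -> List[List[int]]:
    """Split ranks 0..world_size-1 into consecutive blocks of group_size ranks."""
    if world_size % group_size != 0:
        raise ValueError("world_size must be divisible by group_size")
    return [list(range(start, start + group_size))
            for start in range(0, world_size, group_size)]


def create_subgroups(group_size: int):
    """Collective: every rank must call this with the same group_size.

    Returns the handle of the group that contains the calling rank.
    """
    my_group = None
    for ranks in partition_ranks(dist.get_world_size(), group_size):
        group = dist.new_group(ranks=ranks)  # entered by all ranks, same order
        if dist.get_rank() in ranks:
            my_group = group
    return my_group


# Single-process demo so the sketch runs on a CPU-only machine.
store = dist.FileStore(os.path.join(tempfile.mkdtemp(), "store"), 1)
dist.init_process_group("gloo", store=store, rank=0, world_size=1)
group = create_subgroups(group_size=1)
t = torch.tensor([5.0])
dist.all_reduce(t, group=group)  # reduces only within `group`
dist.destroy_process_group()

print(partition_ranks(8, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```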

Point-to-point communication

torch.distributed.send(tensor, dst, group=None, tag=0) [source]

Sends a tensor synchronously.

Parameters
  • tensor (Tensor) – Tensor to send.
  • dst (int) – Destination rank.
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
  • tag (int, optional) – Tag to match send with remote recv
torch.distributed.recv(tensor, src=None, group=None, tag=0) [source]

Receives a tensor synchronously.

Parameters
  • tensor (Tensor) – Tensor to fill with received data.
  • src (int, optional) – Source rank. Will receive from any process if unspecified.
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
  • tag (int, optional) – Tag to match recv with remote send
Returns

Sender rank; -1 if the calling process is not part of the group.
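A minimal two-process sketch of blocking send()/recv() on one machine, assuming the Gloo backend and env:// rendezvous on 127.0.0.1. find_free_port, worker and launch are hypothetical helpers, not part of torch.distributed; the processes are started with torch.multiprocessing.start_processes. The receive buffer must be preallocated with the sender's shape and dtype.

```python
import os
import socket
from datetime import timedelta

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def find_free_port() -> int:
    # Let the OS pick an unused port for the rank 0 store (fine for local demos).
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def worker(rank: int, world_size: int) -> None:
    # start_processes() passes the process index as the first argument.
    dist.init_process_group("gloo", rank=rank, world_size=world_size,
                            timeout=timedelta(seconds=60))
    if rank == 0:
        dist.send(torch.arange(4.0), dst=1)  # blocking send
    elif rank == 1:
        buf = torch.zeros(4)  # same shape and dtype as the sent tensor
        sender = dist.recv(buf, src=0)  # blocking receive
        assert sender == 0
        assert torch.equal(buf, torch.arange(4.0))
    dist.barrier()  # keep every rank alive until the exchange is finished
    dist.destroy_process_group()


def launch(world_size: int = 2, start_method: str = "spawn") -> None:
    # env:// rendezvous: the child processes inherit these variables.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(find_free_port())
    mp.start_processes(worker, args=(world_size,), nprocs=world_size,
                       start_method=start_method)
```

To run it, call launch() from under an if __name__ == "__main__": guard; the default spawn start method re-imports the script in every child, so unguarded top-level calls would run again there.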

isend() and irecv() return distributed request objects when used. The type of these objects is unspecified, since they should never be created manually, but they are guaranteed to support two methods:

  • is_completed() - returns True if the operation has finished
  • wait() - will block the process until the operation is finished. is_completed() is guaranteed to return True once it returns.
torch.distributed.isend(tensor, dst, group=None, tag=0) [source]

Sends a tensor asynchronously.

Parameters
  • tensor (Tensor) – Tensor to send.
  • dst (int) – Destination rank.
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
  • tag (int, optional) – Tag to match send with remote recv
Returns

A distributed request object; None if the calling process is not part of the group.

torch.distributed.irecv(tensor, src=None, group=None, tag=0) [source]

Receives a tensor asynchronously.

Parameters
  • tensor (Tensor) – Tensor to fill with received data.
  • src (int, optional) – Source rank. Will receive from any process if unspecified.
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
  • tag (int, optional) – Tag to match recv with remote send
Returns

A distributed request object; None if the calling process is not part of the group.
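A minimal sketch of a ring exchange with isend()/irecv(), assuming the Gloo backend, env:// rendezvous on 127.0.0.1 and at least two processes (a rank cannot send to itself). Each rank sends its rank number to the next rank and receives from the previous one. ring_neighbors, find_free_port, worker and launch are hypothetical helpers.

```python
import os
import socket
from datetime import timedelta
from typing import Tuple

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def find_free_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def ring_neighbors(rank: int, world_size: int) -> Tuple[int, int]:
    """Return (next_rank, prev_rank) on a ring over ranks 0..world_size-1."""
    return (rank + 1) % world_size, (rank - 1) % world_size


def worker(rank: int, world_size: int) -> None:
    dist.init_process_group("gloo", rank=rank, world_size=world_size,
                            timeout=timedelta(seconds=60))
    next_rank, prev_rank = ring_neighbors(rank, world_size)
    send_buf = torch.tensor([float(rank)])
    recv_buf = torch.zeros(1)
    # Post the send and the receive before waiting on either, so no rank
    # needs its send to finish before it starts receiving.
    requests = [dist.isend(send_buf, dst=next_rank),
                dist.irecv(recv_buf, src=prev_rank)]
    for req in requests:
        req.wait()  # send_buf must not be modified before its request completes
    assert all(req.is_completed() for req in requests)
    assert recv_buf.item() == float(prev_rank)
    dist.barrier()
    dist.destroy_process_group()


def launch(world_size: int = 4, start_method: str = "spawn") -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(find_free_port())
    mp.start_processes(worker, args=(world_size,), nprocs=world_size,
                       start_method=start_method)
```

As with any spawn-based launcher, call launch() from under an if __name__ == "__main__": guard.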

Synchronous and asynchronous collective operations

Every collective operation function supports the following two kinds of operations, depending on the setting of the async_op flag passed into the collective:

Synchronous operation - the default mode, when async_op is set to False. When the function returns, it is guaranteed that the collective operation is performed. In the case of CUDA operations, it is not guaranteed that the CUDA operation is completed, since CUDA operations are asynchronous. For CPU collectives, any further function calls utilizing the output of the collective call will behave as expected. For CUDA collectives, function calls utilizing the output on the same CUDA stream will behave as expected. Users must take care of synchronization when running under different streams. For details on CUDA semantics such as stream synchronization, see CUDA Semantics. See the script below for examples of the differences in these semantics between CPU and CUDA operations.

Asynchronous operation - when async_op is set to True. The collective operation function returns a distributed request object. In general, you don’t need to create it manually and it is guaranteed to support two methods:

  • is_completed() - in the case of CPU collectives, returns True if completed. In the case of CUDA operations, returns True if the operation has been successfully enqueued onto a CUDA stream and the output can be utilized on the default stream without further synchronization.
  • wait() - in the case of CPU collectives, will block the process until the operation is completed. In the case of CUDA collectives, will block until the operation has been successfully enqueued onto a CUDA stream and the output can be utilized on the default stream without further synchronization.

Example

The following code can serve as a reference regarding semantics for CUDA operations when using distributed collectives. It shows the explicit need to synchronize when using collective outputs on different CUDA streams:

import torch
import torch.distributed as dist

# Code runs on each rank; `rank` is this process's rank (0 or 1).
dist.init_process_group("nccl", rank=rank, world_size=2)
output = torch.tensor([rank]).cuda(rank)
s = torch.cuda.Stream()
handle = dist.all_reduce(output, async_op=True)
# Wait ensures the operation is enqueued, but not necessarily complete.
handle.wait()
# Using result on non-default stream.
with torch.cuda.stream(s):
    s.wait_stream(torch.cuda.default_stream())
    output.add_(100)
if rank == 0:
    # if the explicit call to wait_stream was omitted, the output below will be
    # non-deterministically 1 or 101, depending on whether the allreduce overwrote
    # the value after the add completed.
    print(output)
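For comparison, a minimal CPU sketch of the asynchronous handle API, assuming the Gloo backend and a single process (world_size=1) so that it runs anywhere; the request object behaves the same way with more processes. For CPU collectives, wait() blocks until the operation has finished, after which is_completed() returns True.

```python
import os
import tempfile

import torch
import torch.distributed as dist

store = dist.FileStore(os.path.join(tempfile.mkdtemp(), "store"), 1)
dist.init_process_group("gloo", store=store, rank=0, world_size=1)

t = torch.tensor([1.0, 2.0])
work = dist.all_reduce(t, async_op=True)  # returns a request object at once
# ... computation that does not touch `t` could overlap with the collective ...
work.wait()  # CPU collective: blocks until the reduction has finished
done = work.is_completed()
print(done, t)

dist.destroy_process_group()
```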

Collective functions

torch.distributed.broadcast(tensor, src, group=None, async_op=False) [source]

Broadcasts the tensor to the whole group.

tensor must have the same number of elements in all processes participating in the collective.

Parameters
  • tensor (Tensor) – Data to be sent if src is the rank of current process, and tensor to be used to save received data otherwise.
  • src (int) – Source rank.
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
  • async_op (bool, optional) – Whether this op should be an async op
Returns

Async work handle if async_op is set to True; None if async_op is False or if the calling process is not part of the group.
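A minimal multi-process sketch of broadcast(), assuming the Gloo backend and env:// rendezvous on 127.0.0.1. Every rank passes a tensor with the same number of elements; only the contents on src matter, and the other ranks have theirs overwritten in place. find_free_port, worker and launch are hypothetical helpers.

```python
import os
import socket
from datetime import timedelta

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def find_free_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def worker(rank: int, world_size: int) -> None:
    dist.init_process_group("gloo", rank=rank, world_size=world_size,
                            timeout=timedelta(seconds=60))
    expected = torch.tensor([10.0, 20.0, 30.0])
    # src supplies the data; every other rank supplies a same-sized buffer.
    t = expected.clone() if rank == 0 else torch.zeros(3)
    dist.broadcast(t, src=0)
    assert torch.equal(t, expected)
    dist.barrier()
    dist.destroy_process_group()


def launch(world_size: int = 2, start_method: str = "spawn") -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(find_free_port())
    mp.start_processes(worker, args=(world_size,), nprocs=world_size,
                       start_method=start_method)
```

Call launch() from under an if __name__ == "__main__": guard when using the default spawn start method.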

torch.distributed.broadcast_object_list(object_list, src=0, group=None) [source]

Broadcasts picklable objects in object_list to the whole group. Similar to broadcast(), but Python objects can be passed in. Note that all objects in object_list must be picklable in order to be broadcast.

Parameters
  • object_list (List[Any]) – List of input objects to broadcast. Each object must be picklable. Only objects on the src rank will be broadcast, but each rank must provide lists of equal sizes.
  • src (int) – Source rank from which to broadcast object_list.
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Default is None.
Returns

None. If rank is part of the group, object_list will contain the broadcasted objects from src rank.

Note

For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

Note

Note that this API differs slightly from the broadcast() collective since it does not provide an async_op handle and thus will be a blocking call.

Warning

broadcast_object_list() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.

Example::
>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> if dist.get_rank() == 0:
>>>     # Assumes world_size of 3.
>>>     objects = ["foo", 12, {1: 2}] # any picklable object
>>> else:
>>>     objects = [None, None, None]
>>> dist.broadcast_object_list(objects, src=0)
>>> objects
['foo', 12, {1: 2}]
torch.distributed.all_reduce(tensor, op=<ReduceOp.SUM: 0>, group=None, async_op=False) [source]

Reduces the tensor data across all machines in such a way that all get the final result.

After the call, tensor is going to be bitwise identical in all processes.

Complex tensors are supported.

Parameters
  • tensor (Tensor) – Input and output of the collective. The function operates in-place.
  • op (optional) – One of the values from torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions.
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
  • async_op (bool, optional) – Whether this op should be an async op
Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group

Examples

>>> # All tensors below are of torch.int64 type.
>>> # We have 2 ranks in the default process group.
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> dist.all_reduce(tensor, op=ReduceOp.SUM)
>>> tensor
tensor([4, 6]) # Rank 0
tensor([4, 6]) # Rank 1
>>> # All tensors below are of torch.cfloat type.
>>> # We have 2 ranks in the default process group.
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
>>> tensor
tensor([1.+1.j, 2.+2.j]) # Rank 0
tensor([3.+3.j, 4.+4.j]) # Rank 1
>>> dist.all_reduce(tensor, op=ReduceOp.SUM)
>>> tensor
tensor([4.+4.j, 6.+6.j]) # Rank 0
tensor([4.+4.j, 6.+6.j]) # Rank 1
torch.distributed.reduce(tensor, dst, op=<ReduceOp.SUM: 0>, group=None, async_op=False) [source]

Reduces the tensor data across all machines.

Only the process with rank dst is going to receive the final result.

Parameters
  • tensor (Tensor) – Input and output of the collective. The function operates in-place.
  • dst (int) – Destination rank
  • op (optional) – One of the values from torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions.
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
  • async_op (bool, optional) – Whether this op should be an async op
Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group
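The difference between reduce() and all_reduce() can be modeled in plain Python. This is a sketch, not the torch.distributed API: Python lists stand in for tensors, simulate_reduce and simulate_all_reduce are hypothetical helpers, and a binary function such as operator.add stands in for ReduceOp.SUM. The model returns only what dst is guaranteed to receive, and it reuses the inputs of the all_reduce() example above.

```python
import operator
from functools import reduce as fold


def simulate_reduce(rank_tensors, dst, op=operator.add):
    """Model of dist.reduce(): only rank dst receives the reduced result.

    rank_tensors[r] is the list held by rank r; op is a binary function
    standing in for a ReduceOp member (operator.add for ReduceOp.SUM).
    """
    # zip(*...) groups the i-th element of every rank, which is then reduced.
    reduced = [fold(op, column) for column in zip(*rank_tensors)]
    return {dst: reduced}


def simulate_all_reduce(rank_tensors, op=operator.add):
    """Model of dist.all_reduce(): every rank ends up with the same result."""
    reduced = [fold(op, column) for column in zip(*rank_tensors)]
    return {rank: list(reduced) for rank in range(len(rank_tensors))}


inputs = [[1, 2], [3, 4]]  # rank 0 holds [1, 2], rank 1 holds [3, 4]
print(simulate_reduce(inputs, dst=0))  # {0: [4, 6]}
print(simulate_all_reduce(inputs))     # {0: [4, 6], 1: [4, 6]}
```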

torch.distributed.all_gather(tensor_list, tensor, group=None, async_op=False) [source]

Gathers tensors from the whole group in a list.

Complex tensors are supported.

Parameters
  • tensor_list (list[Tensor]) – Output list. It should contain correctly-sized tensors to be used for output of the collective.
  • tensor (Tensor) – Tensor to be broadcast from current process.
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
  • async_op (bool, optional) – Whether this op should be an async op
Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group

Examples

>>> # All tensors below are of torch.int64 dtype.
>>> # We have 2 ranks in the default process group.
>>> tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]
>>> tensor_list
[tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1, 2]), tensor([3, 4])] # Rank 0
[tensor([1, 2]), tensor([3, 4])] # Rank 1
>>> # All tensors below are of torch.cfloat dtype.
>>> # We have 2 ranks in the default process group.
>>> tensor_list = [torch.zeros(2, dtype=torch.cfloat) for _ in range(2)]
>>> tensor_list
[tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
>>> tensor
tensor([1.+1.j, 2.+2.j]) # Rank 0
tensor([3.+3.j, 4.+4.j]) # Rank 1
>>> dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 0
[tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 1
torch.distributed.all_gather_object(object_list, obj, group=None) [source]

Gathers picklable objects from the whole group into a list. Similar to all_gather(), but Python objects can be passed in. Note that the object must be picklable in order to be gathered.

Parameters
  • object_list (list[Any]) – Output list. It should be correctly sized as the size of the group for this collective and will contain the output.
  • obj (Any) – Picklable Python object to be broadcast from current process.
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Default is None.
Returns

None. If the calling rank is part of this group, the output of the collective will be populated into the input object_list. If the calling rank is not part of the group, the passed in object_list will be unmodified.

Note

Note that this API differs slightly from the all_gather() collective since it does not provide an async_op handle and thus will be a blocking call.

Note

For NCCL-based process groups, internal tensor representations of objects must be moved to the GPU device before communication takes place. In this case, the device used is given by torch.cuda.current_device() and it is the user’s responsibility to ensure that this is set so that each rank has an individual GPU, via torch.cuda.set_device().

Warning

all_gather_object() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.

Example::
>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
>>> output
['foo', 12, {1: 2}]
torch.distributed.gather(tensor, gather_list=None, dst=0, group=None, async_op=False) [source]

Gathers a list of tensors in a single process.

Parameters
  • tensor (Tensor) – Input tensor.
  • gather_list (list[Tensor], optional) – List of appropriately-sized tensors to use for gathered data (default is None, must be specified on the destination rank)
  • dst (int, optional) – Destination rank (default is 0)
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
  • async_op (bool, optional) – Whether this op should be an async op
Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group
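A plain-Python model of gather(), assuming Python lists stand in for tensors (simulate_gather is a hypothetical helper, not part of torch.distributed). It shows that only dst receives the per-rank tensors, ordered by rank, while the other ranks pass gather_list=None.

```python
def simulate_gather(rank_tensors, dst=0):
    """Model of dist.gather() with lists standing in for tensors.

    rank_tensors[r] is the tensor argument on rank r. Returns {rank: gather_list}:
    dst gets one entry per rank, in rank order; every other rank keeps None.
    """
    gathered = [list(t) for t in rank_tensors]
    return {rank: (gathered if rank == dst else None) for rank in range(len(rank_tensors))}


inputs = [[1, 2], [3, 4], [5, 6]]  # tensor on ranks 0, 1 and 2
print(simulate_gather(inputs, dst=0))
# {0: [[1, 2], [3, 4], [5, 6]], 1: None, 2: None}
```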

torch.distributed.gather_object(obj, object_gather_list=None, dst=0, group=None) [source]

Gathers picklable objects from the whole group in a single process. Similar to gather(), but Python objects can be passed in. Note that the object must be picklable in order to be gathered.

Parameters
  • obj (Any) – Input object. Must be picklable.
  • object_gather_list (list[Any]) – Output list. On the dst rank, it should be correctly sized as the size of the group for this collective and will contain the output. Must be None on non-dst ranks. (default is None)
  • dst (int, optional) – Destination rank. (default is 0)
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Default is None.
Returns

None. On the dst rank, object_gather_list will contain the output of the collective.

Note

Note that this API differs slightly from the gather collective since it does not provide an async_op handle and thus will be a blocking call.

Note

Note that this API is not supported when using the NCCL backend.

Warning

gather_object() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.

Example::
>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.gather_object(
...     gather_objects[dist.get_rank()],
...     output if dist.get_rank() == 0 else None,
...     dst=0
... )
>>> # On rank 0
>>> output
['foo', 12, {1: 2}]
torch.distributed.scatter(tensor, scatter_list=None, src=0, group=None, async_op=False) [source]

Scatters a list of tensors to all processes in a group.

Each process will receive exactly one tensor and store its data in the tensor argument.

Parameters
  • tensor (Tensor) – Output tensor.
  • scatter_list (list[Tensor]) – List of tensors to scatter (default is None, must be specified on the source rank)
  • src (int) – Source rank (default is 0)
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
  • async_op (bool, optional) – Whether this op should be an async op
Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group
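scatter() can be modeled as handing element i of the src rank's scatter_list to rank i. The sketch below is plain Python with lists standing in for tensors; simulate_scatter is a hypothetical helper, not part of torch.distributed.

```python
def simulate_scatter(src_scatter_list):
    """Model of dist.scatter() with lists standing in for tensors.

    src_scatter_list is the scatter_list given on the src rank and has one
    entry per rank; rank i's output tensor receives entry i.
    Returns {rank: output_tensor}.
    """
    return {rank: list(chunk) for rank, chunk in enumerate(src_scatter_list)}


chunks = [[10, 11], [20, 21], [30, 31]]  # scatter_list on src, world_size of 3
print(simulate_scatter(chunks))  # {0: [10, 11], 1: [20, 21], 2: [30, 31]}
```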

torch.distributed.scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None) [source]

Scatters picklable objects in scatter_object_input_list to the whole group. Similar to scatter(), but Python objects can be passed in. On each rank, the scattered object will be stored as the first element of scatter_object_output_list. Note that all objects in scatter_object_input_list must be picklable in order to be scattered.

Parameters
  • scatter_object_output_list (List[Any]) – Non-empty list whose first element will store the object scattered to this rank.
  • scatter_object_input_list (List[Any]) – List of input objects to scatter. Each object must be picklable. Only objects on the src rank will be scattered, and the argument can be None for non-src ranks.
  • src (int) – Source rank from which to scatter scatter_object_input_list.
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used. Default is None.
Returns

None. If rank is part of the group, scatter_object_output_list will have its first element set to the scattered object for this rank.

Note

Note that this API differs slightly from the scatter collective since it does not provide an async_op handle and thus will be a blocking call.

Warning

scatter_object_list() uses pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Only call this function with data you trust.

Example::
>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> if dist.get_rank() == 0:
>>>     # Assumes world_size of 3.
>>>     objects = ["foo", 12, {1: 2}] # any picklable object
>>> else:
>>>     # Can be any list on non-src ranks, elements are not used.
>>>     objects = [None, None, None]
>>> output_list = [None]
>>> dist.scatter_object_list(output_list, objects, src=0)
>>> # Rank i gets objects[i]. For example, on rank 2:
>>> output_list
[{1: 2}]
torch.distributed.reduce_scatter(output, input_list, op=<ReduceOp.SUM: 0>, group=None, async_op=False) [source]

Reduces, then scatters a list of tensors to all processes in a group.

Parameters
  • output (Tensor) – Output tensor.
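  • input_list (list[Tensor]) – List of tensors to reduce and scatter.
  • op (optional) – One of the values from torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions.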
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
  • async_op (bool, optional) – Whether this op should be an async op.
Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group.
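As a rough mental model, reduce_scatter() combines a reduction with a scatter: rank k's output is the element-wise reduction, over all ranks, of the k-th tensor in each rank's input_list. The plain-Python sketch below makes this concrete; simulate_reduce_scatter is a hypothetical helper, lists stand in for tensors, and operator.add stands in for ReduceOp.SUM.

```python
import operator
from functools import reduce as fold


def simulate_reduce_scatter(input_lists, op=operator.add):
    """Model of dist.reduce_scatter().

    input_lists[r] is rank r's input_list, which holds world_size tensors.
    Rank k's output is the element-wise reduction of input_lists[r][k] over all ranks r.
    """
    world_size = len(input_lists)
    outputs = {}
    for k in range(world_size):
        # Collect the k-th input tensor from every rank, then reduce element-wise.
        pieces = [input_lists[r][k] for r in range(world_size)]
        outputs[k] = [fold(op, column) for column in zip(*pieces)]
    return outputs


inputs = [
    [[1, 2], [3, 4]],      # rank 0's input_list
    [[10, 20], [30, 40]],  # rank 1's input_list
]
print(simulate_reduce_scatter(inputs))  # {0: [11, 22], 1: [33, 44]}
```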

torch.distributed.all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False) [source]

Each process scatters a list of input tensors to all processes in a group and returns the gathered list of tensors in the output list.

Parameters
  • output_tensor_list (list[Tensor]) – List of tensors to be gathered one per rank.
  • input_tensor_list (list[Tensor]) – List of tensors to scatter one per rank.
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
  • async_op (bool, optional) – Whether this op should be an async op.
Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group.

Warning

all_to_all is experimental and subject to change.

Examples

>>> input = torch.arange(4) + rank * 4
>>> input = list(input.chunk(4))
>>> input
[tensor([0]), tensor([1]), tensor([2]), tensor([3])]     # Rank 0
[tensor([4]), tensor([5]), tensor([6]), tensor([7])]     # Rank 1
[tensor([8]), tensor([9]), tensor([10]), tensor([11])]   # Rank 2
[tensor([12]), tensor([13]), tensor([14]), tensor([15])] # Rank 3
>>> output = list(torch.empty([4], dtype=torch.int64).chunk(4))
>>> dist.all_to_all(output, input)
>>> output
[tensor([0]), tensor([4]), tensor([8]), tensor([12])]    # Rank 0
[tensor([1]), tensor([5]), tensor([9]), tensor([13])]    # Rank 1
[tensor([2]), tensor([6]), tensor([10]), tensor([14])]   # Rank 2
[tensor([3]), tensor([7]), tensor([11]), tensor([15])]   # Rank 3
>>> # Essentially, it is similar to the following operation:
>>> scatter_list = input
>>> gather_list = output
>>> for i in range(world_size):
>>>   dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)
>>> # Another example, with uneven splits
>>> input
tensor([0, 1, 2, 3, 4, 5])                                       # Rank 0
tensor([10, 11, 12, 13, 14, 15, 16, 17, 18])                     # Rank 1
tensor([20, 21, 22, 23, 24])                                     # Rank 2
tensor([30, 31, 32, 33, 34, 35, 36])                             # Rank 3
>>> input_splits
[2, 2, 1, 1]                                                     # Rank 0
[3, 2, 2, 2]                                                     # Rank 1
[2, 1, 1, 1]                                                     # Rank 2
[2, 2, 2, 1]                                                     # Rank 3
>>> output_splits
[2, 3, 2, 2]                                                     # Rank 0
[2, 2, 1, 2]                                                     # Rank 1
[1, 2, 1, 2]                                                     # Rank 2
[1, 2, 1, 1]                                                     # Rank 3
>>> input = list(input.split(input_splits))
>>> input
[tensor([0, 1]), tensor([2, 3]), tensor([4]), tensor([5])]                   # Rank 0
[tensor([10, 11, 12]), tensor([13, 14]), tensor([15, 16]), tensor([17, 18])] # Rank 1
[tensor([20, 21]), tensor([22]), tensor([23]), tensor([24])]                 # Rank 2
[tensor([30, 31]), tensor([32, 33]), tensor([34, 35]), tensor([36])]         # Rank 3
>>> output = ...
>>> dist.all_to_all(output, input)
>>> output
[tensor([0, 1]), tensor([10, 11, 12]), tensor([20, 21]), tensor([30, 31])]   # Rank 0
[tensor([2, 3]), tensor([13, 14]), tensor([22]), tensor([32, 33])]           # Rank 1
[tensor([4]), tensor([15, 16]), tensor([23]), tensor([34, 35])]              # Rank 2
[tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])]                  # Rank 3
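Both examples above amount to a transpose: rank r sends input[i] to rank i, which stores it as output[r]. The plain-Python sketch below (simulate_all_to_all is a hypothetical helper; lists stand in for tensors) applies that rule to the uneven-split inputs and recovers the output_splits shown above as the lengths of the received chunks.

```python
def simulate_all_to_all(input_lists):
    """Model of dist.all_to_all().

    input_lists[r] is rank r's input_tensor_list. Rank r sends input_lists[r][i]
    to rank i, so output_lists[i][r] == input_lists[r][i] (a transpose).
    """
    world_size = len(input_lists)
    return [[input_lists[r][i] for r in range(world_size)] for i in range(world_size)]


inputs = [
    [[0, 1], [2, 3], [4], [5]],                    # rank 0
    [[10, 11, 12], [13, 14], [15, 16], [17, 18]],  # rank 1
    [[20, 21], [22], [23], [24]],                  # rank 2
    [[30, 31], [32, 33], [34, 35], [36]],          # rank 3
]
outputs = simulate_all_to_all(inputs)
print(outputs[0])  # [[0, 1], [10, 11, 12], [20, 21], [30, 31]]
# The output_splits of each rank are the lengths of the chunks it received.
print([[len(t) for t in out] for out in outputs])
# [[2, 3, 2, 2], [2, 2, 1, 2], [1, 2, 1, 2], [1, 2, 1, 1]]
```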
torch.distributed.barrier(group=None, async_op=False, device_ids=None) [source]

Synchronizes all processes.

This collective blocks processes until the whole group enters this function, if async_op is False, or if wait() is called on the async work handle.

Parameters
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
  • async_op (bool, optional) – Whether this op should be an async op
  • device_ids ([int], optional) – List of device/GPU ids. Valid only for NCCL backend.
Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group

class torch.distributed.ReduceOp

An enum-like class for available reduction operations: SUM, PRODUCT, MIN, MAX, BAND, BOR, and BXOR.

Note that BAND, BOR, and BXOR reductions are not available when using the NCCL backend.

Additionally, MAX, MIN and PRODUCT are not supported for complex tensors.

The values of this class can be accessed as attributes, e.g., ReduceOp.SUM. They are used in specifying strategies for reduction collectives, e.g., reduce(), all_reduce_multigpu(), etc.

Members:

SUM

PRODUCT

MIN

MAX

BAND

BOR

BXOR
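To illustrate what each member computes, the sketch below maps the member names to Python binary functions and folds them over the values that one element position holds on three ranks. The mapping is an assumption for illustration only, not how torch.distributed implements the reductions; the bitwise members only make sense for integers, and as noted above they are unavailable with NCCL.

```python
import operator
from functools import reduce as fold

# Illustrative mapping from ReduceOp member names to binary functions.
REDUCTIONS = {
    "SUM": operator.add,
    "PRODUCT": operator.mul,
    "MIN": min,
    "MAX": max,
    "BAND": operator.and_,  # bitwise AND (integers only)
    "BOR": operator.or_,    # bitwise OR (integers only)
    "BXOR": operator.xor,   # bitwise XOR (integers only)
}


def reduce_values(op_name, values):
    """Combine the values that one element position holds on each rank."""
    return fold(REDUCTIONS[op_name], values)


per_rank = [7, 5, 13]  # the same element on ranks 0, 1 and 2
print(reduce_values("SUM", per_rank))   # 25
print(reduce_values("MIN", per_rank))   # 5
print(reduce_values("BAND", per_rank))  # 5  (0b0111 & 0b0101 & 0b1101)
print(reduce_values("BOR", per_rank))   # 15 (0b0111 | 0b0101 | 0b1101)
```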

class torch.distributed.reduce_op

Deprecated enum-like class for reduction operations: SUM, PRODUCT, MIN, and MAX.

Use ReduceOp instead.

Autograd-enabled communication primitives

If you want to use collective communication functions that support autograd, you can find an implementation of them in the torch.distributed.nn.* module.

Functions here are synchronous and are inserted into the autograd graph, so you need to ensure that every process that participated in the collective operation also runs the backward pass. Otherwise, the backward communication will not take place and the processes that do run the backward pass can deadlock.

Please note that gloo is currently the only backend with which all of these functions are guaranteed to work. The module provides torch.distributed.nn.broadcast, torch.distributed.nn.gather, torch.distributed.nn.scatter, torch.distributed.nn.reduce, torch.distributed.nn.all_gather, torch.distributed.nn.all_to_all, and torch.distributed.nn.all_reduce.

Multi-GPU collective functions

If you have more than one GPU on each node, broadcast_multigpu(), all_reduce_multigpu(), reduce_multigpu(), all_gather_multigpu(), and reduce_scatter_multigpu() support distributed collective operations among multiple GPUs within each node. These functions can potentially improve the overall distributed training performance and are easily used by passing a list of tensors. Each tensor in the passed tensor list needs to be on a separate GPU device of the host where the function is called. Note that the length of the tensor list needs to be identical among all the distributed processes. Also note that currently only the NCCL backend supports all of the multi-GPU collective functions; broadcast_multigpu() and all_reduce_multigpu() are also supported by Gloo, as noted in their descriptions below.

For example, suppose the system used for distributed training has 2 nodes, each of which has 8 GPUs, and we would like to all-reduce a tensor that lives on each of the 16 GPUs. The following code can serve as a reference:

Code running on Node 0

import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl",
                        init_method="file:///distributed_test",
                        world_size=2,
                        rank=0)
tensor_list = []
for dev_idx in range(torch.cuda.device_count()):
    tensor_list.append(torch.FloatTensor([1]).cuda(dev_idx))

dist.all_reduce_multigpu(tensor_list)

Code running on Node 1

import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl",
                        init_method="file:///distributed_test",
                        world_size=2,
                        rank=1)
tensor_list = []
for dev_idx in range(torch.cuda.device_count()):
    tensor_list.append(torch.FloatTensor([1]).cuda(dev_idx))

dist.all_reduce_multigpu(tensor_list)

After the call, all 16 tensors on the two nodes will have the all-reduced value of 16.

torch.distributed.broadcast_multigpu(tensor_list, src, group=None, async_op=False, src_tensor=0) [source]

Broadcasts the tensor to the whole group with multiple GPU tensors per node.

tensor must have the same number of elements in all the GPUs from all processes participating in the collective. Each tensor in the list must be on a different GPU.

Only the nccl and gloo backends are currently supported, and tensors should only be GPU tensors.

Parameters
  • tensor_list (List[Tensor]) – Tensors that participate in the collective operation. If src is the rank, then the specified src_tensor element of tensor_list (tensor_list[src_tensor]) will be broadcast to all other tensors (on different GPUs) in the src process and all tensors in tensor_list of other non-src processes. You also need to make sure that len(tensor_list) is the same for all the distributed processes calling this function.
  • src (int) – Source rank.
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
  • async_op (bool, optional) – Whether this op should be an async op
  • src_tensor (int, optional) – Source tensor rank within tensor_list
Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group

torch.distributed.all_reduce_multigpu(tensor_list, op=<ReduceOp.SUM: 0>, group=None, async_op=False) [source]

Reduces the tensor data across all machines in such a way that all get the final result. This function reduces a number of tensors on every node, where each tensor resides on a different GPU. Therefore, the input tensors in the tensor list need to be GPU tensors. Also, each tensor in the tensor list needs to reside on a different GPU.

After the call, all tensors in tensor_list are going to be bitwise identical in all processes.

Complex tensors are supported.

Only the nccl and gloo backends are currently supported, and tensors should only be GPU tensors.

Parameters
  • tensor_list (List[Tensor]) – List of input and output tensors of the collective. The function operates in-place and requires each tensor to be a GPU tensor on a different GPU. You also need to make sure that len(tensor_list) is the same for all the distributed processes calling this function.
  • op (optional) – One of the values from torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions.
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
  • async_op (bool, optional) – Whether this op should be an async op
Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group

torch.distributed.reduce_multigpu(tensor_list, dst, op=<ReduceOp.SUM: 0>, group=None, async_op=False, dst_tensor=0) [source]

Reduces the tensor data on multiple GPUs across all machines. Each tensor in tensor_list should reside on a separate GPU.

Only the GPU of tensor_list[dst_tensor] on the process with rank dst is going to receive the final result.

Only the nccl backend is currently supported, and tensors should only be GPU tensors.

Parameters
  • tensor_list (List[Tensor]) – Input and output GPU tensors of the collective. The function operates in-place. You also need to make sure that len(tensor_list) is the same for all the distributed processes calling this function.
  • dst (int) – Destination rank
  • op (optional) – One of the values from torch.distributed.ReduceOp enum. Specifies an operation used for element-wise reductions.
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
  • async_op (bool, optional) – Whether this op should be an async op
  • dst_tensor (int, optional) – Destination tensor rank within tensor_list
Returns

Async work handle, if async_op is set to True. None, otherwise

torch.distributed.all_gather_multigpu(output_tensor_lists, input_tensor_list, group=None, async_op=False) [source]

Gathers tensors from the whole group in a list. Each tensor in input_tensor_list should reside on a separate GPU.

Only the nccl backend is currently supported, and tensors should only be GPU tensors.

Complex tensors are supported.

Parameters
  • output_tensor_lists (List[List[Tensor]]) –

    Output lists. It should contain correctly-sized tensors on each GPU to be used for output of the collective, e.g. output_tensor_lists[i] contains the all_gather result that resides on the GPU of input_tensor_list[i].

    Note that each element of output_tensor_lists has the size of world_size * len(input_tensor_list), since the function all gathers the result from every single GPU in the group. To interpret each element of output_tensor_lists[i], note that input_tensor_list[j] of rank k will appear in output_tensor_lists[i][k * len(input_tensor_list) + j].

    Also note that len(output_tensor_lists), and the size of each element in output_tensor_lists (each element is a list, therefore len(output_tensor_lists[i])) need to be the same for all the distributed processes calling this function.

  • input_tensor_list (List[Tensor]) – List of tensors (on different GPUs) to be broadcast from current process. Note that len(input_tensor_list) needs to be the same for all the distributed processes calling this function.
  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
  • async_op (bool, optional) – Whether this op should be an async op
Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group
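The plain-Python sketch below models the layout of output_tensor_lists, assuming that GPUs are ordered rank-major, so the block of len(input_tensor_list) tensors contributed by rank k starts at index k * len(input_tensor_list). String labels stand in for tensors, and simulate_all_gather_multigpu and gather_index are hypothetical helpers. Using 3 GPUs per rank with a world_size of 2 makes the stride visible: rank 1's first tensor lands at index 3.

```python
def gather_index(rank, gpu, gpus_per_rank):
    """Index of input_tensor_list[gpu] from `rank` inside each output list."""
    return rank * gpus_per_rank + gpu


def simulate_all_gather_multigpu(inputs_per_rank):
    """Model of all_gather_multigpu() with labels standing in for tensors.

    inputs_per_rank[k][j] is input_tensor_list[j] on rank k. Returns
    outputs[k][i], i.e. output_tensor_lists[i] on rank k: a list of
    world_size * gpus_per_rank entries in rank-major order.
    """
    world_size = len(inputs_per_rank)
    gpus_per_rank = len(inputs_per_rank[0])
    flat = [None] * (world_size * gpus_per_rank)
    for k in range(world_size):
        for j in range(gpus_per_rank):
            flat[gather_index(k, j, gpus_per_rank)] = inputs_per_rank[k][j]
    # Every GPU on every rank receives its own copy of the same gathered list.
    return [[list(flat) for _ in range(gpus_per_rank)] for _ in range(world_size)]


inputs = [["r0g0", "r0g1", "r0g2"], ["r1g0", "r1g1", "r1g2"]]  # 2 ranks, 3 GPUs each
outputs = simulate_all_gather_multigpu(inputs)
print(outputs[1][2])  # ['r0g0', 'r0g1', 'r0g2', 'r1g0', 'r1g1', 'r1g2']
print(gather_index(1, 0, 3))  # 3
```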

torch.distributed.reduce_scatter_multigpu(output_tensor_list, input_tensor_lists, op=<ReduceOp.SUM: 0>, group=None, async_op=False) [source]

Reduce and scatter a list of tensors to the whole group. Only nccl backend is currently supported.

Each tensor in output_tensor_list should reside on a separate GPU, as should each list of tensors in input_tensor_lists.

Parameters
  • output_tensor_list (List[Tensor]) –

    Output tensors (on different GPUs) to receive the result of the operation.

    Note that len(output_tensor_list) needs to be the same for all the distributed processes calling this function.

  • input_tensor_lists (List[List[Tensor]]) –

    Input lists. It should contain correctly-sized tensors on each GPU to be used for input of the collective, e.g. input_tensor_lists[i] contains the reduce_scatter input that resides on the GPU of output_tensor_list[i].

    Note that each element of input_tensor_lists has the size of world_size * len(output_tensor_list), since the function scatters the result from every single GPU in the group. To interpret each element of input_tensor_lists[i], note that output_tensor_list[j] of rank k receives the reduce-scattered result from input_tensor_lists[i][k * len(output_tensor_list) + j].

    Also note that len(input_tensor_lists), and the size of each element in input_tensor_lists (each element is a list, therefore len(input_tensor_lists[i])) need to be the same for all the distributed processes calling this function.

  • group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
  • async_op (bool, optional) – Whether this op should be an async op.
Returns

Async work handle, if async_op is set to True. None, if not async_op or if not part of the group.
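Assuming the same rank-major ordering of GPUs, reduce_scatter_multigpu() with SUM can be modeled as summing each slot over every GPU of every rank and then delivering slot k * len(output_tensor_list) + j to GPU j of rank k. In this plain-Python sketch a single number stands in for each tensor, and simulate_reduce_scatter_multigpu is a hypothetical helper.

```python
def simulate_reduce_scatter_multigpu(inputs_per_rank):
    """Model of reduce_scatter_multigpu() with SUM.

    inputs_per_rank[r][i] is input_tensor_lists[i] on rank r: a list of
    world_size * gpus_per_rank numbers (each number stands in for a tensor).
    Returns outputs[k][j], i.e. output_tensor_list[j] on rank k.
    """
    world_size = len(inputs_per_rank)
    gpus_per_rank = len(inputs_per_rank[0])
    outputs = []
    for k in range(world_size):
        row = []
        for j in range(gpus_per_rank):
            slot = k * gpus_per_rank + j  # rank-major position of (rank k, GPU j)
            # Sum this slot over every GPU of every rank.
            row.append(sum(inputs_per_rank[r][i][slot]
                           for r in range(world_size)
                           for i in range(gpus_per_rank)))
        outputs.append(row)
    return outputs


# 2 ranks x 2 GPUs; every GPU contributes [1, 2, 3, 4].
inputs = [[[1, 2, 3, 4] for _ in range(2)] for _ in range(2)]
print(simulate_reduce_scatter_multigpu(inputs))  # [[4, 8], [12, 16]]
```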

Third-party backends

Besides the GLOO/MPI/NCCL backends, PyTorch distributed supports third-party backends through a run-time register mechanism. For references on how to develop a third-party backend through C++ Extension, please refer to Tutorials - Custom C++ and CUDA Extensions and test/cpp_extensions/cpp_c10d_extension.cpp. The capabilities of third-party backends are determined by their own implementations.

The new backend derives from c10d.ProcessGroup and registers the backend name and the instantiating interface through torch.distributed.Backend.register_backend() when imported.

When manually importing this backend and invoking torch.distributed.init_process_group() with the corresponding backend name, the torch.distributed package runs on the new backend.

Warning

The support of third-party backend is experimental and subject to change.

Launch utility

The torch.distributed package also provides a launch utility in torch.distributed.launch. This helper utility can be used to launch multiple processes per node for distributed training.

torch.distributed.launch is a module that spawns multiple distributed training processes on each of the training nodes.

The utility can be used for single-node distributed training, in which one or more processes per node will be spawned. The utility can be used for either CPU training or GPU training. If the utility is used for GPU training, each distributed process will be operating on a single GPU. This can improve single-node training performance. It can also be used for multi-node distributed training, by spawning multiple processes on each node, to improve multi-node distributed training performance as well. This is especially beneficial for systems with multiple InfiniBand interfaces that have direct-GPU support, since all of them can be utilized for aggregated communication bandwidth.

In both cases of single-node distributed training or multi-node distributed training, this utility will launch the given number of processes per node (--nproc_per_node). If used for GPU training, this number needs to be less than or equal to the number of GPUs on the current system, and each process will be operating on a single GPU from GPU 0 to GPU (nproc_per_node - 1).

How to use this module:

  1. Single-Node multi-process distributed training
>>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE
           YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
           arguments of your training script)
  2. Multi-Node multi-process distributed training: (e.g. two nodes)

Node 1: (IP: 192.168.1.1, and has a free port: 1234)

>>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE
           --nnodes=2 --node_rank=0 --master_addr="192.168.1.1"
           --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
           and all other arguments of your training script)

Node 2:

>>> python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE
           --nnodes=2 --node_rank=1 --master_addr="192.168.1.1"
           --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
           and all other arguments of your training script)
  3. To look up what optional arguments this module offers:
>>> python -m torch.distributed.launch --help
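When the launcher starts nproc_per_node processes on each of nnodes nodes, each process has a node rank (--node_rank) and a local rank. The sketch below shows the rank arithmetic we assume the launcher uses, RANK = node_rank * nproc_per_node + local_rank and WORLD_SIZE = nnodes * nproc_per_node; launcher_ranks is a hypothetical helper, applied here to two nodes with 8 GPUs each.

```python
def launcher_ranks(nnodes, nproc_per_node):
    """Assumed rank layout: returns (world_size, {(node_rank, local_rank): rank})."""
    world_size = nnodes * nproc_per_node
    ranks = {}
    for node_rank in range(nnodes):
        for local_rank in range(nproc_per_node):
            # Processes on node n occupy the contiguous block of global ranks
            # starting at n * nproc_per_node.
            ranks[(node_rank, local_rank)] = node_rank * nproc_per_node + local_rank
    return world_size, ranks


world_size, ranks = launcher_ranks(nnodes=2, nproc_per_node=8)
print(world_size)     # 16
print(ranks[(1, 0)])  # 8, the first process on the second node
print(ranks[(1, 7)])  # 15
```

Under this layout, (0, 3) and (1, 3) share local rank 3 but have global ranks 3 and 11, which is why the local rank alone cannot identify a process across nodes.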

Important Notices:

1. This utility and multi-process distributed (single-node or multi-node) GPU training currently only achieves the best performance using the NCCL distributed backend. Thus NCCL backend is the recommended backend to use for GPU training.

2. In your training program, you must parse the command-line argument: --local_rank=LOCAL_PROCESS_RANK, which will be provided by this module. If your training program uses GPUs, you should ensure that your code only runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by:

Parsing the local_rank argument

>>> import argparse
>>> parser = argparse.ArgumentParser()
>>> parser.add_argument("--local_rank", type=int)
>>> args = parser.parse_args()

Set your device to local rank using either

>>> torch.cuda.set_device(args.local_rank)  # before your code runs

or

>>> with torch.cuda.device(args.local_rank):
...     # your code to run
...     pass

3. In your training program, you must call the following function at the beginning to start the distributed backend. Make sure that init_method is env://, which is the only init_method supported by this module.

torch.distributed.init_process_group(backend='YOUR BACKEND',
                                     init_method='env://')

4. In your training program, you can either use regular distributed functions or use torch.nn.parallel.DistributedDataParallel() module. If your training program uses GPUs for training and you would like to use torch.nn.parallel.DistributedDataParallel() module, here is how to configure it.

model = torch.nn.parallel.DistributedDataParallel(model,
                                                  device_ids=[args.local_rank],
                                                  output_device=args.local_rank)

Please ensure that the device_ids argument is set to the only GPU device id that your code will be operating on. This is generally the local rank of the process. In other words, device_ids needs to be [args.local_rank], and output_device needs to be args.local_rank in order to use this utility.

5. Another way to pass local_rank to the subprocesses is via the environment variable LOCAL_RANK. This behavior is enabled when you launch the script with --use_env. You must adjust the example above to replace args.local_rank with int(os.environ['LOCAL_RANK']) (environment variables are strings); the launcher will not pass --local_rank when you specify this flag.
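If a script should work whether or not --use_env is given, it can check LOCAL_RANK first and fall back to the --local_rank argument. The sketch below does that; get_local_rank is a hypothetical helper, and defaulting to 0 when neither is present (for example, when the script is run without the launcher) is an assumption. parse_known_args lets the training script's own arguments pass through untouched, and allow_abbrev=False stops a script option such as --local from being mistaken for --local_rank.

```python
import argparse
import os
from typing import Mapping, Optional, Sequence


def get_local_rank(argv: Optional[Sequence[str]] = None,
                   env: Optional[Mapping[str, str]] = None) -> int:
    """Hypothetical helper: local rank from LOCAL_RANK, else from --local_rank."""
    env = os.environ if env is None else env
    if "LOCAL_RANK" in env:
        # Set in the environment when launching with --use_env.
        return int(env["LOCAL_RANK"])
    parser = argparse.ArgumentParser(allow_abbrev=False)
    # Assumed default of 0 when the launcher passed nothing.
    parser.add_argument("--local_rank", type=int, default=0)
    # parse_known_args ignores the training script's other arguments.
    args, _unknown = parser.parse_known_args(argv)
    return args.local_rank
```

In a training script this would typically be used as torch.cuda.set_device(get_local_rank()) before any CUDA work.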

Warning

local_rank is NOT globally unique: it is only unique among the processes on a single machine. Thus, don’t use it to decide whether you should, e.g., write to a networked filesystem. See https://github.com/pytorch/pytorch/issues/12042 for an example of how things can go wrong if you don’t do this correctly.
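To see why local_rank is unsafe for this purpose, the sketch below counts how many processes would consider themselves the "main" process in a hypothetical job of 2 nodes with 4 processes each, using an assumed node-major rank numbering; count_writers is a hypothetical name. In a real program the global rank is available from torch.distributed.get_rank() after init_process_group() has been called, and a check such as get_rank() == 0 selects exactly one process.

```python
from typing import Tuple


def count_writers(nnodes: int, nproc_per_node: int) -> Tuple[int, int]:
    """Count processes selected by `local_rank == 0` vs. `global rank == 0`."""
    by_local = by_global = 0
    for node_rank in range(nnodes):
        for local_rank in range(nproc_per_node):
            rank = node_rank * nproc_per_node + local_rank  # assumed node-major numbering
            by_local += local_rank == 0   # matches once on every node
            by_global += rank == 0        # matches once in the whole job
    return by_local, by_global


print(count_writers(nnodes=2, nproc_per_node=4))  # (2, 1)
```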

Spawn utility

The torch.multiprocessing package also provides a spawn function, torch.multiprocessing.spawn(). This helper function can be used to spawn multiple processes: you pass in the function that you want to run, and it spawns N processes to run it. This can be used for multiprocess distributed training as well.

For an example of how to use it, please refer to the PyTorch example: ImageNet implementation.

Note that this function requires Python 3.4 or higher.
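A minimal sketch of the spawn-based approach, assuming a single machine, the Gloo backend on CPU, and an arbitrary free port (29501 here). torch.multiprocessing.spawn() calls the given function as fn(i, *args) with i ranging over 0 to nprocs - 1, so the process index doubles as the rank. The worker and main names are illustrative.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    # env:// initialization reads the master address and port from the environment.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")  # assumed to be free
    dist.init_process_group("gloo", init_method="env://",
                            rank=rank, world_size=world_size)

    # Each process contributes rank + 1; all_reduce sums in place (default op is SUM).
    t = torch.tensor([float(rank + 1)])
    dist.all_reduce(t)
    assert t.item() == world_size * (world_size + 1) / 2

    dist.destroy_process_group()


def main(world_size: int = 2) -> None:
    # spawn() starts world_size processes and, with join=True, waits for all of them.
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    # The guard is required: the "spawn" start method re-imports this module in each child.
    main()
```

Saved as a script and run with python, this should exit without output; if a worker fails, spawn() raises an exception in the parent process.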

© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/distributed.html