DDP Communication Hooks

A DDP communication hook is a generic interface for controlling how gradients are communicated across workers by overriding the vanilla allreduce in DistributedDataParallel. A few built-in communication hooks are provided, and users can apply any of them to optimize communication. The hook interface also supports user-defined communication strategies for more advanced use cases.

Warning

DDP communication hook is experimental and subject to change.

Warning

DDP communication hooks support only single-process single-device mode on the NCCL backend.

How to Use a Communication Hook?

To use a communication hook, register it on the DDP model before the training loop by calling:

torch.nn.parallel.DistributedDataParallel.register_comm_hook()

Default Communication Hooks

Default communication hooks are simple stateless hooks, so the input state in register_comm_hook is either a process group or None.

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.allreduce_hook(process_group, bucket) [source]

This DDP communication hook just calls allreduce on the GradBucket tensors. Once the gradient tensors are aggregated across all workers, its then callback takes the mean and returns the result. If this hook is registered, DDP results are expected to be the same as when no hook is registered. It therefore does not change DDP behavior, and users can treat it as a reference or modify it to log useful information or for other purposes without affecting DDP behavior.
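As a rough illustration of the arithmetic only (not the hook's actual code), the following plain-Python sketch simulates what every worker holds after the allreduce and the mean. The function name `simulated_allreduce_mean` and the list-of-lists representation of per-worker gradients are assumptions made for this example; no process group is involved.

```python
def simulated_allreduce_mean(per_worker_grads):
    """Return the element-wise mean of the gradients held by each worker.

    per_worker_grads: one list of floats per simulated worker, all the
    same length (a stand-in for the flattened bucket tensor).
    """
    world_size = len(per_worker_grads)
    # Allreduce step: element-wise sum across workers.
    summed = [sum(values) for values in zip(*per_worker_grads)]
    # "then" callback step: divide by the number of workers.
    return [total / world_size for total in summed]


worker_0 = [1.0, 2.0]
worker_1 = [3.0, 6.0]
print(simulated_allreduce_mean([worker_0, worker_1]))
```

Every worker receives the same averaged values, which is why registering this hook is expected to leave DDP results unchanged.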

Example::
>>> from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
>>> ddp_model.register_comm_hook(process_group, allreduce_hook)

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook(process_group, bucket) [source]

This DDP communication hook implements a simple gradient compression approach that converts GradBucket tensors, whose type is assumed to be torch.float32, to half-precision floating point format (torch.float16). It then allreduces those float16 gradient tensors. Once the compressed gradient tensors are allreduced, its then callback, decompress, converts the aggregated result back to float32 and takes the mean.

Example::
>>> from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook
>>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
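The compress, allreduce, decompress sequence can be imitated without any tensors, which makes the precision loss visible. This is a minimal sketch under stated assumptions: Python floats stand in for float32 values, the standard-library `struct` format `"e"` is used to round to IEEE 754 half precision, and the helper names `to_float16` and `simulated_fp16_compress_mean` are invented for illustration.

```python
import struct


def to_float16(value):
    """Round a Python float to the nearest float16 value."""
    return struct.unpack("e", struct.pack("e", value))[0]


def simulated_fp16_compress_mean(per_worker_grads):
    world_size = len(per_worker_grads)
    # Compress: every simulated worker rounds its gradients to float16.
    compressed = [[to_float16(g) for g in grads] for grads in per_worker_grads]
    # Allreduce: element-wise sum, with the result stored in float16.
    summed = [to_float16(sum(values)) for values in zip(*compressed)]
    # Decompress: back to full precision, then take the mean.
    return [total / world_size for total in summed]


result = simulated_fp16_compress_mean([[0.1, 1.0], [0.3, 3.0]])
print(result)  # close to [0.2, 2.0]; the first entry carries rounding error
print(struct.calcsize("e"), "bytes per float16 vs", struct.calcsize("f"), "per float32")
```

Values that float16 represents exactly (such as 1.0 and 3.0) survive unchanged, while values such as 0.1 pick up a small rounding error in exchange for halving the communicated bytes.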

PowerSGD Communication Hook

PowerSGD (Vogels et al., NeurIPS 2019) is a gradient compression algorithm that can provide very high compression rates and accelerate bandwidth-bound distributed training. The algorithm needs to maintain both some hyperparameters and internal state. The PowerSGD communication hook is therefore a stateful hook, and the user needs to provide a state object defined as below.

PowerSGD State

class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState(process_group, matrix_approximation_rank=1, start_powerSGD_iter=10, use_error_feedback=True, warm_start=True, random_seed=0) [source]

Stores both the algorithm’s hyperparameters and the internal state for all the gradients during training. In particular, matrix_approximation_rank and start_powerSGD_iter are the main hyperparameters that should be tuned by the user. For performance, we suggest keeping the binary hyperparameters use_error_feedback and warm_start on.

  1. matrix_approximation_rank controls the size of compressed low-rank tensors, which determines the compression rate. The lower the rank, the stronger the compression.

    1.1. If matrix_approximation_rank is too low, the model will need more training steps to reach full quality, or will never reach it and lose accuracy.

    1.2. Increasing matrix_approximation_rank can substantially increase the computation cost of the compression, and the accuracy may not improve further beyond a certain matrix_approximation_rank threshold.

To tune matrix_approximation_rank, we suggest starting from 1 and increasing by factors of 2 (like an exponential grid search: 1, 2, 4, …) until a satisfactory accuracy is reached. Typically only a small value of 1-4 is used. For some NLP tasks (as shown in Appendix D of the original paper), this value has been increased to 32.
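To make the rank versus compression trade-off concrete, the sketch below counts elements for one gradient viewed as a rows x cols matrix, assuming P has shape rows x rank and Q has shape cols x rank. The 1024 x 1024 shape and the helper names `powersgd_compression_ratio` and `rank_grid` are illustrative assumptions, not part of the PyTorch API.

```python
def powersgd_compression_ratio(rows, cols, rank):
    """Elements in the full matrix divided by elements in P and Q."""
    full_elements = rows * cols
    compressed_elements = rank * (rows + cols)
    return full_elements / compressed_elements


def rank_grid(max_rank):
    """Exponential grid of candidate ranks: 1, 2, 4, ... up to max_rank."""
    grid = []
    rank = 1
    while rank <= max_rank:
        grid.append(rank)
        rank *= 2
    return grid


for rank in rank_grid(32):
    ratio = powersgd_compression_ratio(1024, 1024, rank)
    print(f"rank={rank:2d}  compression ratio={ratio:.0f}x")
```

Doubling the rank halves the compression ratio for a fixed matrix shape, so each step along the grid trades communication savings for approximation quality.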

  2. start_powerSGD_iter defers PowerSGD compression until step start_powerSGD_iter, and vanilla allreduce runs prior to step start_powerSGD_iter. This hybrid scheme of vanilla allreduce + PowerSGD can effectively improve the accuracy, even when a relatively small matrix_approximation_rank is used. This is because the beginning of training is usually very sensitive to inaccurate gradients, and compressing gradients too early may quickly put training on a suboptimal trajectory, which can have an irrecoverable impact on the accuracy.

To tune start_powerSGD_iter, we suggest starting with 10% of total training steps and increasing it until a satisfactory accuracy is reached.

Warning

If error feedback or warm start is enabled, the minimum value of start_powerSGD_iter allowed in DDP is 2. This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP, and this can conflict with any tensor memorized before the rebuild process.
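The tuning advice and the minimum-value constraint can be summarized in a few lines of plain Python. This is only a sketch: the helper names are hypothetical, the step counter is assumed to be zero-based, and the check mirrors the rule stated above rather than reproducing the library's own validation code.

```python
def initial_start_iter(total_steps):
    """Starting point for tuning: 10% of the total training steps."""
    return total_steps // 10


def check_start_iter(start_powerSGD_iter, use_error_feedback=True, warm_start=True):
    """Reject values below 2 when error feedback or warm start is enabled."""
    if (use_error_feedback or warm_start) and start_powerSGD_iter < 2:
        raise ValueError(
            "start_powerSGD_iter must be at least 2 when error feedback "
            "or warm start is enabled"
        )
    return start_powerSGD_iter


def communication_mode(step, start_powerSGD_iter):
    """Vanilla allreduce for the first start_powerSGD_iter steps, then PowerSGD."""
    return "allreduce" if step < start_powerSGD_iter else "powerSGD"


start = check_start_iter(initial_start_iter(1000))
print(start, communication_mode(99, start), communication_mode(100, start))
```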

PowerSGD Hooks

Warning

PowerSGD typically requires extra memory of the same size as the model’s gradients to enable error feedback, which can compensate for biased compressed communication and improve accuracy.
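The idea behind error feedback can be shown with a deliberately crude compressor. In this sketch, rounding to the nearest integer is a hypothetical stand-in for PowerSGD's low-rank approximation, and a single float plays the role of the per-gradient error memory; the names `toy_compress` and `run_steps` are invented for illustration.

```python
def toy_compress(value):
    """Hypothetical lossy compressor: round to the nearest integer."""
    return float(round(value))


def run_steps(gradients, use_error_feedback):
    """Return the values actually communicated at each step."""
    error = 0.0  # one memory slot per gradient element in a real model
    sent = []
    for grad in gradients:
        # Add back whatever earlier compression steps dropped.
        corrected = grad + error if use_error_feedback else grad
        approx = toy_compress(corrected)
        if use_error_feedback:
            # Remember what this step's compression lost.
            error = corrected - approx
        sent.append(approx)
    return sent


grads = [0.4] * 5  # the true gradients sum to 2.0
print(sum(run_steps(grads, use_error_feedback=False)))  # information is lost
print(sum(run_steps(grads, use_error_feedback=True)))   # compensated over time
```

Because one error value is kept for every gradient element, the memory overhead matches the size of the model's gradients.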

Warning

The current implementation may cause gradient overflow for FP16 input.

torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.powerSGD_hook(state, bucket) [source]

This DDP communication hook implements the PowerSGD gradient compression algorithm described in the paper. Once gradient tensors are aggregated across all workers, this hook applies compression as follows:

  1. Views the input flattened 1D gradient tensor as two groups of per-parameter tensors: high-rank tensors and vector-like rank-1 tensors (for biases).
  2. Handles rank-1 tensors by allreducing them without compression:

    2.1. Allocates contiguous memory for those rank-1 tensors, and allreduces all the rank-1 tensors as a batch, without compression;

    2.2. Copies the individual rank-1 tensors from the contiguous memory back to the input tensor.

  3. Handles high-rank tensors by PowerSGD compression:

    3.1. For each high-rank tensor M, creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;

    3.2. Computes each P in Ps, which is equal to MQ;

    3.3. Allreduces Ps as a batch;

    3.4. Orthogonalizes each P in Ps;

    3.5. Computes each Q in Qs, which is approximately equal to M^TP;

    3.6. Allreduces Qs as a batch;

    3.7. Computes each M among all the high-rank tensors, which is approximately equal to PQ^T.
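Steps 3.1 to 3.7 can be traced for a single high-rank tensor with a small plain-Python simulation. This is a sketch under several assumptions: the rank is fixed at 1 (so orthogonalization reduces to normalizing a vector), averaging across an in-memory list of per-worker matrices stands in for allreduce, and all helper names are invented here. It is not the hook's implementation.

```python
import math
import random


def matvec(matrix, vector):
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def transpose(matrix):
    return [list(column) for column in zip(*matrix)]


def normalize(vector):
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def average(vectors):
    # Stand-in for allreduce followed by division by the world size.
    return [sum(values) / len(vectors) for values in zip(*vectors)]


def rank1_powersgd_step(worker_matrices, seed=0):
    rng = random.Random(seed)
    num_cols = len(worker_matrices[0][0])
    # 3.1: Q drawn from a standard normal distribution, then orthogonalized.
    q = normalize([rng.gauss(0.0, 1.0) for _ in range(num_cols)])
    # 3.2-3.3: each worker computes P = MQ; the Ps are then allreduced.
    p = average([matvec(m, q) for m in worker_matrices])
    # 3.4: orthogonalize P (rescaling P does not change the final result).
    p = normalize(p)
    # 3.5-3.6: each worker computes Q = M^T P; the Qs are then allreduced.
    q = average([matvec(transpose(m), p) for m in worker_matrices])
    # 3.7: reconstruct the approximation PQ^T.
    return [[p_i * q_j for q_j in q] for p_i in p]


base = [[3.0, 4.0, 5.0], [6.0, 8.0, 10.0]]       # a rank-1 matrix
scaled = [[3.0 * x for x in row] for row in base]
approx = rank1_powersgd_step([base, scaled])
print(approx)  # close to the mean of the two matrices, i.e. 2 * base
```

The example uses rank-1 inputs so that a rank-1 approximation recovers the averaged gradient almost exactly; for general gradients the result is only an approximation, and its quality depends on matrix_approximation_rank.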

Note that this communication hook enforces vanilla allreduce for the first state.start_powerSGD_iter iterations. This not only gives the user more control over the tradeoff between speedup and accuracy, but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.

Parameters
  • state (PowerSGDState) – State information to configure the compression rate and support error feedback, warm start, etc. To tune the compression configs, you mainly need to tune matrix_approximation_rank and start_powerSGD_iter.
  • bucket (dist._GradBucket) – Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. Note that since DDP comm hook only supports single process single device mode at this time, exactly one tensor is stored in this bucket.
Returns

Future handler of the communication, which updates the gradients in place.

Example::
>>> from torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook import PowerSGDState, powerSGD_hook
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, powerSGD_hook)

torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook(state, bucket) [source]

This DDP communication hook implements a simplified version of the PowerSGD gradient compression algorithm described in the paper. This variant does not compress the gradients layer by layer, but instead compresses the flattened input tensor that batches all the gradients. It is therefore faster than powerSGD_hook(), but usually results in much lower accuracy, unless matrix_approximation_rank is 1.

Warning

Increasing matrix_approximation_rank here may not necessarily increase the accuracy, because batching per-parameter tensors without column/row alignment can destroy low-rank structure. Therefore, the user should always consider powerSGD_hook() first, and only consider this variant when a satisfactory accuracy can be achieved when matrix_approximation_rank is 1.

Once gradient tensors are aggregated across all workers, this hook applies compression as follows:

  1. Views the input flattened 1D gradient tensor as a square-shaped tensor M with zero padding;
  2. Creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
  3. Computes P, which is equal to MQ;
  4. Allreduces P;
  5. Orthogonalizes P;
  6. Computes Q, which is approximately equal to M^TP;
  7. Allreduces Q;
  8. Computes M, which is approximately equal to PQ^T;
  9. Truncates the input tensor to the original length.
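Steps 1 and 9, which bracket the compression, can be sketched in plain Python. The helpers `pad_to_square` and `flatten_and_truncate` are hypothetical names for this example, and a list of floats stands in for the flattened gradient tensor; the side length is assumed to be the smallest integer whose square holds all elements.

```python
import math


def pad_to_square(flat):
    """View a flat list as a square matrix, padding the tail with zeros."""
    total = len(flat)
    side = math.isqrt(total)
    if side * side < total:
        side += 1  # round the square root up
    padded = flat + [0.0] * (side * side - total)
    return [padded[r * side:(r + 1) * side] for r in range(side)]


def flatten_and_truncate(matrix, original_length):
    """Undo pad_to_square: flatten row by row and drop the padding."""
    flat = [x for row in matrix for x in row]
    return flat[:original_length]


gradients = [float(i) for i in range(10)]
square = pad_to_square(gradients)
for row in square:
    print(row)
print(flatten_and_truncate(square, len(gradients)) == gradients)
```

Because rows of the square matrix cut across parameter boundaries, the per-parameter row and column structure is not preserved, which is consistent with the warning above about raising matrix_approximation_rank for this variant.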

Note that this communication hook enforces vanilla allreduce for the first state.start_powerSGD_iter iterations. This not only gives the user more control over the tradeoff between speedup and accuracy, but also helps abstract away some complexity of the internal optimization of DDP for future communication hook developers.

Parameters
  • state (PowerSGDState) – State information to configure the compression rate and support error feedback, warm start, etc. To tune the compression configs, you mainly need to tune matrix_approximation_rank and start_powerSGD_iter.
  • bucket (dist._GradBucket) – Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors. Note that since DDP comm hook only supports single process single device mode at this time, exactly one tensor is stored in this bucket.
Returns

Future handler of the communication, which updates the gradients in place.

Example::
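>>> from torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook import PowerSGDState, batched_powerSGD_hook
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
>>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)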

Acknowledgements

Many thanks to PowerSGD paper author Thijs Vogels for the code review of the PowerSGD communication hook, as well as for the comparison experiments, which show that the performance of the PowerSGD communication hook is on par with the implementation in the original paper.

© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/ddp_comm_hooks.html