torch.nn.utils.prune.custom_from_mask

torch.nn.utils.prune.custom_from_mask(module, name, mask)

Prunes the tensor corresponding to the parameter called name in module by applying the pre-computed mask in mask. Modifies module in place (and also returns the modified module) by:

  1) adding a named buffer called name+'_mask' corresponding to the binary mask applied to the parameter name by the pruning method.
  2) replacing the parameter name by its pruned version, while the original (unpruned) parameter is stored in a new parameter named name+'_orig'.
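The two in-place changes described above can be inspected directly on a pruned module. The following is a minimal sketch, assuming a small nn.Linear layer and an arbitrary hand-written mask (both chosen only for illustration); it checks which parameters and buffers exist after the call and how the pruned attribute relates to the stored original.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)
layer = nn.Linear(4, 2)

# Keep a copy of the unpruned weight for comparison (illustration only).
original_weight = layer.weight.detach().clone()

# Arbitrary example mask with the same shape as layer.weight: 1 keeps, 0 prunes.
mask = torch.tensor([[1., 0., 1., 0.],
                     [0., 1., 0., 1.]])

pruned = prune.custom_from_mask(layer, name="weight", mask=mask)

# The call modifies the module in place and returns that same module.
same_object = pruned is layer

# 'weight' is no longer a parameter; 'weight_orig' holds the unpruned values.
param_names = sorted(n for n, _ in layer.named_parameters())

# The mask is stored as a buffer called 'weight_mask'.
buffer_names = sorted(n for n, _ in layer.named_buffers())

# The attribute 'weight' is now the masked version of 'weight_orig'.
weight_is_masked = torch.equal(layer.weight, original_weight * mask)

print(same_object, param_names, buffer_names, weight_is_masked)
```

Because the mask is registered as a buffer, it should be saved in the module's state_dict alongside name+'_orig'.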

Parameters
  • module (nn.Module) – module containing the tensor to prune
  • name (str) – parameter name within module on which pruning will act.
  • mask (Tensor) – binary mask to be applied to the parameter.
Returns

modified (i.e. pruned) version of the input module

Return type

module (nn.Module)

Examples

>>> import torch
>>> from torch import nn
>>> from torch.nn.utils import prune
>>> m = prune.custom_from_mask(
...     nn.Linear(5, 3), name='bias', mask=torch.tensor([0., 1., 0.])
... )
>>> print(m.bias_mask)
tensor([0., 1., 0.])
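In practice the mask is often computed from the parameter itself rather than written by hand. The sketch below is one possible way to do that, not a prescribed recipe: it assumes a hypothetical magnitude threshold of 0.1 and hand-set weights so that the result is easy to follow. It also runs one backward pass to illustrate that, since the pruned weight is computed as weight_orig multiplied by the mask, masked entries of weight_orig receive a zero gradient.

```python
import torch
from torch import nn
from torch.nn.utils import prune

layer = nn.Linear(3, 2)

# Hand-set weights so the example is deterministic (illustration only).
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[0.5, -0.05, 0.2],
                                     [0.01, -0.3, 0.0]]))

# Hypothetical threshold: keep entries whose magnitude exceeds it.
threshold = 0.1
mask = (layer.weight.detach().abs() > threshold).float()

# The mask has the same shape as the parameter it is applied to.
prune.custom_from_mask(layer, name="weight", mask=mask)

# Number of entries zeroed out in the pruned weight.
num_zeroed = int((layer.weight == 0).sum())

# Forward and backward pass: the forward pre-hook recomputes
# layer.weight from layer.weight_orig and layer.weight_mask.
x = torch.ones(2, 3)
layer(x).sum().backward()

# Gradient with respect to the unpruned parameter; masked positions are zero.
grad = layer.weight_orig.grad
print(mask)
print(num_zeroed)
print(grad)
```

The original values remain available in weight_orig, so the choice of mask can be revisited later without retraining from scratch.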

© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/generated/torch.nn.utils.prune.custom_from_mask.html