torch.nn.qat

This module implements versions of the key nn modules Conv2d() and Linear() that run in FP32 but apply rounding to simulate the effect of INT8 quantization.
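As a rough illustration of what "FP32 with rounding applied" means, the sketch below fake-quantizes a tensor by hand and compares the result with torch.fake_quantize_per_tensor_affine. The scale, zero point, and input values are arbitrary choices for this example, not defaults taken from this module.

```python
import torch

x = torch.tensor([0.04, 0.26, 1.0, 20.0])
scale, zero_point = 0.1, 0          # illustrative values, not library defaults
quant_min, quant_max = -128, 127    # signed INT8 range

# Quantize: divide by the scale, round to the nearest integer, clamp to INT8.
q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
# Dequantize: map back to FP32 so downstream ops still see float tensors.
manual = (q - zero_point) * scale

# The built-in op performs the same quantize/dequantize round trip.
builtin = torch.fake_quantize_per_tensor_affine(
    x, scale, zero_point, quant_min, quant_max
)

# 0.26 snaps to roughly 0.3, and 20.0 saturates at roughly 127 * 0.1 = 12.7.
print(manual)
print(builtin.dtype)
```

The result is still a float tensor; only its values have been snapped to the grid that an INT8 representation could express.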

Conv2d

class torch.nn.qat.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', qconfig=None)

A Conv2d module with a FakeQuantize module attached to its weight, used for quantization-aware training.

It adopts the same interface as torch.nn.Conv2d; see https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d for documentation.

Similar to torch.nn.Conv2d, with FakeQuantize modules initialized to default.

Variables

Conv2d.weight_fake_quant – fake quant module for weight
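A minimal construction sketch follows. Although the signature lists qconfig=None, the module needs a qconfig to build its weight FakeQuantize, so one is passed explicitly here. The choice of get_default_qat_qconfig("fbgemm") and the layer and input sizes are assumptions made for illustration only.

```python
import torch
import torch.nn.qat as nnqat
from torch.quantization import get_default_qat_qconfig

# Assumed qconfig for the example; any valid QAT qconfig could be used.
qconfig = get_default_qat_qconfig("fbgemm")
conv = nnqat.Conv2d(3, 8, kernel_size=3, padding=1, qconfig=qconfig)

# The weight fake-quant module is created from qconfig.weight().
print(type(conv.weight_fake_quant).__name__)

# The forward pass fake-quantizes the weight, then runs a normal FP32 conv.
x = torch.randn(2, 3, 16, 16)
y = conv(x)
print(y.shape, y.dtype)
```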

classmethod from_float(mod)

Create a QAT module from a float module or qparams_dict.

Args: mod – a float module, either produced by torch.quantization utilities or provided directly by the user
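The sketch below converts a user-built float Conv2d directly. It assumes that from_float reads the quantization settings from a qconfig attribute on the float module, which is why one is attached first; the specific qconfig and layer sizes are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.qat as nnqat
from torch.quantization import get_default_qat_qconfig

float_conv = nn.Conv2d(3, 8, kernel_size=3, stride=2)
# Attach a qconfig to the float module before conversion (assumed requirement).
float_conv.qconfig = get_default_qat_qconfig("fbgemm")

qat_conv = nnqat.Conv2d.from_float(float_conv)

# Hyperparameters and weights carry over from the float module.
print(type(qat_conv).__name__, qat_conv.stride)
print(torch.equal(qat_conv.weight, float_conv.weight))
```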

Linear

class torch.nn.qat.Linear(in_features, out_features, bias=True, qconfig=None)

A linear module with a FakeQuantize module attached to its weight, used for quantization-aware training.

It adopts the same interface as torch.nn.Linear; see https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.

Similar to torch.nn.Linear, with FakeQuantize modules initialized to default.

Variables

Linear.weight_fake_quant – fake quant module for weight
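To show why this module is usable for training, the sketch below runs one forward and backward pass. The weight parameter itself stays in FP32, and gradients reach it through the fake-quantization step. The qconfig, layer sizes, and toy loss are assumptions chosen for the example.

```python
import torch
import torch.nn.qat as nnqat
from torch.quantization import get_default_qat_qconfig

torch.manual_seed(0)
linear = nnqat.Linear(4, 2, qconfig=get_default_qat_qconfig("fbgemm"))
linear.train()

x = torch.randn(8, 4)
# Toy loss for illustration; the forward pass uses the fake-quantized weight.
loss = linear(x).pow(2).mean()
loss.backward()

# The FP32 weight receives a gradient despite the rounding in the forward pass.
print(linear.weight.dtype, linear.weight.grad.shape)
```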

classmethod from_float(mod)

Create a QAT module from a float module or qparams_dict.

Args: mod – a float module, either produced by torch.quantization utilities or provided directly by the user
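In practice the conversion is often driven by a torch.quantization utility rather than by calling from_float by hand. The sketch below uses torch.quantization.prepare_qat on a small hypothetical model as one such path; the model layout and qconfig are assumptions for illustration.

```python
import torch.nn as nn
import torch.nn.qat as nnqat
from torch.quantization import get_default_qat_qconfig, prepare_qat

# Hypothetical two-layer model used only for this example.
model = nn.Sequential(nn.Linear(4, 2), nn.ReLU())
model.train()
model.qconfig = get_default_qat_qconfig("fbgemm")

# prepare_qat propagates the qconfig and swaps supported float modules
# for their QAT counterparts; inplace=False leaves the original untouched.
prepared = prepare_qat(model, inplace=False)

print(type(prepared[0]))
print(type(model[0]))
```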

© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/torch.nn.qat.html