
This module implements the versions of the fused operations needed for quantization-aware training.


class torch.nn.intrinsic.qat.ConvBn2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=None, padding_mode='zeros', eps=1e-05, momentum=0.1, freeze_bn=False, qconfig=None)

A ConvBn2d module is fused from Conv2d and BatchNorm2d, with a FakeQuantize module attached to the weight, and is used in quantization-aware training.

We combined the interface of torch.nn.Conv2d and torch.nn.BatchNorm2d.

Similar to torch.nn.Conv2d, with FakeQuantize modules initialized to default.

  • ConvBn2d.freeze_bn
  • ConvBn2d.weight_fake_quant – fake quant module for weight
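
As an illustration of the interface described above, the following is a minimal sketch of building a ConvBn2d module and running a forward pass. It assumes the default QAT qconfig for the fbgemm backend, obtained from torch.quantization.get_default_qat_qconfig; the tensor sizes and the alias nniqat are arbitrary choices for the example, not part of the original documentation.

```python
import torch
import torch.nn.intrinsic.qat as nniqat

# Assumption: the default fbgemm QAT qconfig; another QAT qconfig could be used instead.
qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")

# Fused Conv2d + BatchNorm2d with fake quantization applied to the weight.
m = nniqat.ConvBn2d(3, 8, kernel_size=3, qconfig=qconfig)

x = torch.randn(4, 3, 32, 32)  # (batch, channels, height, width)
out = m(x)                     # a freshly built module is in training mode

# A 3x3 kernel without padding shrinks each spatial dimension from 32 to 30.
print(out.shape)

# The weight fake-quant module is a regular submodule of the fused module.
print(type(m.weight_fake_quant).__name__)
```

The qconfig is passed explicitly here because the fake-quant module for the weight is created from it.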


class torch.nn.intrinsic.qat.ConvBnReLU2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=None, padding_mode='zeros', eps=1e-05, momentum=0.1, freeze_bn=False, qconfig=None)

A ConvBnReLU2d module is fused from Conv2d, BatchNorm2d and ReLU, with a FakeQuantize module attached to the weight, and is used in quantization-aware training.

We combined the interfaces of torch.nn.Conv2d, torch.nn.BatchNorm2d and torch.nn.ReLU.

Similar to torch.nn.Conv2d, with FakeQuantize modules initialized to default.


  • ConvBnReLU2d.weight_fake_quant – fake quant module for weight
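
A short sketch of the fused Conv2d, BatchNorm2d and ReLU behaviour follows. It again assumes the default fbgemm QAT qconfig; the input size and padding=1 are illustrative choices. Because ReLU is the last fused step, the output should contain no negative values.

```python
import torch
import torch.nn.intrinsic.qat as nniqat

# Assumption: the default fbgemm QAT qconfig, chosen only for illustration.
qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")

# padding=1 with a 3x3 kernel keeps the spatial size unchanged.
m = nniqat.ConvBnReLU2d(3, 8, kernel_size=3, padding=1, qconfig=qconfig)

x = torch.randn(4, 3, 16, 16)
out = m(x)

print(out.shape)

# ReLU is applied last, so every output element is non-negative.
non_negative = bool((out >= 0).all())
print(non_negative)
```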


class torch.nn.intrinsic.qat.ConvReLU2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', qconfig=None)

A ConvReLU2d module is fused from Conv2d and ReLU, with a FakeQuantize module attached to the weight, and is used in quantization-aware training.

We combined the interfaces of Conv2d and ReLU.


  • ConvReLU2d.weight_fake_quant – fake quant module for weight
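
The sketch below shows one way to look at what the weight fake-quant module does: it is applied directly to the module's weight, and the result is compared with the original floating-point weight. The default fbgemm QAT qconfig and the tensor sizes are assumptions made for the example; the size of the reported difference depends on the qconfig in use.

```python
import torch
import torch.nn.intrinsic.qat as nniqat

# Assumption: the default fbgemm QAT qconfig, chosen only for illustration.
qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")

# Fused Conv2d + ReLU with fake quantization applied to the weight.
m = nniqat.ConvReLU2d(3, 8, kernel_size=3, qconfig=qconfig)

x = torch.randn(2, 3, 10, 10)
out = m(x)
print(out.shape)

# Apply the weight fake-quant module by hand to inspect the simulated
# quantization error on the weight tensor.
with torch.no_grad():
    w_fq = m.weight_fake_quant(m.weight)
    max_err = (w_fq - m.weight).abs().max().item()

# The fake-quantized weight keeps the shape of the original weight.
print(w_fq.shape == m.weight.shape)
print(max_err)
```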


class torch.nn.intrinsic.qat.LinearReLU(in_features, out_features, bias=True, qconfig=None)

A LinearReLU module is fused from Linear and ReLU modules, with a FakeQuantize module attached to the weight, and is used in quantization-aware training.

We adopt the same interface as torch.nn.Linear.

Similar to torch.nn.intrinsic.LinearReLU, with FakeQuantize modules initialized to default.


  • LinearReLU.weight_fake_quant – fake quant module for weight


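>>> import torch
>>> import torch.nn.intrinsic.qat as nniqat
>>> # QAT modules need a qconfig; the default fbgemm QAT qconfig is used here.
>>> qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
>>> m = nniqat.LinearReLU(20, 30, qconfig=qconfig)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])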

© 2019 Torch Contributors
Licensed under the 3-clause BSD License.