torch.nn.intrinsic.qat
This module implements the versions of the fused operations needed for quantization aware training.
ConvBn2d
class torch.nn.intrinsic.qat.ConvBn2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=None, padding_mode='zeros', eps=1e-05, momentum=0.1, freeze_bn=False, qconfig=None)
A ConvBn2d module is fused from Conv2d and BatchNorm2d, with FakeQuantize modules attached for the weight. It is used in quantization aware training.
It combines the interfaces of torch.nn.Conv2d and torch.nn.BatchNorm2d.
Similar to torch.nn.Conv2d, with FakeQuantize modules initialized to default.
Variables:
freeze_bn
weight_fake_quant – fake quant module for weight
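A minimal sketch of constructing ConvBn2d directly and running a forward pass. It assumes the default "fbgemm" QAT configuration from torch.quantization.get_default_qat_qconfig is acceptable as the qconfig, and that the module exposes freeze_bn_stats() and update_bn_stats() for toggling freeze_bn; neither is stated in the text above, and the input shape is purely illustrative.

```python
import torch
import torch.nn.intrinsic.qat as nniqat
from torch.quantization import get_default_qat_qconfig

# Assumption: the default fbgemm QAT qconfig is used for illustration.
qconfig = get_default_qat_qconfig("fbgemm")

# Fused conv + batch norm with fake quantization on the weight.
m = nniqat.ConvBn2d(3, 8, kernel_size=3, qconfig=qconfig)

x = torch.randn(4, 3, 16, 16)  # (N, C, H, W), illustrative sizes
out = m(x)
print(out.shape)  # torch.Size([4, 8, 14, 14])

# Assumption: these helper methods toggle the freeze_bn flag.
m.freeze_bn_stats()
frozen = m.freeze_bn
m.update_bn_stats()
unfrozen = m.freeze_bn
print(frozen, unfrozen)
```

Freezing the batch norm statistics is typically done late in training, once the running mean and variance have stabilised.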
ConvBnReLU2d
class torch.nn.intrinsic.qat.ConvBnReLU2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=None, padding_mode='zeros', eps=1e-05, momentum=0.1, freeze_bn=False, qconfig=None)
A ConvBnReLU2d module is fused from Conv2d, BatchNorm2d and ReLU, with FakeQuantize modules attached for the weight. It is used in quantization aware training.
It combines the interfaces of torch.nn.Conv2d, torch.nn.BatchNorm2d and torch.nn.ReLU.
Similar to torch.nn.Conv2d, with FakeQuantize modules initialized to default.
Variables:
weight_fake_quant – fake quant module for weight
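As a rough illustration of the fused ReLU, the sketch below checks that the output of ConvBnReLU2d is non-negative. The qconfig (default "fbgemm" QAT configuration) and the tensor sizes are assumptions chosen for the example, not values taken from the text above.

```python
import torch
import torch.nn.intrinsic.qat as nniqat
from torch.quantization import get_default_qat_qconfig

# Assumption: default fbgemm QAT qconfig, used only for illustration.
qconfig = get_default_qat_qconfig("fbgemm")

# padding=1 keeps the spatial size unchanged for a 3x3 kernel.
m = nniqat.ConvBnReLU2d(3, 8, kernel_size=3, padding=1, qconfig=qconfig)

x = torch.randn(4, 3, 16, 16)
out = m(x)

print(out.shape)                 # torch.Size([4, 8, 16, 16])
non_negative = bool((out >= 0).all())
print(non_negative)              # True, since ReLU is applied last
```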
ConvReLU2d
class torch.nn.intrinsic.qat.ConvReLU2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', qconfig=None)
A ConvReLU2d module is fused from Conv2d and ReLU, with FakeQuantize modules attached for the weight. It is used in quantization aware training.
It combines the interfaces of Conv2d and ReLU.
Variables:
weight_fake_quant – fake quant module for weight
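The weight_fake_quant attribute can be inspected or called directly. The sketch below, which again assumes the default "fbgemm" QAT qconfig and illustrative tensor sizes, passes the raw weight through the fake quant module and compares shapes.

```python
import torch
import torch.nn.intrinsic.qat as nniqat
from torch.quantization import get_default_qat_qconfig

# Assumption: default fbgemm QAT qconfig, used only for illustration.
qconfig = get_default_qat_qconfig("fbgemm")

m = nniqat.ConvReLU2d(3, 8, kernel_size=3, qconfig=qconfig)

# The fake quant module simulates quantization of the weight in floating point.
print(type(m.weight_fake_quant).__name__)
fq_weight = m.weight_fake_quant(m.weight)
print(fq_weight.shape == m.weight.shape)  # True

out = m(torch.randn(2, 3, 10, 10))
print(out.shape)  # torch.Size([2, 8, 8, 8])
```

The fake-quantized weight stays a float tensor of the same shape; only its values are snapped to the simulated quantization grid.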
ConvBn3d
class torch.nn.intrinsic.qat.ConvBn3d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=None, padding_mode='zeros', eps=1e-05, momentum=0.1, freeze_bn=False, qconfig=None)
A ConvBn3d module is fused from Conv3d and BatchNorm3d, with FakeQuantize modules attached for the weight. It is used in quantization aware training.
It combines the interfaces of torch.nn.Conv3d and torch.nn.BatchNorm3d.
Similar to torch.nn.Conv3d, with FakeQuantize modules initialized to default.
Variables:
freeze_bn
weight_fake_quant – fake quant module for weight
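The 3D variants take 5-dimensional input. A minimal sketch, assuming the default "fbgemm" QAT qconfig and an arbitrary (N, C, D, H, W) input shape chosen for illustration:

```python
import torch
import torch.nn.intrinsic.qat as nniqat
from torch.quantization import get_default_qat_qconfig

# Assumption: default fbgemm QAT qconfig, used only for illustration.
qconfig = get_default_qat_qconfig("fbgemm")

m = nniqat.ConvBn3d(3, 8, kernel_size=3, qconfig=qconfig)

x = torch.randn(2, 3, 8, 16, 16)  # (N, C, D, H, W)
out = m(x)
print(out.shape)  # torch.Size([2, 8, 6, 14, 14])
```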
ConvBnReLU3d
class torch.nn.intrinsic.qat.ConvBnReLU3d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=None, padding_mode='zeros', eps=1e-05, momentum=0.1, freeze_bn=False, qconfig=None)
A ConvBnReLU3d module is fused from Conv3d, BatchNorm3d and ReLU, with FakeQuantize modules attached for the weight. It is used in quantization aware training.
It combines the interfaces of torch.nn.Conv3d, torch.nn.BatchNorm3d and torch.nn.ReLU.
Similar to torch.nn.Conv3d, with FakeQuantize modules initialized to default.
Variables:
weight_fake_quant – fake quant module for weight
ConvReLU3d
class torch.nn.intrinsic.qat.ConvReLU3d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', qconfig=None)
A ConvReLU3d module is fused from Conv3d and ReLU, with FakeQuantize modules attached for the weight. It is used in quantization aware training.
It combines the interfaces of Conv3d and ReLU.
Variables:
weight_fake_quant – fake quant module for weight
LinearReLU
class torch.nn.intrinsic.qat.LinearReLU(in_features, out_features, bias=True, qconfig=None)
A LinearReLU module is fused from Linear and ReLU modules, with FakeQuantize modules attached for the weight. It is used in quantization aware training.
It adopts the same interface as torch.nn.Linear.
Similar to torch.nn.intrinsic.LinearReLU, with FakeQuantize modules initialized to default.
Variables:
weight (torch.Tensor) – fake quant module for weight
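Examples:
>>> import torch
>>> import torch.nn.intrinsic.qat as nniqat
>>> # A qconfig must be supplied; the default fbgemm QAT qconfig is used here.
>>> qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
>>> m = nniqat.LinearReLU(20, 30, qconfig=qconfig)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])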