import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.quantized as nnq
from torch.ao.nn.quantized.modules.utils import _quantize_weight


class Linear(nnq.Linear):
    r"""
    A dynamic quantized linear module with floating point tensors as inputs and outputs.
    We adopt the same interface as `torch.nn.Linear`, please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.

    Similar to :class:`torch.nn.Linear`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module which are of
                         shape :math:`(\text{out\_features}, \text{in\_features})`.
        bias (Tensor): the non-learnable floating point bias of the module of shape
                       :math:`(\text{out\_features})`. If :attr:`bias` is ``True``,
                       the values are initialized to zero.

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = nn.quantized.dynamic.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    # version used in this class is different from the parent class nnq.Linear
    _version = 4

    def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8):
        super().__init__(in_features, out_features, bias_, dtype=dtype)
        # We don't muck around with buffers or attributes or anything here
        # to keep the module simple. *Everything* is simply a Python attribute.
        # Serialization logic is explicitly handled in the serialization and
        # deserialization methods below.
        self.version = 4

    def forward(self, x):
        # Note that we can handle the self.bias == None case.
        if self._packed_params.dtype == torch.qint8:
            if self.version is None or self.version < 4:
                Y = torch.ops.quantized.linear_dynamic(
                    x, self._packed_params._packed_params)
            else:
                # reduce_range=True quantizes activations over a reduced range
                # (one fewer bit) to avoid instruction overflow in the x86
                # (fbgemm) kernels.
                Y = torch.ops.quantized.linear_dynamic(
                    x, self._packed_params._packed_params, reduce_range=True)
        elif self._packed_params.dtype == torch.float16:
            Y = torch.ops.quantized.linear_dynamic_fp16(
                x, self._packed_params._packed_params)
        else:
            raise RuntimeError('Unsupported dtype on dynamic quantized linear!')
        return Y.to(x.dtype)

    def _get_name(self):
        return 'DynamicQuantizedLinear'

    def extra_repr(self):
        extra_repr_str = 'in_features={}, out_features={}, dtype={}'.format(
            self.in_features, self.out_features, self._packed_params.dtype)
        if self._packed_params.dtype == torch.qint8:
            extra_repr_str += f', qscheme={self.weight().qscheme()}'
        return extra_repr_str

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        self.version = version
        super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                      missing_keys, unexpected_keys, error_msgs)
    @classmethod
    def from_float(cls, mod):
        r"""Create a dynamic quantized module from a float module or qparams_dict

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user
        """
        float_modules = [torch.nn.Linear,
                         torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
                         torch.ao.nn.intrinsic.modules.fused.LinearReLU,
                         torch.ao.nn.qat.dynamic.Linear]

        assert type(mod) in float_modules, \
            'nn.quantized.dynamic.Linear.from_float only works for one of ' + \
            str([float_mod.__name__ for float_mod in float_modules])
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        if type(mod) == nni.LinearReLU:
            mod = mod[0]
        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer = mod.qconfig.weight()
        else:
            # We have circular import issues if we import qconfig at the beginning of
            # this file: https://github.com/pytorch/pytorch/pull/24231. The current
            # workaround is to postpone the import until we need it.
            from torch.ao.quantization.qconfig import default_dynamic_qconfig
            weight_observer = default_dynamic_qconfig.weight()
        dtype = weight_observer.dtype
        assert dtype in [torch.qint8, torch.float16], \
            "The only supported dtypes for dynamic quantized linear are " \
            f"qint8 and float16, got: {dtype}"
        weight_observer(mod.weight)
        if dtype == torch.qint8:
            qweight = _quantize_weight(mod.weight.float(), weight_observer)
        elif dtype == torch.float16:
            qweight = mod.weight.float()
        else:
            raise RuntimeError('Unsupported dtype specified for dynamic quantized Linear!')
        qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
        qlinear.set_weight_bias(qweight, mod.bias)
        return qlinear
    @classmethod
    def from_reference(cls, ref_qlinear):
        """Create a (fbgemm/qnnpack) dynamic quantized module from a reference
        quantized module

        Args:
            ref_qlinear (Module): a reference quantized module, either produced by
                                  torch.ao.quantization functions or provided by the user
        """
        qlinear = cls(ref_qlinear.in_features, ref_qlinear.out_features,
                      dtype=ref_qlinear.weight_dtype)
        qweight = ref_qlinear.get_quantized_weight()
        bias = ref_qlinear.bias
        qlinear.set_weight_bias(qweight, bias)
        return qlinear
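A hedged usage sketch (not part of the module source above): this class is normally
produced by post-training dynamic quantization rather than constructed by hand.
``torch.ao.quantization.quantize_dynamic`` swaps eligible float ``torch.nn.Linear``
modules for this dynamic quantized version, and ``from_float`` performs the same
conversion per module. The model, shapes, and variable names below are illustrative
assumptions, not part of the API::

    >>> # xdoctest: +SKIP
    >>> import torch
    >>> import torch.nn as nn
    >>>
    >>> # Whole-model conversion: every nn.Linear becomes a dynamic quantized Linear.
    >>> float_model = nn.Sequential(nn.Linear(20, 30), nn.ReLU())
    >>> qmodel = torch.ao.quantization.quantize_dynamic(
    ...     float_model, {nn.Linear}, dtype=torch.qint8)
    >>> qmodel[0]._get_name()
    'DynamicQuantizedLinear'
    >>> qmodel(torch.randn(128, 20)).size()  # float tensors in, float tensors out
    torch.Size([128, 30])
    >>>
    >>> # Per-module conversion via from_float; the module must carry a qconfig.
    >>> lin = nn.Linear(20, 30)
    >>> lin.qconfig = torch.ao.quantization.default_dynamic_qconfig
    >>> dqlin = torch.ao.nn.quantized.dynamic.Linear.from_float(lin)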