Source code for torch.ao.quantization.quantize_fx
from typing import Dict, Any, List, Callable, Tuple, Optional, Set
import torch
from torch.fx import GraphModule
from torch.fx._symbolic_trace import Tracer
from torch.fx.node import Target, Node, Argument
from torch.nn.intrinsic import _FusedModule
from .fx import Fuser # noqa: F401
from .fx import prepare, convert # noqa: F401
from .fx import get_tensorrt_backend_config_dict # noqa: F401
from .fx.graph_module import ObservedGraphModule
from .fx.qconfig_utils import (
check_is_valid_convert_custom_config_dict,
check_is_valid_fuse_custom_config_dict,
check_is_valid_prepare_custom_config_dict,
check_is_valid_qconfig_dict,
)
from .fx.utils import graph_pretty_str # noqa: F401
from .fx.utils import get_custom_module_class_keys # noqa: F401
def _check_is_graph_module(model: torch.nn.Module) -> None:
if not isinstance(model, GraphModule):
raise ValueError(
"input model must be a GraphModule, "
+ "Got type:"
+ str(type(model))
+ " Please make "
+ "sure to follow the tutorials."
)
def _swap_ff_with_fxff(model: torch.nn.Module) -> None:
r""" Swap FloatFunctional with FXFloatFunctional
"""
modules_to_swap = []
for name, module in model.named_children():
if isinstance(module, torch.nn.quantized.FloatFunctional):
modules_to_swap.append(name)
else:
_swap_ff_with_fxff(module)
for name in modules_to_swap:
del model._modules[name]
model._modules[name] = torch.nn.quantized.FXFloatFunctional()
def _fuse_fx(
graph_module: GraphModule,
is_qat: bool,
fuse_custom_config_dict: Optional[Dict[str, Any]] = None,
backend_config_dict: Optional[Dict[str, Any]] = None,
) -> GraphModule:
r""" Internal helper function to fuse modules in preparation for quantization
Args:
graph_module: GraphModule object from symbolic tracing (torch.fx.symbolic_trace)
"""
_check_is_graph_module(graph_module)
fuser = Fuser()
return fuser.fuse(
graph_module, is_qat, fuse_custom_config_dict, backend_config_dict)
class Scope(object):
""" Scope object that records the module path and the module type
of a module. Scope is used to track the information of the module
that contains a Node in a Graph of GraphModule. For example::
class Sub(torch.nn.Module):
def forward(self, x):
# This will be a call_method Node in GraphModule,
# scope for this would be (module_path="sub", module_type=Sub)
return x.transpose(1, 2)
class M(torch.nn.Module):
def __init__(self):
self.sub = Sub()
def forward(self, x):
# This will be a call_method Node as well,
# scope for this would be (module_path="", None)
x = x.transpose(1, 2)
x = self.sub(x)
return x
"""
def __init__(self, module_path: str, module_type: Any):
super().__init__()
self.module_path = module_path
self.module_type = module_type
class ScopeContextManager(object):
""" A context manager to track the Scope of Node during symbolic tracing.
When entering a forward function of a Module, we'll update the scope information of
the current module, and when we exit, we'll restore the previous scope information.
"""
def __init__(
self, scope: Scope, current_module: torch.nn.Module, current_module_path: str
):
super().__init__()
self.prev_module_type = scope.module_type
self.prev_module_path = scope.module_path
self.scope = scope
self.scope.module_path = current_module_path
self.scope.module_type = type(current_module)
def __enter__(self):
return
def __exit__(self, *args):
self.scope.module_path = self.prev_module_path
self.scope.module_type = self.prev_module_type
return
class QuantizationTracer(Tracer):
def __init__(
self, skipped_module_names: List[str], skipped_module_classes: List[Callable]
):
super().__init__()
self.skipped_module_names = skipped_module_names
self.skipped_module_classes = skipped_module_classes
# NB: initialized the module_type of top level module to None
# we are assuming people won't configure the model with the type of top level
# module here, since people can use "" for global config
# We can change this if there is a use case that configures
# qconfig using top level module type
self.scope = Scope("", None)
self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}
self.record_stack_traces = True
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
return (
(
m.__module__.startswith("torch.nn")
and not isinstance(m, torch.nn.Sequential)
)
or module_qualified_name in self.skipped_module_names
or type(m) in self.skipped_module_classes
or isinstance(m, _FusedModule)
)
def call_module(
self,
m: torch.nn.Module,
forward: Callable[..., Any],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> Any:
module_qualified_name = self.path_of_module(m)
# Creating scope with information of current module
# scope will be restored automatically upon exit
with ScopeContextManager(self.scope, m, module_qualified_name):
return super().call_module(m, forward, args, kwargs)
def create_node(
self,
kind: str,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
) -> Node:
node = super().create_node(kind, target, args, kwargs, name, type_expr)
self.node_name_to_scope[node.name] = (
self.scope.module_path,
self.scope.module_type,
)
return node
def _prepare_fx(
model: torch.nn.Module,
qconfig_dict: Any,
is_qat: bool,
prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
backend_config_dict: Optional[Dict[str, Any]] = None,
is_standalone_module: bool = False,
) -> ObservedGraphModule:
r""" Internal helper function for prepare_fx
Args:
`model`, `qconfig_dict`, `prepare_custom_config_dict`, `equalization_qonfig_dict`:
see docs for :func:`~torch.ao.quantization.prepare_fx`
`is_standalone_module`: a boolean flag indicates whether we are
quantizing a standalone module or not, a standalone module
is a submodule of the parent module that is not inlined in the
forward graph of the parent module,
the way we quantize standalone module is described in:
:func:`~torch.ao.quantization._prepare_standalone_module_fx`
"""
if prepare_custom_config_dict is None:
prepare_custom_config_dict = {}
if equalization_qconfig_dict is None:
equalization_qconfig_dict = {}
check_is_valid_qconfig_dict(qconfig_dict)
check_is_valid_prepare_custom_config_dict(prepare_custom_config_dict)
check_is_valid_qconfig_dict(equalization_qconfig_dict)
skipped_module_names = prepare_custom_config_dict.get(
"non_traceable_module_name", []
)
skipped_module_classes = prepare_custom_config_dict.get(
"non_traceable_module_class", []
)
# swap FloatFunctional with FXFloatFunctional
_swap_ff_with_fxff(model)
# symbolically trace the model
if not is_standalone_module:
# standalone module and custom module config are applied in top level module
standalone_module_name_configs = prepare_custom_config_dict.get(
"standalone_module_name", []
)
skipped_module_names += [config[0] for config in standalone_module_name_configs]
standalone_module_class_configs = prepare_custom_config_dict.get(
"standalone_module_class", []
)
skipped_module_classes += [
config[0] for config in standalone_module_class_configs
]
float_custom_module_classes = get_custom_module_class_keys(
prepare_custom_config_dict, "float_to_observed_custom_module_class"
)
skipped_module_classes += float_custom_module_classes
preserved_attributes = prepare_custom_config_dict.get("preserved_attributes", [])
tracer = QuantizationTracer(skipped_module_names, skipped_module_classes)
graph_module = GraphModule(model, tracer.trace(model))
for attr_name in preserved_attributes:
setattr(graph_module, attr_name, getattr(model, attr_name))
graph_module = _fuse_fx(
graph_module,
is_qat,
prepare_custom_config_dict,
backend_config_dict)
prepared = prepare(
graph_module,
qconfig_dict,
is_qat,
tracer.node_name_to_scope,
prepare_custom_config_dict=prepare_custom_config_dict,
equalization_qconfig_dict=equalization_qconfig_dict,
backend_config_dict=backend_config_dict,
is_standalone_module=is_standalone_module,
)
for attr_name in preserved_attributes:
setattr(prepared, attr_name, getattr(model, attr_name))
return prepared
def _prepare_standalone_module_fx(
model: torch.nn.Module,
qconfig_dict: Any,
is_qat: bool,
prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
backend_config_dict: Optional[Dict[str, Any]] = None,
) -> GraphModule:
r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the
parent module.
standalone_module means it a submodule that is not inlined in parent module,
and will be quantized separately as one unit.
How the standalone module is observed is specified by `input_quantized_idxs` and
`output_quantized_idxs` in the prepare_custom_config for the standalone module
Returns:
* model(GraphModule): prepared standalone module. It has these attributes:
* `_standalone_module_input_quantized_idxs(List[Int])`: a list of
indexes for the graph input that is expected to be quantized,
same as input_quantized_idxs configuration provided
for the standalone module
* `_standalone_module_output_quantized_idxs(List[Int])`: a list of
indexs for the graph output that is quantized
same as input_quantized_idxs configuration provided
for the standalone module
"""
return _prepare_fx(
model,
qconfig_dict,
is_qat,
prepare_custom_config_dict,
backend_config_dict=backend_config_dict,
is_standalone_module=True,
)
[docs]def fuse_fx(
model: torch.nn.Module, fuse_custom_config_dict: Optional[Dict[str, Any]] = None
) -> GraphModule:
r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode.
Fusion rules are defined in torch.quantization.fx.fusion_pattern.py
Args:
* `model`: a torch.nn.Module model
* `fuse_custom_config_dict`: Dictionary for custom configurations for fuse_fx, e.g.::
fuse_custom_config_dict = {
"additional_fuser_method_mapping": {
(Module1, Module2): fuse_module1_module2
}
# Attributes that are not used in forward function will
# be removed when constructing GraphModule, this is a list of attributes
# to preserve as an attribute of the GraphModule even when they are
# not used in the code, these attributes will also persist through deepcopy
"preserved_attributes": ["preserved_attr"],
}
Example::
from torch.ao.quantization import fuse_fx
m = Model().eval()
m = fuse_fx(m)
"""
torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx")
assert not model.training, "fuse_fx only works on models in eval mode"
check_is_valid_fuse_custom_config_dict(fuse_custom_config_dict)
graph_module = torch.fx.symbolic_trace(model)
preserved_attributes: Set[str] = set()
if fuse_custom_config_dict:
preserved_attributes = set(
fuse_custom_config_dict.get("preserved_attributes", [])
)
for attr_name in preserved_attributes:
setattr(graph_module, attr_name, getattr(model, attr_name))
return _fuse_fx(graph_module, False, fuse_custom_config_dict)
[docs]def prepare_fx(
model: torch.nn.Module,
qconfig_dict: Any,
prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
equalization_qconfig_dict: Optional[Dict[str, Any]] = None,
backend_config_dict: Optional[Dict[str, Any]] = None,
) -> ObservedGraphModule:
r""" Prepare a model for post training static quantization
Args:
* `model`: torch.nn.Module model, must be in eval mode
* `qconfig_dict`: qconfig_dict is a dictionary with the following configurations::
qconfig_dict = {
# optional, global config
"": qconfig?,
# optional, used for module and function types
# could also be split into module_types and function_types if we prefer
"object_type": [
(torch.nn.Conv2d, qconfig?),
(torch.nn.functional.add, qconfig?),
...,
],
# optional, used for module names
"module_name": [
("foo.bar", qconfig?)
...,
],
# optional, matched in order, first match takes precedence
"module_name_regex": [
("foo.*bar.*conv[0-9]+", qconfig?)
...,
],
# optional, used for matching object type invocations in a submodule by
# order
# TODO(future PR): potentially support multiple indices ('0,1') and/or
# ranges ('0:3').
"module_name_object_type_order": [
# fully_qualified_name, object_type, index, qconfig
("foo.bar", torch.nn.functional.linear, 0, qconfig?),
],
# priority (in increasing order):
# global, object_type, module_name_regex, module_name,
# module_name_object_type_order
# qconfig == None means fusion and quantization should be skipped for anything
# matching the rule
}
* `prepare_custom_config_dict`: customization configuration dictionary for quantization tool::
prepare_custom_config_dict = {
# optional: specify the path for standalone modules
# These modules are symbolically traced and quantized as one unit
"standalone_module_name": [
# module_name, qconfig_dict, prepare_custom_config_dict
("submodule.standalone",
None, # qconfig_dict for the prepare function called in the submodule,
# None means use qconfig from parent qconfig_dict
{"input_quantized_idxs": [], "output_quantized_idxs": []}), # prepare_custom_config_dict
{} # backend_config_dict, TODO: point to README doc when it's ready
],
"standalone_module_class": [
# module_class, qconfig_dict, prepare_custom_config_dict
(StandaloneModule,
None, # qconfig_dict for the prepare function called in the submodule,
# None means use qconfig from parent qconfig_dict
{"input_quantized_idxs": [0], "output_quantized_idxs": [0]}, # prepare_custom_config_dict
{}) # backend_config_dict, TODO: point to README doc when it's ready
],
# user will manually define the corresponding observed
# module class which has a from_float class method that converts
# float custom module to observed custom module
# (only needed for static quantization)
"float_to_observed_custom_module_class": {
"static": {
CustomModule: ObservedCustomModule
}
},
# the qualified names for the submodule that are not symbolically traceable
"non_traceable_module_name": [
"non_traceable_module"
],
# the module classes that are not symbolically traceable
# we'll also put dynamic/weight_only custom module here
"non_traceable_module_class": [
NonTraceableModule
],
# Additional fuser_method mapping
"additional_fuser_method_mapping": {
(torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
},
# Additioanl module mapping for qat
"additional_qat_module_mapping": {
torch.nn.intrinsic.ConvBn2d: torch.nn.qat.ConvBn2d
},
# Additional fusion patterns
"additional_fusion_pattern": {
(torch.nn.BatchNorm2d, torch.nn.Conv2d): ConvReluFusionhandler
},
# Additional quantization patterns
"additional_quant_pattern": {
torch.nn.Conv2d: ConvReluQuantizeHandler,
(torch.nn.ReLU, torch.nn.Conv2d): ConvReluQuantizeHandler,
}
# By default, inputs and outputs of the graph are assumed to be in
# fp32. Providing `input_quantized_idxs` will set the inputs with the
# corresponding indices to be quantized. Providing
# `output_quantized_idxs` will set the outputs with the corresponding
# indices to be quantized.
"input_quantized_idxs": [0],
"output_quantized_idxs": [0],
# Attributes that are not used in forward function will
# be removed when constructing GraphModule, this is a list of attributes
# to preserve as an attribute of the GraphModule even when they are
# not used in the code, these attributes will also persist through deepcopy
"preserved_attributes": ["preserved_attr"],
}
* `equalization_qconfig_dict`: equalization_qconfig_dict is a dictionary
with a similar structure as qconfig_dict except it will contain
configurations specific to equalization techniques such as input-weight
equalization.
* `backend_config_dict`: a dictionary that specifies how operators are quantized
in a backend, this includes how the operaetors are observed,
supported fusion patterns, how quantize/dequantize ops are
inserted, supported dtypes etc. The structure of the dictionary is still WIP
and will change in the future, please don't use right now.
Return:
A GraphModule with observer (configured by qconfig_dict), ready for calibration
Example::
import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization import prepare_fx
float_model.eval()
qconfig = get_default_qconfig('fbgemm')
def calibrate(model, data_loader):
model.eval()
with torch.no_grad():
for image, target in data_loader:
model(image)
qconfig_dict = {"": qconfig}
prepared_model = prepare_fx(float_model, qconfig_dict)
# Run calibration
calibrate(prepared_model, sample_inference_data)
"""
torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx")
assert not model.training, "prepare_fx only works for models in " + "eval mode"
return _prepare_fx(
model,
qconfig_dict,
False, # is_qat
prepare_custom_config_dict,
equalization_qconfig_dict,
backend_config_dict,
)
[docs]def prepare_qat_fx(
model: torch.nn.Module,
qconfig_dict: Any,
prepare_custom_config_dict: Optional[Dict[str, Any]] = None,
backend_config_dict: Optional[Dict[str, Any]] = None,
) -> ObservedGraphModule:
r""" Prepare a model for quantization aware training
Args:
* `model`: torch.nn.Module model, must be in train mode
* `qconfig_dict`: see :func:`~torch.ao.quantization.prepare_fx`
* `prepare_custom_config_dict`: see :func:`~torch.ao.quantization.prepare_fx`
* `backend_config_dict`: see :func:`~torch.ao.quantization.prepare_fx`
Return:
A GraphModule with fake quant modules (configured by qconfig_dict), ready for
quantization aware training
Example::
import torch
from torch.ao.quantization import get_default_qat_qconfig
from torch.ao.quantization import prepare_fx
qconfig = get_default_qat_qconfig('fbgemm')
def train_loop(model, train_data):
model.train()
for image, target in data_loader:
...
float_model.train()
qconfig_dict = {"": qconfig}
prepared_model = prepare_fx(float_model, qconfig_dict)
# Run calibration
train_loop(prepared_model, train_loop)
"""
torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx")
assert model.training, "prepare_qat_fx only works for models in " + "train mode"
return _prepare_fx(
model,
qconfig_dict,
True, # is_qat
prepare_custom_config_dict,
backend_config_dict=backend_config_dict,
)
def _convert_fx(
graph_module: GraphModule,
is_reference: bool,
convert_custom_config_dict: Optional[Dict[str, Any]] = None,
is_standalone_module: bool = False,
_remove_qconfig: bool = True,
qconfig_dict: Dict[str, Any] = None,
) -> torch.nn.Module:
""" `is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`
"""
if convert_custom_config_dict is None:
convert_custom_config_dict = {}
_check_is_graph_module(graph_module)
check_is_valid_convert_custom_config_dict(convert_custom_config_dict)
quantized = convert(
graph_module,
is_reference,
convert_custom_config_dict,
is_standalone_module,
_remove_qconfig_flag=_remove_qconfig,
convert_qconfig_dict=qconfig_dict,
)
preserved_attributes = convert_custom_config_dict.get("preserved_attributes", [])
for attr_name in preserved_attributes:
setattr(quantized, attr_name, getattr(graph_module, attr_name))
return quantized
[docs]def convert_fx(
graph_module: GraphModule,
is_reference: bool = False,
convert_custom_config_dict: Optional[Dict[str, Any]] = None,
_remove_qconfig: bool = True,
qconfig_dict: Dict[str, Any] = None,
) -> torch.nn.Module:
r""" Convert a calibrated or trained model to a quantized model
Args:
* `graph_module`: A prepared and calibrated/trained model (GraphModule)
* `is_reference`: flag for whether to produce a reference quantized model,
which will be a common interface between pytorch quantization with
other backends like accelerators
* `convert_custom_config_dict`: dictionary for custom configurations for convert function::
convert_custom_config_dict = {
# additional object (module/operator) mappings that will overwrite the default
# module mappinng
"additional_object_mapping": {
"static": {
FloatModule: QuantizedModule,
float_op: quantized_op
},
"dynamic": {
FloatModule: DynamicallyQuantizedModule,
float_op: dynamically_quantized_op
},
},
# user will manually define the corresponding quantized
# module class which has a from_observed class method that converts
# observed custom module to quantized custom module
"observed_to_quantized_custom_module_class": {
"static": {
ObservedCustomModule: QuantizedCustomModule
},
"dynamic": {
ObservedCustomModule: QuantizedCustomModule
},
"weight_only": {
ObservedCustomModule: QuantizedCustomModule
}
},
# Attributes that are not used in forward function will
# be removed when constructing GraphModule, this is a list of attributes
# to preserve as an attribute of the GraphModule even when they are
# not used in the code
"preserved_attributes": ["preserved_attr"],
}
* `_remove_qconfig`: Option to remove the qconfig attributes in the model after convert.
* `qconfig_dict`: qconfig_dict with either same keys as what is passed to
the qconfig_dict in `prepare_fx` API, with same values or `None`, or
additional keys with values set to `None`
For each entry whose value is set to None, we skip quantizing that entry in the model::
qconfig_dict = {
# used for object_type, skip quantizing torch.nn.functional.add
"object_type": [
(torch.nn.functional.add, None),
(torch.nn.functional.linear, qconfig_from_prepare)
...,
],
# sed for module names, skip quantizing "foo.bar"
"module_name": [
("foo.bar", None)
...,
],
}
Return:
A quantized model (GraphModule)
Example::
# prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
quantized_model = convert_fx(prepared_model)
"""
torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx")
return _convert_fx(
graph_module,
is_reference,
convert_custom_config_dict,
_remove_qconfig=_remove_qconfig,
qconfig_dict=qconfig_dict,
)
def _convert_standalone_module_fx(
graph_module: GraphModule,
is_reference: bool = False,
convert_custom_config_dict: Optional[Dict[str, Any]] = None,
) -> torch.nn.Module:
r""" [Internal use only] Convert a model produced by :func:`~torch.ao.quantization.prepare_standalone_module_fx`
and convert it to a quantized model
Returns a quantized standalone module, whether input/output is quantized is
specified by prepare_custom_config_dict, with
input_quantized_idxs, output_quantized_idxs, please
see docs for prepare_fx for details
"""
return _convert_fx(
graph_module,
is_reference,
convert_custom_config_dict,
is_standalone_module=True,
)