.. _torch.export_db: ExportDB ======== ExportDB is a centralized dataset of supported and unsupported export cases. It is targeted towards users who want to understand specifically what types of code are supported, the subtleties of export, and how to modify their existing code to be compatible with export. Note that this is not an exhaustive set of everything that is supported by export, but it covers the most common and confusing use cases that users will run into. If you have a feature that you think needs a stronger guarantee from us to support in export, please create an issue in the pytorch/pytorch repo with a module:export tag. .. toctree:: :maxdepth: 1 :caption: Tags torch.escape-hatch torch.dynamic-shape torch.cond python.closure torch.dynamic-value python.data-structure python.assert python.control-flow torch.map python.builtin python.object-model python.context-manager torch.operator torch.mutation Supported --------- assume_constant_result ^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.escape-hatch <torch.escape-hatch>` Support Level: SUPPORTED Original source code: .. code-block:: python import torch import torch._dynamo as torchdynamo class AssumeConstantResult(torch.nn.Module): """ Applying `assume_constant_result` decorator to burn make non-tracable code as constant. """ def __init__(self): super().__init__() @torchdynamo.assume_constant_result def get_item(self, y): return y.int().item() def forward(self, x, y): return x[: self.get_item(y)] Result: .. 
code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3, 2]", arg1: "i64[]"): slice_1: "f32[3, 2]" = torch.ops.aten.slice.Tensor(l_x_, 0, 0, 4); l_x_ = None return (slice_1,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None), InputSpec(kind=, arg=TensorArgument(name='arg1'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='slice_1'), target=None)]) Range constraints: {} Equality constraints: [] autograd_function ^^^^^^^^^^^^^^^^^ .. note:: Tags: Support Level: SUPPORTED Original source code: .. code-block:: python import torch class MyAutogradFunction(torch.autograd.Function): @staticmethod def forward(ctx, x): return x.clone() @staticmethod def backward(ctx, grad_output): return grad_output + 1 class AutogradFunction(torch.nn.Module): """ TorchDynamo does not keep track of backward() on autograd functions. We recommend to use `allow_in_graph` to mitigate this problem. """ def forward(self, x): return MyAutogradFunction.apply(x) Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3, 2]"): clone: "f32[3, 2]" = torch.ops.aten.clone.default(l_x_); l_x_ = None return (clone,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='clone'), target=None)]) Range constraints: {} Equality constraints: [] class_method ^^^^^^^^^^^^ .. note:: Tags: Support Level: SUPPORTED Original source code: .. code-block:: python import torch class ClassMethod(torch.nn.Module): """ Class methods are inlined during tracing. """ @classmethod def method(cls, x): return x + 1 def __init__(self): super().__init__() self.linear = torch.nn.Linear(4, 2) def forward(self, x): x = self.linear(x) return self.method(x) * self.__class__.method(x) * type(self).method(x) Result: .. 
code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[2]", l_x_: "f32[3, 4]"): t: "f32[4, 2]" = torch.ops.aten.t.default(arg0_1); arg0_1 = None addmm: "f32[3, 2]" = torch.ops.aten.addmm.default(arg1_1, l_x_, t); arg1_1 = l_x_ = t = None add: "f32[3, 2]" = torch.ops.aten.add.Tensor(addmm, 1) add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(addmm, 1) mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(add, add_1); add = add_1 = None add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(addmm, 1); addmm = None mul_1: "f32[3, 2]" = torch.ops.aten.mul.Tensor(mul, add_2); mul = add_2 = None return (mul_1,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='arg0_1'), target='linear.weight'), InputSpec(kind=, arg=TensorArgument(name='arg1_1'), target='linear.bias'), InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='mul_1'), target=None)]) Range constraints: {} Equality constraints: [] cond_branch_class_method ^^^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.dynamic-shape `, :doc:`torch.cond ` Support Level: SUPPORTED Original source code: .. code-block:: python import torch from functorch.experimental.control_flow import cond class MySubModule(torch.nn.Module): def foo(self, x): return x.cos() def forward(self, x): return self.foo(x) class CondBranchClassMethod(torch.nn.Module): """ The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: - both branches must take the same args, which must also match the branch args passed to cond. - both branches must return a single tensor - returned tensor must have the same tensor metadata, e.g. 
shape and dtype - branch function can be free function, nested function, lambda, class methods - branch function can not have closure variables - no inplace mutations on inputs or global variables This example demonstrates using class method in cond(). NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. """ def __init__(self): super().__init__() self.subm = MySubModule() def bar(self, x): return x.sin() def forward(self, x): return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x]) Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3]"): true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [l_x_]); true_graph_0 = false_graph_0 = l_x_ = None getitem: "f32[3]" = conditional[0]; conditional = None return (getitem,) class (torch.nn.Module): def forward(self, arg0_1: "f32[3]"): cos: "f32[3]" = torch.ops.aten.cos.default(arg0_1); arg0_1 = None return (cos,) class (torch.nn.Module): def forward(self, arg0_1: "f32[3]"): sin: "f32[3]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None return (sin,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='getitem'), target=None)]) Range constraints: {} Equality constraints: [] cond_branch_nested_function ^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.dynamic-shape `, :doc:`torch.cond ` Support Level: SUPPORTED Original source code: .. code-block:: python import torch from functorch.experimental.control_flow import cond def cond_branch_nested_function(x): """ The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: - both branches must take the same args, which must also match the branch args passed to cond. 
- both branches must return a single tensor - returned tensor must have the same tensor metadata, e.g. shape and dtype - branch function can be free function, nested function, lambda, class methods - branch function can not have closure variables - no inplace mutations on inputs or global variables This example demonstrates using nested function in cond(). NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. """ def true_fn(x): def inner_true_fn(y): return x + y return inner_true_fn(x) def false_fn(x): def inner_false_fn(y): return x - y return inner_false_fn(x) return cond(x.shape[0] < 10, true_fn, false_fn, [x]) Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3]"): true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 conditional = torch.ops.higher_order.cond(True, true_graph_0, false_graph_0, [l_x_]); true_graph_0 = false_graph_0 = l_x_ = None getitem: "f32[3]" = conditional[0]; conditional = None return (getitem,) class (torch.nn.Module): def forward(self, arg0_1: "f32[3]"): add: "f32[3]" = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None return (add,) class (torch.nn.Module): def forward(self, arg0_1: "f32[3]"): sub: "f32[3]" = torch.ops.aten.sub.Tensor(arg0_1, arg0_1); arg0_1 = None return (sub,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='getitem'), target=None)]) Range constraints: {} Equality constraints: [] cond_branch_nonlocal_variables ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.dynamic-shape `, :doc:`torch.cond ` Support Level: SUPPORTED Original source code: .. 
code-block:: python import torch from functorch.experimental.control_flow import cond def cond_branch_nonlocal_variables(x): """ The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: - both branches must take the same args, which must also match the branch args passed to cond. - both branches must return a single tensor - returned tensor must have the same tensor metadata, e.g. shape and dtype - branch function can be free function, nested function, lambda, class methods - branch function can not have closure variables - no inplace mutations on inputs or global variables This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions. The code below will not work because capturing closure variables is not supported. ``` my_tensor_var = x + 100 my_primitive_var = 3.14 def true_fn(y): nonlocal my_tensor_var, my_primitive_var return y + my_tensor_var + my_primitive_var def false_fn(y): nonlocal my_tensor_var, my_primitive_var return y - my_tensor_var - my_primitive_var return cond(x.shape[0] > 5, true_fn, false_fn, [x]) ``` NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. """ my_tensor_var = x + 100 my_primitive_var = 3.14 def true_fn(x, y, z): return x + y + z def false_fn(x, y, z): return x - y - z return cond( x.shape[0] > 5, true_fn, false_fn, [x, my_tensor_var, torch.tensor(my_primitive_var)], ) Result: .. 
code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, _lifted_tensor_constant0: "f32[]", l_x_: "f32[6]"): add: "f32[6]" = torch.ops.aten.add.Tensor(l_x_, 100) lift_fresh_copy: "f32[]" = torch.ops.aten.lift_fresh_copy.default(_lifted_tensor_constant0); _lifted_tensor_constant0 = None true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 conditional = torch.ops.higher_order.cond(True, true_graph_0, false_graph_0, [l_x_, add, lift_fresh_copy]); true_graph_0 = false_graph_0 = l_x_ = add = lift_fresh_copy = None getitem: "f32[6]" = conditional[0]; conditional = None return (getitem,) class (torch.nn.Module): def forward(self, arg0_1: "f32[6]", arg1_1: "f32[6]", arg2_1: "f32[]"): add: "f32[6]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None add_1: "f32[6]" = torch.ops.aten.add.Tensor(add, arg2_1); add = arg2_1 = None return (add_1,) class (torch.nn.Module): def forward(self, arg0_1: "f32[6]", arg1_1: "f32[6]", arg2_1: "f32[]"): sub: "f32[6]" = torch.ops.aten.sub.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None sub_1: "f32[6]" = torch.ops.aten.sub.Tensor(sub, arg2_1); sub = arg2_1 = None return (sub_1,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='_lifted_tensor_constant0'), target='_lifted_tensor_constant0'), InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='getitem'), target=None)]) Range constraints: {} Equality constraints: [] cond_closed_over_variable ^^^^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.cond `, :doc:`python.closure ` Support Level: SUPPORTED Original source code: .. code-block:: python import torch from functorch.experimental.control_flow import cond class CondClosedOverVariable(torch.nn.Module): """ torch.cond() supports branches closed over arbitrary variables. 
""" def forward(self, pred, x): def true_fn(val): return x * 2 def false_fn(val): return x - 2 return cond(pred, true_fn, false_fn, [x + 1]) Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_pred_: "b8[]", l_x_: "f32[3, 2]"): true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 conditional = torch.ops.higher_order.cond(l_pred_, true_graph_0, false_graph_0, [l_x_]); l_pred_ = true_graph_0 = false_graph_0 = l_x_ = None getitem: "f32[3, 2]" = conditional[0]; conditional = None return (getitem,) class (torch.nn.Module): def forward(self, arg0_1: "f32[3, 2]"): mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None return (mul,) class (torch.nn.Module): def forward(self, arg0_1: "f32[3, 2]"): sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(arg0_1, 2); arg0_1 = None return (sub,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_pred_'), target=None), InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='getitem'), target=None)]) Range constraints: {} Equality constraints: [] cond_operands ^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.dynamic-shape `, :doc:`torch.cond ` Support Level: SUPPORTED Original source code: .. code-block:: python import torch from torch.export import Dim from functorch.experimental.control_flow import cond x = torch.randn(3, 2) y = torch.ones(2) dim0_x = Dim("dim0_x") def cond_operands(x, y): """ The operands passed to cond() must be: - a list of tensors - match arguments of `true_fn` and `false_fn` NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. """ def true_fn(x, y): return x + y def false_fn(x, y): return x - y return cond(x.shape[0] > 2, true_fn, false_fn, [x, y]) Result: .. 
code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[s0, 2]", l_y_: "f32[2]"): sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(l_x_, 0) gt: "Sym(s0 > 2)" = sym_size_int > 2; sym_size_int = None true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 conditional = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [l_x_, l_y_]); gt = true_graph_0 = false_graph_0 = l_x_ = l_y_ = None getitem: "f32[s0, 2]" = conditional[0]; conditional = None return (getitem,) class (torch.nn.Module): def forward(self, arg0_1: "f32[s0, 2]", arg1_1: "f32[2]"): add: "f32[s0, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None return (add,) class (torch.nn.Module): def forward(self, arg0_1: "f32[s0, 2]", arg1_1: "f32[2]"): sub: "f32[s0, 2]" = torch.ops.aten.sub.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None return (sub,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None), InputSpec(kind=, arg=TensorArgument(name='l_y_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='getitem'), target=None)]) Range constraints: {s0: ValueRanges(lower=2, upper=oo, is_bool=False)} Equality constraints: [] cond_predicate ^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.dynamic-shape `, :doc:`torch.cond ` Support Level: SUPPORTED Original source code: .. code-block:: python import torch from functorch.experimental.control_flow import cond def cond_predicate(x): """ The conditional statement (aka predicate) passed to cond() must be one of the following: - torch.Tensor with a single element - boolean expression NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. """ pred = x.dim() > 2 and x.shape[2] > 10 return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x]) Result: .. 
code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[6, 4, 3]"): true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [l_x_]); true_graph_0 = false_graph_0 = l_x_ = None getitem: "f32[6, 4, 3]" = conditional[0]; conditional = None return (getitem,) class (torch.nn.Module): def forward(self, arg0_1: "f32[6, 4, 3]"): cos: "f32[6, 4, 3]" = torch.ops.aten.cos.default(arg0_1); arg0_1 = None return (cos,) class (torch.nn.Module): def forward(self, arg0_1: "f32[6, 4, 3]"): sin: "f32[6, 4, 3]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None return (sin,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='getitem'), target=None)]) Range constraints: {} Equality constraints: [] constrain_as_size_example ^^^^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.escape-hatch `, :doc:`torch.dynamic-value ` Support Level: SUPPORTED Original source code: .. code-block:: python import torch def constrain_as_size_example(x): """ If the value is not known at tracing time, you can provide hint so that we can trace further. Please look at constrain_as_value and constrain_as_size APIs constrain_as_size is used for values that NEED to be used for constructing tensor. """ a = x.item() torch._constrain_as_size(a, min=0, max=5) return torch.ones((a, 5)) Result: .. 
code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "i64[]"): _local_scalar_dense: "Sym(i4)" = torch.ops.aten._local_scalar_dense.default(l_x_); l_x_ = None ge: "Sym(i4 >= 0)" = _local_scalar_dense >= 0 scalar_tensor: "f32[]" = torch.ops.aten.scalar_tensor.default(ge); ge = None _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, 5].'); scalar_tensor = None le: "Sym(i4 <= 5)" = _local_scalar_dense <= 5 scalar_tensor_1: "f32[]" = torch.ops.aten.scalar_tensor.default(le); le = None _assert_async_1 = torch.ops.aten._assert_async.msg(scalar_tensor_1, '_local_scalar_dense is outside of inline constraint [0, 5].'); scalar_tensor_1 = None sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense, min = 0, max = 5) ones: "f32[i4, 5]" = torch.ops.aten.ones.default([_local_scalar_dense, 5], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None return (ones,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='ones'), target=None)]) Range constraints: {i0: ValueRanges(lower=2, upper=5, is_bool=False), i1: ValueRanges(lower=2, upper=5, is_bool=False), i2: ValueRanges(lower=2, upper=5, is_bool=False), i3: ValueRanges(lower=2, upper=5, is_bool=False), i4: ValueRanges(lower=2, upper=5, is_bool=False)} Equality constraints: [] constrain_as_value_example ^^^^^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.escape-hatch `, :doc:`torch.dynamic-value ` Support Level: SUPPORTED Original source code: .. code-block:: python import torch def constrain_as_value_example(x, y): """ If the value is not known at tracing time, you can provide hint so that we can trace further. Please look at constrain_as_value and constrain_as_size APIs. 
constrain_as_value is used for values that don't need to be used for constructing tensor. """ a = x.item() torch._constrain_as_value(a, min=0, max=5) if a < 6: return y.sin() return y.cos() Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "i64[]", l_y_: "f32[5, 5]"): _local_scalar_dense: "Sym(i4)" = torch.ops.aten._local_scalar_dense.default(l_x_); l_x_ = None ge: "Sym(i4 >= 0)" = _local_scalar_dense >= 0 scalar_tensor: "f32[]" = torch.ops.aten.scalar_tensor.default(ge); ge = None _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, 5].'); scalar_tensor = None le: "Sym(i4 <= 5)" = _local_scalar_dense <= 5 scalar_tensor_1: "f32[]" = torch.ops.aten.scalar_tensor.default(le); le = None _assert_async_1 = torch.ops.aten._assert_async.msg(scalar_tensor_1, '_local_scalar_dense is outside of inline constraint [0, 5].'); scalar_tensor_1 = None sym_constrain_range = torch.ops.aten.sym_constrain_range.default(_local_scalar_dense, min = 0, max = 5); _local_scalar_dense = None sin: "f32[5, 5]" = torch.ops.aten.sin.default(l_y_); l_y_ = None return (sin,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None), InputSpec(kind=, arg=TensorArgument(name='l_y_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='sin'), target=None)]) Range constraints: {i0: ValueRanges(lower=0, upper=5, is_bool=False), i1: ValueRanges(lower=0, upper=5, is_bool=False), i2: ValueRanges(lower=0, upper=5, is_bool=False), i3: ValueRanges(lower=0, upper=5, is_bool=False), i4: ValueRanges(lower=0, upper=5, is_bool=False)} Equality constraints: [] decorator ^^^^^^^^^ .. note:: Tags: Support Level: SUPPORTED Original source code: .. 
code-block:: python import functools import torch def test_decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) + 1 return wrapper class Decorator(torch.nn.Module): """ Decorators calls are inlined into the exported function during tracing. """ @test_decorator def forward(self, x, y): return x + y Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_args_1_: "f32[3, 2]", l_args_2_: "f32[3, 2]"): add: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_args_1_, l_args_2_); l_args_1_ = l_args_2_ = None add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(add, 1); add = None return (add_1,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_args_1_'), target=None), InputSpec(kind=, arg=TensorArgument(name='l_args_2_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='add_1'), target=None)]) Range constraints: {} Equality constraints: [] dictionary ^^^^^^^^^^ .. note:: Tags: :doc:`python.data-structure ` Support Level: SUPPORTED Original source code: .. code-block:: python import torch def dictionary(x, y): """ Dictionary structures are inlined and flattened along tracing. """ elements = {} elements["x2"] = x * x y = y * elements["x2"] return {"y": y} Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3, 2]", l_y_: "i64[]"): mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(l_x_, l_x_); l_x_ = None mul_1: "f32[3, 2]" = torch.ops.aten.mul.Tensor(l_y_, mul); l_y_ = mul = None return (mul_1,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None), InputSpec(kind=, arg=TensorArgument(name='l_y_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='mul_1'), target=None)]) Range constraints: {} Equality constraints: [] dynamic_shape_assert ^^^^^^^^^^^^^^^^^^^^ .. 
note:: Tags: :doc:`python.assert ` Support Level: SUPPORTED Original source code: .. code-block:: python import torch def dynamic_shape_assert(x): """ A basic usage of python assertion. """ # assertion with error message assert x.shape[0] > 2, f"{x.shape[0]} is greater than 2" # assertion without error message assert x.shape[0] > 1 return x Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3, 2]"): return (l_x_,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)]) Range constraints: {} Equality constraints: [] dynamic_shape_constructor ^^^^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.dynamic-shape ` Support Level: SUPPORTED Original source code: .. code-block:: python import torch def dynamic_shape_constructor(x): """ Tensor constructors should be captured with dynamic shape inputs rather than being baked in with static shape. """ return torch.ones(x.shape[0] * 2) Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0: "f32[3, 2]"): ones: "f32[6]" = torch.ops.aten.ones.default([6], device = device(type='cpu'), pin_memory = False) return (ones,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='arg0'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='ones'), target=None)]) Range constraints: {} Equality constraints: [] dynamic_shape_if_guard ^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.dynamic-shape `, :doc:`python.control-flow ` Support Level: SUPPORTED Original source code: .. code-block:: python import torch class DynamicShapeIfGuard(torch.nn.Module): """ `if` statement with backed dynamic shape predicate will be specialized into one particular branch and generate a guard. 
However, export will fail if the the dimension is marked as dynamic shape from higher level API. """ def forward(self, x): if x.shape[0] == 3: return x.cos() return x.sin() Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3, 2, 2]"): cos: "f32[3, 2, 2]" = torch.ops.aten.cos.default(l_x_); l_x_ = None return (cos,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='cos'), target=None)]) Range constraints: {} Equality constraints: [] dynamic_shape_map ^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.dynamic-shape `, :doc:`torch.map ` Support Level: SUPPORTED Original source code: .. code-block:: python import torch from functorch.experimental.control_flow import map def dynamic_shape_map(xs, y): """ functorch map() maps a function over the first tensor dimension. """ def body(x, y): return x + y return map(body, xs, y) Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_xs_: "f32[3, 2]", l_y_: "f32[2]"): body_graph_0 = self.body_graph_0 map_impl = torch.ops.higher_order.map_impl(body_graph_0, 1, l_xs_, l_y_); body_graph_0 = l_xs_ = l_y_ = None getitem: "f32[3, 2]" = map_impl[0]; map_impl = None return (getitem,) class (torch.nn.Module): def forward(self, arg0_1: "f32[2]", arg1_1: "f32[2]"): add: "f32[2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None return [add] Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_xs_'), target=None), InputSpec(kind=, arg=TensorArgument(name='l_y_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='getitem'), target=None)]) Range constraints: {} Equality constraints: [] dynamic_shape_slicing ^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.dynamic-shape ` Support Level: SUPPORTED Original source code: .. 
code-block:: python import torch def dynamic_shape_slicing(x): """ Slices with dynamic shape arguments should be captured into the graph rather than being baked in. """ return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2] Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3, 2]"): slice_1: "f32[1, 2]" = torch.ops.aten.slice.Tensor(l_x_, 0, 0, 1); l_x_ = None slice_2: "f32[1, 1]" = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 9223372036854775807, 2); slice_1 = None return (slice_2,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='slice_2'), target=None)]) Range constraints: {} Equality constraints: [] dynamic_shape_view ^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.dynamic-shape ` Support Level: SUPPORTED Original source code: .. code-block:: python import torch def dynamic_shape_view(x): """ Dynamic shapes should be propagated to view arguments instead of being baked into the exported graph. """ new_x_shape = x.size()[:-1] + (2, 5) x = x.view(*new_x_shape) return x.permute(0, 2, 1) Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[10, 10]"): view: "f32[10, 2, 5]" = torch.ops.aten.view.default(l_x_, [10, 2, 5]); l_x_ = None permute: "f32[10, 5, 2]" = torch.ops.aten.permute.default(view, [0, 2, 1]); view = None return (permute,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='permute'), target=None)]) Range constraints: {} Equality constraints: [] fn_with_kwargs ^^^^^^^^^^^^^^ .. note:: Tags: :doc:`python.data-structure ` Support Level: SUPPORTED Original source code: .. 
code-block:: python import torch def fn_with_kwargs(pos0, tuple0, *myargs, mykw0, **mykwargs): """ Keyword arguments are not supported at the moment. """ out = pos0 for arg in tuple0: out = out * arg for arg in myargs: out = out * arg out = out * mykw0 out = out * mykwargs["input0"] * mykwargs["input1"] return out Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, out: "f32[4]", arg: "f32[4]", arg_1: "f32[4]", arg_2: "f32[4]", arg_3: "f32[4]", l_mykw0_: "f32[4]", l_mykwargs_input0_: "f32[4]", l_mykwargs_input1_: "f32[4]"): mul: "f32[4]" = torch.ops.aten.mul.Tensor(out, arg); out = arg = None mul_1: "f32[4]" = torch.ops.aten.mul.Tensor(mul, arg_1); mul = arg_1 = None mul_2: "f32[4]" = torch.ops.aten.mul.Tensor(mul_1, arg_2); mul_1 = arg_2 = None mul_3: "f32[4]" = torch.ops.aten.mul.Tensor(mul_2, arg_3); mul_2 = arg_3 = None mul_4: "f32[4]" = torch.ops.aten.mul.Tensor(mul_3, l_mykw0_); mul_3 = l_mykw0_ = None mul_5: "f32[4]" = torch.ops.aten.mul.Tensor(mul_4, l_mykwargs_input0_); mul_4 = l_mykwargs_input0_ = None mul_6: "f32[4]" = torch.ops.aten.mul.Tensor(mul_5, l_mykwargs_input1_); mul_5 = l_mykwargs_input1_ = None return (mul_6,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='out'), target=None), InputSpec(kind=, arg=TensorArgument(name='arg'), target=None), InputSpec(kind=, arg=TensorArgument(name='arg_1'), target=None), InputSpec(kind=, arg=TensorArgument(name='arg_2'), target=None), InputSpec(kind=, arg=TensorArgument(name='arg_3'), target=None), InputSpec(kind=, arg=TensorArgument(name='l_mykw0_'), target=None), InputSpec(kind=, arg=TensorArgument(name='l_mykwargs_input0_'), target=None), InputSpec(kind=, arg=TensorArgument(name='l_mykwargs_input1_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='mul_6'), target=None)]) Range constraints: {} Equality constraints: []
list_contains ^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.dynamic-shape `, :doc:`python.data-structure `, :doc:`python.assert ` Support Level: SUPPORTED Original source code: .. code-block:: python import torch def list_contains(x): """ List containment relation can be checked on a dynamic shape or constants. """ assert x.size(-1) in [6, 2] assert x.size(0) not in [4, 5, 6] assert "monkey" not in ["cow", "pig"] return x + x Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3, 2]"): add: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, l_x_); l_x_ = None return (add,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='add'), target=None)]) Range constraints: {} Equality constraints: [] list_unpack ^^^^^^^^^^^ .. note:: Tags: :doc:`python.data-structure `, :doc:`python.control-flow ` Support Level: SUPPORTED Original source code: .. code-block:: python from typing import List import torch def list_unpack(args: List[torch.Tensor]): """ Lists are treated as static construct, therefore unpacking should be erased after tracing. """ x, *y = args return x + y[0] Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[3, 2]", l_args_1_: "i64[]", arg2: "i64[]"): add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, l_args_1_); x = l_args_1_ = None return (add,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='x'), target=None), InputSpec(kind=, arg=TensorArgument(name='l_args_1_'), target=None), InputSpec(kind=, arg=TensorArgument(name='arg2'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='add'), target=None)]) Range constraints: {} Equality constraints: [] nested_function ^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`python.closure ` Support Level: SUPPORTED Original source code: .. 
code-block:: python import torch def nested_function(a, b): """ Nested functions are traced through. Side effects on global captures are not supported though. """ x = a + b z = a - b def closure(y): nonlocal x x += 1 return x * y + z return closure(x) Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_a_: "f32[3, 2]", l_b_: "f32[2]"): add: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_a_, l_b_) sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(l_a_, l_b_); l_a_ = l_b_ = None add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(add, 1); add = None mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(add_1, add_1); add_1 = None add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(mul, sub); mul = sub = None return (add_2,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_a_'), target=None), InputSpec(kind=, arg=TensorArgument(name='l_b_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='add_2'), target=None)]) Range constraints: {} Equality constraints: [] null_context_manager ^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`python.context-manager ` Support Level: SUPPORTED Original source code: .. code-block:: python import contextlib import torch def null_context_manager(x): """ Null context manager in Python will be traced out. """ ctx = contextlib.nullcontext() with ctx: return x.sin() + x.cos() Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3, 2]"): sin: "f32[3, 2]" = torch.ops.aten.sin.default(l_x_) cos: "f32[3, 2]" = torch.ops.aten.cos.default(l_x_); l_x_ = None add: "f32[3, 2]" = torch.ops.aten.add.Tensor(sin, cos); sin = cos = None return (add,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='add'), target=None)]) Range constraints: {} Equality constraints: [] pytree_flatten ^^^^^^^^^^^^^^ .. 
note:: Tags: Support Level: SUPPORTED Original source code: .. code-block:: python import torch from torch.utils import _pytree as pytree def pytree_flatten(x): """ Pytree from PyTorch cannot be captured by TorchDynamo. """ y, spec = pytree.tree_flatten(x) return y[0] + 1 Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, child: "f32[3, 2]", arg1: "f32[3, 2]"): add: "f32[3, 2]" = torch.ops.aten.add.Tensor(child, 1); child = None return (add,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='child'), target=None), InputSpec(kind=, arg=TensorArgument(name='arg1'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='add'), target=None)]) Range constraints: {} Equality constraints: [] scalar_output ^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.dynamic-shape ` Support Level: SUPPORTED Original source code: .. code-block:: python import torch from torch.export import Dim x = torch.ones(3, 2) dim1_x = Dim("dim1_x") def scalar_output(x): """ Returning scalar values from the graph is supported, in addition to Tensor outputs. Symbolic shapes are captured and rank is specialized. """ return x.shape[1] + 1 Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3, s0]"): # No stacktrace found for following nodes sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(l_x_, 1); l_x_ = None add: "Sym(s0 + 1)" = sym_size_int + 1; sym_size_int = None return (add,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=, arg=SymIntArgument(name='add'), target=None)]) Range constraints: {s0: ValueRanges(lower=2, upper=oo, is_bool=False)} Equality constraints: [] specialized_attribute ^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: Support Level: SUPPORTED Original source code: .. 
code-block:: python from enum import Enum import torch class Animal(Enum): COW = "moo" class SpecializedAttribute(torch.nn.Module): """ Model attributes are specialized. """ def __init__(self): super().__init__() self.a = "moo" self.b = 4 def forward(self, x): if self.a == Animal.COW.value: return x * x + self.b else: raise ValueError("bad") Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3, 2]"): mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(l_x_, l_x_); l_x_ = None add: "f32[3, 2]" = torch.ops.aten.add.Tensor(mul, 4); mul = None return (add,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='add'), target=None)]) Range constraints: {} Equality constraints: [] static_for_loop ^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`python.control-flow <python.control-flow>` Support Level: SUPPORTED Original source code: .. code-block:: python import torch class StaticForLoop(torch.nn.Module): """ A for loop with constant number of iterations should be unrolled in the exported graph. """ def __init__(self): super().__init__() def forward(self, x): ret = [] for i in range(10): # constant ret.append(i + x) return ret Result: .. 
code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3, 2]"): add: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 0) add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 1) add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 2) add_3: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 3) add_4: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 4) add_5: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 5) add_6: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 6) add_7: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 7) add_8: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 8) add_9: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 9); l_x_ = None return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='add'), target=None), OutputSpec(kind=, arg=TensorArgument(name='add_1'), target=None), OutputSpec(kind=, arg=TensorArgument(name='add_2'), target=None), OutputSpec(kind=, arg=TensorArgument(name='add_3'), target=None), OutputSpec(kind=, arg=TensorArgument(name='add_4'), target=None), OutputSpec(kind=, arg=TensorArgument(name='add_5'), target=None), OutputSpec(kind=, arg=TensorArgument(name='add_6'), target=None), OutputSpec(kind=, arg=TensorArgument(name='add_7'), target=None), OutputSpec(kind=, arg=TensorArgument(name='add_8'), target=None), OutputSpec(kind=, arg=TensorArgument(name='add_9'), target=None)]) Range constraints: {} Equality constraints: [] static_if ^^^^^^^^^ .. note:: Tags: :doc:`python.control-flow ` Support Level: SUPPORTED Original source code: .. code-block:: python import torch class StaticIf(torch.nn.Module): """ `if` statement with static predicate value should be traced through with the taken branch. """ def __init__(self): super().__init__() def forward(self, x): if len(x.shape) == 3: return x + torch.ones(1, 1, 1) return x Result: .. 
code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3, 2, 2]"): ones: "f32[1, 1, 1]" = torch.ops.aten.ones.default([1, 1, 1], device = device(type='cpu'), pin_memory = False) add: "f32[3, 2, 2]" = torch.ops.aten.add.Tensor(l_x_, ones); l_x_ = ones = None return (add,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='add'), target=None)]) Range constraints: {} Equality constraints: [] tensor_setattr ^^^^^^^^^^^^^^ .. note:: Tags: :doc:`python.builtin ` Support Level: SUPPORTED Original source code: .. code-block:: python import torch def tensor_setattr(x, attr): """ setattr() call onto tensors is not supported. """ setattr(x, attr, torch.randn(3, 2)) return x + 4 Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3, 2]", arg1): add: "f32[3, 2]" = torch.ops.aten.add.Tensor(l_x_, 4); l_x_ = None return (add,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None), InputSpec(kind=, arg=ConstantArgument(value='attr'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='add'), target=None)]) Range constraints: {} Equality constraints: [] type_reflection_method ^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`python.builtin ` Support Level: SUPPORTED Original source code: .. code-block:: python import torch class A: @classmethod def func(cls, x): return 1 + x def type_reflection_method(x): """ type() calls on custom objects followed by method calls are not allowed due to its overly dynamic nature. """ a = A() return type(a).func(x) Result: .. 
code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3, 4]"): add: "f32[3, 4]" = torch.ops.aten.add.Tensor(l_x_, 1); l_x_ = None return (add,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='add'), target=None)]) Range constraints: {} Equality constraints: [] You can rewrite the example above to something like the following: .. code-block:: python def type_reflection_method_rewrite(x): """ Custom object class methods will be inlined. """ return A.func(x) user_input_mutation ^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.mutation <torch.mutation>` Support Level: SUPPORTED Original source code: .. code-block:: python import torch class UserInputMutation(torch.nn.Module): """ Directly mutate user input in forward """ def forward(self, x): x.mul_(2) return x.cos() Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, l_x_: "f32[3, 2]"): mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(l_x_, 2); l_x_ = None cos: "f32[3, 2]" = torch.ops.aten.cos.default(mul) return (mul, cos) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='mul'), target='l_x_'), OutputSpec(kind=, arg=TensorArgument(name='cos'), target=None)]) Range constraints: {} Equality constraints: [] Not Supported Yet ----------------- dynamic_shape_round ^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`python.builtin <python.builtin>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>` Support Level: NOT_SUPPORTED_YET Original source code: .. code-block:: python import torch from torch.export import Dim x = torch.ones(3, 2) dim0_x = Dim("dim0_x") def dynamic_shape_round(x): """ Calling round on dynamic shapes is not supported. """ return x[: round(x.shape[0] / 2)] Result: .. code-block:: Unsupported: Calling round() on symbolic value is not supported. 
You can use floor() to implement this functionality model_attr_mutation ^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`python.object-model <python.object-model>` Support Level: NOT_SUPPORTED_YET Original source code: .. code-block:: python import torch class ModelAttrMutation(torch.nn.Module): """ Attribute mutation is not supported. """ def __init__(self): super().__init__() self.attr_list = [torch.ones(3, 2), torch.ones(3, 2)] def recreate_list(self): return [torch.zeros(3, 2), torch.zeros(3, 2)] def forward(self, x): self.attr_list = self.recreate_list() return x.sum() + self.attr_list[0].sum() Result: .. code-block:: AssertionError: Mutating module attribute attr_list during export. optional_input ^^^^^^^^^^^^^^ .. note:: Tags: :doc:`python.object-model <python.object-model>` Support Level: NOT_SUPPORTED_YET Original source code: .. code-block:: python import torch class OptionalInput(torch.nn.Module): """ Tracing through optional input is not supported yet """ def forward(self, x, y=torch.ones(2, 3)): if y is not None: return x + y return x Result: .. code-block:: AssertionError: graph-captured input #2, of type , is not among original inputs of types: () torch_sym_min ^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.operator <torch.operator>` Support Level: NOT_SUPPORTED_YET Original source code: .. code-block:: python import torch class TorchSymMin(torch.nn.Module): """ torch.sym_min operator is not supported in export. """ def forward(self, x): return x.sum() + torch.sym_min(x.size(0), 100) Result: .. code-block:: Unsupported: torch.* op returned non-Tensor int call_function