torch.escape-hatch ====================== assume_constant_result ^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.escape-hatch <torch.escape-hatch>` Support Level: SUPPORTED Original source code: .. code-block:: python import torch import torch._dynamo as torchdynamo class AssumeConstantResult(torch.nn.Module): """ Applying the `assume_constant_result` decorator to burn in non-traceable code as a constant. """ def __init__(self): super().__init__() @torchdynamo.assume_constant_result def get_item(self, y): return y.int().item() def forward(self, x, y): return x[: self.get_item(y)] Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: "f32[3, 2]", arg1_1: "i64[]"): slice_1: "f32[3, 2]" = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 4); arg0_1 = None return (slice_1,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='slice_1'), target=None)]) Range constraints: {} constrain_as_size_example ^^^^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.escape-hatch <torch.escape-hatch>`, :doc:`torch.dynamic-value <torch.dynamic-value>` Support Level: SUPPORTED Original source code: .. code-block:: python import torch class ConstrainAsSizeExample(torch.nn.Module): """ If the value is not known at tracing time, you can provide a hint so that we can trace further. Please look at the constrain_as_value and constrain_as_size APIs. constrain_as_size is used for values that NEED to be used for constructing a tensor. """ def __init__(self): super().__init__() def forward(self, x): a = x.item() torch._constrain_as_size(a, min=0, max=5) return torch.zeros((a, 5)) Result: .. 
code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: "i64[]"): _local_scalar_dense: "Sym(u4)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None ge: "Sym(u4 >= 0)" = _local_scalar_dense >= 0 scalar_tensor: "f32[]" = torch.ops.aten.scalar_tensor.default(ge); ge = None _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, 5].'); scalar_tensor = None le: "Sym(u4 <= 5)" = _local_scalar_dense <= 5 scalar_tensor_1: "f32[]" = torch.ops.aten.scalar_tensor.default(le); le = None _assert_async_1 = torch.ops.aten._assert_async.msg(scalar_tensor_1, '_local_scalar_dense is outside of inline constraint [0, 5].'); scalar_tensor_1 = None sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense, min = 0, max = 5) zeros: "f32[u4, 5]" = torch.ops.aten.zeros.default([_local_scalar_dense, 5], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None return (zeros,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='zeros'), target=None)]) Range constraints: {u0: ValueRanges(lower=0, upper=5, is_bool=False), u1: ValueRanges(lower=0, upper=5, is_bool=False), u4: ValueRanges(lower=0, upper=5, is_bool=False)} constrain_as_value_example ^^^^^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.escape-hatch <torch.escape-hatch>`, :doc:`torch.dynamic-value <torch.dynamic-value>` Support Level: SUPPORTED Original source code: .. code-block:: python import torch class ConstrainAsValueExample(torch.nn.Module): """ If the value is not known at tracing time, you can provide a hint so that we can trace further. Please look at the constrain_as_value and constrain_as_size APIs. constrain_as_value is used for values that don't need to be used for constructing a tensor. 
""" def __init__(self): super().__init__() def forward(self, x, y): a = x.item() torch._constrain_as_value(a, min=0, max=5) if a < 6: return y.sin() return y.cos() Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: "i64[]", arg1_1: "f32[5, 5]"): _local_scalar_dense: "Sym(u4)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None ge: "Sym(u4 >= 0)" = _local_scalar_dense >= 0 scalar_tensor: "f32[]" = torch.ops.aten.scalar_tensor.default(ge); ge = None _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, 5].'); scalar_tensor = None le: "Sym(u4 <= 5)" = _local_scalar_dense <= 5 scalar_tensor_1: "f32[]" = torch.ops.aten.scalar_tensor.default(le); le = None _assert_async_1 = torch.ops.aten._assert_async.msg(scalar_tensor_1, '_local_scalar_dense is outside of inline constraint [0, 5].'); scalar_tensor_1 = None sym_constrain_range = torch.ops.aten.sym_constrain_range.default(_local_scalar_dense, min = 0, max = 5); _local_scalar_dense = None sin: "f32[5, 5]" = torch.ops.aten.sin.default(arg1_1); arg1_1 = None return (sin,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sin'), target=None)]) Range constraints: {u0: ValueRanges(lower=0, upper=5, is_bool=False), u1: ValueRanges(lower=0, upper=5, is_bool=False), u4: ValueRanges(lower=0, upper=5, is_bool=False)}