torch.escape-hatch

assume_constant_result

Note

Tags: torch.escape-hatch

Support Level: SUPPORTED

Original source code:

import torch
import torch._dynamo as torchdynamo



class AssumeConstantResult(torch.nn.Module):
    """
    Applying the `assume_constant_result` decorator burns in the result of non-traceable code as a constant.
    """

    def __init__(self):
        super().__init__()

    @torchdynamo.assume_constant_result
    def get_item(self, y):
        return y.int().item()

    def forward(self, x, y):
        return x[: self.get_item(y)]

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, l_x_: "f32[3, 2]", arg1: "i64[]"):
            slice_1: "f32[3, 2]" = torch.ops.aten.slice.Tensor(l_x_, 0, 0, 4);  l_x_ = None
            return (slice_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='l_x_'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='slice_1'), target=None)])
Range constraints: {}
Equality constraints: []
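
The `4` in `aten.slice.Tensor(l_x_, 0, 0, 4)` appears to be the value `get_item(y)` returned while tracing; the graph still accepts `arg1` but never reads it. The toy below models that behavior in plain Python, without PyTorch. `export_with_constant` and `exported_forward` are hypothetical names used only for illustration and are not part of any PyTorch API.

```python
def export_with_constant(get_item, example_y):
    """Toy stand-in for tracing: evaluate get_item once and freeze its result."""
    frozen = get_item(example_y)  # evaluated at "trace time" only

    def exported_forward(x, y):
        # y is still accepted as an input, but it no longer affects the result.
        return x[:frozen]

    return exported_forward


exported = export_with_constant(int, example_y=4)
rows = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]  # three rows, like the f32[3, 2] input

same_y = exported(rows, 4)    # slicing to 4 keeps all three rows
other_y = exported(rows, 1)   # y=1 changes nothing: the bound 4 is baked in
```

The practical consequence is that the decorator is only safe when the decorated function really does return the same value on every call; otherwise the exported program keeps using the value seen during tracing.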

constrain_as_size_example

Note

Tags: torch.escape-hatch, torch.dynamic-value

Support Level: SUPPORTED

Original source code:

import torch



def constrain_as_size_example(x):
    """
    If the value is not known at tracing time, you can provide a hint so that we
    can trace further. See the constrain_as_value and constrain_as_size APIs.
    constrain_as_size is used for values that NEED to be used for constructing
    a tensor.
    """
    a = x.item()
    torch._constrain_as_size(a, min=0, max=5)
    return torch.ones((a, 5))

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, l_x_: "i64[]"):
            _local_scalar_dense: "Sym(i4)" = torch.ops.aten._local_scalar_dense.default(l_x_);  l_x_ = None

            ge: "Sym(i4 >= 0)" = _local_scalar_dense >= 0
            scalar_tensor: "f32[]" = torch.ops.aten.scalar_tensor.default(ge);  ge = None
            _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, 5].');  scalar_tensor = None
            le: "Sym(i4 <= 5)" = _local_scalar_dense <= 5
            scalar_tensor_1: "f32[]" = torch.ops.aten.scalar_tensor.default(le);  le = None
            _assert_async_1 = torch.ops.aten._assert_async.msg(scalar_tensor_1, '_local_scalar_dense is outside of inline constraint [0, 5].');  scalar_tensor_1 = None

            sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense, min = 0, max = 5)

            ones: "f32[i4, 5]" = torch.ops.aten.ones.default([_local_scalar_dense, 5], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
            return (ones,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='l_x_'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='ones'), target=None)])
Range constraints: {i0: ValueRanges(lower=2, upper=5, is_bool=False), i1: ValueRanges(lower=2, upper=5, is_bool=False), i2: ValueRanges(lower=2, upper=5, is_bool=False), i3: ValueRanges(lower=2, upper=5, is_bool=False), i4: ValueRanges(lower=2, upper=5, is_bool=False)}
Equality constraints: []
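
The two `_assert_async.msg` calls in the graph are runtime checks that the scalar read from `l_x_` lies in `[0, 5]` before it is used as a dimension. A plain-Python sketch of the same check-then-construct pattern follows. `check_inline_constraint` and `ones_with_checked_rows` are hypothetical helpers, and the nested list is only a stand-in for `torch.ones((a, 5))`.

```python
def check_inline_constraint(value, *, lower, upper):
    """Raise if value is outside [lower, upper], mirroring the graph's asserts."""
    if not (lower <= value <= upper):
        raise RuntimeError(
            f"value {value} is outside of inline constraint [{lower}, {upper}]."
        )


def ones_with_checked_rows(a, cols=5):
    # Validate the data-dependent size first, then use it to build the result.
    check_inline_constraint(a, lower=0, upper=5)
    return [[1.0] * cols for _ in range(a)]


small = ones_with_checked_rows(3)   # 3 rows of 5 ones

try:
    ones_with_checked_rows(7)       # 7 > 5, so the check fails
    rejected = False
except RuntimeError:
    rejected = True
```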

constrain_as_value_example

Note

Tags: torch.escape-hatch, torch.dynamic-value

Support Level: SUPPORTED

Original source code:

import torch



def constrain_as_value_example(x, y):
    """
    If the value is not known at tracing time, you can provide a hint so that we
    can trace further. See the constrain_as_value and constrain_as_size APIs.
    constrain_as_value is used for values that don't need to be used for constructing
    a tensor.
    """
    a = x.item()
    torch._constrain_as_value(a, min=0, max=5)

    if a < 6:
        return y.sin()
    return y.cos()

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, l_x_: "i64[]", l_y_: "f32[5, 5]"):
            _local_scalar_dense: "Sym(i4)" = torch.ops.aten._local_scalar_dense.default(l_x_);  l_x_ = None

            ge: "Sym(i4 >= 0)" = _local_scalar_dense >= 0
            scalar_tensor: "f32[]" = torch.ops.aten.scalar_tensor.default(ge);  ge = None
            _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, 5].');  scalar_tensor = None
            le: "Sym(i4 <= 5)" = _local_scalar_dense <= 5
            scalar_tensor_1: "f32[]" = torch.ops.aten.scalar_tensor.default(le);  le = None
            _assert_async_1 = torch.ops.aten._assert_async.msg(scalar_tensor_1, '_local_scalar_dense is outside of inline constraint [0, 5].');  scalar_tensor_1 = None

            sym_constrain_range = torch.ops.aten.sym_constrain_range.default(_local_scalar_dense, min = 0, max = 5);  _local_scalar_dense = None

            sin: "f32[5, 5]" = torch.ops.aten.sin.default(l_y_);  l_y_ = None
            return (sin,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='l_x_'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='l_y_'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sin'), target=None)])
Range constraints: {i0: ValueRanges(lower=0, upper=5, is_bool=False), i1: ValueRanges(lower=0, upper=5, is_bool=False), i2: ValueRanges(lower=0, upper=5, is_bool=False), i3: ValueRanges(lower=0, upper=5, is_bool=False), i4: ValueRanges(lower=0, upper=5, is_bool=False)}
Equality constraints: []
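
Only `sin` appears in the graph because the declared range `[0, 5]` makes `a < 6` true for every permitted value, so the `cos` branch is never traced. The interval check below is a minimal sketch of that reasoning, not PyTorch's symbolic-shape implementation; `decide_less_than` is a hypothetical helper.

```python
def decide_less_than(lower, upper, bound):
    """Decide `a < bound` for any integer a in [lower, upper].

    Returns True or False when the range settles the comparison,
    and None when the answer depends on the runtime value.
    """
    if upper < bound:
        return True    # every permitted value is below the bound
    if lower >= bound:
        return False   # no permitted value is below the bound
    return None        # cannot be decided from the range alone


always_sin = decide_less_than(0, 5, 6)     # the range and bound used above
needs_runtime = decide_less_than(0, 5, 3)  # a could be 2 or 4
```

With a bound such as `3`, the range alone would not settle the comparison, which is the situation where a tighter hint (or a different program structure) would be needed for tracing to pick a single branch.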
