torch.escape-hatch
assume_constant_result
Original source code:
import torch
import torch._dynamo as torchdynamo
class AssumeConstantResult(torch.nn.Module):
    """
    Applying the `assume_constant_result` decorator burns the result of
    non-traceable code into the graph as a constant.
    """

    def __init__(self):
        super().__init__()

    @torchdynamo.assume_constant_result
    def get_item(self, y):
        """Return the value held by tensor `y` as a Python int."""
        # `.item()` leaves tensor-land and yields a plain Python number; the
        # decorator marks this method's result to be treated as a constant.
        return y.int().item()

    def forward(self, x, y):
        """Return the leading `get_item(y)` entries of `x` along dim 0."""
        return x[: self.get_item(y)]
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 2]", arg1_1: "i64[]"):
slice_1: "f32[3, 2]" = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 4); arg0_1 = None
return (slice_1,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='slice_1'), target=None)])
Range constraints: {}
constrain_as_size_example
Original source code:
import torch
class ConstrainAsSizeExample(torch.nn.Module):
    """
    If the value is not known at tracing time, you can provide a hint so that we
    can trace further. Please look at the constrain_as_value and constrain_as_size APIs.
    constrain_as_size is used for values that NEED to be used for constructing
    a tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return a tensor of ones with shape `(x.item(), 5)`."""
        a = x.item()
        # Declare that `a` lies in [0, 5]; it is used below as a tensor
        # dimension, hence the size-flavoured constraint.
        torch._constrain_as_size(a, min=0, max=5)
        return torch.ones((a, 5))
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: "i64[]"):
_local_scalar_dense: "Sym(u4)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
ge: "Sym(u4 >= 0)" = _local_scalar_dense >= 0
scalar_tensor: "f32[]" = torch.ops.aten.scalar_tensor.default(ge); ge = None
_assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, 5].'); scalar_tensor = None
le: "Sym(u4 <= 5)" = _local_scalar_dense <= 5
scalar_tensor_1: "f32[]" = torch.ops.aten.scalar_tensor.default(le); le = None
_assert_async_1 = torch.ops.aten._assert_async.msg(scalar_tensor_1, '_local_scalar_dense is outside of inline constraint [0, 5].'); scalar_tensor_1 = None
sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense, min = 0, max = 5)
ones: "f32[u4, 5]" = torch.ops.aten.ones.default([_local_scalar_dense, 5], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None
return (ones,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='ones'), target=None)])
Range constraints: {u0: ValueRanges(lower=0, upper=5, is_bool=False), u1: ValueRanges(lower=0, upper=5, is_bool=False), u4: ValueRanges(lower=0, upper=5, is_bool=False)}
constrain_as_value_example
Original source code:
import torch
class ConstrainAsValueExample(torch.nn.Module):
    """
    If the value is not known at tracing time, you can provide a hint so that we
    can trace further. Please look at the constrain_as_value and constrain_as_size APIs.
    constrain_as_value is used for values that don't need to be used for constructing
    a tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        """Return `y.sin()` when `x.item() < 6`, otherwise `y.cos()`."""
        a = x.item()
        # Declare that `a` lies in [0, 5]; it is only compared, never used as
        # a tensor dimension.
        torch._constrain_as_value(a, min=0, max=5)

        # Under the [0, 5] constraint `a < 6` always holds, so this is the
        # branch that gets taken.
        if a < 6:
            return y.sin()
        return y.cos()
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: "i64[]", arg1_1: "f32[5, 5]"):
_local_scalar_dense: "Sym(u4)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
ge: "Sym(u4 >= 0)" = _local_scalar_dense >= 0
scalar_tensor: "f32[]" = torch.ops.aten.scalar_tensor.default(ge); ge = None
_assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, 5].'); scalar_tensor = None
le: "Sym(u4 <= 5)" = _local_scalar_dense <= 5
scalar_tensor_1: "f32[]" = torch.ops.aten.scalar_tensor.default(le); le = None
_assert_async_1 = torch.ops.aten._assert_async.msg(scalar_tensor_1, '_local_scalar_dense is outside of inline constraint [0, 5].'); scalar_tensor_1 = None
sym_constrain_range = torch.ops.aten.sym_constrain_range.default(_local_scalar_dense, min = 0, max = 5); _local_scalar_dense = None
sin: "f32[5, 5]" = torch.ops.aten.sin.default(arg1_1); arg1_1 = None
return (sin,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sin'), target=None)])
Range constraints: {u0: ValueRanges(lower=0, upper=5, is_bool=False), u1: ValueRanges(lower=0, upper=5, is_bool=False), u4: ValueRanges(lower=0, upper=5, is_bool=False)}