torch.overrides
This module exposes various helper functions for the __torch_function__ protocol. See Extending torch for more detail on the __torch_function__ protocol.
Functions
- torch.overrides.get_ignored_functions()
Return public functions that cannot be overridden by __torch_function__.
- Returns:
A set of functions that are publicly available in the torch API but cannot be overridden with __torch_function__. Mostly this is because none of the arguments of these functions are tensors or tensor-likes.
- Return type:
Set[Callable]
Examples
>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
True
>>> torch.add in torch.overrides.get_ignored_functions()
False
- torch.overrides.get_overridable_functions()
List functions that are overridable via __torch_function__.
- Returns:
A dictionary that maps namespaces that contain overridable functions to functions in that namespace that can be overridden.
- Return type:
Dict[Any, List[Callable]]
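As a rough illustration of the returned mapping, the sketch below counts the overridable functions in each namespace. It assumes the dictionary keys are namespace objects such as the torch module or the torch.Tensor class; the name counts is made up for this example.

```python
import torch
from torch.overrides import get_overridable_functions

overridable = get_overridable_functions()

# Hypothetical summary: namespace name -> number of overridable functions.
counts = {
    getattr(namespace, "__name__", repr(namespace)): len(funcs)
    for namespace, funcs in overridable.items()
}

for name, n in sorted(counts.items()):
    print(f"{name}: {n} overridable functions")

# torch.add takes tensor arguments, so it is expected to be listed under torch.
print(torch.add in overridable[torch])
```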
- torch.overrides.resolve_name(f)
Get a human-readable string name for a function passed to __torch_function__.
- Parameters:
f (Callable) – Function to resolve the name of.
- Returns:
Name of the function; if eval’ed it should give back the input function.
- Return type:
str
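A minimal sketch of the round trip described above, using torch.add as the input. The exact string printed may vary between releases, so the example only checks that evaluating it gives back the same function.

```python
import torch
from torch.overrides import resolve_name

name = resolve_name(torch.add)
print(name)

# Evaluating the resolved name should give back the original function.
# eval is tolerable here only because the string comes from resolve_name itself.
round_trip = eval(name)
print(round_trip is torch.add)
```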
- torch.overrides.get_testing_overrides()
Return a dict containing dummy overrides for all overridable functions.
- Returns:
A dictionary that maps overridable functions in the PyTorch API to lambda functions that have the same signature as the real function and unconditionally return -1. These lambda functions are useful for testing API coverage for a type that defines __torch_function__.
- Return type:
Dict[Callable, Callable]
Examples
>>> import inspect
>>> my_add = torch.overrides.get_testing_overrides()[torch.add]
>>> inspect.signature(my_add)
<Signature (input, other, out=None)>
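One way the dummy overrides might be used for a coverage check is sketched below. The Coverage class is hypothetical and not part of torch: it forwards every intercepted call to the matching dummy override, so a result of -1 indicates that the call reached __torch_function__.

```python
import torch
from torch.overrides import get_testing_overrides

class Coverage:
    """Hypothetical Tensor-like that forwards each call to its dummy override."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Look up the dummy lambda registered for the intercepted function.
        return get_testing_overrides()[func](*args, **kwargs)

# torch.add sees Tensor-like arguments and dispatches to Coverage.__torch_function__.
result = torch.add(Coverage(), Coverage())
print(result)
```

A real coverage test would presumably loop over every key of get_testing_overrides() and build suitable arguments from each lambda's signature; that part is omitted here.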
- torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)
Implement a function with checks for __torch_function__ overrides.
See torch::autograd::handle_torch_function for the equivalent of this function in the C++ implementation.
- Parameters:
public_api (function) – Function exposed by the public torch API, originally called like public_api(*args, **kwargs), whose arguments are now being checked.
relevant_args (iterable) – Iterable of arguments to check for __torch_function__ methods.
args (tuple) – Arbitrary positional arguments originally passed into public_api.
kwargs (dict) – Arbitrary keyword arguments originally passed into public_api.
- Returns:
Result from calling implementation or an __torch_function__ method, as appropriate.
- Return type:
object
- Raises:
TypeError – if no implementation is found.
Example
>>> from torch.overrides import handle_torch_function, has_torch_function_unary
>>> def func(a):
...     if has_torch_function_unary(a):
...         return handle_torch_function(func, (a,), a)
...     return a + 0
- torch.overrides.has_torch_function()
Check for __torch_function__ implementations in the elements of an iterable, or whether a __torch_function__ mode is enabled. Considers exact Tensors and Parameters non-dispatchable. Use this to guard a call to handle_torch_function(); don’t use it to test if something is Tensor-like, use is_tensor_like() instead.
- Parameters:
relevant_args (iterable) – Iterable of arguments to check for __torch_function__ methods.
- Returns:
True if any of the elements of relevant_args have __torch_function__ implementations, False otherwise.
- Return type:
bool
See also
torch.overrides.is_tensor_like
Checks if something is a Tensor-like, including an exact Tensor.
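The guard pattern described above can be sketched as follows. Tagged and add_zero are hypothetical names introduced for illustration: an exact Tensor is treated as non-dispatchable and takes the ordinary path, while a Tensor-like is routed through handle_torch_function().

```python
import torch
from torch.overrides import handle_torch_function, has_torch_function

class Tagged:
    """Hypothetical Tensor-like that records which function was dispatched."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return ("dispatched", func.__name__)

def add_zero(a):
    relevant_args = (a,)
    # Guard the dispatch call: only enter it when an override is present.
    if has_torch_function(relevant_args):
        return handle_torch_function(add_zero, relevant_args, a)
    return a + 0

plain = add_zero(torch.tensor([1.0]))  # exact Tensor: ordinary code path
tagged = add_zero(Tagged())            # Tensor-like: handled by __torch_function__
print(plain, tagged)
```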
- torch.overrides.is_tensor_like(inp)
Returns True if the passed-in input is a Tensor-like.
Currently, this occurs whenever there’s a __torch_function__ attribute on the type of the input.
Examples
A subclass of tensor is generally a Tensor-like.
>>> class SubTensor(torch.Tensor): ...
>>> is_tensor_like(SubTensor([0]))
True
Built-in or user types aren’t usually Tensor-like.
>>> is_tensor_like(6)
False
>>> is_tensor_like(None)
False
>>> class NotATensor: ...
>>> is_tensor_like(NotATensor())
False
But, they can be made Tensor-like by implementing __torch_function__.
>>> class TensorLike:
...     @classmethod
...     def __torch_function__(cls, func, types, args, kwargs):
...         return -1
>>> is_tensor_like(TensorLike())
True
- torch.overrides.is_tensor_method_or_property(func)
Returns True if the function passed in is a handler for a method or property belonging to torch.Tensor, as passed into __torch_function__.
Note
For properties, their __get__ method must be passed in.
This may be needed, in particular, for the following reasons:
Methods/properties sometimes don’t contain a __module__ slot.
They require that the first passed-in argument is an instance of torch.Tensor.
Examples
>>> is_tensor_method_or_property(torch.Tensor.add)
True
>>> is_tensor_method_or_property(torch.add)
False
- Return type:
bool
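The note about properties can be illustrated with a short sketch. It assumes torch.Tensor.shape as the example property; the variable names are made up for this example.

```python
import torch
from torch.overrides import is_tensor_method_or_property

# For a property, the __get__ method is what gets passed in,
# not the property object itself.
shape_getter = torch.Tensor.shape.__get__
is_property_handler = is_tensor_method_or_property(shape_getter)
print(is_property_handler)

# A regular Tensor method is passed in directly.
is_method_handler = is_tensor_method_or_property(torch.Tensor.add)
print(is_method_handler)
```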
- torch.overrides.wrap_torch_function(dispatcher)
Wraps a given function with __torch_function__-related functionality.
- Parameters:
dispatcher (Callable) – A callable that returns an iterable of Tensor-likes passed into the function.
Note
This decorator may reduce the performance of your code. Generally, it’s enough to express your code as a series of functions that, themselves, support __torch_function__. If you find yourself in the rare situation where this is not the case, e.g. if you’re wrapping a low-level library and you also need it to work for Tensor-likes, then this function is available.
Examples
>>> def dispatcher(a):  # Must have the same signature as func
...     return (a,)
>>> @torch.overrides.wrap_torch_function(dispatcher)
... def func(a):  # This will make func dispatchable by __torch_function__
...     return a + 0
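To show what the decorated function does at call time, here is a small sketch with a hypothetical Tensor-like named Logged. It assumes the wrapper passes itself to __torch_function__ under the wrapped function's name; plain tensors fall through to the original body.

```python
import torch
from torch.overrides import wrap_torch_function

class Logged:
    """Hypothetical Tensor-like that reports the name of the dispatched function."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return f"intercepted {func.__name__}"

def dispatcher(a):  # Same signature as add_zero below
    return (a,)

@wrap_torch_function(dispatcher)
def add_zero(a):
    return a + 0

tensor_result = add_zero(torch.tensor([2.0]))  # plain tensor: original body runs
logged_result = add_zero(Logged())             # Tensor-like: __torch_function__ runs
print(tensor_result, logged_result)
```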