torch.func.functional_call
- torch.func.functional_call(module, parameter_and_buffer_dicts, args, kwargs=None, *, tie_weights=True, strict=False)
Performs a functional call on the module by replacing the module parameters and buffers with the provided ones.
Note
If the module has active parametrizations, passing a value in the parameter_and_buffer_dicts argument with the name set to the regular parameter name will completely disable the parametrization. If you want to apply the parametrization function to the value passed, set the key as {submodule_name}.parametrizations.{parameter_name}.original.
Note
If the module performs in-place operations on parameters/buffers, these will be reflected in the parameter_and_buffer_dicts input.
Example:
>>> a = {'foo': torch.zeros(())}
>>> mod = Foo()  # does self.foo = self.foo + 1
>>> print(mod.foo)  # tensor(0.)
>>> functional_call(mod, a, torch.ones(()))
>>> print(mod.foo)  # tensor(0.)
>>> print(a['foo'])  # tensor(1.)
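A self-contained variant of the example above could look like the following sketch. Foo is a hypothetical module defined only for illustration (the original does not show its definition); it is assumed here to register foo as a buffer and to update it with a true in-place add_.

```python
import torch
import torch.nn as nn
from torch.func import functional_call


class Foo(nn.Module):  # hypothetical module, defined only for this illustration
    def __init__(self):
        super().__init__()
        self.register_buffer("foo", torch.zeros(()))

    def forward(self, x):
        # In-place update of whichever tensor is currently bound to `foo`.
        self.foo.add_(1)
        return x + self.foo


mod = Foo()
a = {"foo": torch.zeros(())}

# During the call, `foo` refers to a["foo"], so the in-place add hits that tensor.
out = functional_call(mod, a, (torch.ones(()),))

print(mod.foo)   # the module's own buffer is left untouched
print(a["foo"])  # the tensor supplied in the dict was incremented
```

The module's registered buffer is restored after the call, which is why only the tensor in the dictionary changes.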
Note
If the module has tied weights, whether or not functional_call respects the tying is determined by the tie_weights flag.
Example:
>>> a = {'foo': torch.zeros(())}
>>> mod = Foo()  # has both self.foo and self.foo_tied, which are tied; returns x + self.foo + self.foo_tied
>>> print(mod.foo)  # tensor(1.)
>>> mod(torch.zeros(()))  # tensor(2.)
>>> functional_call(mod, a, torch.zeros(()))  # tensor(0.) since it will change self.foo_tied too
>>> functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.)--self.foo_tied is not updated
>>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
>>> functional_call(mod, new_a, torch.zeros(()))  # tensor(0.)
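To make the tied-weights behavior concrete, here is a minimal runnable sketch. TiedFoo is a hypothetical module (not defined in the original) in which the same Parameter object is assigned to two attribute names; that is one common way weights end up tied, and is an assumption of this example.

```python
import torch
import torch.nn as nn
from torch.func import functional_call


class TiedFoo(nn.Module):  # hypothetical module, defined only for this illustration
    def __init__(self):
        super().__init__()
        self.foo = nn.Parameter(torch.ones(()))
        self.foo_tied = self.foo  # same Parameter object under a second name

    def forward(self, x):
        return x + self.foo + self.foo_tied


mod = TiedFoo()
a = {"foo": torch.zeros(())}
x = torch.zeros(())

with torch.no_grad():
    # The module's own weights: 0 + 1 + 1.
    baseline = mod(x)
    # tie_weights=True (default): foo_tied follows the replacement for foo.
    tied = functional_call(mod, a, (x,))
    # tie_weights=False: only foo is replaced; foo_tied keeps its original value.
    untied = functional_call(mod, a, (x,), tie_weights=False)

print(baseline.item(), tied.item(), untied.item())
```

With tie_weights=False, supplying both 'foo' and 'foo_tied' explicitly is how one would replace both names.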
An example of passing multiple dictionaries:
a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)})  # two separate dictionaries
mod = Bar(1, 1)  # placeholder module (not part of torch.nn); returns self.weight @ x + self.buffer
print(mod.weight)  # tensor(...)
print(mod.buffer)  # tensor(...)
x = torch.randn((1, 1))
print(x)
functional_call(mod, a, x)  # same as x
print(mod.weight)  # same as before functional_call
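A runnable sketch of the multiple-dictionary case follows. Bar is a hypothetical module written for this illustration, assumed to hold a weight parameter and a buffer named buffer, matching the comment in the example above.

```python
import torch
import torch.nn as nn
from torch.func import functional_call


class Bar(nn.Module):  # hypothetical module, defined only for this illustration
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.register_buffer("buffer", torch.randn(out_features))

    def forward(self, x):
        return self.weight @ x + self.buffer


mod = Bar(1, 1)
params = {"weight": torch.ones(1, 1)}
buffers = {"buffer": torch.zeros(1)}
x = torch.randn(1, 1)

weight_before = mod.weight.detach().clone()

# Parameters and buffers are passed as a tuple of dicts with distinct keys.
out = functional_call(mod, (params, buffers), (x,))

print(torch.allclose(out, x))                   # weight of 1 and buffer of 0 reproduce x
print(torch.equal(mod.weight, weight_before))   # the module's own weight is unchanged
```

Keeping parameters and buffers in separate dictionaries is convenient when only the parameters should be differentiated by a transform such as grad.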
And here is an example of applying the grad transform over the parameters of a model.
import torch
import torch.nn as nn
from torch.func import functional_call, grad

x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)

def compute_loss(params, x, t):
    y = functional_call(model, params, x)
    return nn.functional.mse_loss(y, t)

grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)
Note
If the user does not need grad tracking outside of grad transforms, they can detach all of the parameters for better performance and memory usage.
Example:
>>> detached_params = {k: v.detach() for k, v in model.named_parameters()}
>>> grad_weights = grad(compute_loss)(detached_params, x, t)
>>> grad_weights['weight'].grad_fn  # None--it's not tracking gradients outside of grad
This means that the user cannot call backward() on the returned gradients, for example grad_weights['weight'].backward(). However, if they don’t need autograd tracking outside of the transforms, this will result in lower memory usage and faster execution.
- Parameters:
module (torch.nn.Module) – the module to call
parameter_and_buffer_dicts (Dict[str, Tensor] or tuple of Dict[str, Tensor]) – the parameters and buffers that will be used in the module call. If given a tuple of dictionaries, they must have distinct keys so that all dictionaries can be used together.
args (Any or tuple) – arguments to be passed to the module call. If not a tuple, considered a single argument.
kwargs (dict) – keyword arguments to be passed to the module call
tie_weights (bool, optional) – If True, then parameters and buffers tied in the original model will be treated as tied in the reparametrized version. Therefore, if True and different values are passed for the tied parameters and buffers, it will error. If False, it will not respect the originally tied parameters and buffers unless the values passed for both weights are the same. Default: True.
strict (bool, optional) – If True, then the parameters and buffers passed in must match the parameters and buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will error. Default: False.
- Returns:
the result of calling module.
- Return type:
Any
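As an illustration of the strict flag described under Parameters, the sketch below passes a dictionary that omits bias for an nn.Linear. The exception type caught here (RuntimeError) is an assumption based on the behavior of recent PyTorch releases rather than something this page states.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

model = nn.Linear(3, 3)
x = torch.randn(4, 3)
partial = {"weight": torch.zeros(3, 3)}  # "bias" is deliberately omitted

# strict=False (the default): the missing "bias" falls back to the module's own bias.
out = functional_call(model, partial, (x,))

# strict=True: the missing key is reported as an error.
try:
    functional_call(model, partial, (x,), strict=True)
    strict_failed = False
except RuntimeError:  # assumption: RuntimeError, as seen in recent PyTorch releases
    strict_failed = True

print(out.shape, strict_failed)
```

Because the replacement weight is all zeros, the non-strict output is simply the module's bias broadcast across the batch.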