torch.func.linearize
- torch.func.linearize(func, *primals)
Returns the value of func at primals and a linear approximation of func at primals.
- Parameters
func (Callable) – A Python function that takes one or more arguments.
primals (Tensors) – Positional arguments to func that must all be Tensors. These are the values at which the function is linearly approximated.
- Returns
Returns a (output, jvp_fn) tuple containing the output of func applied to primals and a function that computes the Jacobian-vector product (JVP) of func evaluated at primals.
- Return type
Tuple[Any, Callable]
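The parameter and return conventions above can be sketched with a two-argument function. The function f, the random inputs and the step size eps are assumptions chosen only for illustration. The sketch passes both primals positionally, calls jvp_fn with one tangent per primal (in the same order as the primals), compares the result with the product rule, and checks the first-order approximation f(primals + eps * tangents) ≈ output + eps * jvp_out.

```python
import torch
from torch.func import linearize

torch.manual_seed(0)  # fixed seed so the checks below are reproducible

# Hypothetical two-argument function, used only for illustration.
def f(x, y):
    return x.sin() * y

x = torch.randn(4)
y = torch.randn(4)

# Both primals are passed positionally, in the order f expects them.
output, jvp_fn = linearize(f, x, y)

# jvp_fn takes one tangent per primal, in the same order as the primals.
tx = torch.randn(4)
ty = torch.randn(4)
jvp_out = jvp_fn(tx, ty)

# Product rule: the JVP of sin(x) * y along (tx, ty) is cos(x) * tx * y + sin(x) * ty.
expected = x.cos() * tx * y + x.sin() * ty

# First-order check: for a small step, f(primals + eps * tangents) is close to
# output + eps * jvp_out (the error shrinks like eps ** 2).
eps = 1e-3
perturbed = f(x + eps * tx, y + eps * ty)

print(torch.allclose(output, f(x, y)))                                 # True
print(torch.allclose(jvp_out, expected, atol=1e-6))                    # True
print(torch.allclose(perturbed, output + eps * jvp_out, atol=1e-4))    # True
```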
linearize is useful when the JVP has to be computed multiple times at the same primals, for different tangents. However, to achieve this, linearize saves intermediate results and has higher memory requirements than directly applying jvp. So, if all the tangents are known in advance, it may be more efficient to compute vmap(jvp) instead of using linearize.
Note
linearize evaluates func twice. Please file an issue if you need an implementation with a single evaluation.
- Example
>>> import torch
>>> from torch.func import linearize
>>> def fn(x):
...     return x.sin()
...
>>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
>>> jvp_fn(torch.ones(3, 3))
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
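The trade-off between reusing jvp_fn and computing vmap(jvp), described earlier, can be illustrated with the same fn. In this sketch the batch of five random tangents and the helper jvp_single are assumptions made for illustration. It computes the same JVPs in two ways: by calling the jvp_fn returned by linearize once per tangent, and by batching torch.func.jvp with torch.func.vmap when every tangent is available up front. It only checks that both routes agree; which one is cheaper in time or memory depends on the function and input sizes, so it is worth measuring before choosing.

```python
import torch
from torch.func import jvp, linearize, vmap

torch.manual_seed(0)

def fn(x):
    return x.sin()

x = torch.randn(3, 3)
tangents = torch.randn(5, 3, 3)  # five tangents, all known in advance

# Approach 1: linearize once at x, then reuse jvp_fn for every tangent.
# linearize keeps intermediate results alive between calls.
output, jvp_fn = linearize(fn, x)
via_linearize = torch.stack([jvp_fn(t) for t in tangents])

# Approach 2: batch torch.func.jvp over the tangents with vmap.
def jvp_single(t):
    # jvp takes tuples of primals and tangents and returns (output, jvp_out).
    _, jvp_out = jvp(fn, (x,), (t,))
    return jvp_out

via_vmap = vmap(jvp_single)(tangents)

# Both routes compute the same JVPs; for sin, the JVP along t is cos(x) * t.
print(torch.allclose(via_linearize, via_vmap))       # True
print(torch.allclose(via_vmap, x.cos() * tangents))  # True
```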