import collections
import functools
import warnings
from itertools import product
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.testing
from torch._vmap_internals import _vmap, vmap
from torch.overrides import is_tensor_like
from torch.types import _TensorOrTensors

# Note: `get_*_jacobian` functions are added here even though we didn't intend to make them public
# since they have been exposed from before we added `__all__` and we already maintain BC for them
# We should eventually deprecate them and remove them from `__all__`
__all__ = [
    "gradcheck",
    "gradgradcheck",
    "GradcheckError",
    "get_numerical_jacobian",
    "get_analytical_jacobian",
    "get_numerical_jacobian_wrt_specific_input",
]


class GradcheckError(RuntimeError):
    r"""Error raised by :func:`gradcheck` and :func:`gradgradcheck`."""

    pass


def _is_sparse_compressed_tensor(obj: torch.Tensor):
    return obj.layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }


def _is_sparse_any_tensor(obj: torch.Tensor):
    return _is_sparse_compressed_tensor(obj) or obj.layout is torch.sparse_coo


def _is_float_or_complex_tensor(obj):
    return is_tensor_like(obj) and (obj.is_floating_point() or obj.is_complex())


def _allocate_jacobians_with_inputs(
    input_tensors: Tuple, numel_output
) -> Tuple[torch.Tensor, ...]:
    # Makes zero-filled tensors from inputs. If `numel_output` is not None, for
    # each tensor in `input_tensors`, returns a new zero-filled tensor with height
    # of `t.numel` and width of `numel_output`. Otherwise, for each tensor, returns
    # a 1-d tensor with size `(t.numel,)`. Each new tensor will be strided and have
    # the same dtype and device as those of the corresponding input.
    out: List[torch.Tensor] = []
    for t in input_tensors:
        if _is_float_or_complex_tensor(t) and t.requires_grad:
            out.append(t.new_zeros((t.numel(), numel_output), layout=torch.strided))
    return tuple(out)


def _allocate_jacobians_with_outputs(
    output_tensors: Tuple, numel_input, dtype=None, device=None
) -> Tuple[torch.Tensor, ...]:
    # Makes zero-filled tensors from outputs. If `numel_input` is not None, for each
    # tensor in `output_tensors`, returns a new zero-filled tensor with height of
    # `numel_input` and width of `t.numel`. Otherwise, for each tensor, returns a
    # 1-d tensor with size (t.numel,).
    out: List[torch.Tensor] = []
    options = {"dtype": dtype, "device": device, "layout": torch.strided}
    for t in output_tensors:
        if _is_float_or_complex_tensor(t):
            out.append(t.new_zeros((numel_input, t.numel()), **options))
    return tuple(out)


def _iter_tensors(
    x: Union[torch.Tensor, Iterable[torch.Tensor]], only_requiring_grad: bool = False
) -> Iterable[torch.Tensor]:
    if is_tensor_like(x):
        # mypy doesn't narrow type of `x` to torch.Tensor
        if x.requires_grad or not only_requiring_grad:  # type: ignore[union-attr]
            yield x  # type: ignore[misc]
    elif isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
        for elem in x:
            yield from _iter_tensors(elem, only_requiring_grad)


def _densify(x):
    # return a copy of sparse x with all unspecified elements
    # "replaced" with zero-valued elements
    if isinstance(x, (list, tuple)):
        return type(x)(map(_densify, x))
    elif not is_tensor_like(x) or x.layout in {torch.strided, torch._mkldnn}:  # type: ignore[attr-defined] # no attr _mkldnn
        return x
    elif x.layout is torch.sparse_coo:
        device = x.device
        indices_dtype = x._indices().dtype
        tmp = torch.ones(x.shape[: x.sparse_dim()], dtype=torch.int8, device=device)
        indices = tmp.nonzero().t().to(dtype=indices_dtype)
        values = torch.zeros(
            (tmp.numel(), *x.shape[x.sparse_dim() :]), dtype=x.dtype, device=device
        )
        x_coalesced = x.detach().coalesce()
        if x_coalesced.numel() > 0:
            stride = tmp.stride()
            flat_indices = (
                x_coalesced.indices()
                .mul(torch.tensor(stride, dtype=indices_dtype, device=device).unsqueeze(1))
                .sum(0)
            )
            values[flat_indices] = x_coalesced.values()
        return (
            torch.sparse_coo_tensor(indices, values, x.shape)
            ._coalesced_(True)
            .requires_grad_(x.requires_grad)
        )
    elif _is_sparse_compressed_tensor(x):
        blocksize = (
            x.values().shape[1:3]
            if x.layout in {torch.sparse_bsr, torch.sparse_bsc}
            else None
        )
        compressed_indices = (
            x.crow_indices()
            if x.layout in {torch.sparse_csr, torch.sparse_bsr}
            else x.ccol_indices()
        )
        # We'll use intermediate sparse COO for simplicity
        r = _densify(x.detach().to_sparse(layout=torch.sparse_coo)).to_sparse(
            layout=x.layout, blocksize=blocksize
        )
        # Check that all elements are specified also after `to_sparse` op:
        dense_numel = r.values().numel() // max(1, r.values().shape[0])
        batch_numel = compressed_indices.numel() // compressed_indices.shape[-1]
        sparse_numel = r.numel() // max(1, dense_numel * batch_numel)
        if sparse_numel != r._nnz():
            raise AssertionError(
                f"{x.layout} densify failed: expected nnz={sparse_numel} but got {r._nnz()}"
            )
        return r.requires_grad_(x.requires_grad)
    elif _is_sparse_any_tensor(x):
        raise NotImplementedError(x.layout)
    return x
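

# A minimal illustration of `_densify` above (a sketch, assuming a 2x2 sparse COO
# input; not part of the module's API surface):
#
#   >>> x = torch.tensor([[0.0, 1.0], [0.0, 0.0]]).to_sparse()  # nnz == 1
#   >>> _densify(x)._nnz()                                      # every element specified
#   4
#
# Filling in explicit zeros lets slow-mode gradcheck perturb elements that the
# original sparse tensor left unspecified.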


def _iter_tensor(x_tensor):
    # (Only used for slow gradcheck) Returns a generator that yields the following
    # elements at each iteration:
    #  1) a tensor: the same tensor is returned across all iterations. The tensor
    #     is not the same as the original x_tensor as given as input - it is
    #     prepared so that it can be modified in-place. Depending on whether the
    #     input tensor is strided, sparse, or dense, the returned tensor may or may
    #     not share storage with x_tensor.
    #  2) a tuple of indices that can be used with advanced indexing (yielded in
    #     dictionary order)
    #  3) flattened index that will be used to index into the Jacobian tensor
    #
    # For a tensor t with size (2, 2), _iter_tensor yields:
    #     `x, (0, 0), 0`, `x, (0, 1), 1`, `x, (1, 0), 2`, `x, (1, 1), 3`
    #
    # where x is the t.data of the original tensor. Perturbing the entry of x
    # at index (1, 1) yields the 3rd column of the overall Jacobian matrix.
    if _is_sparse_any_tensor(x_tensor):

        def get_stride(size):
            dim = len(size)
            tmp = 1
            stride = [0] * dim
            for i in reversed(range(dim)):
                stride[i] = tmp
                tmp *= size[i]
            return stride

        x_nnz = x_tensor._nnz()
        x_size = list(x_tensor.size())
        if x_tensor.layout is torch.sparse_coo:
            x_indices = x_tensor._indices().t()
            x_values = x_tensor._values()
        elif x_tensor.layout is torch.sparse_csr:
            x_indices = torch._convert_indices_from_csr_to_coo(
                x_tensor.crow_indices(), x_tensor.col_indices()
            ).t()
            x_values = x_tensor.values()
        elif x_tensor.layout is torch.sparse_csc:
            x_indices = torch._convert_indices_from_csr_to_coo(
                x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True
            ).t()
            x_values = x_tensor.values()
        elif x_tensor.layout is torch.sparse_bsr:
            x_block_values = x_tensor.values()
            x_blocksize = x_block_values.size()[1:3]
            x_indices = (
                torch._convert_indices_from_csr_to_coo(
                    x_tensor.crow_indices(), x_tensor.col_indices()
                )
                .repeat_interleave(x_blocksize[0] * x_blocksize[1], 1)
                .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1))
                .add_(
                    torch.stack(
                        torch.where(torch.ones(x_blocksize, device=x_tensor.device))
                    ).repeat(1, x_nnz)
                )
                .t()
            )
            x_values = x_block_values.flatten(0, 2)
            x_nnz = x_values.size(0)
        elif x_tensor.layout is torch.sparse_bsc:
            x_block_values = x_tensor.values()
            x_blocksize = x_block_values.size()[1:3]
            x_indices = (
                torch._convert_indices_from_csr_to_coo(
                    x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True
                )
                .repeat_interleave(x_blocksize[0] * x_blocksize[1], 1)
                .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1))
                .add_(
                    torch.stack(
                        torch.where(torch.ones(x_blocksize, device=x_tensor.device))
                    ).repeat(1, x_nnz)
                )
                .t()
            )
            x_values = x_block_values.flatten(0, 2)
            x_nnz = x_values.size(0)
        else:
            raise NotImplementedError(f"_iter_tensor for {x_tensor.layout} input")
        x_stride = get_stride(x_size)
        # Use .data here to get around the version check
        x_values = x_values.data
        for i in range(x_nnz):
            x_value = x_values[i]
            for x_idx in product(*[range(m) for m in x_values.size()[1:]]):
                indices = x_indices[i].tolist() + list(x_idx)
                d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size)))
                yield x_value, x_idx, d_idx
    elif x_tensor.layout == torch._mkldnn:  # type: ignore[attr-defined]
        for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
            # this is really inefficient, but without indexing implemented, there's
            # not really a better way than converting back and forth
            x_tensor_dense = x_tensor.to_dense()
            yield x_tensor_dense, x_idx, d_idx
    else:
        # Use .data here to get around the version check
        x_tensor = x_tensor.data
        for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
            yield x_tensor, x_idx, d_idx


def _get_numerical_jacobian(
    fn, inputs, outputs=None, target=None, eps=1e-3, is_forward_ad=False
) -> List[Tuple[torch.Tensor, ...]]:
    """Compute the numerical Jacobian of `fn(inputs)` with respect to `target`.

    If not specified, targets are the input. Returns M * N Jacobians where N is the
    number of tensors in target that require grad and M is the number of non-integral
    outputs.

    Args:
        fn: the function to compute the jacobian for
        inputs: inputs to `fn`
        outputs: provide precomputed outputs to avoid one extra invocation of fn
        target: the Tensors wrt whom Jacobians are calculated (default=`inputs`)
        eps: the magnitude of the perturbation during finite differencing
             (default=`1e-3`)
        is_forward_ad: if this numerical jacobian is computed to be checked wrt
                       forward AD gradients (this is used for error checking only)

    Returns:
        A list of M N-tuples of tensors

    Note that `target` may not even be part of `input` to `fn`, so please be
    **very careful** in this to not clone `target`.
    """
    jacobians: List[Tuple[torch.Tensor, ...]] = []
    if outputs is None:
        outputs = _as_tuple(fn(*_as_tuple(inputs)))
    if not is_forward_ad and any(o.is_complex() for o in outputs):
        raise ValueError(
            "Expected output to be non-complex. get_numerical_jacobian no "
            "longer supports functions that return complex outputs."
        )
    if target is None:
        target = inputs
    inp_indices = [
        i for i, a in enumerate(target) if is_tensor_like(a) and a.requires_grad
    ]
    for i, (inp, inp_idx) in enumerate(zip(_iter_tensors(target, True), inp_indices)):
        jacobians += [
            get_numerical_jacobian_wrt_specific_input(
                fn,
                inp_idx,
                inputs,
                outputs,
                eps,
                input=inp,
                is_forward_ad=is_forward_ad,
            )
        ]
    return jacobians


def get_numerical_jacobian(fn, inputs, target=None, eps=1e-3, grad_out=1.0):
    """Compute the numerical Jacobian for a given fn and its inputs.

    This is a deprecated API.

    Args:
        fn: the function to compute the Jacobian for (must take inputs as a tuple)
        input: input to `fn`
        target: the Tensors wrt whom Jacobians are calculated (default=`input`)
        eps: the magnitude of the perturbation during finite differencing
             (default=`1e-3`)

    Returns:
        A list of Jacobians of `fn` (restricted to its first output) with respect to
        each input or target, if provided.

    Note that `target` may not even be part of `input` to `fn`, so please be
    **very careful** in this to not clone `target`.
    """
    warnings.warn(
        "get_numerical_jacobian was part of PyTorch's private API and not "
        "meant to be exposed. We are deprecating it and it will be removed "
        "in a future version of PyTorch. If you have a specific use for "
        "this or feature request for this to be a stable API, please file "
        "us an issue at https://github.com/pytorch/pytorch/issues/new"
    )
    if (
        grad_out != 1.0
    ):  # grad_out param is only kept for backward compatibility reasons
        raise ValueError(
            "Expected grad_out to be 1.0. get_numerical_jacobian no longer "
            "supports values of grad_out != 1.0."
        )

    def fn_pack_inps(*inps):
        return fn(inps)

    jacobians = _get_numerical_jacobian(fn_pack_inps, inputs, None, target, eps)

    return tuple(jacobian_for_each_output[0] for jacobian_for_each_output in jacobians)


def _compute_numerical_gradient(fn, entry, v, norm_v, nbhd_checks_fn):
    # Performs finite differencing by perturbing `entry` in-place by `v` and
    # returns the gradient of each of the outputs wrt to x at idx.
    if _is_sparse_compressed_tensor(entry):
        # sparse compressed tensors don't implement sub/add/copy_
        # yet. However, in non-masked semantics context entry and v
        # have the same sparse indices ...
        assert entry.layout == v.layout, (entry.layout, v.layout)
        assert entry._nnz() == v._nnz(), (entry._nnz(), v._nnz(), entry.shape)
        # ... the finite differencing can be performed on values only:
        entry = entry.values()
        v = v.values()
        # we'll detach to avoid backward computations that sparse
        # tensors have limited support for.
        entry = entry.detach()

    orig = entry.clone()
    entry.copy_(orig - v)
    outa = fn()
    entry.copy_(orig + v)
    outb = fn()
    entry.copy_(orig)

    def compute(a, b):
        nbhd_checks_fn(a, b)
        ret = (b - a) / (2 * norm_v)
        return ret.detach().reshape(-1)

    return tuple(compute(a, b) for (a, b) in zip(outa, outb))


def _compute_numerical_jvps_wrt_specific_input(
    jvp_fn, delta, input_is_complex, is_forward_ad=False
) -> List[torch.Tensor]:
    # Computing the jacobian only works for real delta
    # For details on the algorithm used here, refer:
    # Section 3.5.3 https://arxiv.org/pdf/1701.00392.pdf
    # s = fn(z) where z = x for real valued input
    # and z = x + yj for complex valued input
    jvps: List[torch.Tensor] = []
    ds_dx_tup = jvp_fn(delta[0] if isinstance(delta, tuple) else delta)

    if input_is_complex:  # C -> R
        ds_dy_tup = (
            jvp_fn(delta[1] * 1j) if isinstance(delta, tuple) else jvp_fn(delta * 1j)
        )
        for ds_dx, ds_dy in zip(ds_dx_tup, ds_dy_tup):
            assert not ds_dx.is_complex()
            # conjugate wirtinger derivative
            conj_w_d = ds_dx + ds_dy * 1j
            jvps.append(conj_w_d)
    else:
        for ds_dx in ds_dx_tup:  # R -> R or (R -> C for the forward AD case)
            assert is_forward_ad or not ds_dx.is_complex()
            jvps.append(ds_dx)
    return jvps


def _combine_jacobian_cols(
    jacobians_cols: Dict[int, List[torch.Tensor]], outputs, input, numel
) -> Tuple[torch.Tensor, ...]:
    # jacobian_cols maps column_idx -> output_idx -> single column of jacobian Tensor
    # we return a list that maps output_idx -> full jacobian Tensor
    jacobians = _allocate_jacobians_with_outputs(
        outputs, numel, dtype=input.dtype if input.dtype.is_complex else None
    )
    for i, jacobian in enumerate(jacobians):
        for k, v in jacobians_cols.items():
            jacobian[k] = v[i]
    return jacobians


def _prepare_input(
    input: torch.Tensor, maybe_perturbed_input: Optional[torch.Tensor], fast_mode=False
) -> torch.Tensor:
    # Prepares the inputs to be passed into the function while including the new
    # modified input.
    if input.layout == torch._mkldnn:  # type: ignore[attr-defined] # no attr _mkldnn
        # Convert back to mkldnn
        if maybe_perturbed_input is not None:
            return maybe_perturbed_input.to_mkldnn()
        else:
            return input
    elif _is_sparse_any_tensor(input):
        if fast_mode and maybe_perturbed_input is not None:
            # entry is already a "cloned" version of the original tensor
            # thus changes to entry are not reflected in the input
            return maybe_perturbed_input
        else:
            return input
    else:
        # We cannot use entry (input.data) if we want gradgrad to work because
        # fn (in the gradgrad case) needs to compute grad wrt input
        return input


def _check_outputs_same_dtype_and_shape(output1, output2, eps, idx=None) -> None:
    # Check that the returned outputs don't have different dtype or shape when you
    # perturb the input
    on_index = f"on index {idx} " if idx is not None else ""
    assert output1.shape == output2.shape, (
        f"Expected `func` to return outputs with the same shape"
        f" when inputs are perturbed {on_index}by {eps}, but got:"
        f" shapes {output1.shape} and {output2.shape}."
    )
    assert output1.dtype == output2.dtype, (
        f"Expected `func` to return outputs with the same dtype"
        f" when inputs are perturbed {on_index}by {eps}, but got:"
        f" dtypes {output1.dtype} and {output2.dtype}."
    )
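

# A short sketch of the central-difference rule that _compute_numerical_gradient
# applies above (assuming a scalar-valued fn and a strided input x; the helper
# generalizes this to tuples of tensor outputs and to sparse/mkldnn layouts):
#
#   d fn / d x[i]  ~=  (fn(x + eps * e_i) - fn(x - eps * e_i)) / (2 * eps)
#
# where e_i is the one-hot perturbation of a single element. Each such quotient
# becomes one column of the numerical Jacobian assembled by the helpers below.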


def get_numerical_jacobian_wrt_specific_input(
    fn, input_idx, inputs, outputs, eps, input=None, is_forward_ad=False
) -> Tuple[torch.Tensor, ...]:
    # Computes the numerical jacobians wrt to a single input. Returns N jacobian
    # tensors, where N is the number of outputs. We use a dictionary for
    # jacobian_cols because indices aren't necessarily consecutive for sparse inputs
    # When we perturb only a single element of the input tensor at a time, the jvp
    # is equivalent to a single col of the Jacobian matrix of fn.
    jacobian_cols: Dict[int, List[torch.Tensor]] = {}
    input = inputs[input_idx] if input is None else input
    assert input.requires_grad
    for x, idx, d_idx in _iter_tensor(input):
        wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, x)
        input_to_perturb = x[idx]
        nbhd_checks_fn = functools.partial(
            _check_outputs_same_dtype_and_shape, idx=idx, eps=eps
        )
        jvp_fn = _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn)
        jacobian_cols[d_idx] = _compute_numerical_jvps_wrt_specific_input(
            jvp_fn, eps, x.is_complex(), is_forward_ad
        )
    return _combine_jacobian_cols(jacobian_cols, outputs, input, input.numel())


def _get_analytical_jacobian_forward_ad(
    fn, inputs, outputs, *, check_grad_dtypes=False, all_u=None
) -> Tuple[Tuple[torch.Tensor, ...], ...]:
    """Compute the analytical Jacobian of `fn(inputs)` using forward mode AD with respect to `target`.

    Return N * M Jacobians where N is the number of tensors in target that require
    grad and M is the number of non-integral outputs.
    Contrary to other functions here, this function requires "inputs" to actually be
    used by the function. The computed value is expected to be wrong if the function
    captures the inputs by side effect instead of using the passed ones (many
    torch.nn tests do this).

    Args:
        fn: the function to compute the jacobian for
        inputs: inputs to `fn`
        outputs: provide precomputed outputs to avoid one extra invocation of fn
        check_grad_dtypes: if True, will check that the gradient dtypes are valid
        all_u (optional): if provided, the Jacobian will be right multiplied
            with this vector

    Returns:
        A tuple of M N-tuples of tensors
    """
    # To avoid early import issues
    fwAD = torch.autograd.forward_ad

    tensor_inputs = tuple(i for i in inputs if is_tensor_like(i) and i.requires_grad)

    if any(i.is_complex() for i in tensor_inputs):
        raise ValueError(
            "Expected inputs to be non-complex for _get_analytical_jacobian_forward_ad."
        )

    if all_u:
        jacobians = tuple(
            _allocate_jacobians_with_outputs(outputs, 1) for i in tensor_inputs
        )
    else:
        jacobians = tuple(
            _allocate_jacobians_with_outputs(outputs, i.numel()) for i in tensor_inputs
        )

    with fwAD.dual_level():
        fw_grads = []
        dual_inputs = []
        for i, inp in enumerate(inputs):
            if is_tensor_like(inp) and inp.requires_grad:
                if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
                    raise ValueError(
                        "MKLDNN inputs are not supported for forward AD gradcheck."
                    )

                inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
                # If inp is a differentiable view, the dual might not be the tangent given to
                # make_dual, so read it explicitly from the dual tensor
                fw_grads.append(fwAD.unpack_dual(inp)[1])
            dual_inputs.append(inp)

        if all_u:
            # Do the full reduction in one pass
            # To be consistent with numerical evaluation, we actually compute one reduction per input
            for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)):
                fw_grad.copy_(u.view_as(fw_grad))
                raw_outputs = _as_tuple(fn(*dual_inputs))
                dual_outputs = filter(_is_float_or_complex_tensor, raw_outputs)
                for index_o, d_o in enumerate(dual_outputs):
                    val, res = fwAD.unpack_dual(d_o)
                    if (
                        check_grad_dtypes
                        and res is not None
                        and val.is_complex() != res.is_complex()
                    ):
                        raise GradcheckError("Forward AD gradient has dtype mismatch.")

                    # Remove extra dimension of size 1 corresponding to the reduced input
                    jacobians[i][index_o].squeeze_(0)
                    if res is None:
                        jacobians[i][index_o].zero_()
                    else:
                        jacobians[i][index_o].copy_(res.reshape(-1))
                fw_grad.zero_()
        else:
            # Reconstruct the full Jacobian column by column
            for i, fw_grad in enumerate(fw_grads):
                for lin_idx, grad_idx in enumerate(
                    product(*[range(m) for m in fw_grad.size()])
                ):
                    fw_grad[grad_idx] = 1.0
                    raw_outputs = _as_tuple(fn(*dual_inputs))
                    dual_outputs = filter(_is_float_or_complex_tensor, raw_outputs)
                    for index_o, d_o in enumerate(dual_outputs):
                        val, res = fwAD.unpack_dual(d_o)
                        if (
                            check_grad_dtypes
                            and res is not None
                            and val.is_complex() != res.is_complex()
                        ):
                            raise GradcheckError(
                                "Forward AD gradient has dtype mismatch."
                            )

                        if res is None:
                            jacobians[i][index_o][lin_idx].zero_()
                        else:
                            jacobians[i][index_o][lin_idx].copy_(res.reshape(-1))
                    fw_grad[grad_idx] = 0.0

    return jacobians


def _get_input_to_perturb(input):
    # Prepare the input so that it can be modified in-place and do certain
    # operations that require the tensor to have strides. If fast_mode=False,
    # _iter_tensor would handle the below cases:
    if input.layout == torch._mkldnn:  # type: ignore[attr-defined] # no attr _mkldnn
        # Convert to dense so we can perform operations that require strided tensors
        input_to_perturb = input.to_dense()
    elif _is_sparse_any_tensor(input):
        # Clone because input may require grad, and copy_ calls resize_,
        # which is not allowed for .data
        input_to_perturb = input.clone()
    else:
        input_to_perturb = input.data
    return input_to_perturb


def _with_prepare_inputs(fn, inputs, input_idx, input_to_perturb, fast_mode=False):
    # Wraps `fn` so that its inputs are already supplied
    def wrapped_fn():
        inp = tuple(
            _prepare_input(a, input_to_perturb if i == input_idx else None, fast_mode)
            if is_tensor_like(a)
            else a
            for i, a in enumerate(_as_tuple(inputs))
        )
        return tuple(a.clone() for a in _as_tuple(fn(*inp)))

    return wrapped_fn


def _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn):
    # Wraps jvp_fn so that certain arguments are already supplied
    def jvp_fn(delta):
        return _compute_numerical_gradient(
            wrapped_fn, input_to_perturb, delta, eps, nbhd_checks_fn
        )

    return jvp_fn


def _reshape_tensor_or_tuple(u, shape):
    # We don't need to reshape when input corresponding to u is sparse
    if isinstance(u, tuple):
        if not _is_sparse_any_tensor(u[0]):
            return (u[0].reshape(shape), u[1].reshape(shape))
    else:
        if not _is_sparse_any_tensor(u):
            return u.reshape(shape)
    return u


def _mul_tensor_or_tuple(u, k):
    if isinstance(u, tuple):
        return (k * u[0], k * u[1])
    else:
        return k * u


def _get_numerical_jvp_wrt_specific_input(
    fn, input_idx, inputs, u, eps, is_forward_ad=False
) -> List[torch.Tensor]:
    input = inputs[input_idx]
    input_to_perturb = _get_input_to_perturb(input)
    wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, input_to_perturb, True)
    nbhd_checks_fn = functools.partial(_check_outputs_same_dtype_and_shape, eps=eps)
    jvp_fn = _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn)
    u = _reshape_tensor_or_tuple(u, input_to_perturb.shape)
    u = _mul_tensor_or_tuple(u, eps)
    return _compute_numerical_jvps_wrt_specific_input(
        jvp_fn, u, input.is_complex(), is_forward_ad
    )


def _get_numerical_vJu(
    fn, inputs, inp_indices, func_out, all_u, all_v, eps, is_forward_ad
):
    # Note that all_v can also be None, in that case, this function only computes Ju.
    reduced_jacobians: List[List[torch.Tensor]] = []
    for i, (inp_idx, u) in enumerate(zip(inp_indices, all_u)):
        all_Ju = _get_numerical_jvp_wrt_specific_input(
            fn, inp_idx, inputs, u, eps, is_forward_ad
        )
        # Filter out the Ju for non floating point outputs
        filtered_Ju = []
        func_out = _as_tuple(func_out)
        assert len(all_Ju) == len(func_out)
        for Ju, output in zip(all_Ju, func_out):
            if _is_float_or_complex_tensor(output):
                filtered_Ju.append(Ju)
            else:
                # TODO: handle the other Ju
                pass
        if all_v is not None:
            jacobian_scalars: List[torch.Tensor] = []
            for v, Ju in zip(all_v, filtered_Ju):
                jacobian_scalars.append(_dot_with_type_promotion(v, Ju))
            reduced_jacobians.append(jacobian_scalars)
        else:
            reduced_jacobians.append(filtered_Ju)
    return reduced_jacobians


def _check_jacobians_equal(j1, j2, atol):
    # Check whether the max difference between two Jacobian tensors is within some
    # tolerance `atol`.
    for j1_x, j2_x in zip(j1, j2):
        if j1_x.numel() != 0 and (j1_x - j2_x).abs().max() > atol:
            return False
    return True


def _stack_and_check_tensors(
    list_of_list_of_tensors, inputs, numel_outputs
) -> Tuple[Tuple[torch.Tensor, ...], bool, bool]:
    # For the ith tensor in the inner list checks whether it has the same size and
    # dtype as the ith differentiable input.
    out_jacobians = _allocate_jacobians_with_inputs(inputs, numel_outputs)
    diff_input_list = list(_iter_tensors(inputs, True))
    correct_grad_sizes = True
    correct_grad_types = True
    for i, tensor_list in enumerate(list_of_list_of_tensors):
        inp = diff_input_list[i]
        out_jacobian = out_jacobians[i]
        for j, tensor in enumerate(tensor_list):
            if tensor is not None and tensor.size() != inp.size():
                correct_grad_sizes = False
            elif tensor is not None and tensor.dtype != inp.dtype:
                correct_grad_types = False
            if tensor is None:
                out_jacobian[:, j].zero_()
            else:
                dense = (
                    tensor.to_dense() if not tensor.layout == torch.strided else tensor
                )
                assert out_jacobian[:, j].numel() == dense.numel()
                out_jacobian[:, j] = dense.reshape(-1)
    return out_jacobians, correct_grad_sizes, correct_grad_types


FAILED_NONDET_MSG = """\n
NOTE: If your op relies on non-deterministic operations i.e., it is listed here:
https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
this failure might be expected.

If you are adding a new operator, please file an issue and then use one of the
workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
If the test
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
  with `nondet_tol=<tol>` as a keyword argument.
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for
  the test to have `gradcheck_nondet_tol=<tol>`.
- is a Module test (e.g., in common_nn.py), then modify the corresponding
  module_test entry to have `gradcheck_nondet_tol=<tol>`
"""


def _check_analytical_jacobian_attributes(
    inputs, output, nondet_tol, check_grad_dtypes, fast_mode=False, v=None
) -> Tuple[torch.Tensor, ...]:
    # This is used by both fast and slow mode:
    #  - For slow mode, vjps[i][j] is the jth row of the Jacobian wrt the ith
    #    input.
    #  - For fast mode, vjps[i][0] is a linear combination of the rows
    #    of the Jacobian wrt the ith input
    diff_input_list = list(_iter_tensors(inputs, True))

    def vjp_fn(grad_output):
        return torch.autograd.grad(
            output, diff_input_list, grad_output, retain_graph=True, allow_unused=True
        )

    # Compute everything twice to check for nondeterminism (which we call reentrancy)
    if fast_mode:
        vjps1 = _get_analytical_vjps_wrt_specific_output(vjp_fn, output.clone(), v)
        vjps2 = _get_analytical_vjps_wrt_specific_output(vjp_fn, output.clone(), v)
    else:
        vjps1 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())
        vjps2 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())

    output_numel = output.numel() if not fast_mode else 1
    jacobians1, types_ok, sizes_ok = _stack_and_check_tensors(
        vjps1, inputs, output_numel
    )
    jacobians2, _, _ = _stack_and_check_tensors(vjps2, inputs, output_numel)
    reentrant = _check_jacobians_equal(jacobians1, jacobians2, nondet_tol)

    if not types_ok and check_grad_dtypes:
        raise GradcheckError("Gradient has dtype mismatch")
    if not sizes_ok:
        raise GradcheckError("Analytical gradient has incorrect size")
    if not reentrant:
        raise GradcheckError(
            "Backward is not reentrant, i.e., running backward with "
            "same input and grad_output multiple times gives different values, "
            "although analytical gradient matches numerical gradient."
            f"The tolerance for nondeterminism was {nondet_tol}." + FAILED_NONDET_MSG
        )
    return jacobians1


def _get_analytical_vJu_backward_mode(
    inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u
):
    reduced_jacobians: List[List[torch.Tensor]] = []
    for output, v in zip(outputs, all_v):
        all_vJ = _check_analytical_jacobian_attributes(
            inputs, output, nondet_tol, check_grad_dtypes, fast_mode=True, v=v
        )
        jacobian_scalars: List[torch.Tensor] = []
        for vJ, u in zip(all_vJ, all_u):
            # Why do we need squeeze here? vJ is a 2-d tensor so that we can reuse
            # the error checking logic from slow mode
            vJ = vJ.T.squeeze(0)
            if vJ.is_complex():  # C -> R
                tv = torch.view_as_real(vJ.resolve_conj())
                tr = tv.select(-1, 0)
                ti = tv.select(-1, 1)
                jacobian_scalars.append(tr.dot(u[0]) + 1j * ti.dot(u[1]))
            else:  # R -> R
                jacobian_scalars.append(vJ.dot(u))
        reduced_jacobians.append(jacobian_scalars)
    return reduced_jacobians


def get_analytical_jacobian(inputs, output, nondet_tol=0.0, grad_out=1.0):
    # Replicates the behavior of the old get_analytical_jacobian before the refactor
    # This shares much of its code with _check_analytical_jacobian_attributes
    warnings.warn(
        "get_analytical_jacobian was part of PyTorch's private API and not "
        "meant to be exposed. We are deprecating it and it will be removed "
        "in a future version of PyTorch. If you have a specific use for "
        "this or feature request for this to be a stable API, please file "
        "us an issue at https://github.com/pytorch/pytorch/issues/new"
    )
    if (
        grad_out != 1.0
    ):  # grad_out param is only kept for backward compatibility reasons
        raise ValueError(
            "Expected grad_out to be 1.0. get_analytical_jacobian no longer "
            "supports values of grad_out != 1.0."
        )
    if output.is_complex():
        raise ValueError(
            "Expected output to be non-complex. get_analytical_jacobian no "
            "longer supports functions that return complex outputs."
        )
    diff_input_list = list(_iter_tensors(inputs, True))

    def vjp_fn(grad_output):
        return torch.autograd.grad(
            output, diff_input_list, grad_output, retain_graph=True, allow_unused=True
        )

    # Compute everything twice to check for nondeterminism (which we call reentrancy)
    vjps1 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())
    vjps2 = _compute_analytical_jacobian_rows(vjp_fn, output.clone())

    output_numel = output.numel()
    jacobians1, types_ok, sizes_ok = _stack_and_check_tensors(
        vjps1, inputs, output_numel
    )
    jacobians2, _, _ = _stack_and_check_tensors(vjps2, inputs, output_numel)
    reentrant = _check_jacobians_equal(jacobians1, jacobians2, nondet_tol)

    return jacobians1, reentrant, sizes_ok, types_ok


def _get_analytical_jacobian(inputs, outputs, input_idx, output_idx):
    # Computes the analytical Jacobian in slow mode for a single input-output pair.
    # Forgoes performing checks on dtype, shape, and reentrancy.
    jacobians = _check_analytical_jacobian_attributes(
        inputs, outputs[output_idx], nondet_tol=float("inf"), check_grad_dtypes=False
    )
    return jacobians[input_idx]


def _compute_analytical_jacobian_rows(
    vjp_fn, sample_output
) -> List[List[Optional[torch.Tensor]]]:
    # Computes Jacobian row-by-row using backward function `vjp_fn` = v^T J
    # NB: this function does not assume vjp_fn(v) to return tensors with the same
    # number of elements for different v. This is checked when we later combine the
    # rows into a single tensor.
    grad_out_base = torch.zeros_like(
        sample_output, memory_format=torch.legacy_contiguous_format
    )
    flat_grad_out = grad_out_base.view(-1)
    # jacobians_rows[i][j] represents the jth row of the ith input
    jacobians_rows: List[List[Optional[torch.Tensor]]] = []
    for j in range(flat_grad_out.numel()):
        flat_grad_out.zero_()
        flat_grad_out[j] = 1.0
        grad_inputs = vjp_fn(grad_out_base)
        for i, d_x in enumerate(grad_inputs):
            if j == 0:
                jacobians_rows.append([])
            jacobians_rows[i] += [
                d_x.clone() if isinstance(d_x, torch.Tensor) else None
            ]
    return jacobians_rows


def _get_analytical_vjps_wrt_specific_output(
    vjp_fn, sample_output, v
) -> List[List[Optional[torch.Tensor]]]:
    vjps: List[List[Optional[torch.Tensor]]] = []
    grad_inputs = vjp_fn(v.reshape(sample_output.shape))
    for vjp in grad_inputs:
        vjps.append([vjp.clone() if isinstance(vjp, torch.Tensor) else None])
    return vjps


def _check_inputs(tupled_inputs) -> bool:
    # Make sure that gradients are saved for at least one input
    any_input_requiring_grad = False
    for idx, inp in enumerate(tupled_inputs):
        if is_tensor_like(inp) and inp.requires_grad:
            if not (inp.dtype == torch.float64 or inp.dtype == torch.complex128):
                warnings.warn(
                    f"Input #{idx} requires gradient and "
                    "is not a double precision floating point or complex. "
                    "This check will likely fail if all the inputs are "
                    "not of double precision floating point or complex. "
                )
            if inp.is_sparse:
                content = inp._values()
            elif _is_sparse_compressed_tensor(inp):
                content = inp.values()
            else:
                content = inp
            # TODO: To cover more problematic cases, replace stride = 0 check with
            # "any overlap in memory" once we have a proper function to check it.
            if content.layout is not torch._mkldnn:  # type: ignore[attr-defined]
                if not all(
                    st > 0 or sz <= 1
                    for st, sz in zip(content.stride(), content.size())
                ):
                    raise RuntimeError(
                        f"The {idx}th input has a dimension with stride 0. gradcheck only "
                        "supports inputs that are non-overlapping to be able to "
                        "compute the numerical gradients correctly. You should call "
                        ".contiguous on the input before passing it to gradcheck."
                    )
            any_input_requiring_grad = True

    if not any_input_requiring_grad:
        raise ValueError(
            "gradcheck expects at least one input tensor to require gradient, "
            "but none of them have requires_grad=True."
        )
    return True


def _check_outputs(outputs) -> None:
    if any(_is_sparse_any_tensor(t) for t in outputs if isinstance(t, torch.Tensor)):
        # it is easier to call to_dense() on the sparse output than
        # to modify analytical jacobian
        raise ValueError(
            "Sparse output is not supported at gradcheck yet. "
            "Please call to_dense(masked_grad=...) on the output of fn for gradcheck."
        )
    if any(t.layout == torch._mkldnn for t in outputs if isinstance(t, torch.Tensor)):  # type: ignore[attr-defined]
        raise ValueError(
            "MKLDNN output is not supported at gradcheck yet. "
            "Please call to_dense(masked_grad=...) on the output of fn for gradcheck."
        )


def _check_no_differentiable_outputs(
    func, inputs, func_out, eps, *, is_forward_ad
) -> bool:
    # When there are no differentiable outputs, numerical gradient for a function is
    # expected to be zero.
    jacobians_all_inputs_outputs = _get_numerical_jacobian(
        func, inputs, func_out, eps=eps, is_forward_ad=is_forward_ad
    )
    for jacobians_all_outputs_and_fixed_input in jacobians_all_inputs_outputs:
        for jacobian in jacobians_all_outputs_and_fixed_input:
            if torch.ne(jacobian, 0).sum() > 0:
                raise GradcheckError(
                    "Numerical gradient for function expected to be zero"
                )
    return True


def _check_no_differentiable_outputs_fast(
    func, func_out, all_inputs, inputs_indices, all_u, eps, nondet_tol
):
    for inp_idx, u in zip(inputs_indices, all_u):
        jvps = _get_numerical_jvp_wrt_specific_input(func, inp_idx, all_inputs, u, eps)
        for jvp in jvps:
            if jvp.numel() == 0:
                continue
            if (jvp - torch.zeros_like(jvp)).abs().max() > nondet_tol:
                raise GradcheckError(
                    "Numerical gradient for function expected to be zero"
                )
    return True


FAILED_BATCHED_GRAD_MSG = """
gradcheck or gradgradcheck failed while testing batched gradient computation.
This could have been invoked in a number of ways (via a test that calls
gradcheck/gradgradcheck directly or via an autogenerated test).

If you are adding a new operator, please file an issue and then use one of the
workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
If the test
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
  with `check_batched_grad=False` as a keyword argument.
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for
  the test to have `check_batched_grad=False` and/or `check_batched_gradgrad=False`.

If you're modifying an existing operator that supports batched grad computation,
or wish to make a new operator work with batched grad computation, please read
the following.

To compute batched grads (e.g., jacobians, hessians), we vmap over the backward
computation. The most common failure case is if there is a 'vmap-incompatible
operation' in the backward pass. Please see
NOTE: [How to write vmap-compatible backward formulas]
in the codebase for an explanation of how to fix this.
""".strip()

FAILED_BATCHED_GRAD_MSG_FWD_AD = """
gradcheck failed while testing batched gradient computation with forward-mode AD.

This test is enabled automatically when both `check_batched_grad=True`
and `check_forward_ad=True`, but can be disabled in the following ways
depending on how the test was invoked (via a test that calls gradcheck
directly or via an autogenerated test).

If you are adding a new operator, please file an issue and then use one of the
workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck.
If the test
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
  with `check_batched_forward_grad=False` as a keyword argument.
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for
  the test to have `check_batched_forward_grad=False`
"""


def _get_failed_batched_grad_test_msg(
    output_idx, input_idx, res, exp, is_forward_ad=False
):
    return f"""
For output {output_idx} and input {input_idx}:

{FAILED_BATCHED_GRAD_MSG_FWD_AD if is_forward_ad else FAILED_BATCHED_GRAD_MSG}

Got:
{res}

Expected:
{exp}
""".strip()


def _test_batched_grad_forward_ad(func, inputs) -> bool:
    fwAD = torch.autograd.forward_ad  # To avoid early import issues (do we need this?)
    assert isinstance(inputs, tuple)

    for input_idx, current_input in enumerate(inputs):
        if not (is_tensor_like(current_input) and current_input.requires_grad):
            continue

        def jvp(tangent: torch.Tensor):
            with fwAD.dual_level():
                dual = fwAD.make_dual(current_input.detach(), tangent)
                inputs_with_dual = tuple(
                    dual
                    if idx == input_idx
                    else (inp.detach() if is_tensor_like(inp) else inp)
                    for idx, inp in enumerate(inputs)
                )
                dual_outputs = _as_tuple(func(*inputs_with_dual))
                ret = []
                for dual_output in dual_outputs:
                    if dual_output is None:
                        continue
                    primal_out, tangent_out = fwAD.unpack_dual(dual_output)
                    if tangent_out is not None:
                        ret.append(tangent_out)
                    else:
                        ret.append(
                            torch.zeros(
                                [], dtype=primal_out.dtype, device=primal_out.device
                            ).expand(primal_out.shape)
                        )
                return tuple(ret)

        if not _is_float_or_complex_tensor(current_input):
            continue

        tangents = [torch.randn_like(current_input) for _ in range(2)]
        expected = [jvp(t) for t in tangents]
        expected = [torch.stack(shards) for shards in zip(*expected)]

        try:
            result = _vmap(jvp)(torch.stack(tangents))
        except RuntimeError as ex:
            # Rethrow to provide a better error message
            raise GradcheckError(
                f"While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG_FWD_AD}"
            ) from ex

        for input_idx, (res, exp) in enumerate(zip(result, expected)):
            if torch.allclose(res, exp):
                continue
            raise GradcheckError(
                _get_failed_batched_grad_test_msg(
                    input_idx, input_idx, res, exp, is_forward_ad=True
                )
            )
    return True


def _test_batched_grad(input, output, output_idx) -> bool:
    # NB: _test_batched_grad compares two autograd.grad invocations with a single
    # vmap(autograd.grad) invocation. It's not exactly a "gradcheck" in the
    # sense that we're not comparing an analytical jacobian with a numeric one,
    # but it is morally similar (we could have computed a full analytic jac
    # via vmap, but that is potentially slow)
    diff_input_list = list(_iter_tensors(input, True))
    grad = functools.partial(
        torch.autograd.grad,
        output,
        diff_input_list,
        retain_graph=True,
        allow_unused=True,
    )

    def vjp(v):
        results = grad(v)
        results = tuple(
            grad
            if grad is not None
            else torch.zeros([], dtype=inp.dtype, device=inp.device).expand(inp.shape)
            for grad, inp in zip(results, diff_input_list)
        )
        return results

    grad_outputs = [torch.randn_like(output) for _ in range(2)]

    expected = [vjp(gO) for gO in grad_outputs]
    expected = [torch.stack(shards) for shards in zip(*expected)]

    # Squash warnings since these are expected to happen in most cases
    # NB: this doesn't work for CUDA tests: https://github.com/pytorch/pytorch/issues/50209
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="There is a performance drop")
        warnings.filterwarnings("ignore", message="Please use torch.vmap")
        try:
            result = vmap(vjp)(torch.stack(grad_outputs))
        except RuntimeError as ex:
            # It's OK that we're not raising the error at the correct callsite.
            # That's because the callsite is always going to be inside the Python
            # autograd.grad instead of the C++ traceback of what line in the
            # backward formula
            raise GradcheckError(
                f"While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG}"
            ) from ex

    for input_idx, (res, exp) in enumerate(zip(result, expected)):
        if torch.allclose(res, exp):
            continue
        raise GradcheckError(
            _get_failed_batched_grad_test_msg(output_idx, input_idx, res, exp)
        )
    return True


def _test_backward_mul_by_grad_output(outputs, inputs, masked) -> bool:
    # Tests that backward is multiplied by grad_output
    diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True))
    if not diff_input_list:
        raise GradcheckError("no Tensors requiring grad found in input")
    grads_input = torch.autograd.grad(
        outputs,
        diff_input_list,
        [
            torch.zeros_like(o, memory_format=torch.legacy_contiguous_format)
            for o in outputs
        ],
        allow_unused=True,
    )
    for gi, di in zip(grads_input, diff_input_list):
        if gi is None:
            continue
        if isinstance(gi, torch.Tensor) and gi.layout != torch.strided:
            if gi.layout != di.layout:
                raise GradcheckError(
                    "grad is incorrect layout ("
                    + str(gi.layout)
                    + " is not "
                    + str(di.layout)
                    + ")"
                )
            if _is_sparse_any_tensor(gi):
                sparse_kind = str(gi.layout).replace("torch.", "").replace("_coo", "")
                if gi.sparse_dim() != di.sparse_dim():
                    raise GradcheckError(
                        f"grad is {sparse_kind} tensor, but has incorrect sparse_dim"
                        f" {gi.sparse_dim()}, expected {di.sparse_dim()}"
                    )
                if gi.dense_dim() != di.dense_dim():
                    raise GradcheckError(
                        f"grad is {sparse_kind} tensor, but has incorrect dense_dim"
                        f" {gi.dense_dim()}, expected {di.dense_dim()}"
                    )
            gi = gi.to_dense()
            di = di.to_dense()
        if masked:
            if not torch.allclose(gi, torch.zeros_like(gi)):
                raise GradcheckError("backward not multiplied by grad_output")
        elif not gi.eq(0).all():
            raise GradcheckError("backward not multiplied by grad_output")
        if gi.dtype != di.dtype:
            raise GradcheckError("grad is incorrect type")
        if gi.device != di.device:
            raise GradcheckError("grad is incorrect device")
        if gi.size() != di.size():
            raise GradcheckError("grad is incorrect size")
    return True


def _test_undefined_forward_mode(func, outputs, inputs):
    fwAD = torch.autograd.forward_ad

    inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs)
    all_v, all_u, all_u_dense = _make_vectors(inp_tensors, outputs, use_forward_ad=True)

    tensor_inputs = tuple(i for i in inputs if is_tensor_like(i) and i.requires_grad)

    with fwAD.dual_level():
        fw_grads = []
        dual_inputs = []
        tensor_indices = set()
        for i, inp in enumerate(inputs):
            if is_tensor_like(inp) and inp.requires_grad:
                if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
                    raise ValueError(
                        "MKLDNN inputs are not supported for forward AD gradcheck."
                    )

                inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
                # If inp is a differentiable view, the dual might not be the tangent given to
                # make_dual, so read it explicitly from the dual tensor
                fw_grads.append(fwAD.unpack_dual(inp)[1])
                tensor_indices.add(i)
            dual_inputs.append(inp)

        for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)):
            fw_grad.copy_(u.view_as(fw_grad))

        for idx, inp in enumerate(inputs):
            if idx not in tensor_indices:
                continue
            dual_inp_obj = dual_inputs[idx]

            # case 1 (Materialized Zero Tensor Tangent)
            dual_inputs[idx] = fwAD.make_dual(inp.detach(), torch.zeros_like(inp))
            raw_outputs = _as_tuple(func(*dual_inputs))
            dual_outputs1 = filter(_is_float_or_complex_tensor, raw_outputs)

            # case 2 (Efficient Zero Tensor Tangent since we don't make a dual object and pass a regular tensor)
            dual_inputs[idx] = inp.detach()
            raw_outputs = _as_tuple(func(*dual_inputs))
            dual_outputs2 = filter(_is_float_or_complex_tensor, raw_outputs)

            # reset
            dual_inputs[idx] = dual_inp_obj

            for index_o, (d_o1, d_o2) in enumerate(zip(dual_outputs1, dual_outputs2)):
                val1, res1 = fwAD.unpack_dual(d_o1)
                val2, res2 = fwAD.unpack_dual(d_o2)

                if not (res1 is None or res2 is None):
                    if not torch.allclose(res1, res2):
                        raise GradcheckError(
                            "Mismatch in tangent values for output with index: ",
                            index_o,
                            " when input: ",
                            inp,
                            " has an undefined tangent value. ",
                            " Got: ",
                            res1,
                            " but expected: ",
                            res2,
                        )
    return True


def _test_undefined_backward_mode(func, outputs, inputs) -> bool:
    diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True))
    if not diff_input_list:
        raise GradcheckError("no Tensors requiring grad found in input")

    def warn_bc_breaking():
        warnings.warn(
            "Backwards compatibility: New undefined gradient support checking "
            "feature is enabled by default, but it may break existing callers "
            "of this function. If this is true for you, you can call this "
            'function with "check_undefined_grad=False" to disable the feature'
        )

    def check_undefined_grad_support(output_to_check):
        grads_output = [
            torch.zeros_like(o, memory_format=torch.legacy_contiguous_format)
            for o in output_to_check
        ]
        try:
            grads_input = torch.autograd.grad(
                output_to_check, diff_input_list, grads_output, allow_unused=True
            )
        except RuntimeError as e:
            warn_bc_breaking()
            raise GradcheckError(
                "Expected backward function to handle undefined output grads. "
                'Please look at "Notes about undefined output gradients" in '
                '"tools/autograd/derivatives.yaml"'
            ) from e

        for gi, i in zip(grads_input, diff_input_list):
            if (gi is not None) and (not gi.eq(0).all()):
                warn_bc_breaking()
                raise GradcheckError(
                    "Expected all input grads to be undefined or zero when all output grads are undefined "
                    'or zero. Please look at "Notes about undefined output gradients" in '
                    '"tools/autograd/derivatives.yaml"'
                )
        return True

    # All backward functions must work properly if all output grads are undefined
    outputs_to_check = [
        [
            torch._C._functions.UndefinedGrad()(o)
            for o in _differentiable_outputs(func(*inputs))
            # This check filters out Tensor-likes that aren't instances of Tensor.
            if isinstance(o, torch.Tensor)
        ]
    ]

    # If there are multiple output grads, we should be able to undef one at a time without error
    if len(outputs_to_check[0]) > 1:
        for undef_grad_idx in range(len(outputs)):
            output_to_check = _differentiable_outputs(func(*inputs))
            outputs_to_check.append(
                [
                    torch._C._functions.UndefinedGrad()(o)
                    if idx == undef_grad_idx
                    else o
                    for idx, o in enumerate(output_to_check)
                ]
            )

    return all(check_undefined_grad_support(output) for output in outputs_to_check)


def _as_tuple(x):
    if isinstance(x, tuple):
        return x
    elif isinstance(x, list):
        return tuple(x)
    else:
        return (x,)


def _differentiable_outputs(x):
    return tuple(o for o in _as_tuple(x) if o.requires_grad)


def _get_notallclose_msg(
    analytical,
    numerical,
    output_idx,
    input_idx,
    complex_indices,
    test_imag=False,
    is_forward_ad=False,
) -> str:
    out_is_complex = (
        (not is_forward_ad) and complex_indices and output_idx in complex_indices
    )
    inp_is_complex = is_forward_ad and complex_indices and input_idx in complex_indices
    part = "imaginary" if test_imag else "real"
    element = "inputs" if is_forward_ad else "outputs"
    prefix = (
        ""
        if not (out_is_complex or inp_is_complex)
        else f"While considering the {part} part of complex {element} only, "
    )
    mode = "computed with forward mode " if is_forward_ad else ""
    return (
        prefix
        + "Jacobian %smismatch for output %d with respect to input %d,\n"
        "numerical:%s\nanalytical:%s\n"
        % (mode, output_idx, input_idx, numerical, analytical)
    )


def _transpose(matrix_of_tensors):
    # returns list of tuples
    return list(zip(*matrix_of_tensors))


def _real_and_imag_output(fn):
    # returns new functions real(fn), and imag(fn) where real(fn) and imag(fn) behave the same as
    # the original fn, except torch.real or torch.imag are applied to the complex outputs
    def apply_to_c_outs(fn, fn_to_apply):
        def wrapped_fn(*inputs):
            outs = _as_tuple(fn(*inputs))
            return tuple(fn_to_apply(o) if o.is_complex() else o for o in outs)

        return wrapped_fn

    return apply_to_c_outs(fn, torch.real), apply_to_c_outs(fn, torch.imag)


def _real_and_imag_input(fn, complex_inp_indices, tupled_inputs):
    # returns new functions that take real inputs instead of complex inputs as
    # (x, y) -> fn(x + y * 1j). And it computes: inp -> fn(inp + y * 1j) and
    # inp -> fn(x + inp * 1j). In each case, the other part is considered constant.
    # We do not use 0 for the constant here to make sure we always call the user
    # function with a valid input.
    def apply_to_c_inps(fn, fn_to_apply):
        def wrapped_fn(*inputs):
            new_inputs = list(inputs)
            for should_be_complex in complex_inp_indices:
                new_inputs[should_be_complex] = fn_to_apply(
                    new_inputs[should_be_complex], tupled_inputs[should_be_complex]
                )
            return _as_tuple(fn(*new_inputs))

        return wrapped_fn

    real_fn = apply_to_c_inps(fn, lambda inp, orig: inp + orig.imag * 1j)
    imag_fn = apply_to_c_inps(fn, lambda inp, orig: orig.real + inp * 1j)
    return real_fn, imag_fn


def _gradcheck_real_imag(
    gradcheck_fn,
    func,
    func_out,
    tupled_inputs,
    outputs,
    eps,
    rtol,
    atol,
    check_grad_dtypes,
    check_forward_ad,
    check_backward_ad,
    nondet_tol,
    check_undefined_grad,
):
    complex_out_indices = [i for i, o in enumerate(outputs) if o.is_complex()]
    has_any_complex_output = any(o.is_complex() for o in _as_tuple(func_out))
    if check_backward_ad:
        if has_any_complex_output:
            real_fn, imag_fn = _real_and_imag_output(func)

            imag_func_out = imag_fn(*tupled_inputs)
            imag_outputs = _differentiable_outputs(imag_func_out)
            gradcheck_fn(
                imag_fn, imag_func_out, tupled_inputs, imag_outputs,
                eps, rtol, atol, check_grad_dtypes, nondet_tol,
                complex_indices=complex_out_indices, test_imag=True,
            )

            real_func_out = real_fn(*tupled_inputs)
            real_outputs = _differentiable_outputs(real_func_out)
            gradcheck_fn(
                real_fn, real_func_out, tupled_inputs, real_outputs,
                eps, rtol, atol, check_grad_dtypes, nondet_tol,
                complex_indices=complex_out_indices,
            )
        else:
            gradcheck_fn(
                func, func_out, tupled_inputs, outputs,
                eps, rtol, atol, check_grad_dtypes, nondet_tol,
            )

    if check_forward_ad:
        complex_inp_indices = [
            i
            for i, inp in enumerate(tupled_inputs)
            if is_tensor_like(inp) and inp.is_complex()
        ]
        if complex_inp_indices:
            real_fn, imag_fn = _real_and_imag_input(
                func, complex_inp_indices, tupled_inputs
            )

            imag_inputs = [
                inp.imag if is_tensor_like(inp) and inp.is_complex() else inp
                for inp in tupled_inputs
            ]
            imag_func_out = imag_fn(*imag_inputs)
            diff_imag_func_out = _differentiable_outputs(imag_func_out)
            gradcheck_fn(
                imag_fn, imag_func_out, imag_inputs, diff_imag_func_out,
                eps, rtol, atol, check_grad_dtypes, nondet_tol,
                complex_indices=complex_inp_indices, test_imag=True, use_forward_ad=True,
            )

            real_inputs = [
                inp.real if is_tensor_like(inp) and inp.is_complex() else inp
                for inp in tupled_inputs
            ]
            real_func_out = real_fn(*real_inputs)
            diff_real_func_out = _differentiable_outputs(real_func_out)
            gradcheck_fn(
                real_fn, real_func_out, real_inputs, diff_real_func_out,
                eps, rtol, atol, check_grad_dtypes, nondet_tol,
                complex_indices=complex_inp_indices, use_forward_ad=True,
            )
            if check_undefined_grad:
                _test_undefined_forward_mode(imag_fn, imag_func_out, imag_inputs)
                _test_undefined_forward_mode(real_fn, real_func_out, real_inputs)
        else:
            gradcheck_fn(
                func, func_out, tupled_inputs, outputs,
                eps, rtol, atol, check_grad_dtypes, nondet_tol, use_forward_ad=True,
            )
            if check_undefined_grad:
                _test_undefined_forward_mode(func, outputs, tupled_inputs)
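

# A short sketch of the complex handling above (assuming a function fn with a complex
# output): gradcheck never differentiates fn directly. Instead it checks two
# real-valued views of it,
#
#   real_fn(x) = torch.real(fn(x))    and    imag_fn(x) = torch.imag(fn(x))
#
# and, on the forward-AD path, analogously splits complex *inputs* into their real
# and imaginary parts, so that every individual check has a well-defined real
# Jacobian (see the Wirtinger-derivative discussion in the gradcheck docstring).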


def _slow_gradcheck(
    func,
    func_out,
    tupled_inputs,
    outputs,
    eps,
    rtol,
    atol,
    check_grad_dtypes,
    nondet_tol,
    *,
    use_forward_ad=False,
    complex_indices=None,
    test_imag=False,
    masked=False,
):
    func_out = _as_tuple(func_out)
    if not outputs:
        return _check_no_differentiable_outputs(
            func, tupled_inputs, func_out, eps=eps, is_forward_ad=use_forward_ad
        )
    tupled_inputs_numerical = tupled_inputs if masked else _densify(tupled_inputs)

    numerical = _transpose(
        _get_numerical_jacobian(
            func,
            tupled_inputs_numerical,
            func_out,
            eps=eps,
            is_forward_ad=use_forward_ad,
        )
    )
    # Note: [numerical vs analytical output length]
    # The numerical path returns jacobian quantity for all outputs, even if requires_grad of that
    # output is False. This behavior is necessary for _check_no_differentiable_outputs to work.
    numerical = [nj for o, nj in zip(func_out, numerical) if o.requires_grad]
    if use_forward_ad:
        analytical_forward = _get_analytical_jacobian_forward_ad(
            func, tupled_inputs, func_out, check_grad_dtypes=check_grad_dtypes
        )

        for i, n_per_out in enumerate(numerical):
            for j, n in enumerate(n_per_out):
                a = analytical_forward[j][i]
                if not _allclose_with_type_promotion(a, n.to(a.device), rtol, atol):
                    raise GradcheckError(
                        _get_notallclose_msg(
                            a, n, i, j, complex_indices, test_imag, is_forward_ad=True
                        )
                    )
    else:
        for i, o in enumerate(outputs):
            analytical = _check_analytical_jacobian_attributes(
                tupled_inputs, o, nondet_tol, check_grad_dtypes
            )

            for j, (a, n) in enumerate(zip(analytical, numerical[i])):
                if not _allclose_with_type_promotion(a, n.to(a.device), rtol, atol):
                    raise GradcheckError(
                        _get_notallclose_msg(a, n, i, j, complex_indices, test_imag)
                    )

    return True


def _dot_with_type_promotion(u, v):
    assert u.dim() == 1 and v.dim() == 1
    return (u * v).sum()


def _allclose_with_type_promotion(a, b, rtol, atol):
    promoted_type = torch.promote_types(a.dtype, b.dtype)
    a = a.to(dtype=promoted_type)
    b = b.to(dtype=promoted_type)
    return torch.allclose(a, b, rtol, atol)


def _to_real_dtype(dtype):
    if dtype == torch.complex128:
        return torch.float64
    elif dtype == torch.complex64:
        return torch.float32
    else:
        return dtype


def _vec_from_tensor(x, generator, downcast_complex=False):
    # Create a random vector with the same number of elements as x and the same
    # dtype/device. If x is complex and downcast_complex is False, we create a
    # complex tensor with only real component.
    if x.layout == torch.sparse_coo:
        # For sparse, create a random sparse vec with random values in the same
        # indices. Make sure size is set so that it isn't inferred to be smaller.
        x_values = x._values()
        dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype
        values = (
            torch.rand(x_values.numel(), generator=generator)
            .to(dtype=dtype, device=x.device)
            .view(x_values.shape)
        )
        values /= values.norm()
        vec = torch.sparse_coo_tensor(x._indices(), values, x.size(), device=x.device)
    elif _is_sparse_compressed_tensor(x):
        if x.layout in {torch.sparse_csr, torch.sparse_bsr}:
            compressed_indices, plain_indices = x.crow_indices(), x.col_indices()
        else:
            compressed_indices, plain_indices = x.ccol_indices(), x.row_indices()
        x_values = x.values()
        dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype
        values = (
            torch.rand(x_values.numel(), generator=generator)
            .to(dtype=dtype, device=x.device)
            .view(x_values.shape)
        )
        values /= values.norm()
        vec = torch.sparse_compressed_tensor(
            compressed_indices,
            plain_indices,
            values,
            x.size(),
            layout=x.layout,
            device=x.device,
        )
    else:
        dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype
        vec = torch.rand(x.numel(), generator=generator).to(dtype=dtype, device=x.device)
        vec /= vec.norm()
    return vec


def _get_inp_tensors(tupled_inputs):
    inp_idx_tup = [
        (i, t)
        for i, t in enumerate(tupled_inputs)
        if is_tensor_like(t) and t.requires_grad
    ]
    return [tup[0] for tup in inp_idx_tup], [tup[1] for tup in inp_idx_tup]


def _adjusted_atol(atol, u, v):
    # In slow gradcheck, we compare A and B element-wise, i.e., for some a, b we
    # allow: |a - b| < atol + rtol * b. But since we now compare q1 = v^T A u and
    # q2 = v^T B u, we must allow |q1 - q2| < v^T E u + rtol * v^T B u, where E is
    # the correctly sized matrix in which each entry is atol.
    #
    # We see that atol needs to be scaled by v^T M u (where M is an all-ones M x N
    # matrix): v^T M u = \sum_{i} \sum_{j} u_i * v_j = (\sum_{i} u_i)(\sum_{i} v_i)
    # TODO: properly handle case when u is tuple instead of only taking first element
    u = u[0] if isinstance(u, tuple) else u
    sum_u = u.sum()
    sum_v = 1.0 if v is None else v.sum()
    return atol * float(sum_u) * float(sum_v)


FAST_FAIL_SLOW_OK_MSG = """
Fast gradcheck failed but element-wise differences are small. This means that the
test might've passed in slow_mode!

If you are adding a new operator, please file an issue and then use one of the
workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck:

If the test
- manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck
  with `fast_mode=False` as a keyword argument.
- is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for
  the test to have `gradcheck_fast_mode=False`
- is a Module test (e.g., in common_nn.py), then modify the corresponding
  module_test entry to have `gradcheck_fast_mode=False`
""".strip()


def _run_slow_mode_and_get_error(
    func, tupled_inputs, outputs, input_idx, output_idx, rtol, atol, is_forward_ad
):
    # Compute jacobians in slow mode for better error message
    slow_numerical = _get_numerical_jacobian(
        func, tupled_inputs, outputs, is_forward_ad=is_forward_ad
    )[input_idx][output_idx]
    if is_forward_ad:

        def new_fn(inp):
            new_inputs = list(tupled_inputs)
            new_inputs[input_idx] = inp
            return _as_tuple(func(*new_inputs))[output_idx]

        slow_analytical = _get_analytical_jacobian_forward_ad(
            new_fn, (tupled_inputs[input_idx],), (outputs[output_idx],)
        )[0][0]
    else:
        slow_analytical = _get_analytical_jacobian(
            tupled_inputs, outputs, input_idx, output_idx
        )

    # Assume jacobians are non-empty and have the same shape
    slow_max_diff = (slow_numerical - slow_analytical).abs().max()

    slow_allclose = torch.allclose(slow_analytical, slow_numerical, rtol, atol)
    msg = (
        "\nThe above quantities relating the numerical and analytical jacobians are computed \n"
        "in fast mode. See: https://github.com/pytorch/pytorch/issues/53876 for more background \n"
        "about fast mode. Below, we recompute numerical and analytical jacobians in slow mode:\n\n"
        f"Numerical:\n{slow_numerical}\n"
        f"Analytical:\n{slow_analytical}\n\n"
        f"The max per-element difference (slow mode) is: {slow_max_diff}.\n"
    )
    if slow_allclose:
        # Slow gradcheck would've passed!
        msg += FAST_FAIL_SLOW_OK_MSG
    return msg


def _to_flat_dense_if_sparse(tensor):
    if _is_sparse_any_tensor(tensor):
        return tensor.to_dense().reshape(-1)
    else:
        return tensor


def _make_vectors(inp_tensors, outputs, *, use_forward_ad):
    # Use our own generator to avoid messing with the user's RNG state
    g_cpu = torch.Generator()

    def _vec_from_tensor_cpu(*args):
        # Default allocate all tensors on CPU, so they are on the same device as the generator
        # even if the user specified a default device
        with torch.device("cpu"):
            return _vec_from_tensor(*args)

    all_u = []
    all_u_dense = []
    for inp in inp_tensors:
        ur = _vec_from_tensor_cpu(inp, g_cpu, True)
        ur_dense = _to_flat_dense_if_sparse(ur)
        if inp.is_complex():
            ui = _vec_from_tensor_cpu(inp, g_cpu, True)
            all_u.append((ur, ui))
            ui_dense = _to_flat_dense_if_sparse(ui)
            all_u_dense.append((ur_dense, ui_dense))
        else:
            all_u.append(ur)
            all_u_dense.append(ur_dense)
    all_v = (
        None
        if use_forward_ad
        else [_vec_from_tensor_cpu(out, g_cpu) for out in outputs]
    )
    return all_v, all_u, all_u_dense


def _check_analytical_numerical_equal(
    all_analytical,
    all_numerical,
    complex_indices,
    tupled_inputs,
    outputs,
    func,
    all_v,
    all_u,
    rtol,
    atol,
    test_imag,
    *,
    is_forward_ad=False,
):
    for i, all_numerical_for_input_i in enumerate(all_numerical):
        for j, n in enumerate(all_numerical_for_input_i):
            # Forward AD generates the transpose of what this function expects
            if is_forward_ad:
                a = all_analytical[i][j]
            else:
                a = all_analytical[j][i]
            n = n.to(device=a.device)
            updated_atol = _adjusted_atol(atol, all_u[i], all_v[j] if all_v else None)
            if not _allclose_with_type_promotion(a, n.to(a.device), rtol, updated_atol):
                jacobians_str = _run_slow_mode_and_get_error(
                    func, tupled_inputs, outputs, i, j, rtol, atol, is_forward_ad
                )
                raise GradcheckError(
                    _get_notallclose_msg(
                        a, n, j, i, complex_indices, test_imag, is_forward_ad
                    )
                    + jacobians_str
                )


def _fast_gradcheck(
    func,
    func_out,
    inputs,
    outputs,
    eps,
    rtol,
    atol,
    check_grad_dtypes,
    nondet_tol,
    *,
    use_forward_ad=False,
    complex_indices=None,
    test_imag=False,
    masked=False,
):
    # See https://github.com/pytorch/pytorch/issues/53876 for details
    inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs)
    # Backward mode computes v^T * J (VJP)
    # Since we computed J * u (JVP) through finite difference method, we perform an equality check
    # between VJP * u, v * JVP
    # ----
    # Forward mode computes J * u (JVP)
    # Since we already compute JVP through finite difference method,
    # we don't need v for correctness check here as asserted below
    all_v, all_u, all_u_dense = _make_vectors(
        inp_tensors, outputs, use_forward_ad=use_forward_ad
    )

    inputs_numerical, all_u_numerical, all_v_numerical = (
        (inputs, all_u, all_v) if masked else _densify((inputs, all_u, all_v))
    )

    numerical_vJu = _get_numerical_vJu(
        func,
        inputs_numerical,
        inp_tensors_idx,
        func_out,
        all_u_numerical,
        all_v_numerical,
        eps,
        is_forward_ad=use_forward_ad,
    )
    # TODO: replicate https://github.com/pytorch/pytorch/pull/77743 for fast gradcheck as well
    if use_forward_ad:
        assert all_v is None
        analytical_vJu = _get_analytical_jacobian_forward_ad(
            func,
            inputs,
            _as_tuple(func_out),
            all_u=all_u,
            check_grad_dtypes=check_grad_dtypes,
        )
    else:
        if not outputs:
            _check_no_differentiable_outputs_fast(
                func, func_out, inputs, inp_tensors_idx, all_u, eps, nondet_tol
            )

        analytical_vJu = _get_analytical_vJu_backward_mode(
            inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u_dense
        )

    _check_analytical_numerical_equal(
        analytical_vJu,
        numerical_vJu,
        complex_indices,
        inputs,
        outputs,
        func,
        all_v,
        all_u,
        rtol,
        atol,
        test_imag,
        is_forward_ad=use_forward_ad,
    )

    return True


# Note [VarArg of Tensors]
# ~~~~~~~~~~~~~~~~~~~~~~~~
# 'func' accepts a vararg of tensors, which isn't expressable in the type system at the moment.
# If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted,
# the '...' first argument of Callable can be replaced with VarArg(Tensor).
# For now, we permit any input.
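

# A short sketch of what fast mode (above) verifies, assuming a single input x and a
# single output y = f(x) with Jacobian J: rather than comparing full Jacobians, it
# draws a random input-shaped vector u and a random output-shaped vector v and
# compares the scalars
#
#   numerical:   v . ((f(x + eps * u) - f(x - eps * u)) / (2 * eps))  ~=  v^T J u
#   analytical:  (vjp of y wrt x with cotangent v) . u                 =  v^T J u
#
# With forward-mode AD, J u is computed exactly and compared directly, so no v is
# needed (all_v is None in that path).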
def gradcheck(
    func: Callable[..., Union[_TensorOrTensors]],  # See Note [VarArg of Tensors]
    inputs: _TensorOrTensors,
    *,
    eps: float = 1e-6,
    atol: float = 1e-5,
    rtol: float = 1e-3,
    raise_exception: bool = True,
    check_sparse_nnz: Optional[bool] = None,
    nondet_tol: float = 0.0,
    check_undefined_grad: bool = True,
    check_grad_dtypes: bool = False,
    check_batched_grad: bool = False,
    check_batched_forward_grad: bool = False,
    check_forward_ad: bool = False,
    check_backward_ad: bool = True,
    fast_mode: bool = False,
    masked: Optional[bool] = None,
) -> bool:  # noqa: D400,D205
    r"""Check gradients computed via small finite differences against analytical
    gradients wrt tensors in :attr:`inputs` that are of floating point or complex type
    and with ``requires_grad=True``.

    The check between numerical and analytical gradients uses :func:`~torch.allclose`.

    For most of the complex functions we consider for optimization purposes, no notion of
    Jacobian exists. Instead, gradcheck verifies if the numerical and analytical values of
    the Wirtinger and Conjugate Wirtinger derivatives are consistent. Because the gradient
    computation is done under the assumption that the overall function has a real-valued
    output, we treat functions with complex output in a special way. For these functions,
    gradcheck is applied to two real-valued functions corresponding to taking the real
    components of the complex outputs for the first, and taking the imaginary components
    of the complex outputs for the second. For more details, check out
    :ref:`complex_autograd-doc`.

    .. note::
        The default values are designed for :attr:`input` of double precision.
        This check will likely fail if :attr:`input` is of less precision, e.g.,
        ``FloatTensor``.

    .. note::
        Gradcheck may fail when evaluated on non-differentiable points because the
        numerically computed gradients via finite differencing may differ from those
        computed analytically (not necessarily because either is incorrect).
        For more context, see :ref:`non-differentiable-func-grad`.

    .. warning::
       If any checked tensor in :attr:`input` has overlapping memory, i.e.,
       different indices pointing to the same memory address (e.g., from
       :func:`torch.expand`), this check will likely fail because the numerical
       gradients computed by point perturbation at such indices will change
       values at all other indices that share the same memory address.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor or a tuple of Tensors
        inputs (tuple of Tensor or Tensor): inputs to the function
        eps (float, optional): perturbation for finite differences
        atol (float, optional): absolute tolerance
        rtol (float, optional): relative tolerance
        raise_exception (bool, optional): indicating whether to raise an exception if
            the check fails. The exception gives more information about the
            exact nature of the failure. This is helpful when debugging gradchecks.
        check_sparse_nnz (bool, optional): if ``True``, gradcheck allows for SparseTensor
            input, and for any SparseTensor inputs, gradcheck will perform its check at
            ``nnz`` positions only. The ``check_sparse_nnz`` argument is deprecated, use
            the ``masked`` argument instead. If ``check_sparse_nnz != masked``, an
            exception is raised.
        nondet_tol (float, optional): tolerance for non-determinism. When running
            identical inputs through the differentiation, the results must either match
            exactly (default, 0.0) or be within this tolerance.
        check_undefined_grad (bool, optional): if ``True``, check if undefined output
            grads are supported and treated as zeros, for ``Tensor`` outputs.
        check_batched_grad (bool, optional): if ``True``, check if we can compute
            batched gradients using prototype vmap support. Defaults to False.
        check_batched_forward_grad (bool, optional): if ``True``, checks if we can
            compute batched forward gradients using forward ad and prototype vmap
            support. Defaults to ``False``.
        check_forward_ad (bool, optional): if ``True``, check that the gradients computed
            with forward mode AD match the numerical ones. Defaults to ``False``.
        check_backward_ad (bool, optional): if ``False``, do not perform any checks that
            rely on backward mode AD to be implemented. Defaults to ``True``.
        fast_mode (bool, optional): Fast mode for gradcheck and gradgradcheck is currently
            only implemented for R to R functions. If none of the inputs and outputs are
            complex a faster implementation of gradcheck that no longer computes the
            entire jacobian is run; otherwise, we fall back to the slow implementation.
        masked (bool, optional): if ``True``, the gradients of unspecified elements of
            sparse tensors are ignored. Defaults to ``False``.

    Returns:
        ``True`` if all differences satisfy allclose condition
    """
    if check_sparse_nnz is None:
        if masked is None:
            check_sparse_nnz = masked = False
        else:
            check_sparse_nnz = masked
    else:
        warnings.warn(
            "Backwards compatibility: check_sparse_nnz is deprecated, it will be removed in a future version of PyTorch."
            f" Use masked={check_sparse_nnz} instead."
        )
        if masked is None:
            masked = check_sparse_nnz
        elif check_sparse_nnz != masked:
            raise ValueError(
                f"Expected specified check_sparse_nnz (={check_sparse_nnz}) to be equal to masked (={masked})."
            )
    assert (
        check_forward_ad or check_backward_ad
    ), "Expected at least one of check_forward_ad or check_backward_ad to be True"
    assert not (
        check_batched_grad and not check_backward_ad
    ), "Setting check_batched_grad=True requires check_backward_ad to be True"
    assert not (
        check_batched_forward_grad and not check_forward_ad
    ), "Setting check_batched_forward_grad=True requires check_forward_ad to be True"
    args = locals().copy()
    args.pop("raise_exception")
    args.pop("check_sparse_nnz")
    if not raise_exception:
        try:
            return _gradcheck_helper(**args)
        except GradcheckError:
            return False
    else:
        return _gradcheck_helper(**args)
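# Illustrative usage sketch (exposition only, not called by this module): the
# default tolerances assume double-precision inputs with requires_grad=True.
# The checked function and tensor shapes below are examples.
def _example_gradcheck_usage():
    x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
    # Compare the analytical gradient of torch.tanh against finite differences.
    assert gradcheck(torch.tanh, (x,), eps=1e-6, atol=1e-4)
    # With raise_exception=False a failing check returns False instead of
    # raising GradcheckError.
    return gradcheck(torch.tanh, (x,), fast_mode=True, raise_exception=False)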
def _gradcheck_helper(
    func,
    inputs,
    eps,
    atol,
    rtol,
    nondet_tol,
    check_undefined_grad,
    check_grad_dtypes,
    check_batched_grad,
    check_batched_forward_grad,
    check_forward_ad,
    check_backward_ad,
    fast_mode,
    masked,
):
    tupled_inputs = _as_tuple(inputs)
    _check_inputs(tupled_inputs)

    func_out = func(*tupled_inputs)
    outputs = _differentiable_outputs(func_out)
    _check_outputs(outputs)

    gradcheck_fn = functools.partial(
        _fast_gradcheck if fast_mode else _slow_gradcheck, masked=masked
    )
    _gradcheck_real_imag(
        gradcheck_fn,
        func,
        func_out,
        tupled_inputs,
        outputs,
        eps,
        rtol,
        atol,
        check_grad_dtypes,
        check_forward_ad=check_forward_ad,
        check_backward_ad=check_backward_ad,
        nondet_tol=nondet_tol,
        check_undefined_grad=check_undefined_grad,
    )

    if check_batched_forward_grad:
        _test_batched_grad_forward_ad(func, tupled_inputs)

    # Short circuit because remaining tests rely on backward AD to be implemented
    if not check_backward_ad:
        return True

    for i, o in enumerate(outputs):
        if check_batched_grad:
            _test_batched_grad(tupled_inputs, o, i)

    _test_backward_mul_by_grad_output(outputs, tupled_inputs, masked)

    if check_undefined_grad and check_backward_ad:
        _test_undefined_backward_mode(func, outputs, tupled_inputs)
    return True
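# Illustrative sketch of a common gradcheck workflow (exposition only): verify
# a hand-written backward of a custom autograd.Function against finite
# differences. The Function below is hypothetical and exists only for this
# example.
class _ExampleExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        result = x.exp()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        # d/dx exp(x) = exp(x)
        return grad_output * result


def _example_check_custom_function():
    x = torch.randn(5, dtype=torch.double, requires_grad=True)
    # Double-precision inputs keep the finite-difference error well below atol.
    return gradcheck(_ExampleExp.apply, (x,), eps=1e-6, atol=1e-4)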
def gradgradcheck(
    func: Callable[..., _TensorOrTensors],  # See Note [VarArg of Tensors]
    inputs: _TensorOrTensors,
    grad_outputs: Optional[_TensorOrTensors] = None,
    *,
    eps: float = 1e-6,
    atol: float = 1e-5,
    rtol: float = 1e-3,
    gen_non_contig_grad_outputs: bool = False,
    raise_exception: bool = True,
    nondet_tol: float = 0.0,
    check_undefined_grad: bool = True,
    check_grad_dtypes: bool = False,
    check_batched_grad: bool = False,
    check_fwd_over_rev: bool = False,
    check_rev_over_rev: bool = True,
    fast_mode: bool = False,
    masked: bool = False,
) -> bool:  # noqa: D400,D205
    r"""Check gradients of gradients computed via small finite differences against
    analytical gradients wrt tensors in :attr:`inputs` and :attr:`grad_outputs` that
    are of floating point or complex type and with ``requires_grad=True``.

    This function checks that backpropagating through the gradients computed
    to the given :attr:`grad_outputs` is correct.

    The check between numerical and analytical gradients uses :func:`~torch.allclose`.

    .. note::
        The default values are designed for :attr:`input` and
        :attr:`grad_outputs` of double precision. This check will likely fail if
        they are of less precision, e.g., ``FloatTensor``.

    .. warning::
       If any checked tensor in :attr:`input` and :attr:`grad_outputs` has
       overlapping memory, i.e., different indices pointing to the same memory
       address (e.g., from :func:`torch.expand`), this check will likely fail
       because the numerical gradients computed by point perturbation at such
       indices will change values at all other indices that share the same
       memory address.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor or a tuple of Tensors
        inputs (tuple of Tensor or Tensor): inputs to the function
        grad_outputs (tuple of Tensor or Tensor, optional): The gradients with
            respect to the function's outputs.
        eps (float, optional): perturbation for finite differences
        atol (float, optional): absolute tolerance
        rtol (float, optional): relative tolerance
        gen_non_contig_grad_outputs (bool, optional): if :attr:`grad_outputs` is
            ``None`` and :attr:`gen_non_contig_grad_outputs` is ``True``, the
            randomly generated gradient outputs are made to be noncontiguous
        raise_exception (bool, optional): indicating whether to raise an exception if
            the check fails. The exception gives more information about the
            exact nature of the failure. This is helpful when debugging gradchecks.
        nondet_tol (float, optional): tolerance for non-determinism. When running
            identical inputs through the differentiation, the results must either match
            exactly (default, 0.0) or be within this tolerance. Note that a small amount
            of nondeterminism in the gradient will lead to larger inaccuracies in
            the second derivative.
        check_undefined_grad (bool, optional): if True, check if undefined output grads
            are supported and treated as zeros
        check_batched_grad (bool, optional): if True, check if we can compute
            batched gradients using prototype vmap support. Defaults to False.
        fast_mode (bool, optional): if True, run a faster implementation of
            gradgradcheck that no longer computes the entire jacobian.
        masked (bool, optional): if True, the gradients of unspecified elements of
            sparse tensors are ignored (default, False).
    Returns:
        True if all differences satisfy allclose condition
    """
    assert (
        check_fwd_over_rev or check_rev_over_rev
    ), "Expected at least one of check_fwd_over_rev or check_rev_over_rev to be True"
    assert not (
        check_undefined_grad and not check_rev_over_rev
    ), "Setting check_undefined_grad=True requires check_rev_over_rev to be True"
    assert not (
        check_batched_grad and not check_rev_over_rev
    ), "Setting check_batched_grad=True requires check_rev_over_rev to be True"
    # TODO: do we want to test this too?
    # assert not (check_batched_forward_grad and not check_fwd_over_rev), (
    #     "Setting check_batched_forward_grad=True requires check_fwd_over_rev to be True")
    tupled_inputs = _as_tuple(inputs)

    if grad_outputs is None:
        # If grad_outputs is not specified, create random Tensors of the same shape, type, and device as the outputs
        outputs = _differentiable_outputs(func(*tupled_inputs))
        tupled_grad_outputs = tuple(
            torch.testing.make_tensor(
                x.shape,
                dtype=x.dtype
                if x.is_floating_point() or x.is_complex()
                else torch.double,
                device=x.device,
                low=-1,
                high=1,
                requires_grad=True,
                noncontiguous=gen_non_contig_grad_outputs,
            )
            for x in outputs
        )
    else:
        tupled_grad_outputs = _as_tuple(grad_outputs)

    num_outputs = len(tupled_grad_outputs)

    # NB: We need to save the requires_grad information about the inputs here because gradcheck detaches inputs
    # before running forward mode AD
    diff_input_args_indices = {
        i for i, x in enumerate(tupled_inputs) if is_tensor_like(x) and x.requires_grad
    }
    diff_grad_output_indices = {
        i for i, x in enumerate(tupled_grad_outputs) if x.requires_grad
    }

    def new_func(*args):
        # Restore the requires_grad information
        input_args = tuple(
            x.requires_grad_() if i in diff_input_args_indices else x
            for i, x in enumerate(args[:-num_outputs])
        )
        outputs = _differentiable_outputs(func(*input_args))
        grad_outputs = tuple(
            x.requires_grad_() if i in diff_grad_output_indices else x
            for i, x in enumerate(args[-num_outputs:])
        )
        diff_input_args = tuple(
            x for i, x in enumerate(input_args) if i in diff_input_args_indices
        )
        grad_inputs = torch.autograd.grad(
            outputs, diff_input_args, grad_outputs, create_graph=True, allow_unused=True
        )
        grad_inputs = tuple(g for g in grad_inputs if g is not None)
        return grad_inputs

    return gradcheck(
        new_func,
        tupled_inputs + tupled_grad_outputs,
        eps=eps,
        atol=atol,
        rtol=rtol,
        raise_exception=raise_exception,
        nondet_tol=nondet_tol,
        check_undefined_grad=check_undefined_grad,
        check_grad_dtypes=check_grad_dtypes,
        check_batched_grad=check_batched_grad,
        fast_mode=fast_mode,
        check_forward_ad=check_fwd_over_rev,
        check_backward_ad=check_rev_over_rev,
        masked=masked,
    )
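# Illustrative usage sketch (exposition only): gradgradcheck differentiates
# through torch.autograd.grad to verify second-order gradients. When
# grad_outputs is omitted, random tensors matching the outputs are generated.
# The function and shape below are examples.
def _example_gradgradcheck_usage():
    x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
    # f(t) = sum(t^3) has a nontrivial second derivative, 6t, to verify.
    return gradgradcheck(lambda t: (t * t * t).sum(), (x,), eps=1e-6, atol=1e-4)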