torch.nn.utils.prune.random_structured
- torch.nn.utils.prune.random_structured(module, name, amount, dim)
Prunes the tensor corresponding to the parameter called name in module by removing the specified amount of (currently unpruned) channels along the specified dim, selected at random. Modifies module in place (and also returns the modified module) by:
1) adding a named buffer called name+'_mask' corresponding to the binary mask applied to the parameter name by the pruning method.
2) replacing the parameter name by its pruned version, while the original (unpruned) parameter is stored in a new parameter named name+'_orig'.
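The two in-place modifications described above can be checked directly on a small layer. The following is a minimal sketch, assuming a fresh nn.Linear(5, 3) and the parameter name 'weight'; the variable names (layer, buffer_names, param_names, recomputed, same) are illustrative only and not part of the API.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

# Illustrative layer; its weight has shape (3, 5).
layer = nn.Linear(5, 3)
prune.random_structured(layer, "weight", amount=2, dim=1)

# After pruning, the mask is a buffer and the original tensor is a parameter.
buffer_names = [n for n, _ in layer.named_buffers()]
param_names = [n for n, _ in layer.named_parameters()]
print(buffer_names)  # expected to include 'weight_mask'
print(param_names)   # expected to include 'weight_orig', but not 'weight'

# The pruned attribute is the original tensor multiplied by the mask.
recomputed = layer.weight_orig * layer.weight_mask
same = torch.equal(layer.weight, recomputed)
print(same)
```

Because the original values remain in weight_orig, the mask only zeroes entries of the tensor that the module exposes as weight; nothing is deleted from the stored parameter.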
- Parameters
module (nn.Module) – module containing the tensor to prune.
name (str) – parameter name within module on which pruning will act.
amount (int or float) – quantity of parameters to prune. If float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune.
dim (int) – index of the dim along which we define channels to prune.
- Returns
modified (i.e. pruned) version of the input module
- Return type
module (nn.Module)
Examples
>>> import torch
>>> from torch import nn
>>> import torch.nn.utils.prune as prune
>>> m = prune.random_structured(
...     nn.Linear(5, 3), 'weight', amount=3, dim=1
... )
>>> columns_pruned = int(sum(torch.sum(m.weight, dim=0) == 0))
>>> print(columns_pruned)
3
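As a further illustration of the amount and dim parameters, the sketch below compares a float amount, an int amount, and a different dim. It assumes that, for this structured method, amount is counted in channels along dim (as the description states), so that 0.4 of 5 columns comes to 2 columns. The helper pruned_channels is hypothetical, written only for this example, and is not part of torch.nn.utils.prune.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune


def pruned_channels(layer, dim):
    """Hypothetical helper: count channels along `dim` whose mask is all zeros."""
    mask = layer.weight_mask
    other_dims = [d for d in range(mask.dim()) if d != dim]
    return int((mask.sum(dim=other_dims) == 0).sum())


# nn.Linear(5, 3) has a weight of shape (3, 5): 3 rows (dim=0), 5 columns (dim=1).
frac = prune.random_structured(nn.Linear(5, 3), "weight", amount=0.4, dim=1)
count = prune.random_structured(nn.Linear(5, 3), "weight", amount=2, dim=1)
rows = prune.random_structured(nn.Linear(5, 3), "weight", amount=1, dim=0)

n_frac = pruned_channels(frac, dim=1)    # fraction of the 5 columns
n_count = pruned_channels(count, dim=1)  # absolute number of columns
n_rows = pruned_channels(rows, dim=0)    # absolute number of rows
print(n_frac, n_count, n_rows)
```

Counting zeros in the mask rather than in the weight avoids relying on the values of the randomly initialised parameter.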