torch.scatter_reduce

torch.scatter_reduce(input, dim, index, reduce, *, output_size=None) → Tensor

Reduces all values from the input tensor to the indices specified in the index tensor. For each value in input, its output index is given by its own index in input for every dimension other than dim, and by the corresponding value in index for dimension dim. The reduction applied when several values map to the same output index is selected by the reduce argument ("sum", "prod", "mean", "amax", "amin"). Output positions that no index refers to are filled with the identity of the applied reduction (1 for "prod" and 0 otherwise).

It is also required that index.size(d) == input.size(d) for all dimensions d. Moreover, if output_size is defined, the values of index must be between 0 and output_size - 1 inclusive.
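
The semantics and constraints described so far can be sketched for the 1-D case in plain Python. This is an illustrative reference only, not PyTorch's implementation: the function name scatter_reduce_1d is hypothetical, it works on Python lists rather than tensors, and the fill values for untouched positions simply follow the text above (1 for "prod", 0 otherwise).

```python
import math


def scatter_reduce_1d(values, index, reduce, output_size=None):
    """Hypothetical pure-Python sketch of a 1-D scatter-reduce on lists."""
    # index must have the same size as the input.
    if len(values) != len(index):
        raise ValueError("index must have the same length as values")

    # Infer the output size from the largest index when it is not given.
    if output_size is None:
        output_size = max(index) + 1 if index else 0

    reducers = {
        "sum": sum,
        "prod": math.prod,
        "mean": lambda group: sum(group) / len(group),  # yields floats
        "amax": max,
        "amin": min,
    }
    if reduce not in reducers:
        raise ValueError(f"unsupported reduce: {reduce!r}")

    # Collect the values that land on each output position.
    groups = [[] for _ in range(output_size)]
    for value, i in zip(values, index):
        if not 0 <= i < output_size:
            raise IndexError(f"index {i} is outside [0, {output_size - 1}]")
        groups[i].append(value)

    # Positions that received no value keep the fill value from the text.
    fill = 1 if reduce == "prod" else 0
    return [reducers[reduce](group) if group else fill for group in groups]


values = [1, 2, 3, 4, 5, 6]
index = [0, 1, 0, 1, 2, 1]
print(scatter_reduce_1d(values, index, "sum", output_size=3))   # [4, 12, 5]
print(scatter_reduce_1d(values, index, "prod", output_size=4))  # [3, 48, 5, 1]
```

In the second call, position 3 receives no value, so it holds the "prod" fill value of 1.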

For a 3-D tensor with reduce="sum", the output is given as:

out[index[i][j][k]][j][k] += input[i][j][k]  # if dim == 0
out[i][index[i][j][k]][k] += input[i][j][k]  # if dim == 1
out[i][j][index[i][j][k]] += input[i][j][k]  # if dim == 2
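
The three update rules can be written as one loop over nested Python lists. The sketch below is only an illustration under stated assumptions: scatter_sum_3d is a hypothetical helper, it handles reduce="sum" only, and it assumes inp and index are rectangular nested lists of the same shape.

```python
def scatter_sum_3d(inp, index, dim, output_size=None):
    """Hypothetical sketch: apply the 3-D "sum" rules above to nested lists."""
    size_i, size_j, size_k = len(inp), len(inp[0]), len(inp[0][0])

    # Infer the output size along dim from the largest index value.
    if output_size is None:
        output_size = max(v for plane in index for row in plane for v in row) + 1

    # The output matches the input shape except along dim.
    shape = [size_i, size_j, size_k]
    shape[dim] = output_size
    out = [[[0] * shape[2] for _ in range(shape[1])] for _ in range(shape[0])]

    for i in range(size_i):
        for j in range(size_j):
            for k in range(size_k):
                # Keep the source position, but replace the coordinate
                # along dim with the value stored in index.
                pos = [i, j, k]
                pos[dim] = index[i][j][k]
                a, b, c = pos
                out[a][b][c] += inp[i][j][k]
    return out


inp = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
index = [[[0, 1], [0, 1]], [[0, 1], [0, 1]]]
print(scatter_sum_3d(inp, index, dim=0))  # [[[6, 0], [10, 0]], [[0, 8], [0, 12]]]
```

With dim=0, the two values that share the same j and k are summed into the slice selected by index, and positions that receive nothing stay at 0.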

Note

This out-of-place operation is similar to the in-place scatter_() and scatter_add_(), except that the output tensor is created automatically, sized according to the maximum value in index and filled with the identity of the applied reduction.
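
For reduce="sum", the relationship to scatter_add_() can be sketched by building the output by hand. This is a rough equivalent for illustration, not a description of how scatter_reduce is implemented; the variable names are arbitrary.

```python
import torch

src = torch.tensor([1, 2, 3, 4, 5, 6])
index = torch.tensor([0, 1, 0, 1, 2, 1])

# Size the output from the largest index, as the note describes.
output_size = int(index.max()) + 1

# Start from the identity of "sum" (zeros), then accumulate in place.
out = torch.zeros(output_size, dtype=src.dtype).scatter_add_(0, index, src)
print(out.tolist())  # [4, 12, 5]
```

The manual version needs the output to be allocated and pre-filled by the caller, which is the step the out-of-place operation takes care of.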

Note

This operation may behave nondeterministically when given tensors on a CUDA device. See Reproducibility for more information.
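
One plausible source of run-to-run differences, stated here as an assumption rather than something this page confirms, is that values scattered to the same index may be accumulated in a different order on each run. Floating-point addition is not associative, so the order can change the last bits of the result, as this small plain-Python sketch shows:

```python
a, b, c = 0.1, 0.2, 0.3

# The same three values, summed in two different orders.
left = (a + b) + c
right = a + (b + c)

print(left == right)  # False
print(left, right)
```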

Parameters
  • input (Tensor) – the input tensor, whose elements are the source values to scatter and reduce

  • dim (int) – the axis along which to index

  • index (LongTensor) – the indices of elements to scatter and reduce

  • reduce (str) – the reduction operation to apply for non-unique indices ("sum", "prod", "mean", "amax", "amin")

  • output_size (int, optional) – the size of the output at dimension dim. If set to None, it is inferred as index.max() + 1

Example:

>>> input = torch.tensor([1, 2, 3, 4, 5, 6])
>>> index = torch.tensor([0, 1, 0, 1, 2, 1])
>>> torch.scatter_reduce(input, 0, index, reduce="sum", output_size=3)
tensor([4, 12, 5])
