
CosineEmbeddingLoss

class torch.nn.CosineEmbeddingLoss(margin=0.0, size_average=None, reduce=None, reduction='mean')

Creates a criterion that measures the loss given input tensors $x_1$, $x_2$ and a Tensor label $y$ with values 1 or -1. Use $y = 1$ to maximize the cosine similarity of the two inputs, and $y = -1$ to push their cosine similarity down to at most margin. This is typically used for learning nonlinear embeddings or semi-supervised learning.

The loss function for each sample is:

$$\text{loss}(x, y) = \begin{cases} 1 - \cos(x_1, x_2), & \text{if } y = 1 \\ \max(0, \cos(x_1, x_2) - \text{margin}), & \text{if } y = -1 \end{cases}$$
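
The per-sample formula above can be sketched in plain Python, without PyTorch. This is a minimal illustration, not the library's implementation: the helper names `cosine_similarity` and `cosine_embedding_loss_sample` and the small `eps` guard against zero-length vectors are assumptions made for this example.

```python
import math


def cosine_similarity(a, b, eps=1e-8):
    """Cosine of the angle between two equal-length vectors.

    eps is an assumed guard against division by zero, not a documented value.
    """
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / max(norm_a * norm_b, eps)


def cosine_embedding_loss_sample(x1, x2, y, margin=0.0):
    """Loss for a single pair (x1, x2) with label y in {1, -1}."""
    cos = cosine_similarity(x1, x2)
    if y == 1:
        # Similar pair: loss shrinks as the vectors align.
        return 1.0 - cos
    if y == -1:
        # Dissimilar pair: penalized only while cos exceeds the margin.
        return max(0.0, cos - margin)
    raise ValueError("y must be 1 or -1")
```

Under these assumptions, an identical pair with `y = 1` gives a loss of 0, an orthogonal pair with `y = 1` gives 1, and an identical pair with `y = -1` and `margin=0.5` gives 0.5.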
Parameters
  • margin (float, optional) – Should be a number from $-1$ to $1$; $0$ to $0.5$ is suggested. If margin is missing, the default value is $0$.

  • size_average (bool, optional) – Deprecated (see reduction). By default, the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field size_average is set to False, the losses are instead summed for each minibatch. Ignored when reduce is False. Default: True

  • reduce (bool, optional) – Deprecated (see reduction). By default, the losses are averaged or summed over observations for each minibatch depending on size_average. When reduce is False, returns a loss per batch element instead and ignores size_average. Default: True

  • reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed. Note: size_average and reduce are in the process of being deprecated, and in the meantime, specifying either of those two args will override reduction. Default: 'mean'
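
A small sketch of how margin and reduction interact, using two hand-picked pairs so the values are easy to check by hand. The tensor names and values are chosen for illustration only: the first pair is identical (cosine 1) and labeled -1, the second is orthogonal (cosine 0) and labeled 1.

```python
import torch
import torch.nn as nn

# Pair 0: identical vectors (cos = 1). Pair 1: orthogonal vectors (cos = 0).
x1 = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
x2 = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
y = torch.tensor([-1.0, 1.0])  # pair 0 is "dissimilar", pair 1 is "similar"

# Same margin, three reduction modes.
per_sample = nn.CosineEmbeddingLoss(margin=0.5, reduction='none')(x1, x2, y)
mean_loss = nn.CosineEmbeddingLoss(margin=0.5, reduction='mean')(x1, x2, y)
sum_loss = nn.CosineEmbeddingLoss(margin=0.5, reduction='sum')(x1, x2, y)

print(per_sample, mean_loss, sum_loss)
```

By the formula, pair 0 contributes max(0, 1 - 0.5) = 0.5 and pair 1 contributes 1 - 0 = 1, so `per_sample` should be approximately [0.5, 1.0], `mean_loss` approximately 0.75, and `sum_loss` approximately 1.5.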

Shape:
  • Input1: $(N, D)$ or $(D)$, where N is the batch size and D is the embedding dimension.

  • Input2: $(N, D)$ or $(D)$, same shape as Input1.

  • Target: $(N)$ or $()$.

  • Output: If reduction is 'none', then $(N)$, otherwise scalar.
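
The shape rules above can be checked with a short sketch. The sizes `N = 4` and `D = 8` and the variable names are arbitrary choices for this example, and the unbatched call assumes an installed PyTorch version that accepts $(D)$ inputs with a $()$ target, as the Shape section describes.

```python
import torch
import torch.nn as nn

N, D = 4, 8  # arbitrary batch size and embedding dimension for illustration
a = torch.randn(N, D)
b = torch.randn(N, D)
target = torch.tensor([1.0, -1.0, 1.0, -1.0])  # shape (N)

# Batched inputs: one loss per pair with 'none', a scalar with the default 'mean'.
batched_none = nn.CosineEmbeddingLoss(reduction='none')(a, b, target)
batched_mean = nn.CosineEmbeddingLoss()(a, b, target)

# Unbatched inputs: (D) vectors with a zero-dimensional target.
unbatched = nn.CosineEmbeddingLoss()(a[0], b[0], torch.tensor(1.0))

print(batched_none.shape, batched_mean.shape, unbatched.shape)
```

With reduction='none' the batched result should have shape $(N)$, while the default reduction and the unbatched call should both return zero-dimensional tensors.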

Examples:

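>>> import torch
>>> import torch.nn as nn
>>> loss = nn.CosineEmbeddingLoss()
>>> input1 = torch.randn(3, 5, requires_grad=True)
>>> input2 = torch.randn(3, 5, requires_grad=True)
>>> # All-ones target: every pair is treated as similar (y = 1)
>>> target = torch.ones(3)
>>> output = loss(input1, input2, target)
>>> # Backpropagate the scalar mean loss into input1 and input2
>>> output.backward()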
