Skip to content

Commit

Permalink
Enforce both input tensor shapes of CosineEmbeddingLoss to be equal. (pytorch#112782)
Browse files Browse the repository at this point in the history

Added a test to prevent regressions.

Fixes pytorch#112732.

Pull Request resolved: pytorch#112782
Approved by: https://github.com/lezcano
  • Loading branch information
tringwald authored and pytorchmergebot committed Nov 3, 2023
1 parent 2337d8d commit 29716e8
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
5 changes: 2 additions & 3 deletions aten/src/ATen/native/Loss.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,18 +146,17 @@ Tensor cosine_embedding_loss(const Tensor& input1, const Tensor& input2, const T
TORCH_CHECK(
targ_dim == 1 || targ_dim == 0,
"0D or 1D target tensor expected, multi-target not supported");

if (targ_dim == 1) {
TORCH_CHECK(
input1.dim() == 2,
input1.dim() == 2 && input2.dim() == 2,
"1D target tensor expects 2D input tensors, but found inputs with sizes ",
input1.sizes(),
" and ",
input2.sizes(),
".");
} else {
TORCH_CHECK(
input1.dim() == 1,
input1.dim() == 1 && input2.dim() == 1,
"0D target tensor expects 1D input tensors, but found inputs with sizes ",
input1.sizes(),
" and ",
Expand Down
16 changes: 16 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5395,6 +5395,22 @@ def test_cosine_embedding_loss_with_diff_type(self):
result = torch.nn.functional.cosine_embedding_loss(input1, input2, target)
self.assertEqual(result.item(), expected.item(), atol=0.001, rtol=0)

def test_cosine_embedding_loss_error_on_diff_shapes(self):
    """A 1D target requires both inputs to be 2D; a 2D/1D pair must raise.

    Regression test for pytorch#112732: only input1's rank used to be
    checked, so a rank-mismatched input2 slipped through.
    """
    for device in device_():
        # input1 is 2D but input2 is 1D -> the rank check must reject input2 too.
        matrix_in = torch.empty((0, 0), dtype=torch.double, device=device)
        vector_in = torch.empty((0,), dtype=torch.double, device=device)
        tgt = torch.empty((0,), dtype=torch.int, device=device)
        with self.assertRaisesRegex(RuntimeError, ".*expects 2D.*"):
            torch.nn.functional.cosine_embedding_loss(matrix_in, vector_in, tgt)

def test_cosine_embedding_loss_error_on_nonexpandable_shapes(self):
    """Equal-rank inputs whose sizes cannot broadcast must raise.

    Both inputs are 2D (so the rank check passes), but trailing dims
    5 vs 6 are non-expandable, so broadcasting must fail.
    """
    for device in device_():
        lhs = torch.empty((1, 5), dtype=torch.double, device=device)
        rhs = torch.empty((1, 6), dtype=torch.double, device=device)
        tgt = torch.ones((1,), dtype=torch.int, device=device)
        with self.assertRaisesRegex(RuntimeError, ".*must match the size.*"):
            torch.nn.functional.cosine_embedding_loss(lhs, rhs, tgt)

def test_kl_div_with_diff_type(self):
for device in device_():
input = torch.tensor([[2, 3, 5], [3, 2, 1]], dtype=torch.double, device=device)
Expand Down

0 comments on commit 29716e8

Please sign in to comment.