From 653c1564bf4c46d9758000ea70a610ee3699acf1 Mon Sep 17 00:00:00 2001 From: lezcano Date: Fri, 15 Sep 2023 10:28:09 +0000 Subject: [PATCH] Fix broadcasting cosine_similarity (#109363) Fixes https://github.com/pytorch/pytorch/issues/109333 Pull Request resolved: https://github.com/pytorch/pytorch/pull/109363 Approved by: https://github.com/peterbell10 --- aten/src/ATen/ExpandUtils.h | 11 ++++++----- aten/src/ATen/native/Distance.cpp | 12 +++++++----- test/test_nn.py | 12 ++++++++++++ 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/aten/src/ATen/ExpandUtils.h b/aten/src/ATen/ExpandUtils.h index 32fb3b483947e0..30077f0ad498d7 100644 --- a/aten/src/ATen/ExpandUtils.h +++ b/aten/src/ATen/ExpandUtils.h @@ -187,17 +187,18 @@ expand_inplace( // See NOTE [ ExpandUtils Borrowing ] above for `MaybeOwned` explanation. inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>> expand_outplace(const Tensor& to_expand1, const Tensor& to_expand2) { - if (to_expand1.sizes().equals(to_expand2.sizes())) { + auto s1 = to_expand1.sym_sizes(); + auto s2 = to_expand2.sym_sizes(); + if (s1.equals(s2)) { return std::make_tuple( c10::MaybeOwned<Tensor>::borrowed(to_expand1), c10::MaybeOwned<Tensor>::borrowed(to_expand2)); } - auto expanded_size = - infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes()); + auto expanded_size = infer_size_symdimvector(s1, s2); return std::make_tuple( - c10::MaybeOwned<Tensor>::owned(to_expand1.expand(expanded_size)), - c10::MaybeOwned<Tensor>::owned(to_expand2.expand(expanded_size))); + c10::MaybeOwned<Tensor>::owned(to_expand1.expand_symint(expanded_size)), + c10::MaybeOwned<Tensor>::owned(to_expand2.expand_symint(expanded_size))); } inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>> diff --git a/aten/src/ATen/native/Distance.cpp b/aten/src/ATen/native/Distance.cpp index a66ffa1a946d4e..a46ab41967d0f5 100644 --- a/aten/src/ATen/native/Distance.cpp +++ b/aten/src/ATen/native/Distance.cpp @@ -308,16 +308,18 @@ Tensor cosine_similarity(const Tensor& x1_, const Tensor& x2_, int64_t dim, doub // We accept integral types (and 
bools lol) but vector_norm does not auto x1_is_int = c10::isIntegralType(x1_.scalar_type(), /*includeBool=*/true); auto x2_is_int = c10::isIntegralType(x2_.scalar_type(), /*includeBool=*/true); - auto x1 = x1_is_int ? x1_.to(commonDtype) : x1_; - auto x2 = x2_is_int ? x2_.to(commonDtype) : x2_; + auto x1_t = x1_is_int ? x1_.to(commonDtype) : x1_; + auto x2_t = x2_is_int ? x2_.to(commonDtype) : x2_; + c10::MaybeOwned<Tensor> x1, x2; + std::tie(x1, x2) = expand_outplace(x1_t, x2_t); // We want to divide each tensor by its norm first, as it's more numerically stable. // This keeps the result between -1.0 and 1.0 // We clone them, as we're going to modify them in-place // This allows the gradients to propagate propertly all the way to x1 and x2 - auto x1_norm = at::linalg_vector_norm(x1, 2, /*dim=*/dim, /*keepdim=*/true).clone(); - auto x2_norm = at::linalg_vector_norm(x2, 2, /*dim=*/dim, /*keepdim=*/true).clone(); + auto x1_norm = at::linalg_vector_norm(*x1, 2, /*dim=*/dim, /*keepdim=*/true).clone(); + auto x2_norm = at::linalg_vector_norm(*x2, 2, /*dim=*/dim, /*keepdim=*/true).clone(); { at::NoGradGuard guard; @@ -325,7 +327,7 @@ Tensor cosine_similarity(const Tensor& x1_, const Tensor& x2_, int64_t dim, doub x2_norm.clamp_min_(eps); } - return ((x1 / x1_norm) * (x2 / x2_norm)).sum(dim); + return ((*x1 / x1_norm) * (*x2 / x2_norm)).sum(dim); } }} // namespace at::native diff --git a/test/test_nn.py b/test/test_nn.py index faebb160e3a958..9156ce7772ee12 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -5609,6 +5609,18 @@ def test_cosine_similarity(self): out = F.cosine_similarity(input.to(torch.int8), input, dim=-1) self.assertEqual(out, 1.) 
+ # Check broadcasting #109333 + a = torch.ones(2, 3, dtype=torch.float) + b = torch.ones(1, 1, dtype=torch.float) + out = F.cosine_similarity(a, b) + self.assertEqual(out, torch.ones(2, dtype=torch.float)) + + a = torch.ones(2, 3, dtype=torch.float) + b = torch.ones(1, dtype=torch.float) + out = F.cosine_similarity(a, b) + self.assertEqual(out, torch.ones(2, dtype=torch.float)) + + def test_grid_sample_error_checking(self): input = torch.empty(1, 1, 2, 2) grid = torch.empty(1, 1, 1, 2)