From abee7736f65276059ea5d81f80634ae2c1833ae9 Mon Sep 17 00:00:00 2001 From: Tom Julius <42802270+Tom-Julux@users.noreply.github.com> Date: Sat, 21 Oct 2023 16:13:16 +0200 Subject: [PATCH] Speedup patch embedding by moving coord creation on device --- segment_anything/modeling/image_encoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/segment_anything/modeling/image_encoder.py b/segment_anything/modeling/image_encoder.py index 66351d9d7..0897b4c6e 100644 --- a/segment_anything/modeling/image_encoder.py +++ b/segment_anything/modeling/image_encoder.py @@ -315,8 +315,8 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor rel_pos_resized = rel_pos # Scale the coords with short length if shapes for q and k are different. - q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) - k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max(q_size / k_size, 1.0) relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) return rel_pos_resized[relative_coords.long()]