Skip to content

Commit

Permalink
Merge pull request #10 from kangyiyang/master
Browse files Browse the repository at this point in the history
  • Loading branch information
lcp29 authored Feb 18, 2025
2 parents a5f6252 + 0f33b34 commit 4ade969
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions triro/ray/ray_optix.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,13 @@ def contains_points(
points: Float32[torch.Tensor, "*b 3"],
check_direction: Optional[Float32[torch.Tensor, "3"]] = None,
) -> Bool[torch.Tensor, "*b 3"]:
contains = torch.zeros(points.shape[:-1], dtype=torch.bool)
contains = torch.zeros(points.shape[:-1], dtype=torch.bool, device=points.device)
# check if points are in the aabb
inside_aabb = ~(
(~(points > self.mesh_aabb[0])).any()
| (~(points < self.mesh_aabb[1])).any()
(~(points > self.mesh_aabb[0])).any(dim=1)
| (~(points < self.mesh_aabb[1])).any(dim=1)
)

if not inside_aabb.any():
return contains
default_direction = torch.Tensor(
Expand All @@ -257,19 +258,23 @@ def contains_points(
],
dim=0,
)

# if hit count in two directions are all odd number then the point is likely to be inside the mesh
hit_count_mod_2 = torch.remainder(hit_count, 2)
agree = torch.equal(hit_count_mod_2[0], hit_count_mod_2[1])

agree = torch.all(hit_count_mod_2, dim=0)
contain = inside_aabb & agree & hit_count_mod_2[0] == 1

broken_mask = ~agree & (hit_count == 0).any(dim=-1)
broken_mask = ~agree & (hit_count == 0).any(dim=0)

if not broken_mask.any():
return contain

if check_direction is None:
new_direction = (torch.rand(3) - 0.5).cuda()
contains[broken_mask] = self.contains_points(self, points, new_direction)
contains = contain.cuda()
broken_mask = broken_mask.cuda()
points = points.cuda()
contains[broken_mask] = self.contains_points(points[broken_mask], new_direction)

return contains

Expand Down

0 comments on commit 4ade969

Please sign in to comment.