Skip to content

Commit

Permalink
add tensor stride support
Browse files Browse the repository at this point in the history
  • Loading branch information
lcp29 committed Feb 18, 2024
1 parent b8ec7a8 commit 03a7288
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 36 deletions.
13 changes: 6 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ cd trimesh-ray-optix
pip install .
```
## 📖️ Example
> ⚠️ Triro requires the tensor inputs to be contiguous in memory. You will receive an error if using any non-contiguous input.
```python
import trimesh
import matplotlib.pyplot as plt
Expand All @@ -34,8 +33,8 @@ intersector = RayMeshIntersector(mesh)
x, y = torch.meshgrid([torch.linspace([-1, 1, 800]),
torch.linspace([-1, 1, 800])], indexing='ij')
z = -torch.ones_like(x)
ray_directions = torch.cat([x, y, z], dim=-1).cuda().contiguous()
ray_origins = torch.Tensor([0, 0, 3]).cuda().broadcast_to(ray_directions.shape).contiguous()
ray_directions = torch.cat([x, y, z], dim=-1).cuda()
ray_origins = torch.Tensor([0, 0, 3]).cuda().broadcast_to(ray_directions.shape)

# OptiX, Launch!
hit, front, ray_idx, tri_idx, location, uv = sr.intersects_closest(
Expand All @@ -55,15 +54,15 @@ The above code generates the following result:
## 🕊️ TODOs

- [x] Installation on Windows
- [ ] Supporting Tensor strides
- [x] Supporting Tensor strides

## 🚀️ Performance Comparison

Scene closest-hit ray tracing tested under Ubuntu 22.04, i5-13490F and RTX 3090 ([performance_test.py](test/performance_test.py)):
```
GPU time: 8.121 s / 100000 iters
Trimesh & PyEmbree CPU time: 19.454 s / 100 iters
speedup: 2395x
GPU time: 8.362 s / 100000 iters
Trimesh & PyEmbree CPU time: 18.175 s / 100 iters
speedup: 2173x
```

![](assets/testcase.png)
Expand Down
1 change: 0 additions & 1 deletion test/performance_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def gen_rays(cam_mat, w, h, f):
torch.Tensor([4.3092918e01, -2.9232937e01, 3.7687759e01])
.cuda()
.broadcast_to(ray_dirs.shape)
.contiguous()
)
mesh = trimesh.load('test/models/iscv2.obj', force='mesh')
r = RayMeshIntersector(mesh)
Expand Down
6 changes: 3 additions & 3 deletions test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
x, y = torch.meshgrid(
[torch.linspace(-1, 1, 800), torch.linspace(-1, 1, 800)], indexing="ij"
)

z = -torch.ones_like(x)

dirs = torch.stack([x, -y, z], dim=-1).cuda().contiguous()
origin = torch.Tensor([[0, 0, 3]]).cuda().broadcast_to(dirs.shape).contiguous()
dirs = torch.stack([x, -y, z], dim=-1).cuda()
print(f'dirs: {dirs.shape}, stride: {dirs.stride()}')
origin = torch.Tensor([[0, 0, 3]]).cuda().broadcast_to(dirs.shape)

hit, front, ray_idx, tri_idx, location, uv = sr.intersects_closest(
origin, dirs, stream_compaction=True
Expand Down
11 changes: 9 additions & 2 deletions triro/backend/LaunchParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,21 @@
namespace hmesh {

constexpr int MAX_ANYHIT_SIZE = 8;
constexpr int MAX_SIZE_LENGTH = 4;

struct RayInput {
// ray count
size_t nray;
// ray shape
long rayShape[MAX_SIZE_LENGTH];
// ray origins
float3 *origins;
float *origins;
// ray origins stride
long originsStride[MAX_SIZE_LENGTH];
// ray directions
float3 *directions;
float *directions;
// ray directions stride
long directionsStride[MAX_SIZE_LENGTH];
// hit counts
int *hitCounts;
// global index
Expand Down
53 changes: 40 additions & 13 deletions triro/backend/ray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "CUDABuffer.h"
#include "LaunchParams.h"
#include "base.h"
#include "c10/core/Layout.h"
#include "c10/core/ScalarType.h"
#include "c10/core/TensorOptions.h"
#include "c10/util/ArrayRef.h"
Expand All @@ -19,6 +20,7 @@
#include "optix_types.h"
#include "sbtdef.h"
#include "type.h"
#include <limits>

namespace hmesh {

Expand Down Expand Up @@ -109,10 +111,10 @@ template <typename... Ts> inline bool tensorInputCheck(Ts... ts) {
<< ": input tensors must reside in cuda device.\n";
valid = false;
}
if (!ts.is_contiguous()) {
if (ts.layout() != torch::kStrided) {
std::cerr << "error in file " << __FILE__ << " line "
<< __LINE__
<< ": input tensors must be contiguous.\n";
<< ": input tensor layout must be torch::kStrided.\n";
valid = false;
}
}(),
Expand Down Expand Up @@ -142,10 +144,20 @@ inline std::vector<int64_t> changeLastDim(const c10::IntArrayRef dims,
return dimsVec;
}

template <typename T> inline T *data_ptr(torch::Tensor t) {
/// Return the tensor's raw data pointer viewed as `T *`.
/// This is a plain pointer cast: nothing here checks that `T` matches the
/// tensor's dtype or where the memory lives, so callers must pick `T`
/// themselves.
template <typename T> inline T *data_ptr(const torch::Tensor &t) {
    void *raw = t.data_ptr();
    return static_cast<T *>(raw);
}

/**
 * Copy `src` into the fixed-size array `dst` (MAX_SIZE_LENGTH slots),
 * right-aligned: the last element of `src` lands in the last slot of `dst`.
 *
 * - If `src` has fewer than MAX_SIZE_LENGTH entries, the leading slots are
 *   filled with `defaultValue`.
 * - If `src` has more than MAX_SIZE_LENGTH entries, only its trailing
 *   MAX_SIZE_LENGTH entries are kept; the leading ones are dropped without
 *   any diagnostic.
 *   NOTE(review): callers presumably guarantee at most MAX_SIZE_LENGTH
 *   dimensions -- confirm against the input checks.
 *
 * The destination type `D` and the source element type `S` are deduced
 * independently. With a single shared template parameter, passing a `long`
 * array together with an `ArrayRef<int64_t>` fails template deduction on
 * platforms where `long` and `int64_t` are different types (e.g. LLP64).
 *
 * @param dst          output array with room for MAX_SIZE_LENGTH elements
 * @param src          values to copy (e.g. tensor sizes or strides)
 * @param defaultValue padding value for the unused leading slots
 */
template <typename D, typename S>
void fillArray(D *dst, c10::ArrayRef<S> src, D defaultValue) {
    const int srcSize = static_cast<int>(src.size());
    // Number of leading slots without a source entry; negative when `src`
    // is longer than the destination, in which case it acts as a skip count.
    const int pad = MAX_SIZE_LENGTH - srcSize;
    for (int i = 0; i < MAX_SIZE_LENGTH; i++) {
        if (i < pad)
            dst[i] = defaultValue;
        else
            dst[i] = static_cast<D>(src[i - pad]);
    }
}

torch::Tensor intersectsAny(OptixAccelStructureWrapperCPP as,
const torch::Tensor &origins,
const torch::Tensor &directions) {
Expand All @@ -159,9 +171,12 @@ torch::Tensor intersectsAny(OptixAccelStructureWrapperCPP as,
auto result = torch::empty(resultSize, options);
// fill launch params
LaunchParams lp = {};
lp.rays.origins = data_ptr<float3>(origins);
lp.rays.directions = data_ptr<float3>(directions);
lp.rays.origins = data_ptr<float>(origins);
lp.rays.directions = data_ptr<float>(directions);
lp.rays.nray = nray;
fillArray(lp.rays.rayShape, origins.sizes(), std::numeric_limits<long>::max());
fillArray(lp.rays.originsStride, origins.strides(), (long) 0);
fillArray(lp.rays.directionsStride, directions.strides(), (long) 0);
lp.traversable = as.asHandle;
lp.results.hit = data_ptr<bool>(result);
CUDABuffer lpBuffer;
Expand All @@ -186,9 +201,12 @@ torch::Tensor intersectsFirst(OptixAccelStructureWrapperCPP as,
auto result = torch::empty(resultSize, options);
// fill launch params
LaunchParams lp = {};
lp.rays.origins = data_ptr<float3>(origins);
lp.rays.directions = data_ptr<float3>(directions);
lp.rays.origins = data_ptr<float>(origins);
lp.rays.directions = data_ptr<float>(directions);
lp.rays.nray = nray;
fillArray(lp.rays.rayShape, origins.sizes(), std::numeric_limits<long>::max());
fillArray(lp.rays.originsStride, origins.strides(), (long) 0);
fillArray(lp.rays.directionsStride, directions.strides(), (long) 0);
lp.traversable = as.asHandle;
lp.results.triIdx = data_ptr<int>(result);
CUDABuffer lpBuffer;
Expand Down Expand Up @@ -244,8 +262,11 @@ intersectsClosest(OptixAccelStructureWrapperCPP as, torch::Tensor origins,
// fill and upload launchParams
LaunchParams lp = {};
lp.rays.nray = nray;
lp.rays.origins = data_ptr<float3>(origins);
lp.rays.directions = data_ptr<float3>(directions);
lp.rays.origins = data_ptr<float>(origins);
lp.rays.directions = data_ptr<float>(directions);
fillArray(lp.rays.rayShape, origins.sizes(), std::numeric_limits<long>::max());
fillArray(lp.rays.originsStride, origins.strides(), (long) 0);
fillArray(lp.rays.directionsStride, directions.strides(), (long) 0);

lp.results.hit = data_ptr<bool>(hitbuf);
lp.results.location = data_ptr<float3>(locbuf);
Expand Down Expand Up @@ -281,8 +302,11 @@ torch::Tensor intersectsCount(OptixAccelStructureWrapperCPP as,

LaunchParams lp = {};
lp.rays.nray = nray;
lp.rays.origins = data_ptr<float3>(origins);
lp.rays.directions = data_ptr<float3>(directions);
lp.rays.origins = data_ptr<float>(origins);
lp.rays.directions = data_ptr<float>(directions);
fillArray(lp.rays.rayShape, origins.sizes(), std::numeric_limits<long>::max());
fillArray(lp.rays.originsStride, origins.strides(), (long) 0);
fillArray(lp.rays.directionsStride, directions.strides(), (long) 0);
lp.results.hitCount = data_ptr<int>(hitCountBuf);
lp.traversable = as.asHandle;

Expand Down Expand Up @@ -329,10 +353,13 @@ intersectsLocation(OptixAccelStructureWrapperCPP as, torch::Tensor origins,
LaunchParams lp = {};
lp.traversable = as.asHandle;
lp.rays.nray = nray;
lp.rays.origins = data_ptr<float3>(origins);
lp.rays.directions = data_ptr<float3>(directions);
lp.rays.origins = data_ptr<float>(origins);
lp.rays.directions = data_ptr<float>(directions);
lp.rays.hitCounts = data_ptr<int>(hitCountBuf);
lp.rays.globalIdx = data_ptr<int>(globalIdxBuf);
fillArray(lp.rays.rayShape, origins.sizes(), std::numeric_limits<long>::max());
fillArray(lp.rays.originsStride, origins.strides(), (long) 0);
fillArray(lp.rays.directionsStride, directions.strides(), (long) 0);
lp.results.hitCount = data_ptr<int>(hitCountBuf);
lp.results.location = data_ptr<float3>(locbuf);
lp.results.triIdx = data_ptr<int>(tibuf);
Expand Down
1 change: 1 addition & 0 deletions triro/backend/ray.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "base.h"
#include "optix_types.h"
#include <torch/extension.h>
#include <limits>

namespace hmesh {

Expand Down
53 changes: 43 additions & 10 deletions triro/backend/shaders.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,44 @@ __forceinline__ __host__ __device__ T *getPayloadPointer() {
return (T *)p;
}

/**
 * Split a flat, row-major element index into one index per dimension.
 *
 * Dimensions are processed from last to first: each step takes
 * `idx % shape[i]` as the index along dimension i and carries
 * `idx / shape[i]` on to the next dimension.
 *
 * @param indices output, one index per dimension (MAX_SIZE_LENGTH entries)
 * @param shape   extent of every dimension (MAX_SIZE_LENGTH entries); only read
 * @param idx     flat element index; 64-bit so that large launches do not
 *                wrap around
 */
__forceinline__ __device__ void getIndices(long indices[MAX_SIZE_LENGTH],
                                           const long shape[MAX_SIZE_LENGTH],
                                           long long idx) {
#pragma unroll
    for (int i = MAX_SIZE_LENGTH - 1; i >= 0; i--) {
        indices[i] = idx % shape[i];
        idx /= shape[i];
    }
}

/**
 * Fetch origin and direction of ray `idx` from the (possibly strided) ray
 * buffers described by `launchParams.rays`.
 *
 * The ray index is turned into a flat float index (`idx * 3`), decomposed
 * over `rayShape`, and the per-dimension indices are combined with the
 * origin / direction strides to obtain the offset of the ray's first
 * component. The remaining two components are read one and two last-dimension
 * strides further on.
 * NOTE(review): this assumes the last dimension of the ray tensors has
 * extent 3, so that the last-dimension index of `idx * 3` is 0 -- confirm
 * against the host-side input checks.
 *
 * @param idx ray index (the launch index of the calling raygen program)
 * @return    {ray origin, ray direction}
 */
__forceinline__ __device__ std::tuple<float3, float3> getRay(int idx) {
    // Corresponding float index in [0, 3N). Widen before multiplying:
    // `idx * 3` evaluated in int overflows once idx exceeds INT_MAX / 3.
    const long long float_idx = static_cast<long long>(idx) * 3;
    // index of this ray along every dimension
    long indices[MAX_SIZE_LENGTH];
    getIndices(indices, launchParams.rays.rayShape, float_idx);
    // element offsets into the origin / direction buffers
    long ori_real_idx = 0;
    long dir_real_idx = 0;
#pragma unroll
    for (int i = 0; i < MAX_SIZE_LENGTH; i++) {
        ori_real_idx += indices[i] * launchParams.rays.originsStride[i];
        dir_real_idx += indices[i] * launchParams.rays.directionsStride[i];
    }
    // gather the three components, stepping by the last-dimension stride
    const long ori_step = launchParams.rays.originsStride[MAX_SIZE_LENGTH - 1];
    float3 ray_origin;
    ray_origin.x = launchParams.rays.origins[ori_real_idx];
    ray_origin.y = launchParams.rays.origins[ori_real_idx + ori_step];
    ray_origin.z = launchParams.rays.origins[ori_real_idx + 2 * ori_step];
    const long dir_step =
        launchParams.rays.directionsStride[MAX_SIZE_LENGTH - 1];
    float3 ray_dir;
    ray_dir.x = launchParams.rays.directions[dir_real_idx];
    ray_dir.y = launchParams.rays.directions[dir_real_idx + dir_step];
    ray_dir.z = launchParams.rays.directions[dir_real_idx + 2 * dir_step];
    return {ray_origin, ray_dir};
}

// intersects_any

extern "C" __global__ void __miss__intersectsAny() {
Expand All @@ -42,8 +80,7 @@ extern "C" __global__ void __raygen__intersectsAny() {
// intersection result, to be overwritten by the shader
bool isect_result = false;
// ray info
float3 ray_origin = launchParams.rays.origins[idx];
float3 ray_dir = launchParams.rays.directions[idx];
auto [ray_origin, ray_dir] = getRay(idx);
// result pointer
auto [u0, u1] = setPayloadPointer(&isect_result);
optixTrace(launchParams.traversable, ray_origin, ray_dir, 0., 1e7, 0,
Expand All @@ -69,8 +106,7 @@ extern "C" __global__ void __raygen__intersectsFirst() {
// first hit triangle index, to be overwritten by the shader
int ch_idx = -1;
// ray info
float3 ray_origin = launchParams.rays.origins[idx];
float3 ray_dir = launchParams.rays.directions[idx];
auto [ray_origin, ray_dir] = getRay(idx);
// result pointer
auto [u0, u1] = setPayloadPointer(&ch_idx);
optixTrace(launchParams.traversable, ray_origin, ray_dir, 0., 1e7, 0,
Expand Down Expand Up @@ -121,8 +157,7 @@ extern "C" __global__ void __raygen__intersectsClosest() {
int idx = optixGetLaunchIndex().x;
WBData wbdata;
// ray info
float3 ray_origin = launchParams.rays.origins[idx];
float3 ray_dir = launchParams.rays.directions[idx];
auto [ray_origin, ray_dir] = getRay(idx);
// result pointer
auto [u0, u1] = setPayloadPointer(&wbdata);
optixTrace(launchParams.traversable, ray_origin, ray_dir, 0., 1e7, 0,
Expand Down Expand Up @@ -150,8 +185,7 @@ extern "C" __global__ void __raygen__intersectsCount() {
int idx = optixGetLaunchIndex().x;
int hitCount = 0;
// ray info
float3 ray_origin = launchParams.rays.origins[idx];
float3 ray_dir = launchParams.rays.directions[idx];
auto [ray_origin, ray_dir] = getRay(idx);
// result pointer
auto [u0, u1] = setPayloadPointer(&hitCount);
optixTrace(launchParams.traversable, ray_origin, ray_dir, 0., 1e7, 0,
Expand Down Expand Up @@ -198,8 +232,7 @@ extern "C" __global__ void __raygen__intersectsLocation() {
payload.hitCount = 0;
payload.globalIdx = globalIdx;
// ray info
float3 ray_origin = launchParams.rays.origins[idx];
float3 ray_dir = launchParams.rays.directions[idx];
auto [ray_origin, ray_dir] = getRay(idx);
// result pointer
auto [u0, u1] = setPayloadPointer(&payload);
optixTrace(launchParams.traversable, ray_origin, ray_dir, 0., 1e7, 0,
Expand Down

0 comments on commit 03a7288

Please sign in to comment.