-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathinterpolation_kernel.cu
113 lines (92 loc) · 4.18 KB
/
interpolation_kernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
/**
 * Forward trilinear interpolation; each thread produces one output element
 * feat_interp[n][f].
 *
 * feats       : read as feats[n][k][f] for k = 0..7 (eight values blended per point).
 * points      : read as points[n][0..2]; per the original note the coordinates
 *               are expected in [-1, 1] and are remapped here to [0, 1].
 * feat_interp : written at feat_interp[n][f].
 *
 * Thread layout: the x launch dimension indexes n, the y launch dimension
 * indexes f. Threads falling outside feats.size(0) x feats.size(2) do nothing.
 *
 * NOTE(review): neighbouring threadIdx.x values map to neighbouring n, not f;
 * if feats is laid out with f innermost this is probably not the
 * coalescing-friendly mapping - confirm with a profiler before changing, and
 * change the launcher's grid computation together with it.
 */
template <typename scalar_t>
__global__ void trilinear_fw_kernel(
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> feats,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> points,
    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> feat_interp
){
    const int n = blockIdx.x * blockDim.x + threadIdx.x;
    const int f = blockIdx.y * blockDim.y + threadIdx.y;

    // Grid tail: only threads that map to a real (n, f) pair continue.
    const bool in_range = (n < feats.size(0)) && (f < feats.size(2));
    if (!in_range) return;

    // Per-point views so the index n is applied once.
    const auto pt = points[n];
    const auto corner = feats[n];

    // Remap each coordinate from [-1, 1] to [0, 1].
    const scalar_t u = (pt[0] + 1) / 2;
    const scalar_t v = (pt[1] + 1) / 2;
    const scalar_t w = (pt[2] + 1) / 2;

    // Bilinear weights in the (v, w) plane; the last one is obtained as the
    // remainder so the four weights sum to one (algebraically it equals v*w).
    const scalar_t w00 = (1 - v) * (1 - w);
    const scalar_t w01 = (1 - v) * w;
    const scalar_t w10 = v * (1 - w);
    const scalar_t w11 = 1 - w00 - w01 - w10;

    // Blend entries 0..3 (weighted by 1-u) and entries 4..7 (weighted by u).
    const scalar_t face_u0 = w00 * corner[0][f] +
                             w01 * corner[1][f] +
                             w10 * corner[2][f] +
                             w11 * corner[3][f];
    const scalar_t face_u1 = w00 * corner[4][f] +
                             w01 * corner[5][f] +
                             w10 * corner[6][f] +
                             w11 * corner[7][f];

    feat_interp[n][f] = (1 - u) * face_u0 + u * face_u1;
}
/**
 * Host-side launcher for the forward trilinear interpolation kernel.
 *
 * @param feats   CUDA floating-point tensor of shape (N, 8, F); the kernel
 *                reads entries [n][0..7][f].
 * @param points  CUDA tensor of shape (N, 3) with the same dtype and device
 *                as feats; the kernel reads entries [n][0..2].
 * @return        Tensor of shape (N, F), same dtype/device as feats.
 *
 * Raises a c10::Error (via TORCH_CHECK / the launch check) when the inputs do
 * not satisfy the shapes above or when the kernel launch fails. Previously such
 * inputs led to out-of-bounds device reads or silently pending CUDA errors.
 */
torch::Tensor trilinear_fw_cu(
    const torch::Tensor feats,
    const torch::Tensor points
){
    // The kernel indexes feats[n][0..7][f] and points[n][0..2] without any
    // per-access bounds checks, so the shapes must be validated here.
    TORCH_CHECK(feats.is_cuda(), "trilinear_fw_cu: feats must be a CUDA tensor");
    TORCH_CHECK(points.is_cuda(), "trilinear_fw_cu: points must be a CUDA tensor");
    TORCH_CHECK(points.device() == feats.device(),
                "trilinear_fw_cu: feats and points must be on the same device");
    TORCH_CHECK(feats.dim() == 3 && feats.size(1) == 8,
                "trilinear_fw_cu: feats must have shape (N, 8, F)");
    TORCH_CHECK(points.dim() == 2 && points.size(0) == feats.size(0) && points.size(1) == 3,
                "trilinear_fw_cu: points must have shape (N, 3) with N matching feats");
    TORCH_CHECK(points.scalar_type() == feats.scalar_type(),
                "trilinear_fw_cu: feats and points must have the same dtype");

    // Launch on the device that owns the data, not whatever is current.
    const c10::cuda::CUDAGuard device_guard(feats.device());

    const int N = feats.size(0), F = feats.size(2);

    // Every (n, f) element is written by exactly one kernel thread, so the
    // output does not need to be zero-initialised.
    torch::Tensor feat_interp = torch::empty({N, F}, feats.options());

    // A grid with a zero dimension is an invalid launch configuration;
    // there is nothing to compute for empty inputs anyway.
    if (N == 0 || F == 0) return feat_interp;

    const dim3 threads(16, 16); // 256 = 16*16 threads per block
    const dim3 blocks((N+threads.x-1)/threads.x, (F+threads.y-1)/threads.y);

    // Use PyTorch's current stream so the kernel is ordered with surrounding ops.
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    AT_DISPATCH_FLOATING_TYPES(feats.scalar_type(), "trilinear_fw_cu",
    ([&] {
        trilinear_fw_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            feats.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
            points.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
            feat_interp.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>()
        );
    }));

    // Surface launch-configuration errors here instead of at some later CUDA call.
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return feat_interp;
}
/**
 * Backward pass of the trilinear interpolation with respect to feats; each
 * thread handles one (n, f) pair and writes the eight entries
 * dL_dfeats[n][0..7][f].
 *
 * dL_dfeat_interp : upstream gradient, read at dL_dfeat_interp[n][f].
 * feats           : only its sizes (dimension 0 and 2) are used, for the
 *                   bounds test.
 * points          : read as points[n][0..2]; per the original note the
 *                   coordinates are expected in [-1, 1] and are remapped to
 *                   [0, 1].
 * dL_dfeats       : output; entry k receives (interpolation weight of k) *
 *                   upstream gradient.
 *
 * Thread layout: the x launch dimension indexes n, the y launch dimension
 * indexes f. Threads outside feats.size(0) x feats.size(2) do nothing.
 */
template <typename scalar_t>
__global__ void trilinear_bw_kernel(
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> dL_dfeat_interp,
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> feats,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> points,
    torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> dL_dfeats
){
    const int n = blockIdx.x * blockDim.x + threadIdx.x;
    const int f = blockIdx.y * blockDim.y + threadIdx.y;

    // Grid tail: only threads that map to a real (n, f) pair continue.
    const bool in_range = (n < feats.size(0)) && (f < feats.size(2));
    if (!in_range) return;

    const auto pt = points[n];

    // Remap each coordinate from [-1, 1] to [0, 1].
    const scalar_t u = (pt[0] + 1) / 2;
    const scalar_t v = (pt[1] + 1) / 2;
    const scalar_t w = (pt[2] + 1) / 2;

    // Same four (v, w)-plane weights the forward kernel uses; the last one is
    // the remainder so that they sum to one.
    const scalar_t w00 = (1 - v) * (1 - w);
    const scalar_t w01 = (1 - v) * w;
    const scalar_t w10 = v * (1 - w);
    const scalar_t w11 = 1 - w00 - w01 - w10;
    const scalar_t plane_w[4] = {w00, w01, w10, w11};

    // The upstream gradient is shared by all eight outputs of this thread.
    const scalar_t g = dL_dfeat_interp[n][f];

    auto grad_out = dL_dfeats[n];

    // Entries 0..3 carry the (1-u) factor, entries 4..7 carry the u factor.
    #pragma unroll
    for (int k = 0; k < 4; ++k) {
        grad_out[k][f]     = (1 - u) * plane_w[k] * g;
        grad_out[k + 4][f] = u * plane_w[k] * g;
    }
}
/**
 * Host-side launcher for the backward trilinear interpolation kernel.
 *
 * @param dL_dfeat_interp  CUDA tensor of shape (N, F): gradient with respect
 *                         to the forward output; same dtype/device as feats.
 * @param feats            CUDA floating-point tensor of shape (N, 8, F); only
 *                         its sizes, dtype and device are used here.
 * @param points           CUDA tensor of shape (N, 3), same dtype/device as
 *                         feats; the kernel reads entries [n][0..2].
 * @return                 Tensor of shape (N, 8, F) holding the gradient with
 *                         respect to feats.
 *
 * Raises a c10::Error (via TORCH_CHECK / the launch check) when the inputs do
 * not satisfy the shapes above or when the kernel launch fails. Previously such
 * inputs led to out-of-bounds device reads or silently pending CUDA errors.
 */
torch::Tensor trilinear_bw_cu(
    const torch::Tensor dL_dfeat_interp,
    const torch::Tensor feats,
    const torch::Tensor points
){
    // The kernel indexes dL_dfeat_interp[n][f] and points[n][0..2] for every
    // n < feats.size(0), f < feats.size(2) without per-access bounds checks.
    TORCH_CHECK(feats.is_cuda(), "trilinear_bw_cu: feats must be a CUDA tensor");
    TORCH_CHECK(points.is_cuda(), "trilinear_bw_cu: points must be a CUDA tensor");
    TORCH_CHECK(dL_dfeat_interp.is_cuda(),
                "trilinear_bw_cu: dL_dfeat_interp must be a CUDA tensor");
    TORCH_CHECK(points.device() == feats.device() &&
                dL_dfeat_interp.device() == feats.device(),
                "trilinear_bw_cu: all inputs must be on the same device");
    TORCH_CHECK(feats.dim() == 3 && feats.size(1) == 8,
                "trilinear_bw_cu: feats must have shape (N, 8, F)");
    TORCH_CHECK(points.dim() == 2 && points.size(0) == feats.size(0) && points.size(1) == 3,
                "trilinear_bw_cu: points must have shape (N, 3) with N matching feats");
    TORCH_CHECK(dL_dfeat_interp.dim() == 2 &&
                dL_dfeat_interp.size(0) == feats.size(0) &&
                dL_dfeat_interp.size(1) == feats.size(2),
                "trilinear_bw_cu: dL_dfeat_interp must have shape (N, F) matching feats");
    TORCH_CHECK(points.scalar_type() == feats.scalar_type() &&
                dL_dfeat_interp.scalar_type() == feats.scalar_type(),
                "trilinear_bw_cu: all inputs must have the same dtype");

    // Launch on the device that owns the data, not whatever is current.
    const c10::cuda::CUDAGuard device_guard(feats.device());

    const int N = feats.size(0), F = feats.size(2);

    // All eight entries of every (n, f) pair are written by the kernel, so an
    // uninitialised allocation is sufficient.
    torch::Tensor dL_dfeats = torch::empty({N, 8, F}, feats.options());

    // A grid with a zero dimension is an invalid launch configuration;
    // there is nothing to compute for empty inputs anyway.
    if (N == 0 || F == 0) return dL_dfeats;

    const dim3 threads(16, 16); // 256 = 16*16 threads per block
    const dim3 blocks((N+threads.x-1)/threads.x, (F+threads.y-1)/threads.y);

    // Use PyTorch's current stream so the kernel is ordered with surrounding ops.
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    AT_DISPATCH_FLOATING_TYPES(feats.scalar_type(), "trilinear_bw_cu",
    ([&] {
        trilinear_bw_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            dL_dfeat_interp.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
            feats.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
            points.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
            dL_dfeats.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>()
        );
    }));

    // Surface launch-configuration errors here instead of at some later CUDA call.
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return dL_dfeats;
}