diff --git a/csrc/bgmv/bgmv_impl.cuh b/csrc/bgmv/bgmv_impl.cuh index d2ad687..3c4b973 100644 --- a/csrc/bgmv/bgmv_impl.cuh +++ b/csrc/bgmv/bgmv_impl.cuh @@ -89,7 +89,9 @@ __global__ void bgmv_shrink_kernel(T* __restrict__ Y, const T* __restrict__ X, for (size_t offset = tx / 2; offset > 0; offset /= 2) { sum += __shfl_down_sync(0xffffffff, sum, offset); } - y_warpwise[threadIdx.y] = sum; + if (threadIdx.x == 0) { + y_warpwise[threadIdx.y] = sum; + } block.sync(); #pragma unroll for (size_t i = 0; i < ty; ++i) { @@ -117,10 +119,12 @@ __global__ void bgmv_shrink_kernel(T* __restrict__ Y, const T* __restrict__ X, for (size_t offset = tx / 2; offset > 0; offset /= 2) { sum += __shfl_down_sync(0xffffffff, sum, offset); } - y_warpwise[threadIdx.y] = - ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in) - ? sum - : 0.f; + if (threadIdx.x == 0) { + y_warpwise[threadIdx.y] = + ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in) + ? sum + : 0.f; + } block.sync(); #pragma unroll for (size_t i = 0; i < ty; ++i) {