From 0e6d045c5a210a70e7b3d59274393a5d1d96a0bc Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 3 Jan 2024 21:54:37 -0800 Subject: [PATCH 1/3] Fixed deadlock in sgmv_shrink kernel caused by imbalanced segments --- csrc/sgmv_flashinfer/sgmv_flashinfer.cuh | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/csrc/sgmv_flashinfer/sgmv_flashinfer.cuh b/csrc/sgmv_flashinfer/sgmv_flashinfer.cuh index 51c3650..a1fc9da 100644 --- a/csrc/sgmv_flashinfer/sgmv_flashinfer.cuh +++ b/csrc/sgmv_flashinfer/sgmv_flashinfer.cuh @@ -45,8 +45,9 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp, uint32_t w_frag[num_k_frags][num_blocks_n][4]; float y_frag[num_blocks_n][8]; - for (uint32_t i = 0; - i < (s_end - s_start + (num_warps * 16 - 1)) / (num_warps * 16); ++i) { + const uint32_t s_start = s_starts[problem_id], s_end = s_ends[problem_id]; + const uint32_t num_steps = (s_start < s_end) ? (s_end - s_start + (num_warps * 16 - 1)) / (num_warps * 16) : 0; + for (uint32_t i = 0; i < num_steps; ++i) { // init y_frag if (bx == 0) { if constexpr (num_blocks_n == 1) { @@ -335,6 +336,20 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp, } } } + + // handle the case where one of the segments needs more steps than this one + // to avoid deadlock + if constexpr (cooperative) { + uint32_t max_segment_size = 0; + for (uint32_t i = 0; i < num_problems; ++i) { + max_segment_size = max(max_segment_size, s_ends[i] - s_starts[i]); + } + + const uint32_t max_steps = (max_segment_size + (num_warps * 16 - 1)) / (num_warps * 16); + for (uint32_t i = 0; i < max_steps - num_steps; ++i) { + grid.sync(); + } + } } } // namespace sgmv From cddc1dde1e745b85b40e982182c397255011afe1 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 4 Jan 2024 08:24:09 -0800 Subject: [PATCH 2/3] Fixed s --- csrc/sgmv_flashinfer/sgmv_flashinfer.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/sgmv_flashinfer/sgmv_flashinfer.cuh b/csrc/sgmv_flashinfer/sgmv_flashinfer.cuh index a1fc9da..2a79a9f 100644 --- a/csrc/sgmv_flashinfer/sgmv_flashinfer.cuh +++ b/csrc/sgmv_flashinfer/sgmv_flashinfer.cuh @@ -342,7 +342,7 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp, if constexpr (cooperative) { uint32_t max_segment_size = 0; for (uint32_t i = 0; i < num_problems; ++i) { - max_segment_size = max(max_segment_size, s_ends[i] - s_starts[i]); + max_segment_size = max(max_segment_size, s[i + 1] - s[i]); } const uint32_t max_steps = (max_segment_size + (num_warps * 16 - 1)) / (num_warps * 16); From 2a32749fa0652ca998776d56a263939b3d3b5750 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 4 Jan 2024 10:11:20 -0800 Subject: [PATCH 3/3] Added tests --- csrc/sgmv_flashinfer/sgmv_flashinfer.cuh | 3 +-- tests/test_sgmv.py | 8 +++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/csrc/sgmv_flashinfer/sgmv_flashinfer.cuh b/csrc/sgmv_flashinfer/sgmv_flashinfer.cuh index 2a79a9f..af3659c 100644 --- a/csrc/sgmv_flashinfer/sgmv_flashinfer.cuh +++ b/csrc/sgmv_flashinfer/sgmv_flashinfer.cuh @@ -20,7 +20,6 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp, constexpr auto fill_mode = cp_async::SharedMemFillMode::kFillZero; const uint32_t problem_id = blockIdx.y; const uint32_t bx = blockIdx.x; - const uint32_t s_start = s[problem_id], s_end = s[problem_id + 1]; constexpr uint32_t num_stages = 2; constexpr uint32_t num_k_frags = 8; constexpr uint32_t num_cells_k = (num_k_frags * 16) / cell_capacity(); @@ -45,7 +44,7 @@ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s, float* tmp, uint32_t w_frag[num_k_frags][num_blocks_n][4]; float y_frag[num_blocks_n][8]; - const uint32_t s_start = s_starts[problem_id], s_end = s_ends[problem_id]; + const uint32_t s_start = s[problem_id], s_end = s[problem_id + 1]; const uint32_t num_steps = (s_start < s_end) ? (s_end - s_start + (num_warps * 16 - 1)) / (num_warps * 16) : 0; for (uint32_t i = 0; i < num_steps; ++i) { // init y_frag diff --git a/tests/test_sgmv.py b/tests/test_sgmv.py index d2cb805..ad645cc 100644 --- a/tests/test_sgmv.py +++ b/tests/test_sgmv.py @@ -51,6 +51,12 @@ def get_lora_lens(bs: int, popularity: str) -> list[int]: a *= alpha lens.append(bs - sum(lens)) return sorted(lens, reverse=True) + if popularity.startswith("skewed"): + if bs < 3: + return [bs] + # Create a highly imbalanced distribution by setting the first segment + # length to 1 and the remainder to the second segment. + return [1, bs - 1] raise KeyError(popularity) @@ -81,7 +87,7 @@ def lora_ref_impl( pytest.param("expand", marks=pytest.mark.xfail(reason="TODO: sgmv expand")), ], ) -@pytest.mark.parametrize("popularity", ["distinct", "uniform", "zipf:1.5", "identical"]) +@pytest.mark.parametrize("popularity", ["distinct", "uniform", "zipf:1.5", "identical", "skewed"]) @pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 7, 10, 16, 32, 64, 133]) @torch.inference_mode() def test_sgmv_correctness(dtype_str, h, r, direction, popularity, batch_size):