From 9d85bc0604d559a79d1795bb01129fbe41c37cd5 Mon Sep 17 00:00:00 2001
From: Christos Perivolaropoulos
Date: Mon, 3 Feb 2025 05:13:28 -0800
Subject: [PATCH] [pallas:mgpu] Change FA3 kernel bc lax.div doesn't like mixed types anymore.

PiperOrigin-RevId: 722615539
---
 jax/experimental/pallas/ops/gpu/attention_mgpu.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py
index 7ca13669a1eb..2855930e8f6c 100644
--- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py
+++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py
@@ -71,6 +71,7 @@ def attention(q, k, v, config: TuningConfig):
   def kernel(q_ref, k_ref, v_ref, out_ref, scoped):
     batch = lax.axis_index("batch")
+    q_head = lax.axis_index("heads")
     smem_buffers, buffer_barriers, consumed_barriers, schedule_barrier = scoped
     wg_idx = lax.axis_index("wg")
     qo_smem2, k_smem, v_smem = smem_buffers
@@ -85,7 +86,6 @@ def _compute_wg():
       plgpu.set_max_registers(232, action="increase")
       qo_smem = qo_smem2.at[wg_idx]
       q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q
-      q_head = lax.axis_index("heads")
 
       plgpu.copy_gmem_to_smem(
           q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
@@ -175,7 +175,7 @@ def _wait():
     @pl.when(wg_idx == 2)
     def _memory_wg():
       plgpu.set_max_registers(40, action="decrease")
-      kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head)
+      kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))
       for i in range(max_concurrent_steps):
         s = (batch, pl.ds(i * block_kv, block_kv), kv_head)
         plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], k_barriers.at[i])
@@ -268,11 +268,11 @@ def attention_with_pipeline_emitter(q, k, v, config: TuningConfig):
   def fa3_kernel(q_ref, k_ref, v_ref, out_ref, scoped):
     batch = lax.axis_index("batch")
-    kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head)
     wg_idx = lax.axis_index("wg")
     qo_smem2, q_barriers, schedule_barrier = scoped
     q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q
     q_head = lax.axis_index("heads")
+    kv_head = lax.div(q_head, q_heads_per_kv_head)
     def perform_schedule_barrier():
       if config.use_schedule_barrier:
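
For context, a minimal sketch (not part of the patch) of the dtype-matching pattern the first hunk adopts: the commit message says `lax.div` no longer accepts mixed operand types, so the static Python int `q_heads_per_kv_head` is materialized as an array with the same dtype as the traced head index before dividing. The helper name `kv_head_for` below is hypothetical and only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax


def kv_head_for(q_head: jax.Array, q_heads_per_kv_head: int) -> jax.Array:
  # Integer-divide a head index (e.g. from lax.axis_index) by a static
  # Python int, casting the constant to the index's dtype so both
  # operands of lax.div share a dtype.
  return lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))


# Example: with 8 query heads per KV head, query head 13 maps to KV head 1.
print(kv_head_for(jnp.int32(13), 8))  # -> 1
```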