[pallas:mgpu] Change FA3 kernel because lax.div doesn't like mixed types anymore.

PiperOrigin-RevId: 722615539
cperivol authored and Google-ML-Automation committed Feb 3, 2025
1 parent 28afd25 commit 1e78ebd
Showing 1 changed file with 3 additions and 3 deletions.
jax/experimental/pallas/ops/gpu/attention_mgpu.py (3 additions, 3 deletions)
@@ -71,6 +71,7 @@ def attention(q, k, v, config: TuningConfig):
 
   def kernel(q_ref, k_ref, v_ref, out_ref, scoped):
     batch = lax.axis_index("batch")
+    q_head = lax.axis_index("heads")
     smem_buffers, buffer_barriers, consumed_barriers, schedule_barrier = scoped
     wg_idx = lax.axis_index("wg")
     qo_smem2, k_smem, v_smem = smem_buffers
@@ -85,7 +86,6 @@ def _compute_wg():
       plgpu.set_max_registers(232, action="increase")
       qo_smem = qo_smem2.at[wg_idx]
       q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q
-      q_head = lax.axis_index("heads")
 
       plgpu.copy_gmem_to_smem(
           q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
@@ -175,7 +175,7 @@ def _wait():
     @pl.when(wg_idx == 2)
     def _memory_wg():
       plgpu.set_max_registers(40, action="decrease")
-      kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head)
+      kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))
       for i in range(max_concurrent_steps):
         s = (batch, pl.ds(i * block_kv, block_kv), kv_head)
         plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], k_barriers.at[i])
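
This hunk is the core of the fix: lax.axis_index returns an int32 scalar,
q_heads_per_kv_head is a plain Python int (an assumption based on how the
kernel is parameterized), and lax.div no longer accepts that mix, so the
divisor is cast to the dividend's dtype. A minimal sketch of the same
pattern outside the kernel, with stand-in values:

    # Stand-ins: q_head plays the role of lax.axis_index("heads"),
    # which yields an int32 scalar; q_heads_per_kv_head is a Python int.
    import jax.numpy as jnp
    from jax import lax

    q_head = jnp.int32(5)
    q_heads_per_kv_head = 2

    # Casting the divisor to the dividend's dtype gives lax.div two
    # operands of the same integer type, mirroring the hunk above.
    kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))
    print(kv_head)  # 2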
@@ -268,11 +268,11 @@ def attention_with_pipeline_emitter(q, k, v, config: TuningConfig):
 
   def fa3_kernel(q_ref, k_ref, v_ref, out_ref, scoped):
     batch = lax.axis_index("batch")
-    kv_head = lax.div(lax.axis_index("heads"), q_heads_per_kv_head)
     wg_idx = lax.axis_index("wg")
     qo_smem2, q_barriers, schedule_barrier = scoped
     q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q
     q_head = lax.axis_index("heads")
+    kv_head = lax.div(q_head, q_heads_per_kv_head)
 
     def perform_schedule_barrier():
       if config.use_schedule_barrier:
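For context, lax.axis_index reports the current position along a named
mapped axis as an int32 scalar, which is where the concrete dtype comes
from; hoisting q_head to the top of each kernel also lets the compute and
memory warpgroups share one value. A sketch of the API using jax.pmap as a
stand-in for the Pallas grid (illustration only, not this kernel's setup):

    import jax
    import jax.numpy as jnp
    from jax import lax

    def which_head(x):
      # Returns this program instance's index along the "heads" axis.
      return lax.axis_index("heads")

    idx = jax.pmap(which_head, axis_name="heads")(
        jnp.zeros(jax.local_device_count()))
    print(idx.dtype)  # int32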
