Sorted scatter emitter is sometimes very slow #22233

Open
jreiffers opened this issue Feb 3, 2025 · 2 comments

Comments

jreiffers (Contributor) commented Feb 3, 2025

When scatter indices are sorted, performance is sometimes very slow. Here is a reproducer:

```python
import jax
import jax.numpy as jnp

n = 4_000
size = 1091   # Works on my machine, YMMV. Try 1097, 1103, 1109 if it's not slow.
i = jax.random.randint(jax.random.key(0), (70_000,), 0, n)
i = jnp.sort(i)                 # scatter indices, sorted
y = jnp.zeros((70_000, size))   # one update row of `size` values per index


@jax.jit
def scatter_sorted(y, i):
    # Accumulate the rows of y into an (n, size) buffer, declaring the indices sorted.
    return (
        jnp.zeros((n, y.shape[1]), dtype=y.dtype).at[i].add(y, indices_are_sorted=True)
    )

scatter_sorted(y, i).block_until_ready()
```

On my machine, the scatter kernel runs in 38 ms, versus 0.4 ms for the unsorted one.
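
A small timing harness along the following lines can be used to compare the two variants. It is a sketch that assumes "the unsorted one" means the same scatter on the same sorted indices with `indices_are_sorted=False`; `make_inputs`, `make_scatter` and `time_scatter` are hypothetical helper names. It measures wall-clock time per call with `block_until_ready()`, so it includes dispatch overhead; a profiler such as `jax.profiler` is better suited to isolating the kernel time quoted above.

```python
import time

import jax
import jax.numpy as jnp


def make_inputs(n, size, num_updates, seed=0):
    """Sorted random indices in [0, n) and all-zero updates, as in the reproducer."""
    key = jax.random.key(seed)
    i = jnp.sort(jax.random.randint(key, (num_updates,), 0, n))
    y = jnp.zeros((num_updates, size))
    return y, i


def make_scatter(n, indices_are_sorted):
    """Jitted scatter-add into an (n, size) buffer with the given sortedness hint."""

    @jax.jit
    def scatter(y, i):
        out = jnp.zeros((n, y.shape[1]), dtype=y.dtype)
        return out.at[i].add(y, indices_are_sorted=indices_are_sorted)

    return scatter


def time_scatter(fn, y, i, repeats=10):
    """Median wall-clock seconds per call, measured after a warm-up call."""
    fn(y, i).block_until_ready()  # warm-up: triggers compilation
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(y, i).block_until_ready()
        times.append(time.perf_counter() - start)
    times.sort()
    return times[len(times) // 2]


if __name__ == "__main__":
    # Sizes taken from the reproducer; adjust `size` as suggested there.
    n = 4_000
    y, i = make_inputs(n=n, size=1091, num_updates=70_000)
    for flag in (True, False):
        t = time_scatter(make_scatter(n, flag), y, i)
        print(f"indices_are_sorted={flag}: {t * 1e3:.2f} ms")
```

Both variants compute the same result; only the hint passed to the compiler differs, which keeps the comparison focused on the emitter choice.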

The slowdown does not happen for all input sizes; it appears to depend on the size of the intermediate vector in which the updates are accumulated. With a size of 34, everything is fine on my machine; with 35, there are lots of loads from and stores to local memory, and performance is terrible.
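
A conceptual model of what "accumulating updates in an intermediate vector" means for sorted indices is sketched below in NumPy. It assumes (this is not XLA's emitter code) that a sorted scatter-add sums each run of equal indices into one intermediate row and writes that row to the output once per run, rather than updating the output once per input row. `sorted_scatter_add_reference` is a hypothetical name, and indices are assumed to lie in `[0, n)`.

```python
import numpy as np


def sorted_scatter_add_reference(n, updates, indices):
    """Hypothetical model of a sorted scatter-add, not XLA's implementation.

    Walks the sorted indices once, summing each run of equal indices into a
    single intermediate row, and flushes that row to the output whenever the
    index changes.
    """
    updates = np.asarray(updates)
    indices = np.asarray(indices)
    # The run-based accumulation is only correct for sorted indices.
    assert np.all(indices[:-1] <= indices[1:]), "indices must be sorted"
    out = np.zeros((n, updates.shape[1]), dtype=updates.dtype)
    if len(indices) == 0:
        return out
    acc = np.zeros(updates.shape[1], dtype=updates.dtype)  # intermediate vector
    current = indices[0]
    for idx, row in zip(indices, updates):
        if idx != current:
            out[current] += acc  # flush the finished run
            acc[:] = 0
            current = idx
        acc += row
    out[current] += acc  # flush the last run
    return out
```

In a GPU kernel, a buffer like `acc` would be per-thread state. One possible (unconfirmed) reading of the 34 vs. 35 threshold is that once the accumulation loop is no longer fully unrolled, the buffer is indexed dynamically and can no longer be kept in registers, so it lands in local memory.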

The issue starts somewhere in LLVM: the initial MLIR is nearly identical between the fast and slow cases, and the PTX already contains the loads from and stores to local memory. As far as I can tell, the IR starts to diverge in LoopFullUnrollPass, but I don't know whether that pass is the actual culprit.
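
One way to check a given `size` for local-memory traffic is to dump the compiled PTX and count `ld.local`/`st.local` instructions, as sketched below. It assumes the reproducer is run with `XLA_FLAGS=--xla_dump_to=<dir>` set before JAX is imported and that the GPU dump includes `*.ptx` files, whose exact names vary across XLA versions; `count_local_memory_ops` and `scan_dump_dir` are hypothetical helpers.

```python
import re
from pathlib import Path

# Matches PTX local-memory loads/stores such as `ld.local.f32` or `st.local.v4.f32`.
_LOCAL_OP = re.compile(r"\b(ld|st)\.local\b")


def count_local_memory_ops(ptx_text):
    """Return (num_local_loads, num_local_stores) found in PTX source text."""
    loads = stores = 0
    for match in _LOCAL_OP.finditer(ptx_text):
        if match.group(1) == "ld":
            loads += 1
        else:
            stores += 1
    return loads, stores


def scan_dump_dir(dump_dir):
    """Count local-memory ops in every *.ptx file under an XLA dump directory."""
    results = {}
    for path in sorted(Path(dump_dir).rglob("*.ptx")):
        results[path.name] = count_local_memory_ops(path.read_text())
    return results
```

For example, run the reproducer once with a fast size and once with a slow one, each into its own dump directory, and compare the per-file counts returned by `scan_dump_dir`.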

jreiffers (Contributor, Author) commented:

@pifon2a

pifon2a (Contributor) commented Feb 3, 2025

Ah, I think I know what it is. I even have a fix for it, but I haven't submitted it yet.
