When scatter indices are sorted and `indices_are_sorted=True` is passed, the scatter is sometimes very slow. Here is a reproducer:
import jax
import jax.numpy as jnp
n = 4_000
size = 1091 # Works on my machine, YMMV. Try 1097, 1103, 1109 if it's not slow.
i = jax.random.randint(jax.random.key(0), (70_000,), 0, n)
i = jnp.sort(i)
y = jnp.zeros((70_000, size))
@jax.jit
def scatter_sorted(y, i):
    return (
        jnp.zeros((n, y.shape[1]), dtype=y.dtype).at[i].add(y, indices_are_sorted=True)
    )
scatter_sorted(y, i)
On my machine, the scatter kernel runs in 38 ms, versus 0.4 ms for the unsorted one.
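For anyone who wants to reproduce the comparison, here is a minimal timing sketch. It assumes that "the unsorted one" means the same scatter with `indices_are_sorted=False` on the same indices. The helpers `make_scatter`, `time_scatter`, and `compare` are hypothetical names, not part of JAX. It measures wall-clock time per call, which includes dispatch overhead, so the numbers will not match kernel times from a profiler exactly.

```python
import time

import jax
import jax.numpy as jnp


def make_scatter(n, indices_are_sorted):
    # Build a jitted scatter-add into an (n, size) zero buffer.
    @jax.jit
    def scatter(y, i):
        return jnp.zeros((n, y.shape[1]), dtype=y.dtype).at[i].add(
            y, indices_are_sorted=indices_are_sorted
        )

    return scatter


def time_scatter(fn, y, i, repeats=10):
    # The first call triggers compilation, so keep it out of the measurement.
    fn(y, i).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        fn(y, i).block_until_ready()
    return (time.perf_counter() - start) / repeats


def compare(n=4_000, rows=70_000, size=1091):
    # Same inputs as the reproducer: sorted indices, zero-valued updates.
    i = jnp.sort(jax.random.randint(jax.random.key(0), (rows,), 0, n))
    y = jnp.zeros((rows, size))
    t_sorted = time_scatter(make_scatter(n, True), y, i)
    t_unsorted = time_scatter(make_scatter(n, False), y, i)
    return t_sorted, t_unsorted
```

Calling `compare()` returns the average seconds per call for the two variants. Varying `size` should show whether the slowdown appears on a given GPU.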
The slowdown does not happen for all input sizes; it appears to depend on the size of the intermediate vector where the updates are accumulated. With a size of 34, everything is fine on my machine. With 35, there are lots of loads/stores from/to local memory and performance is terrible.
The issue starts somewhere in LLVM: the initial MLIR is nearly identical for the fast and slow cases, and the PTX already contains the loads/stores from local memory. As far as I can tell, the IR starts to diverge in LoopFullUnrollPass, but I don't know if that's the actual culprit.
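One way to check the "nearly identical MLIR" observation is to diff the lowered text for two input sizes. This is a rough sketch: `lowered_mlir` and `mlir_diff` are hypothetical helper names, and the sizes passed in are the reproducer's `size` argument, which is not necessarily the same quantity as the intermediate vector size mentioned above. It only covers the StableHLO that JAX hands to XLA. The LLVM IR and PTX have to come from XLA's dump output (for example by setting `--xla_dump_to` in `XLA_FLAGS`).

```python
import difflib

import jax
import jax.numpy as jnp


def lowered_mlir(size, n=4_000, rows=70_000):
    # Lower the reproducer's scatter without running it; abstract shapes suffice.
    def scatter_sorted(y, i):
        return jnp.zeros((n, y.shape[1]), dtype=y.dtype).at[i].add(
            y, indices_are_sorted=True
        )

    y = jax.ShapeDtypeStruct((rows, size), jnp.float32)
    i = jax.ShapeDtypeStruct((rows,), jnp.int32)
    return jax.jit(scatter_sorted).lower(y, i).as_text()


def mlir_diff(size_a, size_b):
    # Return only the changed lines, dropping unified-diff headers and hunk markers.
    a = lowered_mlir(size_a).splitlines()
    b = lowered_mlir(size_b).splitlines()
    return [
        line
        for line in difflib.unified_diff(a, b, lineterm="", n=0)
        if not line.startswith(("---", "+++", "@@"))
    ]
```

If the lowered modules differ only in tensor shapes, every line returned by `mlir_diff` should be a shape change, which would be consistent with the divergence being introduced later in the pipeline.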