diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 431319db6af8..d7f5d95ec29b 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1232,12 +1232,24 @@ def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]: k: integer specifying the number of top entries. Returns: - values: array containing the top k values along the last axis. - indices: array containing the indices corresponding to values. + A tuple ``(values, indices)`` where + + - ``values`` is an array containing the top k values along the last axis. + - ``indices`` is an array containing the indices corresponding to values. See also: - - :func:`jax.lax.approx_max_k` - - :func:`jax.lax.approx_min_k` + - :func:`jax.lax.approx_max_k` + - :func:`jax.lax.approx_min_k` + + Example: + Find the largest three values, and their indices, within an array: + + >>> x = jnp.array([9., 3., 6., 4., 10.]) + >>> values, indices = jax.lax.top_k(x, 3) + >>> values + Array([10., 9., 6.], dtype=float32) + >>> indices + Array([4, 0, 2], dtype=int32) """ if core.is_constant_dim(k): k = int(k)