Skip to content

Commit

Permalink
DOC: jax.lax.top_k: fix rendering of return values
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jun 10, 2024
1 parent 27de854 commit c1598ec
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1232,12 +1232,24 @@ def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]:
k: integer specifying the number of top entries.
Returns:
values: array containing the top k values along the last axis.
indices: array containing the indices corresponding to values.
A tuple ``(values, indices)`` where
- ``values`` is an array containing the top k values along the last axis.
- ``indices`` is an array containing the indices corresponding to values.
See also:
- :func:`jax.lax.approx_max_k`
- :func:`jax.lax.approx_min_k`
- :func:`jax.lax.approx_max_k`
- :func:`jax.lax.approx_min_k`
Example:
Find the largest three values, and their indices, within an array:
>>> x = jnp.array([9., 3., 6., 4., 10.])
>>> values, indices = jax.lax.top_k(x, 3)
>>> values
Array([10., 9., 6.], dtype=float32)
>>> indices
Array([4, 0, 2], dtype=int32)
"""
if core.is_constant_dim(k):
k = int(k)
Expand Down

0 comments on commit c1598ec

Please sign in to comment.