From 6b8e2f3467fe400fc0d6e5736b8e001ba5739050 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 10 Jun 2024 13:57:21 -0700 Subject: [PATCH] DOC: jax.lax.top_k: fix docstring rendering & add example --- jax/_src/lax/lax.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 431319db6af8..d7f5d95ec29b 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1232,12 +1232,24 @@ def top_k(operand: ArrayLike, k: int) -> tuple[Array, Array]: k: integer specifying the number of top entries. Returns: - values: array containing the top k values along the last axis. - indices: array containing the indices corresponding to values. + A tuple ``(values, indices)`` where + + - ``values`` is an array containing the top k values along the last axis. + - ``indices`` is an array containing the indices corresponding to values. See also: - - :func:`jax.lax.approx_max_k` - - :func:`jax.lax.approx_min_k` + - :func:`jax.lax.approx_max_k` + - :func:`jax.lax.approx_min_k` + + Example: + Find the largest three values, and their indices, within an array: + + >>> x = jnp.array([9., 3., 6., 4., 10.]) + >>> values, indices = jax.lax.top_k(x, 3) + >>> values + Array([10., 9., 6.], dtype=float32) + >>> indices + Array([4, 0, 2], dtype=int32) """ if core.is_constant_dim(k): k = int(k)