feat: add API specification for returning the `k` largest elements #722
@ogrisel Do you have opinions on whether having three separate APIs is preferable to having just I tried tracking down actual usages of top There may also be other combinations. The PR currently specifies
We could, e.g., only specify
or strictly complementary
If you have a feel for what is preferable, that would be great to hear! |
Thank you very much for the survey of current implementations and the API proposal. From a potential API consumer point of view the main proposal seems good to me. About the name:
I believe so. As a user I would accept the call to fail with a standard exception type such as
I think it's fine to allow topk to be faster by not enforcing this constraint. It would always be possible to add
I wouldn't like this requirement to hurt speed of adoption by backing libraries. I don't think many users need that in practice.
I don't think users have a need for this: they can flatten the input by themselves if need be. But I have the impression that NumPy (and therefore Array API) often However for this particular case, the default in
That would remove a lot of value by preventing efficient parallelization by the underlying backend when running topk on a 2D array with
The benefit of the 3-function API would be to always be able to optimize memory usage by not allocating unnecessary arrays when k is large enough for this to matter. However, I have no good feeling for how much of a problem this would really be in practice (and how much the underlying implementation would be able to skip the extra contiguous memory allocation internally). Something that seems missing from this spec is the handling of NaN values. Maybe an extra kwarg is needed to specify whether they should be considered as either smallest or largest, or whether they need to be filtered out of the result (but then the result size would be data-dependent and arrays potentially empty, which might also cause problems). Also, I assume that NaN values are always smaller than +inf and larger than -inf, but maybe not all libraries agree on that.
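As a point of reference for the NaN question, the snippet below shows how NumPy's existing `np.sort` orders NaN relative to the infinities. It is only an illustration of current NumPy behavior, not part of the proposed specification, and other libraries may order NaNs differently.

```python
import numpy as np

x = np.array([np.nan, 1.0, np.inf, -np.inf, 3.0])

# NumPy's sort places NaN at the very end, i.e. after +inf.
ascending = np.sort(x)

# A naive "k largest" built by reversing an ascending sort
# therefore reports NaN as the largest element.
k = 2
naive_largest = ascending[::-1][:k]

print(ascending)
print(naive_largest)
```

In NumPy, then, NaN is not ordered between -inf and +inf, which is one reason the behavior may need to be spelled out or explicitly left unspecified.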
Revisiting this topic in preparation of helping it move forward. Quick first comment on:
I am not sure if it changed in the meantime or you just misread the JAX docs at the time, but https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.top_k.html says that both I had a peek at the jax.lax implementation as well, but couldn't tell so quickly because I'm not familiar with the So it looks like PyTorch, JAX and TF all return
JAX's current The rendered documentation is misleading due to a misplaced colon in the Returns block, but JAX returns
In [2]: x = jax.numpy.arange(100, 110)
In [3]: jax.lax.top_k(x, 3)
Out[3]: [Array([109, 108, 107], dtype=int32), Array([9, 8, 7], dtype=int32)]
Here is a PR with a draft implementation for NumPy: numpy/numpy#26666, aligned with the |
The most sane NaN handling, IMO, is to sort NaNs always to the end, which however means that if you implement The annoyance with that is that it means sort behavior diverges for asc/desc sort beyond a Of course one can just leave it unspecified here. OTOH, I dunno how much that limits the usability.
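To illustrate the "NaNs always last" idea, here is a minimal NumPy sketch for 1-D float input. The helper name `topk_nan_last` is hypothetical, and the negation trick is just one assumed way to get this ordering; it is not taken from any of the implementations discussed here.

```python
import numpy as np


def topk_nan_last(x, k, largest=True):
    """Hypothetical helper: top-k of a 1-D float array with NaNs ranked last."""
    x = np.asarray(x, dtype=float)
    if not 0 <= k <= x.size:
        raise ValueError("k must be between 0 and the number of elements")
    # An ascending argsort already puts NaNs at the end. For the descending
    # case, sort the negated values instead of reversing the ascending
    # result, so that NaNs stay last rather than moving to the front.
    key = -x if largest else x
    idx = np.argsort(key, kind="stable")[:k]
    return x[idx], idx


data = [1.0, np.nan, 3.0, 2.0]
print(topk_nan_last(data, 2))                 # two largest, NaN ignored
print(topk_nan_last(data, 2, largest=False))  # two smallest, NaN ignored
print(np.sort(data)[::-1])                    # plain reversed sort: NaN comes first
```

With this choice the descending result is no longer the mirror image of the ascending one, which is the divergence between asc/desc sorts mentioned above.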
The purpose of this PR is to continue several threads of discussion regarding `top_k`.

This follows roughly the specifications of `top_k` in data-apis/array-api#722, with slight modifications to the API:

```py
def topk(
    x: array,
    k: int,
    /,
    axis: Optional[int] = None,
    *,
    largest: bool = True,
) -> Tuple[array, array]:
    ...
```

Modifications:
- `mode: Literal["largest", "smallest"]` is replaced with `largest: bool`
- `axis` is no longer a kw-only arg. This makes `torch.topk` slightly more compatible.

The tests implemented here follow the proposed `top_k` implementation at numpy/numpy#26666.
Co-authored-by: ndgrigorian <[email protected]>
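For readers who want to experiment with the `topk` signature quoted above, the following is a rough NumPy sketch based on `np.argpartition`. It is not the draft implementation from numpy/numpy#26666. The descending (or ascending) ordering of the returned values and the `ValueError` for an out-of-range `k` are assumptions made for this example; tie-breaking and NaN handling are left as NumPy happens to provide them.

```python
from typing import Optional, Tuple

import numpy as np


def topk(
    x, k: int, /, axis: Optional[int] = None, *, largest: bool = True
) -> Tuple[np.ndarray, np.ndarray]:
    """Sketch: return (values, indices) of the k largest/smallest elements."""
    x = np.asarray(x)
    if axis is None:
        # Assumed behavior: search over the flattened array.
        x = x.reshape(-1)
        axis = 0
    n = x.shape[axis]
    if not 0 < k <= n:
        raise ValueError(f"k={k} is out of range for an axis of length {n}")
    if largest:
        # After partitioning at n - k, the last k slots hold the k largest.
        part = np.argpartition(x, n - k, axis=axis)
        idx = np.take(part, np.arange(n - k, n), axis=axis)
    else:
        # After partitioning at k - 1, the first k slots hold the k smallest.
        part = np.argpartition(x, k - 1, axis=axis)
        idx = np.take(part, np.arange(k), axis=axis)
    vals = np.take_along_axis(x, idx, axis=axis)
    # Order only the k selected entries (descending when largest=True).
    order = np.argsort(vals, axis=axis)
    if largest:
        order = np.flip(order, axis=axis)
    idx = np.take_along_axis(idx, order, axis=axis)
    vals = np.take_along_axis(vals, order, axis=axis)
    return vals, idx


x = np.array([[1, 5, 3], [9, 2, 7]])
values, indices = topk(x, 2, axis=1)
print(values)
print(indices)
```

Only the `k` selected entries are sorted, so the cost of the final ordering step depends on `k` rather than on the full axis length.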
Update: This PR resolves #629 ("RFC: add topk and / or argpartition") by adding one new API to the Array API specification:
- `top_k`: returns a tuple whose first element is an array containing the top `k` largest (or smallest) values and whose second element is an array containing the indices of those `k` values.

The design decisions largely follow the discussion below. The specialized methods `top_k_indices` and `top_k_values` have been dropped from the specification.

This PR resolves #629 ("RFC: add topk and / or argpartition") by adding 3 new APIs to the Array API specification:
- `top_k`: returns a tuple whose first element is an array containing the top `k` largest (or smallest) values and whose second element is an array containing the indices of those `k` values.
- `top_k_indices`: returns an array containing the indices of the `k` largest (or smallest) values.
- `top_k_values`: returns an array containing the `k` largest (or smallest) values.

Prior Art
As illustrated in the API comparison, there is currently no consistent API across array libraries for returning the `k` largest or smallest values.
- NumPy: has `partition` and `argpartition`, but these return full arrays. When `axis` is `None`, NumPy operates on a flattened input array. To get the top `k` values, one must index and, if wanting sorted values, sort.
- CuPy: has `partition` and `argpartition` and follows NumPy; however, for `argpartition`, for implementation reasons, it performs a full sort.
- Dask: has `topk` and switches "modes" (largest or smallest) based on whether `k` is positive or negative.
- JAX: has `top_k`, which always returns both values and indices, as well as NumPy-equivalent `partition` and `argpartition` APIs (however, JAX differs in how it handles NaNs). `top_k` only supports searching along the last axis.
- PyTorch: has `topk`, which always returns both values and indices.
- TensorFlow: has `top_k`, which always returns both values and indices and only supports searching along the last axis.

Proposed APIs
This PR attempts to synthesize the common themes and best ideas for "top k" APIs as observed among array libraries and attempts to define APIs which adhere to specification precedent in order to promote consistent design and reduce cognitive load.
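To make the relationship between the three proposed functions concrete, here is an illustrative NumPy sketch. The keyword-only `axis` and the `mode` argument follow the PR text; the helper `_select`, the use of a full `np.argsort` instead of a partial sort, and the sorted ordering of the results are assumptions made for brevity and are not prescribed by the proposal.

```python
from typing import Literal, Optional, Tuple

import numpy as np

Mode = Literal["largest", "smallest"]


def _select(x, k: int, axis: Optional[int], mode: Mode) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical shared helper (not part of the proposal)."""
    if mode not in ("largest", "smallest"):
        raise ValueError(f"unsupported mode: {mode!r}")
    x = np.asarray(x)
    if axis is None:
        # Assumed to mirror argmin/argmax: search over the flattened array.
        x, axis = x.reshape(-1), 0
    # A full sort keeps the sketch short; a backend could use a partial sort.
    order = np.argsort(x, axis=axis, kind="stable")
    if mode == "largest":
        order = np.flip(order, axis=axis)
    # np.take raises IndexError if k exceeds the axis length; the proposal
    # itself leaves that case unspecified.
    indices = np.take(order, np.arange(k), axis=axis)
    values = np.take_along_axis(x, indices, axis=axis)
    return values, indices


def top_k(x, k: int, /, *, axis: Optional[int] = None, mode: Mode = "largest"):
    """Both values and indices, as a tuple."""
    return _select(x, k, axis, mode)


def top_k_values(x, k: int, /, *, axis: Optional[int] = None, mode: Mode = "largest"):
    """Values only."""
    return _select(x, k, axis, mode)[0]


def top_k_indices(x, k: int, /, *, axis: Optional[int] = None, mode: Mode = "largest"):
    """Indices only."""
    return _select(x, k, axis, mode)[1]


x = np.array([4, 1, 7, 3])
print(top_k(x, 2))
print(top_k_values(x, 2, mode="smallest"))
print(top_k_indices(x, 2, mode="smallest"))
```

Each function has a single, fixed return type, which is the monomorphic design the proposal borrows from `unique_*`.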
top_k
Returns a tuple containing the `k` largest (or smallest) elements in `x`.

top_k_indices
Returns an array containing the indices of the `k` largest (or smallest) elements in `x`.

top_k_values
Returns an array containing the `k` largest (or smallest) elements in `x`.

Design Decision Rationale
- The default value for `axis` is `None` in order to match `min`, `max`, `argmin`, and `argmax`. In those APIs, when `axis` is `None` (the default), the functions operate over a flattened array. Given that `top_k*` may be considered a generalization of the mentioned APIs, ensuring consistency seemed preferable to requiring users to remember a separate set of rules for `top_k*`.
- `axis` only supports `int` and `None` in order to match `argmin` and `argmax`. In `min` and `max`, the specification supports specifying multiple axes. Support for multiple axes can be a future specification extension.
- `top_k` was chosen over `topk` due to naming concerns discussed elsewhere (namely `top k` vs `to pk`). Furthermore, "top k" follows ML conventions, as opposed to `maxk`/`max_k` or `nlargest`/`nsmallest` as found in other languages.
- Having three separate APIs follows the precedent set by `unique`. In that case, rather than support polymorphic return values (e.g., returning values, returning values and indices, returning values and counts, etc.), we chose to define specific APIs which are monomorphic in their output. We innovated there, and the thinking that went into those design decisions seemed applicable here, where a user may want only values, indices, or both.
- The APIs follow the `unique_*` naming convention, rather than the `arg*` naming convention, as there are three different return value situations: values, indices, and indices and values. Hence, using a suffix to describe what is returned, as in `unique_*`, seems reasonable and follows existing precedent in the specification.
- Existing APIs such as `torch.topk` use a `largest` keyword argument. This PR chooses to name the kwarg `mode` in order to be more explicit (what does `largest=False` mean to the lay reader?) and follows precedent elsewhere in the specification (e.g., `linalg.qr`) where `mode` is used to toggle between different operating modes.
- This PR does not include a `sorted` kwarg to instruct the API to return sorted values (or indices corresponding to sorted values) because (a) the kwarg is not universally supported currently, (b) downstream users can, at least for values, explicitly call `sort` (except in Dask, which doesn't currently support full sorting) after calling `top_k` or `top_k_values`, and (c) it can be addressed in a future specification extension. Additionally, if we support `sorted`, we may also want to support a `stable` kwarg as in `sort` to allow ensuring that returned indices are consistent when provided the same input array.
- This PR does not specify behavior when `k` exceeds the number of elements, as different behaviors seem acceptable (e.g., raising an exception or returning `m < k` values).

Questions
- Should the specification define what happens when `k` exceeds the number of elements?
- In `argmin` and `argmax`, the specification requires returning the index of the first occurrence when a minimum/maximum value occurs multiple times. Given that `top_k*` can be implemented as a partial sort, presumably we do not want to specify a first-occurrence restriction. Is this a reasonable assumption?
- Should `axis` support multiple axes, as in `min` and `max`?
- Are we okay with `None` being the default for `axis`, where the default behavior is searching over a flattened array?

Considerations
The APIs included in this PR have implications for the following array libraries:
- [...] `unique_*`, will need to be added to the main namespace.
- [...] `topk` and `argtopk`.
- If implemented in its `lax` namespace, this PR would introduce breaking changes, as JAX would need to flatten by default rather than search along the last dimension. However, if implemented in its `numpy` namespace, these will simply be new APIs. In both scenarios, JAX will need to add support for `axis` and `mode` behavior.
- [...] `topk`.
- [...] `top_k_values` and `top_k_indices`). If implemented in its `math` namespace, this PR would introduce breaking changes, as TensorFlow would need to flatten by default. However, if implemented in its `numpy` namespace, these will simply be new APIs. In both scenarios, TensorFlow will need to add support for `axis` and `mode` behavior.

Related Links