Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add API specification for returning the k largest elements #722

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

kgryte
Copy link
Contributor

@kgryte kgryte commented Dec 14, 2023

Update: This PR

  • resolves RFC: add topk and / or argpartition #629 by adding one new API to the Array API specification

    • top_k: returns a tuple whose first element is an array containing the top k largest (or smallest) values and whose second element is an array containing the indices of those k values.

The design decisions largely follow the discussion below. The specialized methods top_k_indices and top_k_values have been dropped from the specification.


This PR

  • resolves RFC: add topk and / or argpartition #629 by adding 3 new APIs to the Array API specification

    • top_k: returns a tuple whose first element is an array containing the top k largest (or smallest) values and whose second element is an array containing the indices of those k values.
    • top_k_indices: returns an array containing the indices of the k largest (or smallest) values.
    • top_k_values: returns an array containing the k largest (or smallest) values.

Prior Art

As illustrated in the API comparison, there is currently no consistent API across array libraries for returning the k largest or smallest values.

  • NumPy has partition and argpartition, but these return full arrays. When axis is None, NumPy operates on a flattened input array. To get the top k values, one must index and, if wanting sorted values, sort.
  • CuPy has partition and argpartition and follows NumPy; however, for argpartition, for implementation reasons, it performs a full sort.
  • Dask has topk and switches "modes" (largest or smallest) based on whether k is positive or negative.
  • JAX has top_k which only returns values always returns both values and indices, as well as NumPy equivalent partition and argpartition APIs (however, JAX differs in how it handles NaNs). The function only supports searching along the last axis.
  • PyTorch has topk which always returns both values and indices.
  • TensorFlow has top_k which always returns both values and indices and only supports searching along the last axis.

Proposed APIs

This PR attempts to synthesize the common themes and best ideas for "top k" APIs as observed among array libraries and attempts to define APIs which adhere to specification precedent in order to promote consistent design and reduce cognitive load.

top_k

def top_k(
    x: array,
    k: int,
    /,
    *,
    axis: Optional[int] = None,
    mode: Literal["largest", "smallest"] = "largest",
) -> Tuple[array, array]

Returns a tuple containing the k largest (or smallest) elements in x.

def top_k_indices(
    x: array,
    k: int,
    /,
    *,
    axis: Optional[int] = None,
    mode: Literal["largest", "smallest"] = "largest",
) -> array

Returns an array containing the indices of the k largest (or smallest) elements in x.

def top_k_values(
    x: array,
    k: int,
    /,
    *,
    axis: Optional[int] = None,
    mode: Literal["largest", "smallest"] = "largest",
) -> array

Returns an array containing the k largest (or smallest) elements in x.

Design Decision Rationale

  • The default for axis is None in order to match min, max, argmin, and argmax. In those APIs, when axis is None (the default), the functions operate over a flattened array. Given that top_k* may be considered a generalization of the mentioned APIs, ensuring consistency seemed preferable to requiring users to remember a separate set of rules for top_k*.
  • axis only supports int and None in order to match argmin and argmax. In min and max, the specification supports specifying multiple axes. Support for multiple axes can be a future specification extension.
  • top_k was chosen over topk due to naming concerns discussed elsewhere (namely top k vs to pk). Furthermore, "top k" follows ML conventions, as opposed to maxk/max_k or nlargest/nsmallest as found in other languages.
  • The PR includes three separate APIs following the lead of unique. In that case, rather than support polymorphic return values (e.g., returning values, returning values and indices, return values and counts, etc), we chose define specific API which are monomorphic in their output. We innovated there, and the thinking that went into those design decisions seemed applicable here, where a user may want only values, indices, or both.
  • The PR follows the unique_* naming convention, rather than the arg* naming convention, as there are three different return value situations: values, indices, and indices and values. Hence, using a suffix to describe what is returned as in unique_* seems reasonable and follows existing precedent in the specification.
  • The APIs include a "mode" option to specify the type (largest or smallest) of values to return. Most existing array libraries supporting a "top k" API return only the largest values; however, PyTorch supports returning either the smallest or largest and does so via a largest keyword argument. This PR chooses to name the kwarg mode in order to be more explicit (what does largest=False mean to the lay reader?) and follows precedent elsewhere in the specification (e.g., linalg.qr) where mode is used to toggle between different operating modes.
  • The PR does not include a sorted kwarg in order to instruct the API to return sorted values (or indices corresponding to sorted values) because (a) the kwarg is not universally supported currently, (b) downstream users can, at least for values, explicitly call sort (except in Dask which doesn't currently support full sorting) after calling top_k or top_k_values, and (c) can be addressed in a future specification extension. Additionally, if we support sorted, we may also want to support a stable kwarg as in sort to allow ensuring that returned indices are consistent when provided the same input array.
  • Leaves unspecified what should happen when k exceeds the number of elements, as different behaviors seem acceptable (e.g., raising an exception or returning m < k values).

Questions

  • Should we be more strict in specifying what should happen when k exceeds the number of elements?
  • Should zero-dimensional arrays be supported?
  • In argmin and argmax, the specification requires returning the index of first occurrence when a minimum/maximum value occurs multiple times. Given that top_k* can be implemented as a partial sort, presumably we do not want specify a first occurrence restriction. Is this a reasonable assumption?
  • Should we defer adding support for specifying multiple axes until a future revision of the specification or should we go ahead and add now for parity with min and max?
  • Are we okay with None being the default for axis, where the default behavior is searching over a flattened array?

Considerations

The APIs included in this PR have implications for the following array libraries:

  • NumPy: these will be new APIs, and, similar to unique_*, will need to be added to the main namespace.
  • CuPy: same as NumPy.
  • Dask: will need to introduce new APIs; however, the new APIs can be implemented as lightweight wrappers around Dask's existing topk and argtopk.
  • JAX: were JAX to place the APIs in its lax namespace, this PR would introduce breaking changes, as JAX would need to return both values and indices, by default, and JAX would need to flatten by default rather than search along the last dimension. However, if implemented in its numpy namespace, these will simply be new APIs. In both scenarios, JAX will need to add support for axis and mode behavior.
  • PyTorch: these will be new APIs; however, the new APIs can be implemented as lightweight wrappers around PyTorch's existing topk.
  • TensorFlow: additional APIs (top_k_values and top_k_indices). If implemented in its math namespace, this PR would introduce breaking changes as TensorFlow would need to flatten by default. However, if implemented in its numpy namespace, these will simply be new APIs. In both scenarios, TensorFlow will need to add support for axis and mode behavior.

Related Links

@kgryte kgryte added the API extension Adds new functions or objects to the API. label Dec 14, 2023
@kgryte kgryte added this to the v2023 milestone Dec 14, 2023
@kgryte
Copy link
Contributor Author

kgryte commented Dec 14, 2023

cc @ogrisel for visibility since you originally opened #722.

@kgryte
Copy link
Contributor Author

kgryte commented Jan 25, 2024

@ogrisel Do you have opinions on whether having three separate APIs is preferable to having just argtop_k and top_k?

I tried tracking down actual usages of top k in the wild, but I wasn't able to get a good sense on whether having only two APIs suffices or having three separate APIs is more desirable.

There may also be other combinations. The PR currently specifies

  • top_k returns [ values, indices ]
  • top_k_values returns values
  • top_k_indices returns indices

We could, e.g., only specify

  • top_k returns [ values, indices ]
  • argtop_k returns indices

or strictly complementary

  • top_k returns values
  • argtop_k returns indices

If you have a feel for what is preferable, that would be great to hear!

@kgryte kgryte modified the milestones: v2023, v2024 Jan 25, 2024
@ogrisel
Copy link

ogrisel commented Mar 29, 2024

Thank you very much for the survey of current implementations and API proposal. From a potential API consumer point of view the main proposal seems good for me.

About the name: topk seems to be a bit more popular than top_k in existing libraries, but using top_k might help reducing breaking changes when adopting this spec but it might be better to hear from library implementers.

Should we be more strict in specifying what should happen when k exceeds the number of elements?

I believe so. As a user I would accept the call to fail with a standard exception type such as ValueError.

In argmin and argmax, the specification requires returning the index of first occurrence when a minimum/maximum value occurs multiple times. Given that top_k* can be implemented as a partial sort, presumably we do not want specify a first occurrence restriction. Is this a reasonable assumption?

I think it's fine to allow topk to be faster by not enforcing this constraint. It would always be possible to add stable=False bool kwarg later to make it possible for the users to request stability of the results maybe at the cost of a performance penalty as done for the xp.sort function.

Should we defer adding support for specifying multiple axes until a future revision of the specification or should we go ahead and add now for parity with min and max?

I wouldn't like this requirement to hurt speed of adoption by backing libraries. I don't think many users need that in practice.

Are we okay with None being the default for axis, where the default behavior is searching over a flattened array?

I don't think users have an need for this: then can flatten the input by themselves if need. But I have the impression that NumPy (and therefore Array API) often axis=None as a default convention to work on 1d flatten array anyway so I am fine with staying consistent in that regard.

However for this particular case, the default in numpy.argpartition is axis=-1 instead of axis=None. No strong opinion of which is most natural.

Should we defer adding support for specifying multiple axes until a future revision of the specification or should we go ahead and add now for parity with min and max?

That would remove a lot of value by preventing efficient parallelization by the underlying backend when running topk on a 2D array with axis=0 or axis=1. This pattern was the original motivation for #629 (e.g. k-nearest neighbors classification in scikit-learn).

I tried tracking down actual usages of top k in the wild, but I wasn't able to get a good sense on whether having only two APIs suffices or having three separate APIs is more desirable.

The benefit for the 3 function API would be to always be able to optimize memory usage by not allocating unnecessary arrays when k is large enough for this to matter. However I have no good feeling about how much this would really be a problem in practice (and how much the underlying implementation would be able to skip the extra contiguous memory allocation internally).

Something that seems missing from this spec is to specify the handling of NaN values. Maybe and extra kwarg is needed to specify if they should be considered as either smallest or largest, or if they need to be filtere out from the result (but then the result size would be data-dependent and potentially empty arrays which might also cause problems).

Also I assume that nan values are always smaller than +inf and larger than -inf but maybe not all libraries agree on that.

@rgommers
Copy link
Member

Revisiting this topic in preparation of helping it move forward. Quick first comment on:

JAX has top_k which only returns values,

I am not sure if it changed in the meantime or you just misread the JAX docs at the time, but https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.top_k.html says that both values and indices are returned.

I had a peek at the jax.lax implementation as well, but could tell so quickly because I'm not familiar with the .bind method. @jakevdp would you mind confirming that the docs are correct here?

So it looks like PyTorch, JAX and TF all return (values, indices). That seems to align well enough with the top_k definition proposed here.

@jakevdp
Copy link

jakevdp commented Jun 10, 2024

JAX's current top_k functions are in the jax.lax namespace, while the array API implementations will be in the jax.numpy namespace. So there is no issue with having different API conventions here.

The rendered documentation is misleading due to a misplaced colon in the Returns block, but JAX returns (values, indices):

In [2]: x = jax.numpy.arange(100, 110)

In [3]: jax.lax.top_k(x, 3)
Out[3]: [Array([109, 108, 107], dtype=int32), Array([9, 8, 7], dtype=int32)]

@rgommers
Copy link
Member

Here is a PR with a draft implementation for NumPy: numpy/numpy#26666, aligned with the top_k signature in this PR.

@seberg
Copy link
Contributor

seberg commented Jun 17, 2024

The most sane NaN handling IMO, is to sort NaNs always to the end, which however means that if you implement sort="desc" or largest values here, NaNs should end up also at the end, which is opposite of what happens for ascending sort/smallest values.

The annoyance with that is that it means sort behavior diverges for asc/desc sort beyong a [::-1] also for unstable sort.

Of course one can just leave it unspecified here. OTOH, I dunno how much that limits the usability.

JuliaPoo added a commit to JuliaPoo/array-api-tests that referenced this pull request Jun 24, 2024
The purpose of this PR is to continue several threads of discussion
regarding `top_k`.

This follows roughly the specifications of `top_k` in
data-apis/array-api#722, with slight modifications to the API:

```py
def topk(
    x: array,
    k: int,
    /,
    axis: Optional[int] = None,
    *,
    largest: bool = True,
) -> Tuple[array, array]:
    ...
```

Modifications:
- `mode: Literal["largest", "smallest"]` is replaced with
`largest: bool`
- `axis` is no longer a kw-only arg. This makes `torch.topk`
slightly more compatible.

The tests implemented here follows the proposed `top_k`
implementation at numpy/numpy#26666.
@kgryte kgryte added the Needs Changes Pull request which needs changes before being merged. label Sep 19, 2024
@kgryte kgryte changed the title Add API specifications for returning the k largest elements feat: add API specification for returning the k largest elements Dec 12, 2024
@kgryte kgryte removed the Needs Changes Pull request which needs changes before being merged. label Dec 12, 2024
@kgryte kgryte added the Needs Review Pull request which needs review. label Dec 12, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
API extension Adds new functions or objects to the API. Needs Review Pull request which needs review.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

RFC: add topk and / or argpartition
6 participants