WIP: top_k tests #274
The purpose of this PR is to continue several threads of discussion regarding `top_k`.

This follows roughly the specifications of `top_k` in data-apis/array-api#722, with slight modifications to the API to increase compatibility:

```py
def topk(
    x: array,
    k: int,
    /,
    axis: Optional[int] = None,
    *,
    largest: bool = True,
) -> Tuple[array, array]:
    ...
```

Modifications:

- `mode: Literal["largest", "smallest"]` is replaced with `largest: bool`
- `axis` is no longer a kw-only arg. This makes `torch.topk` slightly more compatible.

The tests implemented here follow the proposed `top_k` implementation at numpy/numpy#26666.
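To make the proposed signature concrete, here is a minimal NumPy sketch of a function with the same shape. It is not the implementation from numpy/numpy#26666 or from any library: the name `topk_sketch`, the treatment of `axis=None` as "search the flattened array", and the order in which ties are returned are all assumptions made for illustration.

```python
from typing import Optional, Tuple

import numpy as np


def topk_sketch(
    x: np.ndarray,
    k: int,
    /,
    axis: Optional[int] = None,
    *,
    largest: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """Illustrative only: return (values, indices) of the k largest/smallest entries."""
    if axis is None:
        # Assumption: axis=None means the search runs over the flattened array.
        x = x.reshape(-1)
        axis = 0
    if not 0 <= k <= x.shape[axis]:
        raise ValueError("k must be between 0 and the size of the searched axis")

    # Full ascending sort; a real implementation would likely use a partial sort.
    order = np.argsort(x, axis=axis, kind="stable")
    if largest:
        # Reverse so that the largest entries come first along the axis.
        order = np.flip(order, axis=axis)

    # Keep the first k positions along the axis, then gather the matching values.
    indices = np.take(order, np.arange(k), axis=axis)
    values = np.take_along_axis(x, indices, axis=axis)
    return values, indices


x = np.array([[1, 5, 3], [9, 2, 7]])
values, indices = topk_sketch(x, 2, 1)              # axis passed positionally
smallest, _ = topk_sketch(x, 2, 1, largest=False)   # largest is keyword-only
```

Passing `axis` positionally, as in the usage lines, is what the second modification allows; `largest` stays keyword-only.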
So the next step here would be to implement wrappers for this function in array_api_compat. That way we will be able to see just how complex the required changes are, and we can also verify that there aren't other incompatibilities, since the tests can't check things later in a test if things earlier in it fail. The way I would do this is to make a PR to the compat library that
Here are some development notes for the compat library which should be helpful https://data-apis.org/array-api-compat/dev/index.html, but also feel free to ask any questions here. The main purpose is just to get an idea of what the wrappers will look like and to get a CI log showing the tests pass (or, if something is too hard to wrap, what the error is). So you don't need to worry too much about making everything perfect and mergeable. If you see something that could be changed in array-api-compat or the test suite to make this process easier, make a note of it. We are going to want to be able to repeat this whole process in the future any time a new function is proposed for inclusion in the array API.
Quite a few things in PyTorch don't work with smaller uint dtypes. They are skipped in the CI, so you don't need to worry about that. If the torch wrapper is just
This references the PR data-apis/array-api-tests#274.
def test_top_k(x, data):

    if dh.is_float_dtype(x.dtype):
        assume(not xp.any(x == -0.0) and not xp.any(x == +0.0))
This check won't work properly, since `0.0 == -0.0` is `True`.
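A small standard-library sketch of the point above: IEEE 754 equality treats the two zeros as equal, so comparing against `-0.0` and against `+0.0` selects exactly the same elements. Telling them apart requires looking at the sign bit. The helper name `is_negative_zero` is hypothetical and is not one of the test suite's helpers.

```python
import math


def is_negative_zero(v: float) -> bool:
    """Hypothetical helper: True only for -0.0, using the sign bit rather than ==."""
    return v == 0.0 and math.copysign(1.0, v) < 0.0


# Equality cannot distinguish the two zeros.
zeros_compare_equal = (0.0 == -0.0)

# The sign bit can.
neg_detected = is_negative_zero(-0.0)
pos_detected = is_negative_zero(0.0)

print(zeros_compare_equal, neg_detected, pos_detected)  # True True False
```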
There are helpers in `array_helpers.py` that you can use, although I'm not clear why you need to skip signed zeros here.
Compatibility concerns with prior art:

- `numpy`: None if the draft implementation numpy/numpy#26666 ("WIP: top_k draft implementation") gets merged.
- `torch`: The API name is `topk` instead, and `torch.topk` is only implemented for certain dtypes (e.g., `topk_cpu` does not implement `UInt16`).
- `tensorflow`: The `axis` keyword does not exist; it behaves like `axis=-1`.
- `JAX`: Same as `tensorflow`.
- `Dask`: The `largest` keyword does not exist; the `largest` flag is instead determined by the sign of `k`.
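Two of the gaps listed above could plausibly be bridged in a compatibility wrapper. The sketch below is only an illustration under stated assumptions: `last_axis_top_k` is a hypothetical NumPy stand-in for a backend that only works along the last axis (it does not call TensorFlow or JAX), `top_k_with_axis` adds an `axis` argument by moving the requested axis to the end and back, and `signed_k` shows how a `largest` flag could be mapped onto a sign-of-`k` convention like the one described for Dask.

```python
import numpy as np


def last_axis_top_k(x, k):
    """Hypothetical stand-in for a backend top-k that only supports the last axis."""
    order = np.flip(np.argsort(x, axis=-1, kind="stable"), axis=-1)[..., :k]
    return np.take_along_axis(x, order, axis=-1), order


def top_k_with_axis(x, k, axis=-1):
    """Add an `axis` argument on top of a last-axis-only top-k."""
    moved = np.moveaxis(x, axis, -1)             # bring the target axis to the end
    values, indices = last_axis_top_k(moved, k)  # run the backend routine
    # Move the result axis back to where the caller expects it.
    return np.moveaxis(values, -1, axis), np.moveaxis(indices, -1, axis)


def signed_k(k, largest=True):
    """Map a `largest` flag onto a convention where negative k means smallest."""
    return k if largest else -k


x = np.array([[1, 5, 3], [9, 2, 7]])
col_values, col_indices = top_k_with_axis(x, 1, axis=0)
```

The indices returned this way refer to positions along the original `axis`, so no extra index arithmetic is needed after moving the axis back.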