WIP: top_k tests #274

Open · wants to merge 1 commit into master
Conversation

@JuliaPoo JuliaPoo commented Jun 24, 2024

The purpose of this PR is to continue several threads of discussion regarding top_k.

This roughly follows the specification of `top_k` proposed in data-apis/array-api#722, with slight modifications to the API to increase compatibility:

```py
def top_k(
    x: array,
    k: int,
    /,
    axis: Optional[int] = None,
    *,
    largest: bool = True,
) -> Tuple[array, array]:
    ...
```

Modifications:

  • `mode: Literal["largest", "smallest"]` is replaced with `largest: bool`
  • `axis` is no longer a keyword-only argument. This makes the signature slightly more compatible with `torch.topk`, whose `dim` argument is positional.
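
For illustration, here is a minimal sketch of the proposed semantics in NumPy. This is not the numpy/numpy#26666 implementation, just an argsort-based illustration of what the function computes:

```py
# Illustrative sketch only; not the numpy/numpy#26666 implementation.
import numpy as np

def top_k(x, k, /, axis=None, *, largest=True):
    if axis is None:
        # The proposed spec operates on the flattened array here.
        x = np.ravel(x)
        axis = -1
    idx = np.argsort(x, axis=axis)
    if largest:
        idx = np.flip(idx, axis=axis)
    idx = np.take(idx, np.arange(k), axis=axis)
    return np.take_along_axis(x, idx, axis=axis), idx

values, indices = top_k(np.asarray([1, 5, 3, 2, 4]), 2)
# values -> [5, 4], indices -> [1, 4]
```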

The tests implemented here follow the proposed `top_k` implementation at numpy/numpy#26666.

Compatibility concerns with prior art:

  • numpy: none, provided the draft implementation in numpy/numpy#26666 is merged.
  • torch: the API is named `topk` instead, and `torch.topk` is only implemented for certain dtypes (e.g., `topk_cpu` does not implement `UInt16`).
  • tensorflow: the `axis` keyword does not exist; behavior is as if `axis=-1`.
  • JAX: same as tensorflow.
  • Dask: the `largest` keyword does not exist; the direction is instead determined by the sign of `k` (see the snippet after this list).
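
For reference, Dask's existing sign-of-`k` convention looks like this (a small illustration using the real `dask.array.topk`; the sample values are made up):

```py
import dask.array as da

x = da.from_array([1, 5, 3, 2, 4], chunks=5)
print(da.topk(x, 2).compute())    # [5 4]: largest two
print(da.topk(x, -2).compute())   # [1 2]: smallest two
```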

@asmeurer (Member) commented:

So the next step here would be to implement wrappers for this function in array_api_compat. That way we will be able to see just how complex the required changes are, and also verify that there aren't other incompatibilities, since the tests can't check anything later in a test if something earlier in it fails.

The way I would do this is to make a PR to the compat library that

  1. Modifies the actions file to point to this PR: https://github.com/data-apis/array-api-compat/blob/ac15c526d9769f77c780958a00097dfd183a2a37/.github/workflows/array-api-tests.yml#L53

  2. Adds compat wrappers for the different libraries.

  3. Sparse and JAX are currently not tested in the compat library CI, because their support is entirely in the libraries themselves. So what you can do is make a simple wrapper namespace that just wraps top_k and nothing else (see the sketch after this list). Then add a CI workflow like

    ```yaml
    # .github/workflows/array-api-tests-jax.yml
    name: Array API Tests (JAX)

    on: [push, pull_request]

    jobs:
      array-api-tests-jax:
        uses: ./.github/workflows/array-api-tests.yml
        with:
          package-name: jax
          pytest-extra-args: -k top_k
    ```

    This should hopefully be straightforward, but let me know if you run into any issues.

  4. CuPy cannot be tested on CI. However, CuPy should be identical to NumPy, so if you don't have access to a CUDA machine, I wouldn't worry about it for now.

  5. Don't worry about tensorflow. It hasn't been included in the compat library at all yet.

  6. Considering numpy/numpy#26666 (the WIP top_k draft implementation) is a simple pure Python implementation, if top_k is accepted we can reuse it for the NumPy 1.26 wrapper. For now, though, you can either copy it into your compat PR as the NumPy wrapper to verify it, or change the NumPy dev CI job to point to your NumPy PR.
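
A wrapper namespace along the lines of item 3 might look like the following sketch. This is hypothetical, not actual array-api-compat code; `jax.lax.top_k` is real but only returns the largest values along the last axis, so the `axis` and `largest` handling here is an assumption about what the wrapper would need:

```py
# Hypothetical wrapper namespace that wraps only top_k and defers
# everything else to jax.numpy.
import jax.numpy as jnp
from jax import lax

def top_k(x, k, /, axis=None, *, largest=True):
    if axis is None:
        x = jnp.ravel(x)
        axis = -1
    # lax.top_k operates on the last axis, so move the target axis there.
    x = jnp.moveaxis(x, axis, -1)
    if largest:
        values, indices = lax.top_k(x, k)
    else:
        # lax.top_k only finds the largest values; negate to get the
        # smallest (assumes a signed numeric dtype).
        values, indices = lax.top_k(-x, k)
        values = -values
    return jnp.moveaxis(values, -1, axis), jnp.moveaxis(indices, -1, axis)

def __getattr__(name):
    # Fall back to jax.numpy for every other attribute.
    return getattr(jnp, name)
```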

Here are some development notes for the compat library, which should be helpful: https://data-apis.org/array-api-compat/dev/index.html. But also feel free to ask any questions here.

The main purpose is just to get an idea of what the wrappers will look like and to get a CI log showing the tests pass (or, if something is too hard to wrap, what the error is). So you don't need to worry too much about making everything perfect and mergeable. If top_k is eventually added to the standard, we can clean up these PRs and use them.

If you see something that could be changed in array-api-compat or the test suite to make this process easier, make a note of it. We are going to want to be able to repeat this whole process in the future any time a new function is proposed for inclusion in the array API.

@asmeurer (Member) commented:

> torch.topk is only implemented for certain dtypes (e.g., topk_cpu does not implement UInt16).

Quite a few things in PyTorch don't work with smaller uint dtypes. They are skipped in the CI, so you don't need to worry about that. If the torch wrapper is essentially just `top_k = topk` (see the sketch below) and these tests pass, that will be a good sign that the proposed specification matches the existing PyTorch implementation.
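
For concreteness, such a torch wrapper might look something like this sketch (not the actual array-api-compat code; it only adapts the signature, since the proposed spec flattens when `axis` is None while `torch.topk` defaults to the last dimension):

```py
# Sketch only; not the actual array-api-compat wrapper.
import torch

def top_k(x, k, /, axis=None, *, largest=True):
    if axis is None:
        # The proposed spec operates on the flattened array here.
        x = torch.reshape(x, (-1,))
        axis = -1
    values, indices = torch.topk(x, k, dim=axis, largest=largest)
    return values, indices
```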

JuliaPoo added a commit to JuliaPoo/array-api-compat that referenced this pull request Jun 26, 2024
JuliaPoo added a commit to JuliaPoo/array-api-compat that referenced this pull request Jun 27, 2024
```py
def test_top_k(x, data):

    if dh.is_float_dtype(x.dtype):
        assume(not xp.any(x == -0.0) and not xp.any(x == +0.0))
```
A member commented:
This check won't work properly, since `0.0 == -0.0` is `True`.
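
To see the problem (a quick demonstration; any library with IEEE 754 floats behaves the same way):

```py
# Signed zeros compare equal under ==, so the check cannot tell them apart.
import numpy as np

x = np.asarray([-0.0, 1.0])
print(bool(np.any(x == +0.0)))  # True: -0.0 also matches +0.0
print(bool(np.any(x == -0.0)))  # True: +0.0 also matches -0.0
print(np.signbit(x))            # [ True False]: distinguishes the signs
```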

Another member commented:
There are helpers in `array_helpers.py` that you can use, although I'm not clear on why you need to skip signed zeros here.
