Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: top_k tests #274

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions array_api_tests/test_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from hypothesis import given, note
from hypothesis import strategies as st
from hypothesis.control import assume

from . import _array_module as xp
from . import dtype_helpers as dh
Expand Down Expand Up @@ -203,3 +204,102 @@ def test_searchsorted(data):
expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
)
# TODO: shapes and values testing


@pytest.mark.unvectorized
# TODO: Test with signed zeros and NaNs (and ignore them somehow)
@given(
    x=hh.arrays(
        dtype=hh.real_dtypes,
        shape=hh.shapes(min_dims=1, min_side=1),
        elements={"allow_nan": False},
    ),
    data=st.data(),
)
def test_top_k(x, data):
    """Check the dtypes, shapes and values returned by ``xp.top_k``.

    ``x`` is a real-valued array with at least one dimension and no NaNs.
    An integer ``axis``, a ``largest`` flag and a ``k`` between 1 and the
    length of ``x`` along ``axis`` are drawn from ``data``.

    The test asserts that:

    * ``out_values`` has the dtype of ``x`` and ``out_indices`` has
      ``dh.default_int``;
    * both outputs have the shape of ``x`` with the ``axis`` dimension
      replaced by ``k``;
    * along every 1-D slice, the values selected by ``out_indices`` match
      the ``k`` best values computed with Python's ``sorted``;
    * ``out_values`` holds exactly the values of ``x`` that ``out_indices``
      points at.
    """
    if dh.is_float_dtype(x.dtype):
        # ``0.0 == -0.0`` is True, so a single comparison against zero
        # matches zeros of either sign; comparing against -0.0 and +0.0
        # separately is the same test twice. As written, this rejects every
        # array that contains a zero.
        # NOTE(review): a reviewer pointed at helpers in array_helpers.py for
        # signed zeros and questioned whether zeros need skipping here at
        # all -- confirm before relying on this guard.
        assume(not xp.any(x == 0.0))

    # NOTE: axis is always an integer here; axis=None is not exercised yet.
    axis = data.draw(
        st.integers(-x.ndim, x.ndim - 1), label='axis')
    largest = data.draw(st.booleans(), label='largest')
    k = data.draw(st.integers(1, x.shape[axis]))

    kw = {
        "x": x,
        "k": k,
        "axis": axis,
        "largest": largest,
    }

    out_values, out_indices = xp.top_k(x, k, axis, largest=largest)

    ph.assert_dtype("top_k", in_dtype=x.dtype, out_dtype=out_values.dtype)
    # NOTE(review): test_searchsorted in this file checks its result against
    # the namespace's default "indexing" dtype; confirm whether the indices
    # returned by top_k should be checked the same way.
    ph.assert_dtype(
        "top_k",
        in_dtype=x.dtype,
        out_dtype=out_indices.dtype,
        expected=dh.default_int,
    )

    # Non-negative form of ``axis`` so it can be used for shape slicing.
    (norm_axis,) = sh.normalise_axis(axis, x.ndim)
    expected_shape = x.shape[:norm_axis] + (k,) + x.shape[norm_axis + 1:]
    for arr in [out_values, out_indices]:
        ph.assert_shape(
            "top_k",
            out_shape=arr.shape,
            expected=expected_shape,
            kw=kw,
        )

    scalar_type = dh.get_scalar_type(x.dtype)

    # Each ``indices`` is the sequence of index tuples of one 1-D slice of
    # ``x`` along ``norm_axis``.
    for indices in sh.axes_ndindex(x.shape, (norm_axis,)):

        # Test if the values indexed by out_indices corresponds to
        # the correct top_k values.
        elements = [scalar_type(x[idx]) for idx in indices]
        correct_order = sorted(
            range(len(elements)),
            key=elements.__getitem__,
            reverse=largest,
        )[:k]
        # Convert to Python ints so the positions can be used as list
        # indices and sort keys regardless of the array library.
        test_order = [int(out_indices[idx]) for idx in indices[:k]]
        # Sort because top_k does not necessarily return the values in
        # sorted order.
        test_sorted_order = sorted(
            test_order,
            key=elements.__getitem__,
            reverse=largest,
        )

        for y_o, x_o in zip(correct_order, test_sorted_order):
            y_idx = indices[y_o]
            x_idx = indices[x_o]
            ph.assert_0d_equals(
                "top_k",
                x_repr=f"x[{x_idx}]",
                x_val=x[x_idx],
                out_repr=f"x[{y_idx}]",
                out_val=x[y_idx],
                kw=kw,
            )

        # Test if the values indexed by out_indices corresponds to out_values.
        for y_o, x_idx in zip(test_order, indices[:k]):
            y_idx = indices[y_o]
            ph.assert_0d_equals(
                "top_k",
                x_repr=f"out_values[{x_idx}]",
                x_val=scalar_type(out_values[x_idx]),
                out_repr=f"x[{y_idx}]",
                out_val=x[y_idx],
                kw=kw,
            )
Loading