Clean up imports for dask>2024.12.1 support #1424

Merged (2 commits, Jan 9, 2025)
Changes from 1 commit
18 changes: 10 additions & 8 deletions dask_cuda/__init__.py
@@ -5,10 +5,9 @@

import dask
import dask.utils
-import dask.dataframe.core
+import dask.dataframe as dd
import dask.dataframe.shuffle
from .explicit_comms.dataframe.shuffle import patch_shuffle_expression
-from dask.dataframe import DASK_EXPR_ENABLED
from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
from distributed.protocol.serialize import dask_deserialize, dask_serialize

@@ -19,12 +18,15 @@
from .proxify_device_objects import proxify_decorator, unproxify_decorator


-if not DASK_EXPR_ENABLED:
-    raise ValueError(
-        "Dask-CUDA no longer supports the legacy Dask DataFrame API. "
-        "Please set the 'dataframe.query-planning' config to `True` "
-        "or None, or downgrade RAPIDS to <=24.12."
-    )
Comment on lines -22 to -27 (Member, Author):

Moved this logic into explicit_comms.dataframe.shuffle since that's really the only place in Dask-CUDA where dask-expr matters.

+try:
+    if not dd._dask_expr_enabled():
Comment on the line above (Member, Author):

I'm not sure when/if the `_dask_expr_enabled` attribute will be removed from the dask.dataframe module. However, when it is removed, we don't need to worry about query-planning being disabled, because that version of dask won't include the legacy API anyway.
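A minimal, self-contained sketch of that guard pattern (the `legacy_dataframe_api_active` helper name is hypothetical, used here only for illustration):

```python
import dask.dataframe as dd


def legacy_dataframe_api_active() -> bool:
    """Return True only if dask still ships the legacy DataFrame API
    and query planning (dask-expr) is disabled."""
    try:
        # `_dask_expr_enabled` exists in dask<=2024.12.1; later releases
        # may remove it along with the legacy API itself.
        return not dd._dask_expr_enabled()
    except AttributeError:
        # Attribute gone: this dask only has the expr-based API, so the
        # legacy code path can never be active.
        return False


if legacy_dataframe_api_active():
    raise ValueError("Dask-CUDA no longer supports the legacy Dask DataFrame API.")
```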

+        raise ValueError(
+            "Dask-CUDA no longer supports the legacy Dask DataFrame API. "
+            "Please set the 'dataframe.query-planning' config to `True` "
+            "or None, or downgrade RAPIDS to <=24.12."
+        )
+except AttributeError:
+    pass


# Monkey patching Dask to make use of explicit-comms when `DASK_EXPLICIT_COMMS=True`
4 changes: 2 additions & 2 deletions dask_cuda/explicit_comms/dataframe/shuffle.py
@@ -18,7 +18,7 @@
from dask.base import tokenize
from dask.dataframe import DataFrame, Series
from dask.dataframe.core import _concat as dd_concat
-from dask.dataframe.shuffle import group_split_dispatch, hash_object_dispatch
+from dask.dataframe.dispatch import group_split_dispatch, hash_object_dispatch
Comment on the line above (Member, Author):

All dispatch functions have been centralized in the dispatch module for a long time now. Many changes in this PR are just using that preferred module.
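To make that concrete, here is a small runnable sketch of the dispatch mechanism (not part of the PR; `WrappedFrame` is a hypothetical type invented for illustration). Registering a handler on `hash_object_dispatch` from `dask.dataframe.dispatch` lets a third-party container participate in hash-based partitioning:

```python
import pandas as pd
from dask.dataframe.dispatch import hash_object_dispatch


class WrappedFrame:
    """Hypothetical container that wraps a pandas DataFrame."""

    def __init__(self, df: pd.DataFrame):
        self._df = df


@hash_object_dispatch.register(WrappedFrame)
def _hash_wrapped(obj, **kwargs):
    # Unwrap, then re-dispatch so the registered pandas handler
    # does the actual hashing.
    return hash_object_dispatch(obj._df, **kwargs)


# The dispatch object routes on the type of its first argument:
print(hash_object_dispatch(WrappedFrame(pd.DataFrame({"a": [1, 2, 3]}))))
```

dask-cuda uses this same hook (together with `group_split_dispatch`, `make_meta_dispatch`, and friends, as seen in proxy_object.py below) to teach Dask about `ProxyObject`.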

from distributed import wait
from distributed.protocol import nested_deserialize, to_serialize
from distributed.worker import Worker
@@ -585,7 +585,7 @@ def _layer(self):
        # Execute an explicit-comms shuffle
        if not hasattr(self, "_ec_shuffled"):
            on = self.partitioning_index
-            df = dask_expr._collection.new_collection(self.frame)
+            df = dask_expr.new_collection(self.frame)
            self._ec_shuffled = shuffle(
                df,
                [on] if isinstance(on, str) else on,
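For readers unfamiliar with the API being swapped in: `new_collection` wraps a raw dask-expr expression back into a user-facing collection, which is what this code needs after operating on `self.frame`. A rough sketch of that round trip, under the assumption that dask-expr is importable (standalone in dask<=2024.12, vendored into dask afterwards):

```python
import pandas as pd
import dask.dataframe as dd

try:
    import dask_expr  # standalone package (dask<=2024.12)
except ImportError:
    from dask.dataframe import dask_expr  # vendored location in newer dask

ddf = dd.from_pandas(pd.DataFrame({"a": range(8)}), npartitions=2)

expr = ddf.expr                           # the underlying expression node
rebuilt = dask_expr.new_collection(expr)  # public helper, no `_collection`

assert rebuilt.compute().equals(ddf.compute())
```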
33 changes: 13 additions & 20 deletions dask_cuda/proxy_object.py
@@ -12,7 +12,8 @@

import dask
import dask.array.core
-import dask.dataframe.methods
+import dask.dataframe.backends
+import dask.dataframe.dispatch
import dask.dataframe.utils
import dask.utils
import distributed.protocol
@@ -22,16 +23,6 @@

from dask_cuda.disk_io import disk_read

-try:
-    from dask.dataframe.backends import concat_pandas
-except ImportError:
-    from dask.dataframe.methods import concat_pandas
-
-try:
-    from dask.dataframe.dispatch import make_meta_dispatch as make_meta_dispatch
-except ImportError:
-    from dask.dataframe.utils import make_meta as make_meta_dispatch

from .disk_io import SpillToDiskFile
from .is_device_object import is_device_object

@@ -893,10 +884,12 @@ def obj_pxy_dask_deserialize(header, frames):
    return subclass(pxy)


-@dask.dataframe.core.get_parallel_type.register(ProxyObject)
+@dask.dataframe.dispatch.get_parallel_type.register(ProxyObject)
def get_parallel_type_proxy_object(obj: ProxyObject):
    # Notice, `get_parallel_type()` needs a instance not a type object
-    return dask.dataframe.core.get_parallel_type(obj.__class__.__new__(obj.__class__))
+    return dask.dataframe.dispatch.get_parallel_type(
+        obj.__class__.__new__(obj.__class__)
+    )
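A side note on the `obj.__class__.__new__(obj.__class__)` idiom above: `cls.__new__(cls)` allocates an instance without running `__init__`, which is what a type-based dispatch needs when it must be handed a live instance cheaply. A generic Python sketch, unrelated to dask specifics:

```python
class Expensive:
    def __init__(self):
        raise RuntimeError("__init__ is never called in this example")


obj = Expensive.__new__(Expensive)  # allocate only; __init__ is skipped
print(type(obj) is Expensive)       # True
```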


def unproxify_input_wrapper(func):
@@ -913,24 +906,24 @@ def wrapper(*args, **kwargs):

# Register dispatch of ProxyObject on all known dispatch objects
for dispatch in (
-    dask.dataframe.core.hash_object_dispatch,
-    make_meta_dispatch,
+    dask.dataframe.dispatch.hash_object_dispatch,
+    dask.dataframe.dispatch.make_meta_dispatch,
    dask.dataframe.utils.make_scalar,
-    dask.dataframe.core.group_split_dispatch,
+    dask.dataframe.dispatch.group_split_dispatch,
    dask.array.core.tensordot_lookup,
    dask.array.core.einsum_lookup,
    dask.array.core.concatenate_lookup,
):
    dispatch.register(ProxyObject, unproxify_input_wrapper(dispatch))

-dask.dataframe.methods.concat_dispatch.register(
-    ProxyObject, unproxify_input_wrapper(dask.dataframe.methods.concat)
+dask.dataframe.dispatch.concat_dispatch.register(
+    ProxyObject, unproxify_input_wrapper(dask.dataframe.dispatch.concat)
)


# We overwrite the Dask dispatch of Pandas objects in order to
# deserialize all ProxyObjects before concatenating
-dask.dataframe.methods.concat_dispatch.register(
+dask.dataframe.dispatch.concat_dispatch.register(
    (pandas.DataFrame, pandas.Series, pandas.Index),
-    unproxify_input_wrapper(concat_pandas),
+    unproxify_input_wrapper(dask.dataframe.backends.concat_pandas),
)
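The registrations above all funnel through `unproxify_input_wrapper`, whose general shape is: unwrap every argument, then defer to the original dispatch target. A simplified, self-contained sketch of that pattern (`unwrap_inputs_wrapper` and its toy unwrapping rule are hypothetical; the real dask-cuda helper unwraps `ProxyObject` arguments):

```python
import functools


def unwrap_inputs_wrapper(func, unwrap):
    """Return `func` with every argument passed through `unwrap` first."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        args = [unwrap(a) for a in args]
        kwargs = {k: unwrap(v) for k, v in kwargs.items()}
        return func(*args, **kwargs)

    return wrapper


# Toy demonstration: strip a one-field wrapper before summing.
unwrapped_sum = unwrap_inputs_wrapper(
    sum, unwrap=lambda x: x["payload"] if isinstance(x, dict) else x
)
print(unwrapped_sum({"payload": [1, 2, 3]}))  # 6
```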
16 changes: 8 additions & 8 deletions dask_cuda/tests/test_proxy.py
@@ -504,27 +504,27 @@ def test_pandas():
    df1 = pandas.DataFrame({"a": range(10)})
    df2 = pandas.DataFrame({"a": range(10)})

-    res = dask.dataframe.methods.concat([df1, df2])
-    got = dask.dataframe.methods.concat([df1, df2])
+    res = dask.dataframe.dispatch.concat([df1, df2])
+    got = dask.dataframe.dispatch.concat([df1, df2])
    assert_frame_equal(res, got)

-    got = dask.dataframe.methods.concat([proxy_object.asproxy(df1), df2])
+    got = dask.dataframe.dispatch.concat([proxy_object.asproxy(df1), df2])
    assert_frame_equal(res, got)

-    got = dask.dataframe.methods.concat([df1, proxy_object.asproxy(df2)])
+    got = dask.dataframe.dispatch.concat([df1, proxy_object.asproxy(df2)])
    assert_frame_equal(res, got)

    df1 = pandas.Series(range(10))
    df2 = pandas.Series(range(10))

-    res = dask.dataframe.methods.concat([df1, df2])
-    got = dask.dataframe.methods.concat([df1, df2])
+    res = dask.dataframe.dispatch.concat([df1, df2])
+    got = dask.dataframe.dispatch.concat([df1, df2])
    assert all(res == got)

-    got = dask.dataframe.methods.concat([proxy_object.asproxy(df1), df2])
+    got = dask.dataframe.dispatch.concat([proxy_object.asproxy(df1), df2])
    assert all(res == got)

-    got = dask.dataframe.methods.concat([df1, proxy_object.asproxy(df2)])
+    got = dask.dataframe.dispatch.concat([df1, proxy_object.asproxy(df2)])
    assert all(res == got)

