-
-
Notifications
You must be signed in to change notification settings - Fork 27
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Allow expressions to be shipped to the scheduler #294
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,3 +10,6 @@ bench/shakespeare.txt | |
.idea/ | ||
.ipynb_checkpoints/ | ||
coverage.xml | ||
|
||
|
||
test_cluster_dump/* |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -215,16 +215,20 @@ def _wrap_unary_expr_op(self, op=None): | |
# | ||
# Collection classes | ||
# | ||
from dask.typing import DaskCollection2 | ||
|
||
|
||
class FrameBase(DaskMethodsMixin): | ||
class FrameBase(DaskMethodsMixin, DaskCollection2): | ||
"""Base class for Expr-backed Collections""" | ||
|
||
__dask_scheduler__ = staticmethod( | ||
named_schedulers.get("threads", named_schedulers["sync"]) | ||
) | ||
__dask_optimize__ = staticmethod(lambda dsk, keys, **kwargs: dsk) | ||
|
||
def __dask_tokenize__(self): | ||
return self.expr._name | ||
|
||
def __init__(self, expr): | ||
self._expr = expr | ||
|
||
|
@@ -311,25 +315,11 @@ def persist(self, fuse=True, **kwargs): | |
return DaskMethodsMixin.persist(out, **kwargs) | ||
|
||
def compute(self, fuse=True, **kwargs): | ||
out = self | ||
if not isinstance(out, Scalar): | ||
out = out.repartition(npartitions=1) | ||
out = out.optimize(fuse=fuse) | ||
out = self.finalize_compute() | ||
return DaskMethodsMixin.compute(out, **kwargs) | ||
|
||
@property | ||
def dask(self): | ||
return self.__dask_graph__() | ||
|
||
def __dask_graph__(self): | ||
out = self.expr | ||
out = out.lower_completely() | ||
return out.__dask_graph__() | ||
|
||
def __dask_keys__(self): | ||
out = self.expr | ||
out = out.lower_completely() | ||
return out.__dask_keys__() | ||
Comment on lines
-329
to
-332
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Having keys defined on the collections level feels like an abstraction leak. Among other things this is what threw me off for a while when implementing this the first time. I had a hard time distinguishing collections from graphs in the existing code. I find this now a bit clearer in the above PRs |
||
def __dask_graph_factory__(self): | ||
return self.expr | ||
|
||
def simplify(self): | ||
return new_collection(self.expr.simplify()) | ||
|
@@ -342,7 +332,20 @@ def optimize(self, fuse: bool = True): | |
|
||
@property | ||
def dask(self): | ||
return self.__dask_graph__() | ||
# FIXME: This is highly problematic. Defining this as a property can | ||
# cause very unfortunate materializations. Even a mere hasattr(obj, | ||
# "dask") check already triggers this since it's a property, not even a | ||
# method. | ||
return self.__dask_graph_factory__().lower_completely().materialize() | ||
|
||
def finalize_compute(self): | ||
return new_collection(Repartition(self.expr, 1)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is very problematic in the fairly common use case where the client mounts a lot more memory than a single worker. This forces the whole object to be unnecessarily collected onto a single worker and then sent to the client, whereas we could just have the client fetch separate partitions from separate workers (which may or may not happen all at once if it needs to transit through the scheduler). This replicates the issue with the current finalizer methods in dask/dask, which are created by Memory considerations aside, bouncing through a single worker instead of collecting it on the client directly is adding latency. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is exactly how it is done right now and I don't intend to touch that behavior now |
||
|
||
def postpersist(self, futures): | ||
if not isinstance(futures, dict): | ||
raise TypeError("Provided `futures` must be a dictionary") | ||
func, args = self.__dask_postpersist__() | ||
return func(futures, *args) | ||
|
||
def __dask_postcompute__(self): | ||
state = new_collection(self.expr.lower_completely()) | ||
|
@@ -3001,7 +3004,9 @@ def __bool__(self): | |
"a conditional statement." | ||
) | ||
|
||
def __dask_postcompute__(self): | ||
def finalize_compute(self): | ||
return self | ||
|
||
return first, () | ||
|
||
def to_series(self, index=0) -> Series: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
import dask | ||
import numpy as np | ||
import pandas as pd | ||
import toolz | ||
from dask.array import Array | ||
from dask.base import normalize_token | ||
from dask.core import flatten | ||
|
@@ -27,7 +28,7 @@ | |
safe_head, | ||
total_mem_usage, | ||
) | ||
from dask.dataframe.dispatch import meta_nonempty | ||
from dask.dataframe.dispatch import make_meta_dispatch, meta_nonempty | ||
from dask.dataframe.rolling import CombinedOutput, _head_timedelta, overlap_chunk | ||
from dask.dataframe.shuffle import drop_overlap, get_overlap | ||
from dask.dataframe.utils import ( | ||
|
@@ -67,9 +68,6 @@ def ndim(self): | |
except AttributeError: | ||
return 0 | ||
|
||
def __dask_keys__(self): | ||
return [(self._name, i) for i in range(self.npartitions)] | ||
|
||
def optimize(self, **kwargs): | ||
return optimize(self, **kwargs) | ||
|
||
|
@@ -105,6 +103,10 @@ def __getattr__(self, key): | |
f"API function. Current API coverage is documented here: {link}." | ||
) | ||
|
||
@classmethod | ||
def combine_factories(cls, *exprs: Expr, **kwargs) -> Expr: | ||
return Tuple(*exprs) | ||
Comment on lines
+106
to
+108
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is mostly syntactic sugar and I don't know if I want to keep this. I see a way forward to just move HLGs and old-style collections to the new protocol and nuke a lot of compat code. In this case, such a hook here would be useful. For now, you can ignore this |
||
|
||
@property | ||
def index(self): | ||
return Index(self) | ||
|
@@ -387,6 +389,7 @@ def memory_usage_per_partition(self, index=True, deep=False): | |
|
||
@functools.cached_property | ||
def divisions(self): | ||
# Note: This is triggering a divisions calculation on an hasattr check! | ||
return tuple(self._divisions()) | ||
|
||
def _divisions(self): | ||
|
@@ -3060,3 +3063,33 @@ def _get_meta_map_partitions(args, dfs, func, kwargs, meta, parent_meta): | |
Var, | ||
) | ||
from dask_expr.io import IO, BlockwiseIO | ||
|
||
|
||
class Tuple(Expr): | ||
def __getitem__(self, other): | ||
return self.operands[other] | ||
|
||
def _layer(self) -> dict: | ||
return toolz.merge(op._layer() for op in self.operands) | ||
|
||
def __dask_output_keys__(self) -> list: | ||
all_keys = [] | ||
for op in self.operands: | ||
l = op.__dask_output_keys__() | ||
if len(l) > 1: | ||
raise NotImplementedError() | ||
all_keys.append(l[0]) | ||
return all_keys | ||
|
||
def __len__(self): | ||
return len(self.operands) | ||
|
||
def __iter__(self): | ||
return iter(self.operands) | ||
|
||
|
||
@make_meta_dispatch.register(Expr) | ||
def make_meta_expr(expr, index=None): | ||
# make_meta only access the _meta attribute for collections but Expr is not | ||
# a collection. Still, we're sometimes calling make_meta on Expr instances | ||
return expr._meta |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just defined as part of the collections protocol. Not sure if it is actually required