[WIP] Allow expressions to be shipped to the scheduler #294

Draft · wants to merge 3 commits into main
3 changes: 3 additions & 0 deletions .gitignore
@@ -10,3 +10,6 @@ bench/shakespeare.txt
.idea/
.ipynb_checkpoints/
coverage.xml


test_cluster_dump/*
45 changes: 25 additions & 20 deletions dask_expr/_collection.py
@@ -215,16 +215,20 @@ def _wrap_unary_expr_op(self, op=None):
#
# Collection classes
#
from dask.typing import DaskCollection2


class FrameBase(DaskMethodsMixin):
class FrameBase(DaskMethodsMixin, DaskCollection2):
"""Base class for Expr-backed Collections"""

__dask_scheduler__ = staticmethod(
named_schedulers.get("threads", named_schedulers["sync"])
)
__dask_optimize__ = staticmethod(lambda dsk, keys, **kwargs: dsk)

def __dask_tokenize__(self):
return self.expr._name
Comment on lines +229 to +230
fjetter (Member, Author):
This is just defined as part of the collections protocol. Not sure if it is actually required
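
For context: dask's tokenization falls back to __dask_tokenize__ when an object defines it, so returning the expression's deterministic name keeps the collection's token stable. A minimal standalone sketch (the class below is invented for illustration; only the hook mirrors this PR):

from dask.base import tokenize

class ExprBackedCollection:
    """Invented stand-in for an Expr-backed collection."""

    def __init__(self, expr_name):
        # In this PR the token comes from the expression (self.expr._name).
        self._expr_name = expr_name

    def __dask_tokenize__(self):
        # Delegate the token to the deterministic expression name.
        return self._expr_name

# Two wrappers around the same expression name produce the same token.
assert tokenize(ExprBackedCollection("expr-abc")) == tokenize(ExprBackedCollection("expr-abc"))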


def __init__(self, expr):
self._expr = expr

@@ -311,25 +315,11 @@ def persist(self, fuse=True, **kwargs):
return DaskMethodsMixin.persist(out, **kwargs)

def compute(self, fuse=True, **kwargs):
out = self
if not isinstance(out, Scalar):
out = out.repartition(npartitions=1)
out = out.optimize(fuse=fuse)
out = self.finalize_compute()
return DaskMethodsMixin.compute(out, **kwargs)

@property
def dask(self):
return self.__dask_graph__()

def __dask_graph__(self):
out = self.expr
out = out.lower_completely()
return out.__dask_graph__()

def __dask_keys__(self):
out = self.expr
out = out.lower_completely()
return out.__dask_keys__()
Comment on lines -329 to -332
fjetter (Member, Author):
Having keys defined on the collections level feels like an abstraction leak. Among other things this is what threw me off for a while when implementing this the first time. I had a hard time distinguishing collections from graphs in the existing code. I find this now a bit clearer in the above PRs
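
To make the intended split concrete, here is a rough consumer-side sketch (not part of the diff) of how the new protocol could be driven, e.g. by the scheduler: the collection only hands over its expression, and the graph and output keys are derived from the lowered expression. The method names follow this PR; the surrounding plumbing is illustrative:

def execute(collection, get):
    # The collection hands over an expression, not a materialized dict.
    factory = collection.__dask_graph_factory__()
    # Lowering/optimization happens on the expression, as late as possible.
    lowered = factory.lower_completely()
    graph = lowered.materialize()          # plain dict of tasks
    keys = lowered.__dask_output_keys__()  # output keys live on the expression
    return get(graph, keys)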

def __dask_graph_factory__(self):
return self.expr

def simplify(self):
return new_collection(self.expr.simplify())
@@ -342,7 +332,20 @@ def optimize(self, fuse: bool = True):

@property
def dask(self):
return self.__dask_graph__()
# FIXME: This is highly problematic. Defining this as a property can
# cause very unfortunate materializations. Even a mere hasattr(obj,
# "dask") check already triggers this since it's a property, not even a
# method.
return self.__dask_graph_factory__().lower_completely().materialize()
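
A standalone illustration of the FIXME above (invented class, not dask code): hasattr works by calling getattr, so probing for a property executes its body, which is why a mere feature check can trigger a full materialization.

class SomeCollection:
    @property
    def dask(self):
        print("materializing the whole graph ...")  # stand-in for an expensive build
        return {}

# The property body runs even though we only wanted to check for the attribute.
hasattr(SomeCollection(), "dask")  # prints the message above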

def finalize_compute(self):
return new_collection(Repartition(self.expr, 1))
crusaderky (Collaborator) commented on Dec 22, 2023:
This is very problematic in the fairly common use case where the client has a lot more memory than a single worker. It forces the whole object to be unnecessarily collected onto a single worker and then sent to the client, whereas we could just have the client fetch the separate partitions from separate workers (which may or may not happen all at once if it needs to transit through the scheduler).

This replicates the issue with the current finalizer methods in dask/dask, which are created by dask.compute(df) but are skipped by df.compute().

Memory considerations aside, bouncing through a single worker instead of collecting the result on the client directly also adds latency.

fjetter (Member, Author) replied on Dec 22, 2023:
this is exactly how it is done right now and I don't intend to touch that behavior now
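
For illustration, a rough sketch of the two gathering strategies discussed in this thread; it assumes a distributed client, dask_expr.from_pandas, and that the futures_of helper works for expression-backed collections the way it does for classic ones:

import pandas as pd
from distributed import Client, futures_of
import dask_expr as dx

client = Client()
df = dx.from_pandas(pd.DataFrame({"x": range(1000)}), npartitions=4)

# Current behaviour (what finalize_compute expresses): concatenate everything
# into a single partition on one worker, then send that partition to the client.
result = df.repartition(npartitions=1).compute()

# Alternative raised above: persist, then let the client gather the individual
# partition futures from their workers and concatenate locally.
persisted = df.persist()
parts = client.gather(futures_of(persisted))
result = pd.concat(parts)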


def postpersist(self, futures):
if not isinstance(futures, dict):
raise TypeError("Provided `futures` must be a dictionary")
func, args = self.__dask_postpersist__()
return func(futures, *args)

def __dask_postcompute__(self):
state = new_collection(self.expr.lower_completely())
@@ -3001,7 +3004,9 @@ def __bool__(self):
"a conditional statement."
)

def __dask_postcompute__(self):
def finalize_compute(self):
return self

return first, ()

def to_series(self, index=0) -> Series:
29 changes: 26 additions & 3 deletions dask_expr/_core.py
@@ -10,6 +10,7 @@
import pandas as pd
import toolz
from dask.dataframe.core import is_dataframe_like, is_index_like, is_series_like
from dask.typing import TaskGraphFactory
from dask.utils import funcname, import_required, is_arraylike

from dask_expr._util import _BackendData, _tokenize_deterministic
@@ -426,7 +427,10 @@ def __getattr__(self, key):
f"API function. Current API coverage is documented here: {link}."
)

def __dask_graph__(self):
def get_annotations(self):
return {}

def materialize(self):
"""Traverse expression tree, collect layers"""
stack = [self]
seen = set()
@@ -444,9 +448,12 @@ def __dask_graph__(self):

return toolz.merge(layers)

def __dask_output_keys__(self) -> list:
return [(self._name, i) for i in range(self.npartitions)]

@property
def dask(self):
return self.__dask_graph__()
def dask(self) -> dict:
return self.materialize()

def substitute(self, old, new) -> Expr:
"""Substitute a specific term within the expression
@@ -619,6 +626,22 @@ def _to_graphviz(

return g

@classmethod
def combine_factories(cls, *exprs: Expr) -> Expr:
"""Combine multiple expressions into a single expression

Parameters
----------
exprs:
Expressions to combine

Returns
-------
expr:
Combined expression
"""
raise NotImplementedError()

def visualize(self, filename="dask-expr.svg", format=None, **kwargs):
"""
Visualize the expression graph.
41 changes: 37 additions & 4 deletions dask_expr/_expr.py
@@ -10,6 +10,7 @@
import dask
import numpy as np
import pandas as pd
import toolz
from dask.array import Array
from dask.base import normalize_token
from dask.core import flatten
@@ -27,7 +28,7 @@
safe_head,
total_mem_usage,
)
from dask.dataframe.dispatch import meta_nonempty
from dask.dataframe.dispatch import make_meta_dispatch, meta_nonempty
from dask.dataframe.rolling import CombinedOutput, _head_timedelta, overlap_chunk
from dask.dataframe.shuffle import drop_overlap, get_overlap
from dask.dataframe.utils import (
@@ -67,9 +68,6 @@ def ndim(self):
except AttributeError:
return 0

def __dask_keys__(self):
return [(self._name, i) for i in range(self.npartitions)]

def optimize(self, **kwargs):
return optimize(self, **kwargs)

@@ -105,6 +103,10 @@ def __getattr__(self, key):
f"API function. Current API coverage is documented here: {link}."
)

@classmethod
def combine_factories(cls, *exprs: Expr, **kwargs) -> Expr:
return Tuple(*exprs)
Comment on lines +106 to +108
fjetter (Member, Author):
This is mostly syntactic sugar and I don't know if I want to keep this. I see a way forward to just move HLGs and old-style collections to the new protocol and nuke a lot of compat code. In this case, such a hook here would be useful. For now, you can ignore this
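
As an illustration of what the hook could enable, a rough sketch (paraphrased, not actual dask.compute code) of computing several collections through a single shipped expression; it glosses over finalization, so each operand is assumed to end up with a single output key:

def compute_many(*collections, get):
    exprs = [c.__dask_graph_factory__() for c in collections]
    # Combine all expressions into a single Tuple expression ...
    combined = type(exprs[0]).combine_factories(*exprs)
    lowered = combined.lower_completely()
    # ... so that one graph and one key per operand can be shipped at once.
    graph = lowered.materialize()
    keys = lowered.__dask_output_keys__()
    return get(graph, keys)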


@property
def index(self):
return Index(self)
@@ -387,6 +389,7 @@ def memory_usage_per_partition(self, index=True, deep=False):

@functools.cached_property
def divisions(self):
# Note: This is triggering a divisions calculation on a mere hasattr check!
return tuple(self._divisions())

def _divisions(self):
@@ -3060,3 +3063,33 @@ def _get_meta_map_partitions(args, dfs, func, kwargs, meta, parent_meta):
Var,
)
from dask_expr.io import IO, BlockwiseIO


class Tuple(Expr):
def __getitem__(self, other):
return self.operands[other]

def _layer(self) -> dict:
return toolz.merge(op._layer() for op in self.operands)

def __dask_output_keys__(self) -> list:
all_keys = []
for op in self.operands:
l = op.__dask_output_keys__()
if len(l) > 1:
raise NotImplementedError()
all_keys.append(l[0])
return all_keys

def __len__(self):
return len(self.operands)

def __iter__(self):
return iter(self.operands)


@make_meta_dispatch.register(Expr)
def make_meta_expr(expr, index=None):
# make_meta only accesses the _meta attribute for collections, but Expr is not
# a collection. Still, we sometimes call make_meta on Expr instances.
return expr._meta
80 changes: 19 additions & 61 deletions dask_expr/_merge.py
@@ -70,7 +70,7 @@ class Merge(Expr):
}

def __str__(self):
return f"Merge({self._name[-7:]})"
return f"{type(self).__name__}({self._name[-7:]})"

@property
def kwargs(self):
@@ -114,56 +114,7 @@ def _bcast_right(self):
return self.right

def _divisions(self):
if self.merge_indexed_left and self.merge_indexed_right:
divisions = list(
unique(merge_sorted(self.left.divisions, self.right.divisions))
)
if len(divisions) == 1:
return (divisions[0], divisions[0])
if self.left.npartitions == 1 and self.right.npartitions == 1:
return (min(divisions), max(divisions))
return divisions

if self._is_single_partition_broadcast:
use_left = self.right_index or _contains_index_name(
self.right._meta, self.right_on
)
use_right = self.left_index or _contains_index_name(
self.left._meta, self.left_on
)
if (
use_right
and self.left.npartitions == 1
and self.how in ("right", "inner")
):
return self.right.divisions
elif (
use_left
and self.right.npartitions == 1
and self.how in ("inner", "left")
):
return self.left.divisions
else:
_npartitions = max(self.left.npartitions, self.right.npartitions)

elif self.is_broadcast_join:
meta_index_names = set(self._meta.index.names)
if (
self.broadcast_side == "left"
and set(self.right._meta.index.names) == meta_index_names
):
return self._bcast_right._divisions()
elif (
self.broadcast_side == "right"
and set(self.left._meta.index.names) == meta_index_names
):
return self._bcast_left._divisions()
_npartitions = max(self.left.npartitions, self.right.npartitions)

else:
_npartitions = self._npartitions

return (None,) * (_npartitions + 1)
return self.lower_completely()._divisions()

@functools.cached_property
def broadcast_side(self):
@@ -235,7 +186,6 @@ def _lower(self):
left_index = self.left_index
right_index = self.right_index
shuffle_method = self.shuffle_method

# TODO:
# 1. Add/leverage partition statistics

@@ -437,6 +387,13 @@ class HashJoinP2P(Merge, PartitionsFiltered):
}
is_broadcast_join = False

@property
def npartitions(self):
return self._npartitions or max(self.left.npartitions, self.right.npartitions)

def _divisions(self):
return (None,) * (self.npartitions + 1)

def _lower(self):
return None

@@ -679,16 +636,17 @@ class BlockwiseMerge(Merge, Blockwise):

is_broadcast_join = False

def dependencies(self):
# FIXME: Blockwise._divisions assumes that the left-most dependency is not
# a broadcast dep.
return sorted(super().dependencies(), key=self._broadcast_dep)

def _divisions(self):
if self.left.npartitions == self.right.npartitions:
return super()._divisions()
is_unknown = any(d is None for d in super()._divisions())
frame = (
self.left if self.left.npartitions > self.right.npartitions else self.right
)
if is_unknown:
return (None,) * (frame.npartitions + 1)
return frame.divisions
# Note: If we reversed the MRO so that Blockwise took precedence, we wouldn't
# need this, but then we'd also get Blockwise's _meta implementation even
# though we want Merge's to take precedence. This is probably the lesser evil.
return Blockwise._divisions(self)
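
A toy illustration of the MRO trade-off described in the note above (the classes below are invented stand-ins, not the real ones):

class FakeBlockwise:
    def _divisions(self):
        return "blockwise divisions"

    @property
    def _meta(self):
        return "blockwise meta"

class FakeMerge:
    def _divisions(self):
        return "merge divisions"

    @property
    def _meta(self):
        return "merge meta"

class FakeBlockwiseMerge(FakeMerge, FakeBlockwise):
    # FakeMerge is first in the MRO, so its _meta wins (as desired), but its
    # _divisions would win too -- hence the explicit delegation below.
    def _divisions(self):
        return FakeBlockwise._divisions(self)

bm = FakeBlockwiseMerge()
assert bm._meta == "merge meta"
assert bm._divisions() == "blockwise divisions"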

def _lower(self):
return None
2 changes: 1 addition & 1 deletion dask_expr/_quantiles.py
@@ -43,7 +43,7 @@ def _layer(self):
random_state = self.random_state
state_data = random_state_data(self.frame.npartitions, random_state)

keys = self.frame.__dask_keys__()
keys = self.frame.__dask_output_keys__()
dtype_dsk = {(self._name, 0, 0): (dtype_info, keys[0])}

percentiles_dsk = {
2 changes: 1 addition & 1 deletion dask_expr/_reductions.py
@@ -319,7 +319,7 @@ def _layer(self):
# apply combine to batches of intermediate results
j = 1
d = {}
keys = self.frame.__dask_keys__()
keys = self.frame.__dask_output_keys__()
split_every = self.split_every
while len(keys) > 1:
new_keys = []