DNM: Handle pipeline breakers through avoiding reuse #873

Open · wants to merge 2 commits into base: main
Changes from 1 commit
Implement pipeline breaker
phofl committed Feb 13, 2024
commit 915f758ec1df30e00b1935a556cbb5aaed056000
7 changes: 7 additions & 0 deletions dask_expr/_core.py
@@ -53,6 +53,12 @@ def _tune_down(self):
     def _tune_up(self, parent):
         return None
 
+    def _pipe_down(self):
+        return None
+
+    def _pipe_up(self, parent):
+        return None
+
     def _cull_down(self):
         return None
@@ -342,6 +348,7 @@ def simplify(self) -> Expr:
         while True:
             dependents = collect_dependents(expr)
             new = expr.simplify_once(dependents=dependents, simplified={})
+            new = new.rewrite("pipe")
            if new._name == expr._name:
                break
            expr = new
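
The `_pipe_down`/`_pipe_up` pair mirrors the existing `_tune_down`/`_tune_up` hooks: `rewrite("pipe")` walks the expression tree and lets each node (or its parent) return a replacement, with `None` meaning "no change". A minimal sketch of that dispatch protocol, assuming a simplified `Expr` model rather than the actual `_core.py` dispatcher:

```python
# Simplified model of the rewrite-hook dispatch behind rewrite("pipe").
# Illustrative only: the real dask_expr dispatcher also caches visited
# nodes and threads the parent into the _{kind}_up hook.
def rewrite(expr, kind="pipe"):
    # Top-down hook: the node may replace itself (None means no change).
    out = getattr(expr, f"_{kind}_down")()
    if out is not None:
        return rewrite(out, kind)

    # Recurse into dependencies, substituting any that were rewritten.
    for dep in expr.dependencies():
        new_dep = rewrite(dep, kind)
        if new_dep._name != dep._name:
            expr = expr.substitute(dep, new_dep)
    return expr
```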
11 changes: 9 additions & 2 deletions dask_expr/_expr.py
@@ -1205,7 +1205,7 @@ def _meta(self):
         args = [
             meta_nonempty(op._meta) if isinstance(op, Expr) else op for op in self._args
         ]
-        return self.operation(*args, **self._kwargs)
+        return make_meta(self.operation(*args, **self._kwargs))
 
     @staticmethod
     def operation(df, index, sorted_index):
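
The `make_meta` wrap matters because `operation` runs on `meta_nonempty` dummies and can therefore return a populated object; `make_meta` normalizes it back into a zero-length object with the correct dtypes. For example:

```python
# make_meta turns a concrete (non-empty) result back into an empty
# meta object with matching dtype, which is what _meta must return:
import pandas as pd
from dask.dataframe.utils import make_meta

out = pd.Series([1.0, 2.0])   # what operation(*args) might produce
meta = make_meta(out)
assert len(meta) == 0 and meta.dtype == out.dtype
```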
@@ -2062,6 +2062,9 @@ class ResetIndex(Elemwise):
     operation = M.reset_index
     _filter_passthrough = True
 
+    def __new__(cls, *args, **kwargs):
+        return super().__new__(cls, *args, **kwargs)
+
     @functools.cached_property
     def _kwargs(self) -> dict:
         kwargs = {"drop": self.drop}
@@ -2099,7 +2102,11 @@ def _simplify_up(self, parent, dependents):
             return self._filter_simplification(parent, predicate)
 
         if isinstance(parent, Projection):
-            if self.frame.ndim == 1 and not self.drop and not isinstance(parent, list):
+            if (
+                self.frame.ndim == 1
+                and not self.drop
+                and not isinstance(parent.operand("columns"), list)
+            ):
                 col = parent.operand("columns")
                 if col in (self.name, "index"):
                     return
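
The rewritten guard also fixes a bug: `isinstance(parent, list)` tested the `Projection` node itself, which is never a list, so the branch also fired for list projections. Checking `parent.operand("columns")` separates the two cases; illustratively:

```python
# Scalar vs. list projections over reset_index (illustrative):
import pandas as pd
from dask_expr import from_pandas

ser = from_pandas(pd.Series([1, 2], name="a"), npartitions=1)

ser.reset_index()["a"]    # operand("columns") == "a": branch applies
ser.reset_index()[["a"]]  # operand("columns") == ["a"]: branch now skipped
```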
12 changes: 10 additions & 2 deletions dask_expr/_groupby.py
@@ -136,7 +136,7 @@ class GroupByApplyConcatApply(ApplyConcatApply, GroupByBase):
     @functools.cached_property
     def _meta_chunk(self):
         meta = meta_nonempty(self.frame._meta)
-        return self.chunk(meta, *self._by_meta, **self.chunk_kwargs)
+        return make_meta(self.chunk(meta, *self._by_meta, **self.chunk_kwargs))
 
     @property
     def _chunk_cls_args(self):
@@ -201,6 +201,7 @@ class SingleAggregation(GroupByApplyConcatApply, GroupByBase):
         "split_out",
         "sort",
         "shuffle_method",
+        "_pipeline_breaker_counter",
     ]
     _defaults = {
         "observed": None,
@@ -212,6 +213,7 @@ class SingleAggregation(GroupByApplyConcatApply, GroupByBase):
         "split_out": None,
         "sort": None,
         "shuffle_method": None,
+        "_pipeline_breaker_counter": None,
     }
 
     groupby_chunk = None
@@ -251,7 +253,11 @@ def aggregate_kwargs(self) -> dict:
         }
 
     def _simplify_up(self, parent, dependents):
-        return groupby_projection(self, parent, dependents)
+        if isinstance(parent, Projection):
+            return groupby_projection(self, parent, dependents)
+
+    def _pipe_down(self):
+        return self._adjust_for_pipelinebreaker()
 
 
 class GroupbyAggregationBase(GroupByApplyConcatApply, GroupByBase):
@@ -1479,6 +1485,7 @@ def _single_agg(
                 split_out,
                 self.sort,
                 shuffle_method,
+                None,
                 *self.by,
             )
         )
@@ -2161,6 +2168,7 @@ def nunique(self, split_every=None, split_out=True, shuffle_method=None):
                 split_out,
                 self.sort,
                 shuffle_method,
+                None,
                 *self.by,
             )
         )
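
The bare `None` threaded into both call sites fills the new `_pipeline_breaker_counter` slot: operands map positionally onto `_parameters`, with the variadic `by` columns appended last, so positional constructor calls must supply the new slot explicitly. A simplified model of that mapping:

```python
# Simplified model of positional operand lookup (illustrative, not the
# real dask_expr Expr class):
class MiniExpr:
    _parameters = ["frame", "shuffle_method", "_pipeline_breaker_counter"]

    def __init__(self, *operands):
        self.operands = list(operands)

    def operand(self, name):
        return self.operands[self._parameters.index(name)]

expr = MiniExpr("frame", "tasks", None, "by_col")  # trailing *by columns
assert expr.operand("_pipeline_breaker_counter") is None
```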
81 changes: 76 additions & 5 deletions dask_expr/_reductions.py
@@ -507,6 +507,41 @@ def _lower(self):
             ignore_index=getattr(self, "ignore_index", True),
         )
 
+    def _adjust_for_pipelinebreaker(self):
+        if self._pipeline_breaker_counter is not None:
+            return
+        from dask_expr.io.io import IO
+
+        seen = set()
+        stack = self.dependencies()
+        io_nodes = []
+        counter = 1
+
+        while stack:
+            node = stack.pop()
+
+            if node._name in seen:
+                continue
+            seen.add(node._name)
+
+            if isinstance(node, IO):
+                io_nodes.append(node)
+                continue
+            elif isinstance(node, ApplyConcatApply):
+                counter += 1
+                continue
+            stack.extend(node.dependencies())
+        if len(io_nodes) == 0:
+            return
+        io_nodes_new = [
+            io.substitute_parameters({"_pipeline_breaker_counter": counter})
+            for io in io_nodes
+        ]
+        expr = self
+        for io_node_old, io_node_new in zip(io_nodes, io_nodes_new):
+            expr = expr.substitute(io_node_old, io_node_new)
+        return expr.substitute_parameters({"_pipeline_breaker_counter": counter})
+
 
 class Unique(ApplyConcatApply):
     _parameters = ["frame", "split_every", "split_out", "shuffle_method"]
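
`_adjust_for_pipelinebreaker` walks the dependency tree, collecting IO leaves that are reachable without crossing another `ApplyConcatApply` (nested breakers bump the counter but are not descended into), then stamps the counter into those IO nodes and into the expression itself via `substitute_parameters`. Because `_name` is a deterministic hash over the operands, the stamp changes the name, so tagged IO nodes no longer deduplicate against identical untagged copies elsewhere in the graph, which is exactly the reuse this PR avoids. The naming effect, as a hypothetical snippet against the public API:

```python
import pandas as pd
from dask_expr import from_pandas

io = from_pandas(pd.DataFrame({"a": [1, 2, 3]}), npartitions=2).expr

# Stamp the counter exactly as _adjust_for_pipelinebreaker does:
tagged = io.substitute_parameters({"_pipeline_breaker_counter": 1})

# _name hashes the operand values, so the tagged node is distinct and
# no longer collapses with the untagged copy during reuse detection.
assert io._name != tagged._name
```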
@@ -773,13 +808,23 @@ def _simplify_up(self, parent, dependents):
         if isinstance(parent, Projection):
             return plain_column_projection(self, parent, dependents)
 
+    def _pipe_down(self):
+        return self._adjust_for_pipelinebreaker()
+
 
 class Sum(Reduction):
-    _parameters = ["frame", "skipna", "numeric_only", "split_every"]
+    _parameters = [
+        "frame",
+        "skipna",
+        "numeric_only",
+        "split_every",
+        "_pipeline_breaker_counter",
+    ]
     _defaults = {
         "split_every": False,
         "numeric_only": False,
         "skipna": True,
+        "_pipeline_breaker_counter": None,
     }
     reduction_chunk = M.sum
 
@@ -1090,8 +1135,21 @@ def reduction_aggregate(cls, vals, order):
 
 
 class Mean(Reduction):
-    _parameters = ["frame", "skipna", "numeric_only", "split_every", "axis"]
-    _defaults = {"skipna": True, "numeric_only": False, "split_every": False, "axis": 0}
+    _parameters = [
+        "frame",
+        "skipna",
+        "numeric_only",
+        "split_every",
+        "axis",
+        "_pipeline_breaker_counter",
+    ]
+    _defaults = {
+        "skipna": True,
+        "numeric_only": False,
+        "split_every": False,
+        "axis": 0,
+        "_pipeline_breaker_counter": None,
+    }
 
     @functools.cached_property
     def _meta(self):
@@ -1267,8 +1325,21 @@ def _nlast(df, columns, n, ascending):
 
 
 class NFirst(NLargest):
-    _parameters = ["frame", "n", "_columns", "ascending", "split_every"]
-    _defaults = {"n": 5, "_columns": None, "ascending": None, "split_every": None}
+    _parameters = [
+        "frame",
+        "n",
+        "_columns",
+        "ascending",
+        "split_every",
+        "_pipeline_breaker_counter",
+    ]
+    _defaults = {
+        "n": 5,
+        "_columns": None,
+        "ascending": None,
+        "split_every": None,
+        "_pipeline_breaker_counter": None,
+    }
     reduction_chunk = _nfirst
     reduction_aggregate = _nfirst
2 changes: 2 additions & 0 deletions dask_expr/io/io.py
@@ -320,6 +320,7 @@ class FromPandas(PartitionsFiltered, BlockwiseIO):
         "columns",
         "_partitions",
         "_series",
+        "_pipeline_breaker_counter",
     ]
     _defaults = {
         "npartitions": None,
@@ -328,6 +329,7 @@ class FromPandas(PartitionsFiltered, BlockwiseIO):
         "_partitions": None,
         "_series": False,
         "chunksize": None,
+        "_pipeline_breaker_counter": None,
     }
     _pd_length_stats = None
     _absorb_projections = True
2 changes: 2 additions & 0 deletions dask_expr/io/parquet.py
@@ -402,6 +402,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
         "_partitions",
         "_series",
         "_dataset_info_cache",
+        "_pipeline_breaker_counter",
     ]
     _defaults = {
         "columns": None,
@@ -422,6 +423,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
         "_partitions": None,
         "_series": False,
         "_dataset_info_cache": None,
+        "_pipeline_breaker_counter": None,
     }
     _pq_length_stats = None
     _absorb_projections = True
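
With these defaults in place, untouched IO nodes carry `_pipeline_breaker_counter=None` until a reduction's `_pipe_down` stamps them. A quick hypothetical plumbing check (the path is a placeholder):

```python
from dask_expr import read_parquet

df = read_parquet("data.parquet")  # placeholder path

# The new operand rides along with its default until a pipeline
# breaker rewrites the node:
assert df.expr.operand("_pipeline_breaker_counter") is None
```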