Skip to content

Commit

Permalink
Add multi-partition Join support to cuDF-Polars (#17518)
Browse files Browse the repository at this point in the history
Adds multi-partition `Join` support following the same design as #17441

In order to support parallel joins, this PR also introduces a special `Shuffle` node.

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
  - Matthew Murray (https://github.com/Matt711)
  - Lawrence Mitchell (https://github.com/wence-)

URL: #17518
  • Loading branch information
rjzamora authored Feb 4, 2025
1 parent 99b207f commit a477a6b
Show file tree
Hide file tree
Showing 3 changed files with 369 additions and 0 deletions.
314 changes: 314 additions & 0 deletions python/cudf_polars/cudf_polars/experimental/join.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,314 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
"""Parallel Join Logic."""

from __future__ import annotations

import operator
from functools import reduce
from typing import TYPE_CHECKING, Any

from cudf_polars.dsl.ir import Join
from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name
from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node
from cudf_polars.experimental.shuffle import Shuffle, _partition_dataframe

if TYPE_CHECKING:
from collections.abc import MutableMapping

from cudf_polars.dsl.expr import NamedExpr
from cudf_polars.dsl.ir import IR
from cudf_polars.experimental.parallel import LowerIRTransformer


def _maybe_shuffle_frame(
    frame: IR,
    on: tuple[NamedExpr, ...],
    partition_info: MutableMapping[IR, PartitionInfo],
    shuffle_options: dict[str, Any],
    output_count: int,
) -> IR:
    """
    Return ``frame`` partitioned on ``on`` with ``output_count`` partitions.

    If ``frame`` already carries exactly that partitioning, it is returned
    unchanged; otherwise a new ``Shuffle`` node is inserted above it and
    registered in ``partition_info``.
    """
    pinfo = partition_info[frame]
    if pinfo.partitioned_on == on and pinfo.count == output_count:
        # Already shuffled the way we need — avoid a redundant shuffle.
        return frame

    # Wrap the frame in a Shuffle node and record its new partitioning.
    shuffled = Shuffle(frame.schema, on, shuffle_options, frame)
    partition_info[shuffled] = PartitionInfo(
        count=output_count,
        partitioned_on=on,
    )
    return shuffled


def _make_hash_join(
    ir: Join,
    output_count: int,
    partition_info: MutableMapping[IR, PartitionInfo],
    left: IR,
    right: IR,
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
    """
    Lower ``ir`` to a multi-partition hash join.

    Both children are shuffled onto their respective join keys (when not
    already shuffled that way), and the output's known partitioning is
    recorded in ``partition_info``.
    """
    # Shuffle options are unused for now; reserved for future tuning.
    shuffle_options: dict[str, Any] = {}

    shuffled_left = _maybe_shuffle_frame(
        left, ir.left_on, partition_info, shuffle_options, output_count
    )
    shuffled_right = _maybe_shuffle_frame(
        right, ir.right_on, partition_info, shuffle_options, output_count
    )
    if shuffled_left != left or shuffled_right != right:
        # At least one side gained a Shuffle node — rebuild the join on top.
        ir = ir.reconstruct([shuffled_left, shuffled_right])

    # Record the output partitioning where it is known: if both sides use
    # the same keys (or only one side's keys survive the join kind), the
    # result stays partitioned on those keys.
    how = ir.options[0]
    partitioned_on: tuple[NamedExpr, ...] = ()
    if ir.left_on == ir.right_on or how in ("Left", "Semi", "Anti"):
        partitioned_on = ir.left_on
    elif how == "Right":
        partitioned_on = ir.right_on
    partition_info[ir] = PartitionInfo(
        count=output_count,
        partitioned_on=partitioned_on,
    )

    return ir, partition_info


def _should_bcast_join(
    ir: Join,
    left: IR,
    right: IR,
    partition_info: MutableMapping[IR, PartitionInfo],
    output_count: int,
) -> bool:
    """Return True if ``ir`` should be lowered to a broadcast join."""
    left_count = partition_info[left].count
    right_count = partition_info[right].count
    if left_count >= right_count:
        small_count, large, large_on = right_count, left, ir.left_on
    else:
        small_count, large, large_on = left_count, right, ir.right_on

    # Criterion 1: broadcasting is pointless if the "large" table is
    # already shuffled onto its join keys with the output count.
    large_info = partition_info[large]
    if large_info.partitioned_on == large_on and large_info.count == output_count:
        return False

    # Criterion 2: the small table must have 8 partitions or fewer.
    # TODO: Make this value/heuristic configurable.
    # We may want to account for the number of workers.
    if small_count > 8:
        return False

    # Criterion 3: the join "kind" must be compatible with broadcasting
    # the small side.
    how = ir.options[0]
    return (
        how == "Inner"
        or (how in ("Left", "Semi", "Anti") and large == left)
        or (how == "Right" and large == right)
    )


def _make_bcast_join(
    ir: Join,
    output_count: int,
    partition_info: MutableMapping[IR, PartitionInfo],
    left: IR,
    right: IR,
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
    """
    Lower ``ir`` to a broadcast join.

    The smaller table is broadcast against each large-table partition at
    task-generation time; here we only (possibly) insert a shuffle on the
    smaller side and record the output partition count. Note that the
    smaller side is shuffled to its *own* partition count — this aligns
    its rows with keys without changing how many pieces are broadcast.
    """
    if ir.options[0] != "Inner":
        shuffle_options: dict[str, Any] = {}
        left_count = partition_info[left].count
        right_count = partition_info[right].count

        # Shuffle the smaller table (if necessary) - Notes:
        # - We need to shuffle the smaller table if
        #   (1) we are not doing an "inner" join,
        #   and (2) the small table contains multiple
        #   partitions.
        # - We cannot simply join a large-table partition
        #   to each small-table partition, and then
        #   concatenate the partial-join results, because
        #   a non-"inner" join does NOT commute with
        #   concatenation.
        # - In some cases, we can perform the partial joins
        #   sequentially. However, we are starting with a
        #   catch-all algorithm that works for all cases.
        if left_count >= right_count:
            # Right side is smaller: shuffle it on the right join keys.
            right = _maybe_shuffle_frame(
                right,
                ir.right_on,
                partition_info,
                shuffle_options,
                right_count,
            )
        else:
            # Left side is smaller: shuffle it on the left join keys.
            left = _maybe_shuffle_frame(
                left,
                ir.left_on,
                partition_info,
                shuffle_options,
                left_count,
            )

    # Rebuild the join over the (possibly shuffled) children. No known
    # output partitioning is recorded for a broadcast join.
    new_node = ir.reconstruct([left, right])
    partition_info[new_node] = PartitionInfo(count=output_count)
    return new_node, partition_info


@lower_ir_node.register(Join)
def _(
    ir: Join, rec: LowerIRTransformer
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
    """
    Lower a ``Join`` node for multi-partition execution.

    Children are lowered first; if everything fits in a single partition,
    the join is left as-is. Otherwise we choose between a broadcast join
    and a hash join via ``_should_bcast_join``.

    Raises
    ------
    NotImplementedError
        For cross joins over multiple partitions.
    """
    # Lower children and merge their partitioning information.
    children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True)
    partition_info = reduce(operator.or_, _partition_info)

    left, right = children
    output_count = max(partition_info[left].count, partition_info[right].count)
    if output_count == 1:
        # Single-partition case: no shuffle or broadcast machinery needed.
        new_node = ir.reconstruct(children)
        partition_info[new_node] = PartitionInfo(count=1)
        return new_node, partition_info
    elif ir.options[0] == "Cross":
        # Fix: error message previously read "not support".
        raise NotImplementedError(
            "Cross join not supported for multiple partitions."
        )  # pragma: no cover

    if _should_bcast_join(ir, left, right, partition_info, output_count):
        # Create a broadcast join
        return _make_bcast_join(
            ir,
            output_count,
            partition_info,
            left,
            right,
        )
    else:
        # Create a hash join
        return _make_hash_join(
            ir,
            output_count,
            partition_info,
            left,
            right,
        )


@generate_ir_tasks.register(Join)
def _(
    ir: Join, partition_info: MutableMapping[IR, PartitionInfo]
) -> MutableMapping[Any, Any]:
    """
    Build the task graph for a lowered multi-partition ``Join``.

    Two strategies are emitted:

    * Partition-wise join — when both children are already shuffled onto
      their join keys with the output count (or there is one partition),
      partition ``i`` of the left joins partition ``i`` of the right.
    * Broadcast join — every small-side partition is joined against each
      large-side partition, and the partials are concatenated. For
      non-"Inner" joins the large side is first split on the join keys
      (via ``_partition_dataframe``) so each partial join only sees the
      matching key range, because non-inner joins do not commute with
      concatenation.
    """
    left, right = ir.children
    output_count = partition_info[ir].count

    # A side counts as "partitioned" if it is shuffled on its join keys
    # with exactly the output partition count.
    left_partitioned = (
        partition_info[left].partitioned_on == ir.left_on
        and partition_info[left].count == output_count
    )
    right_partitioned = (
        partition_info[right].partitioned_on == ir.right_on
        and partition_info[right].count == output_count
    )

    if output_count == 1 or (left_partitioned and right_partitioned):
        # Partition-wise join
        left_name = get_key_name(left)
        right_name = get_key_name(right)
        return {
            key: (
                ir.do_evaluate,
                *ir._non_child_args,
                (left_name, i),
                (right_name, i),
            )
            for i, key in enumerate(partition_info[ir].keys(ir))
        }
    else:
        # Broadcast join
        left_parts = partition_info[left]
        right_parts = partition_info[right]
        # The side with more partitions is "large"; the other is broadcast.
        if left_parts.count >= right_parts.count:
            small_side = "Right"
            small_name = get_key_name(right)
            small_size = partition_info[right].count
            large_name = get_key_name(left)
            large_on = ir.left_on
        else:
            small_side = "Left"
            small_name = get_key_name(left)
            small_size = partition_info[left].count
            large_name = get_key_name(right)
            large_on = ir.right_on

        graph: MutableMapping[Any, Any] = {}

        out_name = get_key_name(ir)
        out_size = partition_info[ir].count
        # Key prefixes for intermediate tasks: large-side splits and
        # per-(large, small) partial join results.
        split_name = f"split-{out_name}"
        inter_name = f"inter-{out_name}"

        for part_out in range(out_size):
            if ir.options[0] != "Inner":
                # Non-inner join: split this large partition into
                # `small_size` pieces keyed the same way as the
                # (shuffled) small side.
                graph[(split_name, part_out)] = (
                    _partition_dataframe,
                    (large_name, part_out),
                    large_on,
                    small_size,
                )

            _concat_list = []
            for j in range(small_size):
                # For non-inner joins pick the j-th split of the large
                # partition; for inner joins the whole partition is safe.
                join_children = [
                    (
                        (
                            operator.getitem,
                            (split_name, part_out),
                            j,
                        )
                        if ir.options[0] != "Inner"
                        else (large_name, part_out)
                    ),
                    (small_name, j),
                ]
                # `do_evaluate` expects (left, right) order; swap when the
                # broadcast side is actually the left child.
                if small_side == "Left":
                    join_children.reverse()

                inter_key = (inter_name, part_out, j)
                graph[(inter_name, part_out, j)] = (
                    ir.do_evaluate,
                    ir.left_on,
                    ir.right_on,
                    ir.options,
                    *join_children,
                )
                _concat_list.append(inter_key)
            if len(_concat_list) == 1:
                # Single partial result — promote it directly to the
                # output key instead of concatenating.
                graph[(out_name, part_out)] = graph.pop(_concat_list[0])
            else:
                graph[(out_name, part_out)] = (_concat, _concat_list)

        return graph
1 change: 1 addition & 0 deletions python/cudf_polars/cudf_polars/experimental/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import TYPE_CHECKING, Any

import cudf_polars.experimental.io
import cudf_polars.experimental.join
import cudf_polars.experimental.select
import cudf_polars.experimental.shuffle # noqa: F401
from cudf_polars.dsl.ir import IR, Cache, Filter, HStack, Projection, Select, Union
Expand Down
54 changes: 54 additions & 0 deletions python/cudf_polars/tests/experimental/test_join.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pytest

import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.mark.parametrize("how", ["inner", "left", "right", "full", "semi", "anti"])
@pytest.mark.parametrize("reverse", [True, False])
@pytest.mark.parametrize("max_rows_per_partition", [1, 5, 10, 15])
def test_join(how, reverse, max_rows_per_partition):
    # Exercise multi-partition joins across all supported join kinds,
    # both argument orders, and several partition sizes (15 means the
    # larger frame collapses to a single partition).
    engine = pl.GPUEngine(
        raise_on_fail=True,
        executor="dask-experimental",
        executor_options={"max_rows_per_partition": max_rows_per_partition},
    )
    # 15-row "large" frame and 6-row "small" frame sharing key column "y";
    # the key values only partially overlap so each join kind produces a
    # distinct result.
    left = pl.LazyFrame(
        {
            "x": range(15),
            "y": ["cat", "dog", "fish"] * 5,
            "z": [1.0, 2.0, 3.0, 4.0, 5.0] * 3,
        }
    )
    right = pl.LazyFrame(
        {
            "xx": range(6),
            "y": ["dog", "bird", "fish"] * 2,
            "zz": [1, 2] * 3,
        }
    )
    if reverse:
        left, right = right, left

    q = left.join(right, on="y", how=how)

    # Row order is not deterministic across partitions, so compare sorted.
    assert_gpu_result_equal(q, engine=engine, check_row_order=False)

    # Join again on the same key.
    # (covers code path that avoids redundant shuffles)
    if how in ("inner", "left", "right"):
        right2 = pl.LazyFrame(
            {
                "xxx": range(6),
                "yyy": ["dog", "bird", "fish"] * 2,
                "zzz": [3, 4] * 3,
            }
        )
        q2 = q.join(right2, left_on="y", right_on="yyy", how=how)
        assert_gpu_result_equal(q2, engine=engine, check_row_order=False)

0 comments on commit a477a6b

Please sign in to comment.