Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Limit number of sources in merged scan task #3695

Merged
merged 5 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions daft/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ def set_execution_config(
config: PyDaftExecutionConfig | None = None,
scan_tasks_min_size_bytes: int | None = None,
scan_tasks_max_size_bytes: int | None = None,
max_sources_per_scan_task: int | None = None,
broadcast_join_size_bytes_threshold: int | None = None,
parquet_split_row_groups_max_files: int | None = None,
sort_merge_join_sort_with_aligned_boundaries: bool | None = None,
Expand Down Expand Up @@ -368,6 +369,7 @@ def set_execution_config(
scan_tasks_max_size_bytes: Maximum size in bytes when merging ScanTasks when reading files from storage.
Increasing this value will increase the upper bound of the size of merged ScanTasks, which leads to bigger but
fewer partitions. (Defaults to 384 MiB)
max_sources_per_scan_task: Maximum number of sources in a single ScanTask. (Defaults to 10)
broadcast_join_size_bytes_threshold: If one side of a join is smaller than this threshold, a broadcast join will be used.
Default is 10 MiB.
parquet_split_row_groups_max_files: Maximum number of files to read in which the row group splitting should happen. (Defaults to 10)
Expand Down Expand Up @@ -406,6 +408,7 @@ def set_execution_config(
new_daft_execution_config = old_daft_execution_config.with_config_values(
scan_tasks_min_size_bytes=scan_tasks_min_size_bytes,
scan_tasks_max_size_bytes=scan_tasks_max_size_bytes,
max_sources_per_scan_task=max_sources_per_scan_task,
broadcast_join_size_bytes_threshold=broadcast_join_size_bytes_threshold,
parquet_split_row_groups_max_files=parquet_split_row_groups_max_files,
sort_merge_join_sort_with_aligned_boundaries=sort_merge_join_sort_with_aligned_boundaries,
Expand Down
3 changes: 3 additions & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1703,6 +1703,7 @@ class PyDaftExecutionConfig:
self,
scan_tasks_min_size_bytes: int | None = None,
scan_tasks_max_size_bytes: int | None = None,
max_sources_per_scan_task: int | None = None,
broadcast_join_size_bytes_threshold: int | None = None,
parquet_split_row_groups_max_files: int | None = None,
sort_merge_join_sort_with_aligned_boundaries: bool | None = None,
Expand Down Expand Up @@ -1731,6 +1732,8 @@ class PyDaftExecutionConfig:
@property
def scan_tasks_max_size_bytes(self) -> int: ...
@property
def max_sources_per_scan_task(self) -> int: ...
@property
def broadcast_join_size_bytes_threshold(self) -> int: ...
@property
def sort_merge_join_sort_with_aligned_boundaries(self) -> bool: ...
Expand Down
2 changes: 2 additions & 0 deletions src/common/daft-config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ impl DaftPlanningConfig {
pub struct DaftExecutionConfig {
pub scan_tasks_min_size_bytes: usize,
pub scan_tasks_max_size_bytes: usize,
pub max_sources_per_scan_task: usize,
pub broadcast_join_size_bytes_threshold: usize,
pub sort_merge_join_sort_with_aligned_boundaries: bool,
pub hash_join_partition_size_leniency: f64,
Expand Down Expand Up @@ -69,6 +70,7 @@ impl Default for DaftExecutionConfig {
Self {
scan_tasks_min_size_bytes: 96 * 1024 * 1024, // 96MB
scan_tasks_max_size_bytes: 384 * 1024 * 1024, // 384MB
max_sources_per_scan_task: 10,
broadcast_join_size_bytes_threshold: 10 * 1024 * 1024, // 10 MiB
sort_merge_join_sort_with_aligned_boundaries: false,
hash_join_partition_size_leniency: 0.5,
Expand Down
10 changes: 10 additions & 0 deletions src/common/daft-config/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
#[pyo3(signature = (
scan_tasks_min_size_bytes=None,
scan_tasks_max_size_bytes=None,
max_sources_per_scan_task=None,
broadcast_join_size_bytes_threshold=None,
parquet_split_row_groups_max_files=None,
sort_merge_join_sort_with_aligned_boundaries=None,
Expand Down Expand Up @@ -105,6 +106,7 @@
&self,
scan_tasks_min_size_bytes: Option<usize>,
scan_tasks_max_size_bytes: Option<usize>,
max_sources_per_scan_task: Option<usize>,
broadcast_join_size_bytes_threshold: Option<usize>,
parquet_split_row_groups_max_files: Option<usize>,
sort_merge_join_sort_with_aligned_boundaries: Option<bool>,
Expand Down Expand Up @@ -136,6 +138,9 @@
if let Some(scan_tasks_min_size_bytes) = scan_tasks_min_size_bytes {
config.scan_tasks_min_size_bytes = scan_tasks_min_size_bytes;
}
if let Some(max_sources_per_scan_task) = max_sources_per_scan_task {
config.max_sources_per_scan_task = max_sources_per_scan_task;
}
if let Some(broadcast_join_size_bytes_threshold) = broadcast_join_size_bytes_threshold {
config.broadcast_join_size_bytes_threshold = broadcast_join_size_bytes_threshold;
}
Expand Down Expand Up @@ -236,6 +241,11 @@
Ok(self.config.scan_tasks_max_size_bytes)
}

/// Python-facing getter for `DaftExecutionConfig.max_sources_per_scan_task`:
/// the cap on how many source ScanTasks may be merged into one ScanTask
/// (defaults to 10 per `DaftExecutionConfig::default`). Infallible; wrapped
/// in `PyResult` only to satisfy the pyo3 getter signature.
#[getter]
fn get_max_sources_per_scan_task(&self) -> PyResult<usize> {
Ok(self.config.max_sources_per_scan_task)
}

Check warning on line 247 in src/common/daft-config/src/python.rs

View check run for this annotation

Codecov / codecov/patch

src/common/daft-config/src/python.rs#L245-L247

Added lines #L245 - L247 were not covered by tests

#[getter]
fn get_broadcast_join_size_bytes_threshold(&self) -> PyResult<usize> {
Ok(self.config.broadcast_join_size_bytes_threshold)
Expand Down
18 changes: 12 additions & 6 deletions src/daft-scan/src/scan_task_iters/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type BoxScanTaskIter<'a> = Box<dyn Iterator<Item = DaftResult<ScanTaskRef>> + 'a
/// * `scan_tasks`: A Boxed Iterator of ScanTaskRefs to perform merging on
/// * `min_size_bytes`: Minimum size in bytes of a ScanTask, after which no more merging will be performed
/// * `max_size_bytes`: Maximum size in bytes of a ScanTask, capping the maximum size of a merged ScanTask
/// * `max_source_count`: Maximum number of source ScanTasks that may be combined into a single merged ScanTask
#[must_use]
fn merge_by_sizes<'a>(
scan_tasks: BoxScanTaskIter<'a>,
Expand Down Expand Up @@ -57,6 +58,7 @@ fn merge_by_sizes<'a>(
target_upper_bound_size_bytes: (limit_bytes * 1.5) as usize,
target_lower_bound_size_bytes: (limit_bytes / 2.) as usize,
accumulator: None,
max_source_count: cfg.max_sources_per_scan_task,
}) as BoxScanTaskIter;
}
}
Expand All @@ -69,6 +71,7 @@ fn merge_by_sizes<'a>(
target_upper_bound_size_bytes: cfg.scan_tasks_max_size_bytes,
target_lower_bound_size_bytes: cfg.scan_tasks_min_size_bytes,
accumulator: None,
max_source_count: cfg.max_sources_per_scan_task,
}) as BoxScanTaskIter
}
}
Expand All @@ -83,6 +86,9 @@ struct MergeByFileSize<'a> {

// Current element being accumulated on
accumulator: Option<ScanTaskRef>,

// Maximum number of source ScanTasks allowed in a single merged ScanTask
max_source_count: usize,
}

impl<'a> MergeByFileSize<'a> {
Expand All @@ -92,11 +98,11 @@ impl<'a> MergeByFileSize<'a> {
/// in estimated bytes, as well as other factors including any limit pushdowns.
fn accumulator_ready(&self) -> bool {
// Emit the accumulator as soon as it is bigger than the specified `target_lower_bound_size_bytes`
if let Some(acc) = &self.accumulator
&& let Some(acc_bytes) = acc.estimate_in_memory_size_bytes(Some(self.cfg))
&& acc_bytes >= self.target_lower_bound_size_bytes
{
true
if let Some(acc) = &self.accumulator {
acc.sources.len() >= self.max_source_count
|| acc
.estimate_in_memory_size_bytes(Some(self.cfg))
.map_or(false, |bytes| bytes >= self.target_lower_bound_size_bytes)
} else {
false
}
Expand Down Expand Up @@ -143,7 +149,7 @@ impl<'a> Iterator for MergeByFileSize<'a> {
};
}

// Emit accumulator if ready
// Emit accumulator if ready or if merge count limit is reached
if self.accumulator_ready() {
return self.accumulator.take().map(Ok);
}
Expand Down
12 changes: 12 additions & 0 deletions tests/io/test_merge_scan_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,15 @@ def test_merge_scan_task_limit_override(csv_files):
):
df = daft.read_csv(str(csv_files)).limit(1)
assert df.num_partitions() == 3, "Should have 3 partitions [(CSV1, CSV2, CSV3)] since we have a limit 1"


def test_merge_scan_task_up_to_max_sources(csv_files):
    """Merging must stop adding sources once max_sources_per_scan_task is reached.

    With the source cap set to 2, the first two CSVs share one merged scan
    task and the third ends up in its own partition.
    """
    with daft.execution_config_ctx(
        scan_tasks_min_size_bytes=30,
        scan_tasks_max_size_bytes=30,
        max_sources_per_scan_task=2,
    ):
        result = daft.read_csv(str(csv_files))
        partition_count = result.num_partitions()
        assert (
            partition_count == 2
        ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the third CSV is too large to merge with the first two, and max_sources_per_scan_task is set to 2"
Loading