diff --git a/daft/context.py b/daft/context.py
index d96400980d..76ba82325c 100644
--- a/daft/context.py
+++ b/daft/context.py
@@ -331,6 +331,7 @@ def set_execution_config(
     config: PyDaftExecutionConfig | None = None,
     scan_tasks_min_size_bytes: int | None = None,
     scan_tasks_max_size_bytes: int | None = None,
+    max_sources_per_scan_task: int | None = None,
     broadcast_join_size_bytes_threshold: int | None = None,
     parquet_split_row_groups_max_files: int | None = None,
     sort_merge_join_sort_with_aligned_boundaries: bool | None = None,
@@ -368,6 +369,7 @@ def set_execution_config(
         scan_tasks_max_size_bytes: Maximum size in bytes when merging ScanTasks when reading files from storage.
             Increasing this value will increase the upper bound of the size of merged ScanTasks, which leads to bigger but
             fewer partitions. (Defaults to 384 MiB)
+        max_sources_per_scan_task: Maximum number of sources in a single ScanTask. (Defaults to 10)
         broadcast_join_size_bytes_threshold: If one side of a join is smaller than this threshold, a broadcast join will be used.
             Default is 10 MiB.
         parquet_split_row_groups_max_files: Maximum number of files to read in which the row group splitting should happen. (Defaults to 10)
@@ -406,6 +408,7 @@ def set_execution_config(
         new_daft_execution_config = old_daft_execution_config.with_config_values(
             scan_tasks_min_size_bytes=scan_tasks_min_size_bytes,
             scan_tasks_max_size_bytes=scan_tasks_max_size_bytes,
+            max_sources_per_scan_task=max_sources_per_scan_task,
             broadcast_join_size_bytes_threshold=broadcast_join_size_bytes_threshold,
             parquet_split_row_groups_max_files=parquet_split_row_groups_max_files,
             sort_merge_join_sort_with_aligned_boundaries=sort_merge_join_sort_with_aligned_boundaries,
diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi
index 457bdbb895..3e40c8800f 100644
--- a/daft/daft/__init__.pyi
+++ b/daft/daft/__init__.pyi
@@ -1703,6 +1703,7 @@ class PyDaftExecutionConfig:
         self,
         scan_tasks_min_size_bytes: int | None = None,
         scan_tasks_max_size_bytes: int | None = None,
+        max_sources_per_scan_task: int | None = None,
         broadcast_join_size_bytes_threshold: int | None = None,
         parquet_split_row_groups_max_files: int | None = None,
         sort_merge_join_sort_with_aligned_boundaries: bool | None = None,
@@ -1731,6 +1732,8 @@ class PyDaftExecutionConfig:
     @property
    def scan_tasks_max_size_bytes(self) -> int: ...
     @property
+    def max_sources_per_scan_task(self) -> int: ...
+    @property
     def broadcast_join_size_bytes_threshold(self) -> int: ...
     @property
     def sort_merge_join_sort_with_aligned_boundaries(self) -> bool: ...
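
For context, a minimal sketch of how the new option is exercised from Python, using the `daft.set_execution_config` and `daft.execution_config_ctx` entry points that appear in this diff (the file glob is hypothetical):

```python
import daft

# Process-wide override: subsequent reads merge at most 5 source files
# into a single ScanTask.
daft.set_execution_config(max_sources_per_scan_task=5)

# Scoped override: applies only within the `with` block.
with daft.execution_config_ctx(max_sources_per_scan_task=2):
    df = daft.read_csv("data/*.csv")  # hypothetical path
    print(df.num_partitions())
```
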
diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs
index a23090d753..ddd367a6e3 100644
--- a/src/common/daft-config/src/lib.rs
+++ b/src/common/daft-config/src/lib.rs
@@ -40,6 +40,7 @@ impl DaftPlanningConfig {
 pub struct DaftExecutionConfig {
     pub scan_tasks_min_size_bytes: usize,
     pub scan_tasks_max_size_bytes: usize,
+    pub max_sources_per_scan_task: usize,
     pub broadcast_join_size_bytes_threshold: usize,
     pub sort_merge_join_sort_with_aligned_boundaries: bool,
     pub hash_join_partition_size_leniency: f64,
@@ -69,6 +70,7 @@ impl Default for DaftExecutionConfig {
         Self {
             scan_tasks_min_size_bytes: 96 * 1024 * 1024,   // 96MB
             scan_tasks_max_size_bytes: 384 * 1024 * 1024,  // 384MB
+            max_sources_per_scan_task: 10,
             broadcast_join_size_bytes_threshold: 10 * 1024 * 1024, // 10 MiB
             sort_merge_join_sort_with_aligned_boundaries: false,
             hash_join_partition_size_leniency: 0.5,
diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs
index 3371ef349c..bb6c7f9b4d 100644
--- a/src/common/daft-config/src/python.rs
+++ b/src/common/daft-config/src/python.rs
@@ -78,6 +78,7 @@ impl PyDaftExecutionConfig {
     #[pyo3(signature = (
         scan_tasks_min_size_bytes=None,
         scan_tasks_max_size_bytes=None,
+        max_sources_per_scan_task=None,
         broadcast_join_size_bytes_threshold=None,
         parquet_split_row_groups_max_files=None,
         sort_merge_join_sort_with_aligned_boundaries=None,
@@ -105,6 +106,7 @@ impl PyDaftExecutionConfig {
         &self,
         scan_tasks_min_size_bytes: Option<usize>,
         scan_tasks_max_size_bytes: Option<usize>,
+        max_sources_per_scan_task: Option<usize>,
         broadcast_join_size_bytes_threshold: Option<usize>,
         parquet_split_row_groups_max_files: Option<usize>,
         sort_merge_join_sort_with_aligned_boundaries: Option<bool>,
@@ -136,6 +138,9 @@ impl PyDaftExecutionConfig {
         if let Some(scan_tasks_min_size_bytes) = scan_tasks_min_size_bytes {
             config.scan_tasks_min_size_bytes = scan_tasks_min_size_bytes;
         }
+        if let Some(max_sources_per_scan_task) = max_sources_per_scan_task {
+            config.max_sources_per_scan_task = max_sources_per_scan_task;
+        }
         if let Some(broadcast_join_size_bytes_threshold) = broadcast_join_size_bytes_threshold {
             config.broadcast_join_size_bytes_threshold = broadcast_join_size_bytes_threshold;
         }
@@ -236,6 +241,11 @@ impl PyDaftExecutionConfig {
         Ok(self.config.scan_tasks_max_size_bytes)
     }
 
+    #[getter]
+    fn get_max_sources_per_scan_task(&self) -> PyResult<usize> {
+        Ok(self.config.max_sources_per_scan_task)
+    }
+
     #[getter]
     fn get_broadcast_join_size_bytes_threshold(&self) -> PyResult<usize> {
         Ok(self.config.broadcast_join_size_bytes_threshold)
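
One thing worth noting in `with_config_values` above: only arguments that arrive as `Some(...)` overwrite the corresponding field, so each call layers on top of the previous config. A small illustration of that semantics from the Python side (the values are arbitrary):

```python
import daft

# max_sources_per_scan_task defaults to 10 (see DaftExecutionConfig::default).
daft.set_execution_config(max_sources_per_scan_task=4)

# Omitted parameters stay None and are left untouched: this call changes
# scan_tasks_max_size_bytes, while max_sources_per_scan_task remains 4.
daft.set_execution_config(scan_tasks_max_size_bytes=128 * 1024 * 1024)
```
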
diff --git a/src/daft-scan/src/scan_task_iters/mod.rs b/src/daft-scan/src/scan_task_iters/mod.rs
index 226d4c3ee2..87ab489a13 100644
--- a/src/daft-scan/src/scan_task_iters/mod.rs
+++ b/src/daft-scan/src/scan_task_iters/mod.rs
@@ -26,6 +26,7 @@ type BoxScanTaskIter<'a> = Box<dyn Iterator<Item = DaftResult<ScanTaskRef>> + 'a>;
 /// * `scan_tasks`: A Boxed Iterator of ScanTaskRefs to perform merging on
 /// * `min_size_bytes`: Minimum size in bytes of a ScanTask, after which no more merging will be performed
 /// * `max_size_bytes`: Maximum size in bytes of a ScanTask, capping the maximum size of a merged ScanTask
+/// * `max_source_count`: Maximum number of ScanTasks to merge
 #[must_use]
 fn merge_by_sizes<'a>(
     scan_tasks: BoxScanTaskIter<'a>,
@@ -57,6 +58,7 @@ fn merge_by_sizes<'a>(
             target_upper_bound_size_bytes: (limit_bytes * 1.5) as usize,
             target_lower_bound_size_bytes: (limit_bytes / 2.) as usize,
             accumulator: None,
+            max_source_count: cfg.max_sources_per_scan_task,
         }) as BoxScanTaskIter;
     }
 }
@@ -69,6 +71,7 @@ fn merge_by_sizes<'a>(
             target_upper_bound_size_bytes: cfg.scan_tasks_max_size_bytes,
             target_lower_bound_size_bytes: cfg.scan_tasks_min_size_bytes,
             accumulator: None,
+            max_source_count: cfg.max_sources_per_scan_task,
         }) as BoxScanTaskIter
     }
 }
@@ -83,6 +86,9 @@ struct MergeByFileSize<'a> {
 
     // Current element being accumulated on
     accumulator: Option<ScanTaskRef>,
+
+    // Maximum number of files in a merged ScanTask
+    max_source_count: usize,
 }
 
 impl<'a> MergeByFileSize<'a> {
@@ -92,11 +98,11 @@ impl<'a> MergeByFileSize<'a> {
     /// in estimated bytes, as well as other factors including any limit pushdowns.
     fn accumulator_ready(&self) -> bool {
         // Emit the accumulator as soon as it is bigger than the specified `target_lower_bound_size_bytes`
-        if let Some(acc) = &self.accumulator
-            && let Some(acc_bytes) = acc.estimate_in_memory_size_bytes(Some(self.cfg))
-            && acc_bytes >= self.target_lower_bound_size_bytes
-        {
-            true
+        if let Some(acc) = &self.accumulator {
+            acc.sources.len() >= self.max_source_count
+                || acc
+                    .estimate_in_memory_size_bytes(Some(self.cfg))
+                    .map_or(false, |bytes| bytes >= self.target_lower_bound_size_bytes)
         } else {
             false
         }
@@ -143,7 +149,7 @@ impl<'a> Iterator for MergeByFileSize<'a> {
             };
         }
 
-        // Emit accumulator if ready
+        // Emit accumulator if ready or if merge count limit is reached
         if self.accumulator_ready() {
             return self.accumulator.take().map(Ok);
         }
diff --git a/tests/io/test_merge_scan_tasks.py b/tests/io/test_merge_scan_tasks.py
index dd7696d8c4..9bd1773917 100644
--- a/tests/io/test_merge_scan_tasks.py
+++ b/tests/io/test_merge_scan_tasks.py
@@ -81,3 +81,15 @@ def test_merge_scan_task_limit_override(csv_files):
     ):
         df = daft.read_csv(str(csv_files)).limit(1)
         assert df.num_partitions() == 3, "Should have 3 partitions [(CSV1, CSV2, CSV3)] since we have a limit 1"
+
+
+def test_merge_scan_task_up_to_max_sources(csv_files):
+    with daft.execution_config_ctx(
+        scan_tasks_min_size_bytes=30,
+        scan_tasks_max_size_bytes=30,
+        max_sources_per_scan_task=2,
+    ):
+        df = daft.read_csv(str(csv_files))
+        assert (
+            df.num_partitions() == 2
+        ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the third CSV is too large to merge with the first two, and max_sources_per_scan_task is set to 2"
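
To make the new policy concrete, here is a rough pure-Python sketch of the greedy rule `MergeByFileSize` now implements, with ScanTasks reduced to their estimated byte sizes. This is illustrative only: the real iterator also checks source and pushdown compatibility before merging, and the function name and numbers below are made up.

```python
def merge_by_sizes(sizes, lower_bound, upper_bound, max_sources):
    """Greedily group task sizes, mirroring MergeByFileSize's emit rules."""
    acc = []  # sizes of the ScanTasks accumulated so far

    def ready():
        # Mirrors accumulator_ready(): enough sources OR enough bytes.
        return len(acc) >= max_sources or sum(acc) >= lower_bound

    for size in sizes:
        # Merging is only allowed while the total stays within the upper bound.
        if acc and sum(acc) + size > upper_bound:
            yield tuple(acc)
            acc = []
        acc.append(size)
        if ready():
            yield tuple(acc)
            acc = []
    if acc:
        yield tuple(acc)  # flush whatever is left at the end

# Three 10-byte tasks with 30-byte bounds would all merge into one ScanTask,
# but max_sources=2 splits them into (10, 10), (10,) -- the 2-partition
# shape the new test asserts.
assert list(merge_by_sizes([10, 10, 10], 30, 30, 2)) == [(10, 10), (10,)]
```
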