From bc5eb21f3b1899ed68103aab298d33a650e76e36 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Wed, 24 Apr 2024 09:33:42 +0100
Subject: [PATCH] Allow df_concat to be created without a linker

---
 splink/database_api.py               |  7 +--
 splink/linker.py                     | 15 +++++-
 splink/profile_data.py               | 52 +++++++++++--------
 splink/settings.py                   |  5 ++
 splink/vertically_concatenate.py     | 76 ++++++++++++++++++----------
 tests/test_total_comparison_count.py |  9 +++-
 6 files changed, 106 insertions(+), 58 deletions(-)

diff --git a/splink/database_api.py b/splink/database_api.py
index f12d6b70f8..47a9cd9386 100644
--- a/splink/database_api.py
+++ b/splink/database_api.py
@@ -233,12 +233,7 @@ def register_multiple_tables(
         existing_tables = []
 
         if not input_aliases:
-            # If any of the input_tables are strings, this means they refer
-            # to tables that already exist in the database and an alias is not needed
-            input_aliases = [
-                table if isinstance(table, str) else f"__splink__{ascii_uid(8)}"
-                for table in input_tables
-            ]
+            input_aliases = [f"__splink__{ascii_uid(8)}" for table in input_tables]
 
         for table, alias in zip(input_tables, input_aliases):
             if isinstance(table, str):
diff --git a/splink/linker.py b/splink/linker.py
index 744e825710..3e028f27aa 100644
--- a/splink/linker.py
+++ b/splink/linker.py
@@ -124,6 +124,8 @@
     _composite_unique_id_from_nodes_sql,
 )
 from .unlinkables import unlinkables_data
+
+from .column_expression import ColumnExpression
 from .vertically_concatenate import (
     vertically_concatenate_sql,
     split_df_concat_with_tf_into_two_tables_sqls,
 )
@@ -1669,7 +1671,10 @@ def compute_graph_metrics(
         )
 
     def profile_columns(
-        self, column_expressions: str | list[str] | None = None, top_n=10, bottom_n=10
+        self,
+        column_expressions: Optional[List[Union[str, ColumnExpression]]] = None,
+        top_n=10,
+        bottom_n=10,
     ):
         """
         Profiles the specified columns of the dataframe initiated with the linker.
@@ -2695,7 +2700,13 @@ def count_num_comparisons_from_blocking_rule(
 
         pipeline = CTEPipeline()
 
-        sql = vertically_concatenate_sql(self)
+        sds_name = self._settings_obj.column_info_settings.source_dataset_column_name
+
+        sql = vertically_concatenate_sql(
+            input_tables=self._input_tables_dict,
+            salting_reqiured=self._settings_obj.salting_required,
+            source_dataset_column_name=sds_name,
+        )
         pipeline.enqueue_sql(sql, "__splink__df_concat")
 
         sql = number_of_comparisons_generated_by_blocking_rule_post_filters_sql(
diff --git a/splink/profile_data.py b/splink/profile_data.py
index b4e219d1ec..c2ef57eea3 100644
--- a/splink/profile_data.py
+++ b/splink/profile_data.py
@@ -1,10 +1,14 @@
 import logging
 import re
 from copy import deepcopy
+from typing import List, Optional, Union
 
 from .charts import altair_or_json, load_chart_definition
+from .column_expression import ColumnExpression
+from .database_api import DatabaseAPI
 from .misc import ensure_is_list
 from .pipeline import CTEPipeline
+from .vertically_concatenate import vertically_concatenate_sql
 
 logger = logging.getLogger(__name__)
 
@@ -195,7 +199,11 @@ def _add_100_percentile_to_df_percentiles(percentile_rows):
 
 
 def profile_columns(
-    table_or_tables, db_api, column_expressions=None, top_n=10, bottom_n=10
+    table_or_tables,
+    db_api: DatabaseAPI,
+    column_expressions: Optional[List[Union[str, ColumnExpression]]] = None,
+    top_n=10,
+    bottom_n=10,
 ):
     """
     Profiles the specified columns of the dataframe initiated with the linker.
@@ -235,32 +243,33 @@ def profile_columns(
     """
 
     splink_df_dict = db_api.register_multiple_tables(table_or_tables)
+
+    pipeline = CTEPipeline()
+    sql = vertically_concatenate_sql(
+        splink_df_dict, salting_reqiured=False, source_dataset_column_name=None
+    )
+    pipeline.enqueue_sql(sql, "__splink__df_concat")
+
     input_dataframes = list(splink_df_dict.values())
-    input_aliases = list(splink_df_dict.keys())
+
     input_columns = input_dataframes[0].columns_escaped
 
     if not column_expressions:
         column_expressions_raw = input_columns
     else:
         column_expressions_raw = ensure_is_list(column_expressions)
 
-    column_expressions = expressions_to_sql(column_expressions_raw)
-    pipeline = CTEPipeline(input_dataframes)
+    # If the user has provided any ColumnExpression, convert to string
+    for i in column_expressions_raw:
+        if isinstance(i, ColumnExpression):
+            i.sql_dialect = db_api.sql_dialect
 
-    cols_to_select = ", ".join(input_columns)
-    template = """
-    select {cols_to_select}
-    from {table_name}
-    """
-
-    sql_df_concat = " UNION ALL".join(
-        [
-            template.format(cols_to_select=cols_to_select, table_name=table_name)
-            for table_name in input_aliases
-        ]
-    )
+    column_expressions_raw = [
+        ce.name if isinstance(ce, ColumnExpression) else ce
+        for ce in column_expressions_raw
+    ]
 
-    pipeline.enqueue_sql(sql_df_concat, "__splink__df_concat")
+    column_expressions_as_sql = expressions_to_sql(column_expressions_raw)
 
     sql = _col_or_expr_frequencies_raw_data_sql(
         column_expressions_raw, "__splink__df_concat"
@@ -271,27 +280,26 @@ def profile_columns(
     pipeline = CTEPipeline(input_dataframes=[df_raw])
     sqls = _get_df_percentiles()
-    for sql in sqls:
-        pipeline.enqueue_sql(sql["sql"], sql["output_table_name"])
+    pipeline.enqueue_list_of_sqls(sqls)
 
     df_percentiles = db_api.sql_pipeline_to_splink_dataframe(pipeline)
     percentile_rows_all = df_percentiles.as_record_dict()
 
     pipeline = CTEPipeline(input_dataframes=[df_raw])
-    sql = _get_df_top_bottom_n(column_expressions, top_n, "desc")
+    sql = _get_df_top_bottom_n(column_expressions_as_sql, top_n, "desc")
     pipeline.enqueue_sql(sql, "__splink__df_top_n")
     df_top_n = db_api.sql_pipeline_to_splink_dataframe(pipeline)
     top_n_rows_all = df_top_n.as_record_dict()
 
     pipeline = CTEPipeline(input_dataframes=[df_raw])
-    sql = _get_df_top_bottom_n(column_expressions, bottom_n, "asc")
+    sql = _get_df_top_bottom_n(column_expressions_as_sql, bottom_n, "asc")
     pipeline.enqueue_sql(sql, "__splink__df_bottom_n")
     df_bottom_n = db_api.sql_pipeline_to_splink_dataframe(pipeline)
     bottom_n_rows_all = df_bottom_n.as_record_dict()
 
     inner_charts = []
-    for expression in column_expressions:
+    for expression in column_expressions_as_sql:
         percentile_rows = [
             p for p in percentile_rows_all if p["group_name"] == _group_name(expression)
         ]
diff --git a/splink/settings.py b/splink/settings.py
index 8f57d3e013..b1220b9a08 100644
--- a/splink/settings.py
+++ b/splink/settings.py
@@ -656,6 +656,11 @@ def human_readable_description(self):
 
     @property
     def salting_required(self):
+        # see https://github.com/duckdb/duckdb/discussions/9710
+        # in duckdb to parallelise we need salting
+        if self._sql_dialect == "duckdb":
+            return True
+
         for br in self._blocking_rules_to_generate_predictions:
             if isinstance(br, SaltedBlockingRule):
                 return True
diff --git a/splink/vertically_concatenate.py b/splink/vertically_concatenate.py
index 9fc605df5d..98810eabd7 100644
--- a/splink/vertically_concatenate.py
+++ b/splink/vertically_concatenate.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Dict
 
 from .pipeline import CTEPipeline
 from .splink_dataframe import SplinkDataFrame
@@ -14,9 +14,13 @@
     from .linker import Linker
 
 
-def vertically_concatenate_sql(linker: Linker) -> str:
+def vertically_concatenate_sql(
+    input_tables: Dict[str, SplinkDataFrame],
+    salting_reqiured: bool,
+    source_dataset_column_name: str = None,
+) -> str:
     """
-    Using `input_table_or_tables`, create a single table with the columns and
+    Using `input_tables`, create a single table with the columns and
     rows required for linking.
 
     This table will later be the basis for the generation of pairwise record comparisons,
@@ -30,37 +34,32 @@ def vertically_concatenate_sql(linker: Linker) -> str:
     """
 
     # Use column order from first table in dict
-    df_obj = next(iter(linker._input_tables_dict.values()))
+    df_obj = next(iter(input_tables.values()))
     columns = df_obj.columns_escaped
 
     select_columns_sql = ", ".join(columns)
 
-    salting_reqiured = False
-
-    source_dataset_col_req = (
-        linker._settings_obj.column_info_settings.source_dataset_column_name is not None
-    )
-
-    salting_reqiured = linker._settings_obj.salting_required
-
-    # see https://github.com/duckdb/duckdb/discussions/9710
-    # in duckdb to parallelise we need salting
-    if linker._sql_dialect == "duckdb":
-        salting_reqiured = True
-
     if salting_reqiured:
         salt_sql = ", random() as __splink_salt"
     else:
         salt_sql = ""
 
-    if source_dataset_col_req:
-        sqls_to_union = []
+    source_dataset_column_already_exists = False
+    if source_dataset_column_name:
+        source_dataset_column_already_exists = source_dataset_column_name in [
+            c.unquote().name for c in df_obj.columns
+        ]
 
-        create_sds_if_needed = ""
+    select_columns_sql = ", ".join(columns)
 
-        for df_obj in linker._input_tables_dict.values():
-            if not linker._source_dataset_column_already_exists:
+    if len(input_tables) > 1:
+        sqls_to_union = []
+
+        for df_obj in input_tables.values():
+            if source_dataset_column_already_exists:
+                create_sds_if_needed = ""
+            else:
                 create_sds_if_needed = f"'{df_obj.templated_name}' as source_dataset,"
+
             sql = f"""
             select
             {create_sds_if_needed}
@@ -81,14 +80,19 @@ def vertically_concatenate_sql(linker: Linker) -> str:
 
 
 def enqueue_df_concat_with_tf(linker: Linker, pipeline: CTEPipeline) -> CTEPipeline:
-
     cache = linker._intermediate_table_cache
     if "__splink__df_concat_with_tf" in cache:
         nodes_with_tf = cache.get_with_logging("__splink__df_concat_with_tf")
         pipeline.append_input_dataframe(nodes_with_tf)
         return pipeline
 
-    sql = vertically_concatenate_sql(linker)
+    sds_name = linker._settings_obj.column_info_settings.source_dataset_column_name
+
+    sql = vertically_concatenate_sql(
+        input_tables=linker._input_tables_dict,
+        salting_reqiured=linker._settings_obj.salting_required,
+        source_dataset_column_name=sds_name,
+    )
     pipeline.enqueue_sql(sql, "__splink__df_concat")
 
     sqls = compute_all_term_frequencies_sqls(linker, pipeline)
@@ -104,7 +108,13 @@ def compute_df_concat_with_tf(linker: Linker, pipeline: CTEPipeline) -> SplinkDa
     if "__splink__df_concat_with_tf" in cache:
         return cache.get_with_logging("__splink__df_concat_with_tf")
 
-    sql = vertically_concatenate_sql(linker)
+    sds_name = linker._settings_obj.column_info_settings.source_dataset_column_name
+
+    sql = vertically_concatenate_sql(
+        input_tables=linker._input_tables_dict,
+        salting_reqiured=linker._settings_obj.salting_required,
+        source_dataset_column_name=sds_name,
+    )
     pipeline.enqueue_sql(sql, "__splink__df_concat")
 
     sqls = compute_all_term_frequencies_sqls(linker, pipeline)
@@ -131,7 +141,13 @@ def enqueue_df_concat(linker: Linker, pipeline: CTEPipeline) -> CTEPipeline:
         pipeline.append_input_dataframe(nodes_with_tf)
         return pipeline
 
-    sql = vertically_concatenate_sql(linker)
+    sds_name = linker._settings_obj.column_info_settings.source_dataset_column_name
+
+    sql = vertically_concatenate_sql(
+        input_tables=linker._input_tables_dict,
+        salting_reqiured=linker._settings_obj.salting_required,
+        source_dataset_column_name=sds_name,
+    )
     pipeline.enqueue_sql(sql, "__splink__df_concat")
 
     return pipeline
@@ -148,7 +164,13 @@ def compute_df_concat(linker: Linker, pipeline: CTEPipeline) -> SplinkDataFrame:
         df.templated_name = "__splink__df_concat"
         return df
 
-    sql = vertically_concatenate_sql(linker)
+    sds_name = linker._settings_obj.column_info_settings.source_dataset_column_name
+
+    sql = vertically_concatenate_sql(
+        input_tables=linker._input_tables_dict,
+        salting_reqiured=linker._settings_obj.salting_required,
+        source_dataset_column_name=sds_name,
+    )
     pipeline.enqueue_sql(sql, "__splink__df_concat")
 
     nodes_with_tf = db_api.sql_pipeline_to_splink_dataframe(pipeline)
diff --git a/tests/test_total_comparison_count.py b/tests/test_total_comparison_count.py
index 6680862052..b78c0e22b2 100644
--- a/tests/test_total_comparison_count.py
+++ b/tests/test_total_comparison_count.py
@@ -91,7 +91,14 @@ def make_dummy_frame(row_count):
     linker = Linker(dfs, settings, database_api=db_api)
 
     pipeline = CTEPipeline()
-    sql = vertically_concatenate_sql(linker)
+
+    sds_name = linker._settings_obj.column_info_settings.source_dataset_column_name
+
+    sql = vertically_concatenate_sql(
+        input_tables=linker._input_tables_dict,
+        salting_reqiured=linker._settings_obj.salting_required,
+        source_dataset_column_name=sds_name,
+    )
     pipeline.enqueue_sql(sql, "__splink__df_concat")
 
     df_concat = linker.db_api.sql_pipeline_to_splink_dataframe(pipeline)
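
Note (not part of the patch): the sketch below illustrates the code path this change
enables, i.e. building __splink__df_concat from a bare DatabaseAPI without constructing
a Linker, mirroring the refactored profile_columns and the updated test above. The
helper name build_df_concat and its db_api / tables arguments are hypothetical; the
individual calls are the ones shown in the diff.

    from splink.pipeline import CTEPipeline
    from splink.vertically_concatenate import vertically_concatenate_sql


    def build_df_concat(db_api, tables):
        # Register the inputs; returns a dict of alias -> SplinkDataFrame,
        # which is the `input_tables` argument of the new signature
        splink_df_dict = db_api.register_multiple_tables(tables)

        pipeline = CTEPipeline()
        sql = vertically_concatenate_sql(
            splink_df_dict,
            salting_reqiured=False,  # spelling follows the existing parameter name
            source_dataset_column_name=None,  # no Linker settings available here
        )
        pipeline.enqueue_sql(sql, "__splink__df_concat")
        return db_api.sql_pipeline_to_splink_dataframe(pipeline)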