diff --git a/splink/blocking.py b/splink/blocking.py index 47160a499c..b1468067c8 100644 --- a/splink/blocking.py +++ b/splink/blocking.py @@ -7,7 +7,9 @@ from sqlglot.expressions import Column, Join from sqlglot.optimizer.eliminate_joins import join_condition +from .input_column import InputColumn from .misc import ensure_is_list +from .splink_dataframe import SplinkDataFrame from .unique_id_concat import _composite_unique_id_from_nodes_sql logger = logging.getLogger(__name__) @@ -27,13 +29,26 @@ def blocking_rule_to_obj(br): sqlglot_dialect = br.get("sql_dialect", None) salting_partitions = br.get("salting_partitions", None) - if salting_partitions is None: - return BlockingRule(blocking_rule, sqlglot_dialect) - else: + arrays_to_explode = br.get("arrays_to_explode", None) + + if arrays_to_explode is not None and salting_partitions is not None: + raise ValueError( + "Splink does not support blocking rules that are " + "both salted and exploding" + ) + + if salting_partitions is not None: return SaltedBlockingRule( blocking_rule, sqlglot_dialect, salting_partitions ) + if arrays_to_explode is not None: + return ExplodingBlockingRule( + blocking_rule, sqlglot_dialect, arrays_to_explode + ) + + return BlockingRule(blocking_rule, sqlglot_dialect) + else: br = BlockingRule(br) return br @@ -69,8 +84,7 @@ def add_preceding_rules(self, rules): rules = ensure_is_list(rules) self.preceding_rules = rules - @property - def exclude_pairs_generated_by_this_rule_sql(self): + def exclude_pairs_generated_by_this_rule_sql(self, linker: Linker): """A SQL string specifying how to exclude the results of THIS blocking rule from subsequent blocking statements, so that subsequent statements do not produce duplicate pairs @@ -81,15 +95,15 @@ def exclude_pairs_generated_by_this_rule_sql(self): # meaning these comparisons get lost return f"coalesce(({self.blocking_rule_sql}),false)" - @property - def exclude_pairs_generated_by_all_preceding_rules_sql(self): + def exclude_pairs_generated_by_all_preceding_rules_sql(self, linker: Linker): """A SQL string that excludes the results of ALL previous blocking rules from the pairwise comparisons generated. 
""" if not self.preceding_rules: return "" or_clauses = [ - br.exclude_pairs_generated_by_this_rule_sql for br in self.preceding_rules + br.exclude_pairs_generated_by_this_rule_sql(linker) + for br in self.preceding_rules ] previous_rules = " OR ".join(or_clauses) return f"AND NOT ({previous_rules})" @@ -107,8 +121,8 @@ def create_blocked_pairs_sql(self, linker: Linker, where_condition, probability) inner join {linker._input_tablename_r} as r on ({self.blocking_rule_sql}) - {self.exclude_pairs_generated_by_all_preceding_rules_sql} {where_condition} + {self.exclude_pairs_generated_by_all_preceding_rules_sql(linker)} """ return sql @@ -233,14 +247,181 @@ def create_blocked_pairs_sql(self, linker: Linker, where_condition, probability) inner join {linker._input_tablename_r} as r on ({self.blocking_rule_sql} {salt_condition}) - {self.exclude_pairs_generated_by_all_preceding_rules_sql} {where_condition} + {self.exclude_pairs_generated_by_all_preceding_rules_sql(linker)} """ sqls.append(sql) return " UNION ALL ".join(sqls) +class ExplodingBlockingRule(BlockingRule): + def __init__( + self, + blocking_rule: BlockingRule | dict | str, + sqlglot_dialect: str = None, + array_columns_to_explode: list = [], + ): + super().__init__(blocking_rule, sqlglot_dialect) + self.array_columns_to_explode: List[str] = array_columns_to_explode + self.exploded_id_pair_table: SplinkDataFrame = None + + def marginal_exploded_id_pairs_table_sql(self, linker: Linker, br: BlockingRule): + """generates a table of the marginal id pairs from the exploded blocking rule + i.e. pairs are only created that match this blocking rule and NOT any of + the preceding blocking rules + """ + + settings_obj = linker._settings_obj + unique_id_col = settings_obj._unique_id_column_name + + link_type = settings_obj._link_type + + if linker._two_dataset_link_only: + link_type = "two_dataset_link_only" + + if linker._self_link_mode: + link_type = "self_link" + + where_condition = _sql_gen_where_condition( + link_type, settings_obj._unique_id_input_columns + ) + + id_expr_l = _composite_unique_id_from_nodes_sql( + settings_obj._unique_id_input_columns, "l" + ) + id_expr_r = _composite_unique_id_from_nodes_sql( + settings_obj._unique_id_input_columns, "r" + ) + + if link_type == "two_dataset_link_only": + where_condition = ( + where_condition + " and l.source_dataset < r.source_dataset" + ) + + sql = f""" + select distinct + {id_expr_l} as {unique_id_col}_l, + {id_expr_r} as {unique_id_col}_r + from __splink__df_concat_with_tf_unnested as l + inner join __splink__df_concat_with_tf_unnested as r + on ({br.blocking_rule_sql}) + {where_condition} + {self.exclude_pairs_generated_by_all_preceding_rules_sql(linker)} + """ + + return sql + + def drop_materialised_id_pairs_dataframe(self): + self.exploded_id_pair_table.drop_table_from_database_and_remove_from_cache() + self.exploded_id_pair_table = None + + def exclude_pairs_generated_by_this_rule_sql(self, linker: Linker): + """A SQL string specifying how to exclude the results + of THIS blocking rule from subseqent blocking statements, + so that subsequent statements do not produce duplicate pairs + """ + + unique_id_column = linker._settings_obj._unique_id_column_name + splink_df = self.exploded_id_pair_table + ids_to_compare_sql = f"select * from {splink_df.physical_name}" + + settings_obj = linker._settings_obj + id_expr_l = _composite_unique_id_from_nodes_sql( + settings_obj._unique_id_input_columns, "l" + ) + id_expr_r = _composite_unique_id_from_nodes_sql( + 
settings_obj._unique_id_input_columns, "r" ) + + return f"""EXISTS ( + select 1 from ({ids_to_compare_sql}) as ids_to_compare + where ( + {id_expr_l} = ids_to_compare.{unique_id_column}_l and + {id_expr_r} = ids_to_compare.{unique_id_column}_r + ) + ) + """ + + def create_blocked_pairs_sql(self, linker: Linker, where_condition, probability): + columns_to_select = linker._settings_obj._columns_to_select_for_blocking + sql_select_expr = ", ".join(columns_to_select) + + if self.exploded_id_pair_table is None: + raise ValueError( + "Exploding blocking rules are not supported for the function you have" + " called." + ) + settings_obj = linker._settings_obj + id_expr_l = _composite_unique_id_from_nodes_sql( + settings_obj._unique_id_input_columns, "l" + ) + id_expr_r = _composite_unique_id_from_nodes_sql( + settings_obj._unique_id_input_columns, "r" + ) + + exploded_id_pair_table = self.exploded_id_pair_table + unique_id_col = linker._settings_obj._unique_id_column_name + sql = f""" + select + {sql_select_expr}, + '{self.match_key}' as match_key + {probability} + from {exploded_id_pair_table.physical_name} as pairs + left join {linker._input_tablename_l} as l + on pairs.{unique_id_col}_l={id_expr_l} + left join {linker._input_tablename_r} as r + on pairs.{unique_id_col}_r={id_expr_r} + """ + return sql + + def as_dict(self): + output = super().as_dict() + output["arrays_to_explode"] = self.array_columns_to_explode + return output + + +def materialise_exploded_id_tables(linker: Linker): + settings_obj = linker._settings_obj + + blocking_rules = settings_obj._blocking_rules_to_generate_predictions + exploding_blocking_rules = [ + br for br in blocking_rules if isinstance(br, ExplodingBlockingRule) + ] + exploded_tables = [] + + input_dataframe = linker._initialise_df_concat_with_tf() + input_colnames = {col.name for col in input_dataframe.columns} + + for br in exploding_blocking_rules: + arrays_to_explode_quoted = [ + InputColumn(colname, sql_dialect=linker._sql_dialect).quote().name + for colname in br.array_columns_to_explode + ] + expl_sql = linker._explode_arrays_sql( + "__splink__df_concat_with_tf", + br.array_columns_to_explode, + list(input_colnames.difference(arrays_to_explode_quoted)), + ) + + linker._enqueue_sql( + expl_sql, + "__splink__df_concat_with_tf_unnested", + ) + + base_name = "__splink__marginal_exploded_ids_blocking_rule" + table_name = f"{base_name}_mk_{br.match_key}" + + sql = br.marginal_exploded_id_pairs_table_sql(linker, br) + + linker._enqueue_sql(sql, table_name) + + marginal_ids_table = linker._execute_sql_pipeline([input_dataframe]) + br.exploded_id_pair_table = marginal_ids_table + exploded_tables.append(marginal_ids_table) + return exploding_blocking_rules + + def _sql_gen_where_condition(link_type, unique_id_cols): id_expr_l = _composite_unique_id_from_nodes_sql(unique_id_cols, "l") id_expr_r = _composite_unique_id_from_nodes_sql(unique_id_cols, "r") diff --git a/splink/duckdb/linker.py b/splink/duckdb/linker.py index 79fc0b43c1..1685bb136f 100644 --- a/splink/duckdb/linker.py +++ b/splink/duckdb/linker.py @@ -319,3 +319,23 @@ def export_to_duckdb_file(self, output_path, delete_intermediate_tables=False): new_con = duckdb.connect(database=output_path) new_con.execute(f"IMPORT DATABASE '{tmpdir}';") new_con.close() + + def _explode_arrays_sql( + self, tbl_name, columns_to_explode, other_columns_to_retain + ): + """Generates SQL that explodes one or more columns in a table""" + columns_to_explode = columns_to_explode.copy() + other_columns_to_retain = 
other_columns_to_retain.copy() + # base case + if len(columns_to_explode) == 0: + return f"select {','.join(other_columns_to_retain)} from {tbl_name}" + else: + column_to_explode = columns_to_explode.pop() + cols_to_select = ( + [f"unnest({column_to_explode}) as {column_to_explode}"] + + other_columns_to_retain + + columns_to_explode + ) + other_columns_to_retain.append(column_to_explode) + return f"""select {','.join(cols_to_select)} + from ({self._explode_arrays_sql(tbl_name,columns_to_explode,other_columns_to_retain)})""" # noqa: E501 diff --git a/splink/linker.py b/splink/linker.py index 4a5001c91e..7aae0180f2 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -39,6 +39,7 @@ BlockingRule, block_using_rules_sqls, blocking_rule_to_obj, + materialise_exploded_id_tables, ) from .cache_dict_with_logging import CacheDictWithLogging from .charts import ( @@ -1432,10 +1433,15 @@ def deterministic_link(self) -> SplinkDataFrame: self._deterministic_link_mode = True concat_with_tf = self._initialise_df_concat_with_tf() + exploding_br_with_id_tables = materialise_exploded_id_tables(self) + sqls = block_using_rules_sqls(self) for sql in sqls: self._enqueue_sql(sql["sql"], sql["output_table_name"]) - return self._execute_sql_pipeline([concat_with_tf]) + + deterministic_link_df = self._execute_sql_pipeline([concat_with_tf]) + [b.drop_materialised_id_pairs_dataframe() for b in exploding_br_with_id_tables] + return deterministic_link_df def estimate_u_using_random_sampling( self, max_pairs: int = None, seed: int = None, *, target_rows=None @@ -1655,7 +1661,15 @@ def estimate_parameters_using_expectation_maximisation( # to be used by the training linkers self._initialise_df_concat_with_tf() + # Extract the blocking rule + # Check it's a BlockingRule (not a SaltedBlockingRule, ExplodingBlockingRule) + # and raise error if not specifically a BlockingRule blocking_rule = blocking_rule_to_obj(blocking_rule) + if type(blocking_rule) is not BlockingRule: + raise TypeError( + "EM blocking rules must be plain blocking rules, not " + "salted or exploding blocking rules" + ) if comparisons_to_deactivate: # If user provided a string, convert to Comparison object @@ -1754,6 +1768,10 @@ def predict( if nodes_with_tf: input_dataframes.append(nodes_with_tf) + # If exploded blocking rules exist, we need to materialise + # the tables of ID pairs + exploding_br_with_id_tables = materialise_exploded_id_tables(self) + sqls = block_using_rules_sqls(self) for sql in sqls: self._enqueue_sql(sql["sql"], sql["output_table_name"]) @@ -1779,6 +1797,9 @@ def predict( predictions = self._execute_sql_pipeline(input_dataframes) self._predict_warning() + + [b.drop_materialised_id_pairs_dataframe() for b in exploding_br_with_id_tables] + return predictions def find_matches_to_new_records( @@ -4024,3 +4045,10 @@ def _detect_blocking_rules_for_em_training( "suggested_blocking_rules_as_splink_brs" ].iloc[0] return suggestion + + def _explode_arrays_sql( + self, tbl_name, columns_to_explode, other_columns_to_retain + ): + raise NotImplementedError( + f"Unnesting blocking rules are not supported for {type(self)}" + ) diff --git a/splink/spark/linker.py b/splink/spark/linker.py index c19e09f716..1514e63b0c 100644 --- a/splink/spark/linker.py +++ b/splink/spark/linker.py @@ -540,3 +540,21 @@ def _check_ansi_enabled_if_converting_dates(self): classed as comparison level = "ELSE". 
Ensure date strings are cleaned to remove bad dates \n""" ) + + def _explode_arrays_sql( + self, tbl_name, columns_to_explode, other_columns_to_retain + ): + """Generates SQL that explodes one or more columns in a table""" + columns_to_explode = columns_to_explode.copy() + other_columns_to_retain = other_columns_to_retain.copy() + if len(columns_to_explode) == 0: + return f"select {','.join(other_columns_to_retain)} from {tbl_name}" + else: + column_to_explode = columns_to_explode.pop() + cols_to_select = ( + [f"explode({column_to_explode}) as {column_to_explode}"] + + other_columns_to_retain + + columns_to_explode + ) + return f"""select {','.join(cols_to_select)} + from ({self._explode_arrays_sql(tbl_name,columns_to_explode,other_columns_to_retain+[column_to_explode])})""" # noqa: E501 diff --git a/tests/test_array_based_blocking.py b/tests/test_array_based_blocking.py new file mode 100644 index 0000000000..a8ca333d7b --- /dev/null +++ b/tests/test_array_based_blocking.py @@ -0,0 +1,254 @@ +import random + +import pandas as pd + +from tests.decorator import mark_with_dialects_including + + +@mark_with_dialects_including("duckdb", "spark", pass_dialect=True) +def test_simple_example_link_only(test_helpers, dialect): + data_l = pd.DataFrame.from_dict( + [ + {"unique_id": 1, "gender": "m", "postcode": ["2612", "2000"]}, + {"unique_id": 2, "gender": "m", "postcode": ["2612", "2617"]}, + {"unique_id": 3, "gender": "f", "postcode": ["2617"]}, + ] + ) + data_r = pd.DataFrame.from_dict( + [ + {"unique_id": 4, "gender": "m", "postcode": ["2617", "2600"]}, + {"unique_id": 5, "gender": "f", "postcode": ["2000"]}, + {"unique_id": 6, "gender": "m", "postcode": ["2617", "2612", "2000"]}, + ] + ) + helper = test_helpers[dialect] + settings = { + "link_type": "link_only", + "blocking_rules_to_generate_predictions": [ + { + "blocking_rule": "l.gender = r.gender and l.postcode = r.postcode", + "arrays_to_explode": ["postcode"], + }, + "l.gender = r.gender", + ], + "comparisons": [helper.cl.array_intersect_at_sizes("postcode", [1])], + } + ## the pairs returned by the first blocking rule are (1,6),(2,4),(2,6) + ## the additional pairs returned by the second blocking rule are (1,4),(3,5) + linker = helper.Linker([data_l, data_r], settings, **helper.extra_linker_args()) + linker.debug_mode = False + returned_triples = linker.predict().as_pandas_dataframe()[ + ["unique_id_l", "unique_id_r", "match_key"] + ] + returned_triples = { + (unique_id_l, unique_id_r, match_key) + for unique_id_l, unique_id_r, match_key in zip( + returned_triples.unique_id_l, + returned_triples.unique_id_r, + returned_triples.match_key, + ) + } + expected_triples = {(1, 6, "0"), (2, 4, "0"), (2, 6, "0"), (1, 4, "1"), (3, 5, "1")} + assert expected_triples == returned_triples + + +def generate_array_based_datasets_helper( + n_rows=1000, n_array_based_columns=3, n_distinct_values=1000, array_size=3, seed=1 +): + random.seed(seed) + datasets = [] + for _k in range(2): + results_dict = {} + results_dict["cluster"] = list(range(n_rows)) + for i in range(n_array_based_columns): + col = [] + for j in range(n_rows): + col.append(random.sample(range(n_distinct_values), array_size)) + if random.random() < 0.8 or i == n_array_based_columns - 1: + col[-1].append(j) + random.shuffle(col[-1]) + results_dict[f"array_column_{i}"] = col + datasets.append(pd.DataFrame.from_dict(results_dict)) + return datasets + + +@mark_with_dialects_including("duckdb", "spark", pass_dialect=True) +def test_array_based_blocking_with_random_data_dedupe(test_helpers, 
dialect): + helper = test_helpers[dialect] + input_data_l, input_data_r = generate_array_based_datasets_helper() + input_data_l = input_data_l.assign( + unique_id=[str(cluster_id) + "-0" for cluster_id in input_data_l.cluster] + ) + input_data_r = input_data_r.assign( + unique_id=[str(cluster_id) + "-1" for cluster_id in input_data_r.cluster] + ) + input_data = pd.concat([input_data_l, input_data_r]) + blocking_rules = [ + { + "blocking_rule": """l.array_column_0 = r.array_column_0 + and l.array_column_1 = r.array_column_1""", + "arrays_to_explode": ["array_column_0", "array_column_1"], + }, + { + "blocking_rule": """l.array_column_0 = r.array_column_0 + and l.array_column_1 = r.array_column_1 + and l.array_column_2 = r.array_column_2""", + "arrays_to_explode": ["array_column_0", "array_column_1"], + }, + { + "blocking_rule": "l.array_column_2 = r.array_column_2", + "arrays_to_explode": ["array_column_2"], + }, + ] + settings = { + "link_type": "dedupe_only", + "blocking_rules_to_generate_predictions": blocking_rules, + "unique_id_column_name": "unique_id", + "additional_columns_to_retain": ["cluster"], + "comparisons": [helper.cl.array_intersect_at_sizes("array_column_1", [1])], + } + linker = helper.Linker(input_data, settings, **helper.extra_linker_args()) + linker.debug_mode = False + df_predict = linker.predict().as_pandas_dataframe() + ## check that there are no duplicates in the output + assert ( + df_predict.drop_duplicates(["unique_id_l", "unique_id_r"]).shape[0] + == df_predict.shape[0] + ) + + ## check that the output contains no links with match_key=1, + ## since all pairs returned by the second rule should also be + ## returned by the first rule and so should be filtered out + assert df_predict[df_predict.match_key == 1].shape[0] == 0 + + ## check that all 1000 true matches are in the output + ## (this is guaranteed by how the data was generated) + assert sum(df_predict.cluster_l == df_predict.cluster_r) == 1000 + + +@mark_with_dialects_including("duckdb", "spark", pass_dialect=True) +def test_array_based_blocking_with_random_data_link_only(test_helpers, dialect): + helper = test_helpers[dialect] + input_data_l, input_data_r = generate_array_based_datasets_helper() + blocking_rules = [ + { + "blocking_rule": """l.array_column_0 = r.array_column_0 + and l.array_column_1 = r.array_column_1""", + "arrays_to_explode": ["array_column_0", "array_column_1"], + }, + { + "blocking_rule": """l.array_column_0 = r.array_column_0 + and l.array_column_1 = r.array_column_1 + and l.array_column_2=r.array_column_2""", + "arrays_to_explode": ["array_column_0", "array_column_1", "array_column_2"], + }, + { + "blocking_rule": "l.array_column_2 = r.array_column_2", + "arrays_to_explode": ["array_column_2"], + }, + ] + settings = { + "link_type": "link_only", + "blocking_rules_to_generate_predictions": blocking_rules, + "unique_id_column_name": "cluster", + "additional_columns_to_retain": ["cluster"], + "comparisons": [helper.cl.array_intersect_at_sizes("array_column_1", [1])], + } + linker = helper.Linker( + [input_data_l, input_data_r], settings, **helper.extra_linker_args() + ) + linker.debug_mode = False + df_predict = linker.predict().as_pandas_dataframe() + + ## check that we get no within-dataset links + within_dataset_links = df_predict[ + df_predict.source_dataset_l == df_predict.source_dataset_r + ].shape[0] + assert within_dataset_links == 0 + + ## check that no pair of ids appears twice in the output + assert ( + df_predict.drop_duplicates(["cluster_l", "cluster_r"]).shape[0] + == 
df_predict.shape[0] + ) + + ## check that the second blocking rule returns no matches, + ## since every pair matching the second rule will also match the first, + ## and so should be filtered out + assert df_predict[df_predict.match_key == 1].shape[0] == 0 + + ## check that all 1000 true matches are returned + assert sum(df_predict.cluster_l == df_predict.cluster_r) == 1000 + + +@mark_with_dialects_including("duckdb", pass_dialect=True) +def test_link_only_unique_id_ambiguity(test_helpers, dialect): + helper = test_helpers[dialect] + data_1 = [ + { + "unique_id": 1, + "first_name": "John", + "surname": "Doe", + "postcode": ["A", "B"], + }, + {"unique_id": 3, "first_name": "John", "surname": "Doe", "postcode": ["B"]}, + ] + + data_2 = [ + {"unique_id": 3, "first_name": "John", "surname": "Smith", "postcode": ["A"]}, + ] + + data_3 = [ + {"unique_id": 3, "first_name": "John", "surname": "Smith", "postcode": ["A"]}, + {"unique_id": 4, "first_name": "John", "surname": "Doe", "postcode": ["C"]}, + ] + + df_1 = pd.DataFrame(data_1) + df_2 = pd.DataFrame(data_2) + df_3 = pd.DataFrame(data_3) + + settings = { + "link_type": "link_only", + "blocking_rules_to_generate_predictions": [ + { + "blocking_rule": """l.postcode = r.postcode + and l.first_name = r.first_name""", + "arrays_to_explode": ["postcode"], + }, + "l.surname = r.surname", + ], + "comparisons": [ + helper.cl.exact_match("first_name"), + helper.cl.exact_match("surname"), + helper.cl.exact_match("postcode"), + ], + "retain_intermediate_calculation_columns": True, + } + + linker = helper.Linker( + [df_1, df_2, df_3], settings, input_table_aliases=["a_", "b_", "c_"] + ) + returned_triples = linker.predict().as_pandas_dataframe()[ + [ + "source_dataset_l", + "unique_id_l", + "source_dataset_r", + "unique_id_r", + "match_key", + ] + ] + + actual_triples = { + tuple(t) for t in returned_triples.to_dict(orient="split")["data"] + } + assert len(returned_triples) == 5 + + rule1_tuples = { + ("a_", 1, "b_", 3, "0"), + ("a_", 1, "c_", 3, "0"), + ("b_", 3, "c_", 3, "0"), + } + rule2_tuples = {("a_", 1, "c_", 4, "1"), ("a_", 3, "c_", 4, "1")} + + all_tuples = rule1_tuples.union(rule2_tuples) + assert actual_triples == all_tuples
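
Reviewer note: the sketch below shows how the new `arrays_to_explode` option introduced by this patch is intended to be used, based on `test_simple_example_link_only` above. It assumes this branch is installed and uses the Splink 3 DuckDB API (`DuckDBLinker`, `splink.duckdb.comparison_library`); treat it as an illustrative sketch, not canonical documentation.

```python
# Minimal usage sketch for an exploding blocking rule (assumes this branch).
import pandas as pd
import splink.duckdb.comparison_library as cl
from splink.duckdb.linker import DuckDBLinker

df_l = pd.DataFrame([
    {"unique_id": 1, "gender": "m", "postcode": ["2612", "2000"]},
    {"unique_id": 2, "gender": "m", "postcode": ["2612", "2617"]},
])
df_r = pd.DataFrame([
    {"unique_id": 4, "gender": "m", "postcode": ["2617", "2600"]},
    {"unique_id": 6, "gender": "m", "postcode": ["2617", "2612", "2000"]},
])

settings = {
    "link_type": "link_only",
    "blocking_rules_to_generate_predictions": [
        {
            # Rows are unnested on `postcode` before the join, so two records
            # block together if ANY postcode is shared (and gender agrees).
            "blocking_rule": "l.gender = r.gender and l.postcode = r.postcode",
            "arrays_to_explode": ["postcode"],
        },
        # A plain rule; pairs already produced by the rule above are excluded,
        # so its matches appear with the next match_key.
        "l.gender = r.gender",
    ],
    "comparisons": [cl.array_intersect_at_sizes("postcode", [1])],
}

linker = DuckDBLinker([df_l, df_r], settings)
predictions = linker.predict().as_pandas_dataframe()
```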
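The recursion in `_explode_arrays_sql` can be hard to follow from the diff alone. The standalone sketch below mirrors the DuckDB version (the Spark version is identical except `explode()` replaces `unnest()`); `explode_arrays_sql` is a hypothetical helper written only to illustrate the recursion and is not part of the patch.

```python
# Standalone sketch of the recursive SQL generation (illustration only).
def explode_arrays_sql(tbl_name, columns_to_explode, other_columns_to_retain):
    columns_to_explode = columns_to_explode.copy()
    other_columns_to_retain = other_columns_to_retain.copy()
    if not columns_to_explode:
        # Base case: nothing left to unnest, just select the retained columns
        return f"select {','.join(other_columns_to_retain)} from {tbl_name}"
    column_to_explode = columns_to_explode.pop()
    cols_to_select = (
        [f"unnest({column_to_explode}) as {column_to_explode}"]
        + other_columns_to_retain
        + columns_to_explode
    )
    # The exploded column is retained (unexploded) by the inner subqueries
    other_columns_to_retain.append(column_to_explode)
    return (
        f"select {','.join(cols_to_select)} "
        f"from ({explode_arrays_sql(tbl_name, columns_to_explode, other_columns_to_retain)})"
    )


print(explode_arrays_sql("df", ["postcode"], ["unique_id", "gender"]))
# select unnest(postcode) as postcode,unique_id,gender
#     from (select unique_id,gender,postcode from df)
```

Each recursion level unnests exactly one array column: the innermost subquery selects the original columns, and every outer layer explodes one of them, so exploding several columns yields the per-row cross product of their elements.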
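For completeness, a quick sketch of the dispatch behaviour added to `blocking_rule_to_obj`: a dict containing `arrays_to_explode` now yields an `ExplodingBlockingRule`, and combining it with `salting_partitions` is rejected. This assumes the branch is installed, since `ExplodingBlockingRule` only exists once the patch is applied.

```python
# Dispatch sketch for blocking_rule_to_obj (assumes this branch is installed).
from splink.blocking import (
    BlockingRule,
    ExplodingBlockingRule,
    blocking_rule_to_obj,
)

br = blocking_rule_to_obj(
    {"blocking_rule": "l.postcode = r.postcode", "arrays_to_explode": ["postcode"]}
)
assert isinstance(br, ExplodingBlockingRule)

# A plain string still produces an ordinary BlockingRule
assert type(blocking_rule_to_obj("l.gender = r.gender")) is BlockingRule

try:
    blocking_rule_to_obj(
        {
            "blocking_rule": "l.postcode = r.postcode",
            "arrays_to_explode": ["postcode"],
            "salting_partitions": 4,
        }
    )
except ValueError as e:
    print(e)  # rejected: a rule cannot be both salted and exploding
```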