diff --git a/splink/blocking_analysis.py b/splink/blocking_analysis.py
index 62a5f93448..3d49a54d24 100644
--- a/splink/blocking_analysis.py
+++ b/splink/blocking_analysis.py
@@ -2,10 +2,12 @@
     count_comparisons_from_blocking_rule,
     cumulative_comparisons_to_be_scored_from_blocking_rules_chart,
     cumulative_comparisons_to_be_scored_from_blocking_rules_data,
+    n_largest_blocks,
 )
 
 __all__ = [
     "count_comparisons_from_blocking_rule",
     "cumulative_comparisons_to_be_scored_from_blocking_rules_chart",
     "cumulative_comparisons_to_be_scored_from_blocking_rules_data",
+    "n_largest_blocks",
 ]
diff --git a/splink/internals/blocking_analysis.py b/splink/internals/blocking_analysis.py
index b20c7091f2..51926920b9 100644
--- a/splink/internals/blocking_analysis.py
+++ b/splink/internals/blocking_analysis.py
@@ -190,6 +190,107 @@ def _count_comparisons_from_blocking_rule_pre_filter_conditions_sqls(
     return sqls
 
 
+def _count_comparisons_from_n_largest_blocks_pre_filter_conditions_sqls(
+    input_data_dict: dict[str, "SplinkDataFrame"],
+    blocking_rule: "BlockingRule",
+    link_type: str,
+    db_api: DatabaseAPISubClass,
+    n_largest=5,
+) -> list[dict[str, str]]:
+    input_dataframes = list(input_data_dict.values())
+
+    join_conditions = blocking_rule._equi_join_conditions
+    two_dataset_link_only = link_type == "link_only" and len(input_dataframes) == 2
+
+    sqls = []
+
+    if two_dataset_link_only:
+        input_tablename_l = input_dataframes[0].physical_name
+        input_tablename_r = input_dataframes[1].physical_name
+    else:
+        sql = vertically_concatenate_sql(
+            input_data_dict, salting_required=False, source_dataset_input_column=None
+        )
+        sqls.append({"sql": sql, "output_table_name": "__splink__df_concat"})
+
+        input_tablename_l = "__splink__df_concat"
+        input_tablename_r = "__splink__df_concat"
+
+    l_cols_sel = []
+    r_cols_sel = []
+    l_cols_gb = []
+    r_cols_gb = []
+    using = []
+    for (
+        i,
+        (l_key, r_key),
+    ) in enumerate(join_conditions):
+        l_cols_sel.append(f"{l_key} as key_{i}")
+        r_cols_sel.append(f"{r_key} as key_{i}")
+        l_cols_gb.append(l_key)
+        r_cols_gb.append(r_key)
+        using.append(f"key_{i}")
+
+    l_cols_sel_str = ", ".join(l_cols_sel)
+    r_cols_sel_str = ", ".join(r_cols_sel)
+    l_cols_gb_str = ", ".join(l_cols_gb)
+    r_cols_gb_str = ", ".join(r_cols_gb)
+    using_str = ", ".join(using)
+
+    if not join_conditions:
+        if two_dataset_link_only:
+            sql = f"""
+            SELECT
+                (SELECT COUNT(*) FROM {input_tablename_l})
+                *
+                (SELECT COUNT(*) FROM {input_tablename_r})
+                    AS count_of_pairwise_comparisons_generated
+            """
+        else:
+            sql = """
+            select count(*) * count(*) as count_of_pairwise_comparisons_generated
+            from __splink__df_concat
+
+            """
+        sqls.append(
+            {"sql": sql, "output_table_name": "__splink__total_of_block_counts"}
+        )
+        return sqls
+
+    sql = f"""
+    select {l_cols_sel_str}, count(*) as count_l
+    from {input_tablename_l}
+    group by {l_cols_gb_str}
+    """
+
+    sqls.append(
+        {"sql": sql, "output_table_name": "__splink__count_comparisons_from_blocking_l"}
+    )
+
+    sql = f"""
+    select {r_cols_sel_str}, count(*) as count_r
+    from {input_tablename_r}
+    group by {r_cols_gb_str}
+    """
+
+    sqls.append(
+        {"sql": sql, "output_table_name": "__splink__count_comparisons_from_blocking_r"}
+    )
+
+    sql = f"""
+    select {using_str}, count_l, count_r, count_l * count_r as block_count
+    from __splink__count_comparisons_from_blocking_l
+    inner join __splink__count_comparisons_from_blocking_r
+    using ({using_str})
+    order by count_l * count_r desc
+    limit {n_largest}
+    """
+
+    sqls.append({"sql": sql, "output_table_name": "__splink__block_counts"})
+
+    return sqls
+
+
 def _row_counts_per_input_table(
     *,
     splink_df_dict: dict[str, "SplinkDataFrame"],
@@ -654,3 +755,25 @@ def cumulative_comparisons_to_be_scored_from_blocking_rules_chart(
     return cumulative_blocking_rule_comparisons_generated(
         pd_df.to_dict(orient="records")
     )
+
+
+def n_largest_blocks(
+    *,
+    table_or_tables: Sequence[AcceptableInputTableType],
+    blocking_rule: Union[BlockingRuleCreator, str, Dict[str, Any]],
+    link_type: user_input_link_type_options,
+    db_api: DatabaseAPISubClass,
+):
+    blocking_rule_as_br = to_blocking_rule_creator(blocking_rule).get_blocking_rule(
+        db_api.sql_dialect.name
+    )
+
+    splink_df_dict = db_api.register_multiple_tables(table_or_tables)
+
+    sqls = _count_comparisons_from_n_largest_blocks_pre_filter_conditions_sqls(
+        splink_df_dict, blocking_rule_as_br, link_type, db_api
+    )
+    pipeline = CTEPipeline()
+    pipeline.enqueue_list_of_sqls(sqls)
+
+    return db_api.sql_pipeline_to_splink_dataframe(pipeline)
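
Usage note (not part of the patch): a minimal sketch of how the new `n_largest_blocks` function might be called, assuming the DuckDB backend and the `fake_1000` demo dataset that ship with Splink; the columns passed to `block_on` are illustrative. The result is a `SplinkDataFrame` with one row per block, carrying the blocking-key values plus `count_l`, `count_r`, and `block_count = count_l * count_r`, limited to the five largest blocks by the helper's `n_largest` default.

```python
# Hypothetical usage sketch (not part of the diff above).
# Assumes the DuckDB backend and the fake_1000 demo dataset bundled with Splink;
# the blocking columns are illustrative.
from splink import DuckDBAPI, block_on, splink_datasets
from splink.blocking_analysis import n_largest_blocks

db_api = DuckDBAPI()
df = splink_datasets.fake_1000

# Find the blocks that would generate the most pairwise comparisons under this
# blocking rule, without materialising the comparisons themselves.
result = n_largest_blocks(
    table_or_tables=df,
    blocking_rule=block_on("first_name", "last_name"),
    link_type="dedupe_only",
    db_api=db_api,
)
print(result.as_pandas_dataframe())
```

Because the counts come from a group-by on the equi-join keys followed by a join of the two aggregates, the pipeline never materialises the pairwise comparisons, so inspecting the largest blocks stays cheap even when a skewed key would otherwise make the comparison count explode.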