Skip to content

Commit

Permalink
add function that reports biggest blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinL committed Jul 16, 2024
1 parent 7bc6ab3 commit 26259d3
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 0 deletions.
2 changes: 2 additions & 0 deletions splink/blocking_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
count_comparisons_from_blocking_rule,
cumulative_comparisons_to_be_scored_from_blocking_rules_chart,
cumulative_comparisons_to_be_scored_from_blocking_rules_data,
n_largest_blocks,
)

# Public API of the blocking-analysis module: these names are re-exported
# from splink.internals.blocking_analysis (imported above).
__all__ = [
    "count_comparisons_from_blocking_rule",
    "cumulative_comparisons_to_be_scored_from_blocking_rules_chart",
    "cumulative_comparisons_to_be_scored_from_blocking_rules_data",
    "n_largest_blocks",
]
123 changes: 123 additions & 0 deletions splink/internals/blocking_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,107 @@ def _count_comparisons_from_blocking_rule_pre_filter_conditions_sqls(
return sqls


def _count_comparisons_from_n_largest_blocks_pre_filter_conditions_sqls(
    input_data_dict: dict[str, "SplinkDataFrame"],
    blocking_rule: "BlockingRule",
    link_type: str,
    db_api: DatabaseAPISubClass,
    n_largest: int = 5,
) -> list[dict[str, str]]:
    """Build the SQL steps that find the ``n_largest`` biggest blocks created
    by the equi-join conditions of ``blocking_rule``.

    Strategy: count rows per join-key value on the left and right inputs
    separately, then join the two count tables on the aliased keys so each
    block's pairwise-comparison count is ``count_l * count_r``, ordered
    descending and limited to ``n_largest``.

    Args:
        input_data_dict: Mapping of table alias to registered SplinkDataFrame.
        blocking_rule: Rule whose ``_equi_join_conditions`` provide the
            (left_key, right_key) expression pairs to block on.
        link_type: Splink link type; only ``"link_only"`` with exactly two
            inputs changes which tables are scanned.
        db_api: Database API object.  Not used when generating the SQL but
            kept for signature consistency with the sibling helpers.
        n_largest: Number of largest blocks to report.

    Returns:
        A list of ``{"sql": ..., "output_table_name": ...}`` dicts suitable
        for enqueueing on a CTE pipeline.  The final table is
        ``__splink__block_counts``, or ``__splink__total_of_block_counts``
        (a single cartesian count) when the rule has no equi-join conditions.
    """
    input_dataframes = list(input_data_dict.values())

    join_conditions = blocking_rule._equi_join_conditions
    two_dataset_link_only = link_type == "link_only" and len(input_dataframes) == 2

    sqls = []

    if two_dataset_link_only:
        # Two-table linkage: count key frequencies on each physical table.
        input_tablename_l = input_dataframes[0].physical_name
        input_tablename_r = input_dataframes[1].physical_name
    else:
        # Otherwise both sides of the count join read from the vertical
        # concatenation of all input tables.
        sql = vertically_concatenate_sql(
            input_data_dict, salting_required=False, source_dataset_input_column=None
        )
        sqls.append({"sql": sql, "output_table_name": "__splink__df_concat"})

        input_tablename_l = "__splink__df_concat"
        input_tablename_r = "__splink__df_concat"

    # Alias each equi-join key expression as key_0, key_1, ... so the left
    # and right count tables can later be joined with a USING clause even
    # when the raw left/right expressions differ.
    l_cols_sel = []
    r_cols_sel = []
    l_cols_gb = []
    r_cols_gb = []
    using = []
    for i, (l_key, r_key) in enumerate(join_conditions):
        l_cols_sel.append(f"{l_key} as key_{i}")
        r_cols_sel.append(f"{r_key} as key_{i}")
        l_cols_gb.append(l_key)
        r_cols_gb.append(r_key)
        using.append(f"key_{i}")

    l_cols_sel_str = ", ".join(l_cols_sel)
    r_cols_sel_str = ", ".join(r_cols_sel)
    l_cols_gb_str = ", ".join(l_cols_gb)
    r_cols_gb_str = ", ".join(r_cols_gb)
    using_str = ", ".join(using)

    if not join_conditions:
        # No equi-join conditions means every row pairs with every row:
        # there is a single "block" whose size is the cartesian product,
        # so report that total instead of a per-block breakdown.
        if two_dataset_link_only:
            sql = f"""
            SELECT
            (SELECT COUNT(*) FROM {input_tablename_l})
            *
            (SELECT COUNT(*) FROM {input_tablename_r})
            AS count_of_pairwise_comparisons_generated
            """
        else:
            sql = """
            select count(*) * count(*) as count_of_pairwise_comparisons_generated
            from __splink__df_concat
            """
        sqls.append(
            {"sql": sql, "output_table_name": "__splink__total_of_block_counts"}
        )
        return sqls

    # Frequency of each key combination on the left input.
    sql = f"""
    select {l_cols_sel_str}, count(*) as count_l
    from {input_tablename_l}
    group by {l_cols_gb_str}
    """

    sqls.append(
        {"sql": sql, "output_table_name": "__splink__count_comparisons_from_blocking_l"}
    )

    # Frequency of each key combination on the right input.
    sql = f"""
    select {r_cols_sel_str}, count(*) as count_r
    from {input_tablename_r}
    group by {r_cols_gb_str}
    """

    sqls.append(
        {"sql": sql, "output_table_name": "__splink__count_comparisons_from_blocking_r"}
    )

    # A block's comparison count is count_l * count_r; keep only the largest.
    sql = f"""
    select {using_str}, count_l, count_r, count_l * count_r as block_count
    from __splink__count_comparisons_from_blocking_l
    inner join __splink__count_comparisons_from_blocking_r
    using ({using_str})
    order by count_l * count_r desc
    limit {n_largest}
    """

    sqls.append({"sql": sql, "output_table_name": "__splink__block_counts"})

    return sqls


def _row_counts_per_input_table(
*,
splink_df_dict: dict[str, "SplinkDataFrame"],
Expand Down Expand Up @@ -654,3 +755,25 @@ def cumulative_comparisons_to_be_scored_from_blocking_rules_chart(
return cumulative_blocking_rule_comparisons_generated(
pd_df.to_dict(orient="records")
)


def n_largest_blocks(
    *,
    table_or_tables: Sequence[AcceptableInputTableType],
    blocking_rule: Union[BlockingRuleCreator, str, Dict[str, Any]],
    link_type: user_input_link_type_options,
    db_api: DatabaseAPISubClass,
    n_largest: int = 5,
):
    """Report the largest blocks that a blocking rule would generate.

    For each equi-join key of ``blocking_rule``, counts how many records
    share that key on the left and right inputs and returns the blocks with
    the most pairwise comparisons (``count_l * count_r``), largest first.

    Args:
        table_or_tables: Input table(s) to register with the database API.
        blocking_rule: The blocking rule to analyse, as a creator object,
            a raw SQL string, or a dict specification.
        link_type: The link type (e.g. ``"dedupe_only"``, ``"link_only"``).
        db_api: Database API used to register tables and run the SQL.
        n_largest: How many of the largest blocks to return.  Defaults to 5.

    Returns:
        A SplinkDataFrame of the ``n_largest`` blocks with their key values
        and comparison counts.
    """
    blocking_rule_as_br = to_blocking_rule_creator(blocking_rule).get_blocking_rule(
        db_api.sql_dialect.name
    )

    splink_df_dict = db_api.register_multiple_tables(table_or_tables)

    # n_largest is forwarded so callers are not pinned to the helper's default
    sqls = _count_comparisons_from_n_largest_blocks_pre_filter_conditions_sqls(
        splink_df_dict,
        blocking_rule_as_br,
        link_type,
        db_api,
        n_largest=n_largest,
    )
    pipeline = CTEPipeline()
    pipeline.enqueue_list_of_sqls(sqls)

    return db_api.sql_pipeline_to_splink_dataframe(pipeline)

0 comments on commit 26259d3

Please sign in to comment.