
Commit

formatting and updating docstrings
aymonwuolanne committed Jan 6, 2025
1 parent 85fa9e5 commit be4b680
Showing 2 changed files with 87 additions and 43 deletions.
45 changes: 38 additions & 7 deletions splink/internals/linker_components/clustering.py
@@ -5,9 +5,6 @@
from splink.internals.connected_components import (
solve_connected_components,
)
from splink.internals.one_to_one_clustering import (
one_to_one_clustering,
)
from splink.internals.edge_metrics import compute_edge_metrics
from splink.internals.graph_metrics import (
GraphMetricsResults,
@@ -17,6 +14,9 @@
from splink.internals.misc import (
threshold_args_to_match_prob,
)
from splink.internals.one_to_one_clustering import (
one_to_one_clustering,
)
from splink.internals.pipeline import CTEPipeline
from splink.internals.splink_dataframe import SplinkDataFrame
from splink.internals.unique_id_concat import (
@@ -181,7 +181,7 @@ def cluster_pairwise_predictions_at_threshold(
return df_clustered_with_input_data


def cluster_single_best_links_at_threshold(
def cluster_using_single_best_links(
self,
df_predict: SplinkDataFrame,
source_datasets: List[str],
@@ -192,7 +192,37 @@ def cluster_single_best_links_at_threshold(
Clusters the pairwise match predictions that result from
`linker.inference.predict()` into groups of connected records using a single
best links method that restricts the clusters to have at most one record from
each source dataset in the `source_datasets` list.

This method will include a record into a cluster if it is mutually the best
match for the record and for the cluster, and if adding the record will not
violate the criteria of having at most one record from each of the
`source_datasets`.

Args:
df_predict (SplinkDataFrame): The results of `linker.inference.predict()`
source_datasets (List[str]): The source datasets which should be treated
as having no duplicates. Clusters will not form with more than
one record from each of these datasets. This can be a subset of all of
the source datasets in the input data.
threshold_match_probability (float, optional): Pairwise comparisons with a
`match_probability` at or above this threshold are matched.
threshold_match_weight (float, optional): Pairwise comparisons with a
`match_weight` at or above this threshold are matched. Only one of
threshold_match_probability or threshold_match_weight should be provided.

Returns:
SplinkDataFrame: A SplinkDataFrame containing a list of all IDs, clustered
into groups based on the desired match threshold and the source datasets
for which duplicates are not allowed.

Examples:
```python
df_predict = linker.inference.predict(threshold_match_probability=0.5)
df_clustered = linker.clustering.cluster_using_single_best_links(
df_predict, source_datasets=["A", "B"], threshold_match_probability=0.95
)
```
"""
linker = self._linker
db_api = linker._db_api
@@ -206,7 +236,9 @@ def cluster_single_best_links_at_threshold(
uid_concat_edges_r = _composite_unique_id_from_edges_sql(uid_cols, "r")
uid_concat_nodes = _composite_unique_id_from_nodes_sql(uid_cols, None)

source_dataset_column_name = linker._settings_obj.column_info_settings.source_dataset_column_name
source_dataset_column_name = (
linker._settings_obj.column_info_settings.source_dataset_column_name
)

sql = f"""
select
Expand Down Expand Up @@ -301,7 +333,6 @@ def cluster_single_best_links_at_threshold(

return df_clustered_with_input_data


def _compute_metrics_nodes(
self,
df_predict: SplinkDataFrame,
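For orientation, a minimal usage sketch of the renamed `cluster_using_single_best_links` method (illustrative only, not part of this commit): the `linker` object, the input datasets "A", "B" and "C", and the default `source_dataset` column name are all assumptions, while `predict`, `cluster_using_single_best_links` and `as_pandas_dataframe` are the splink calls referenced above.

```python
# Illustrative sketch only, not part of this commit. Assumes `linker` is an
# already-configured splink Linker whose input records carry a source_dataset
# column (default name) with values "A", "B" and "C".
df_predict = linker.inference.predict(threshold_match_probability=0.5)

# Treat only datasets "A" and "B" as duplicate-free; records from "C" may
# still repeat within a cluster.
df_clusters = linker.clustering.cluster_using_single_best_links(
    df_predict,
    source_datasets=["A", "B"],
    threshold_match_probability=0.95,
)

clusters = df_clusters.as_pandas_dataframe()

# Sanity check: no cluster should contain more than one record from "A" or
# more than one record from "B".
counts = (
    clusters[clusters["source_dataset"].isin(["A", "B"])]
    .groupby(["cluster_id", "source_dataset"])
    .size()
)
assert (counts <= 1).all()
```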
85 changes: 49 additions & 36 deletions splink/internals/one_to_one_clustering.py
@@ -10,6 +10,7 @@

logger = logging.getLogger(__name__)


def one_to_one_clustering(
nodes_table: SplinkDataFrame,
edges_table: SplinkDataFrame,
@@ -23,13 +24,8 @@ def one_to_one_clustering(
) -> SplinkDataFrame:
"""One to one clustering algorithm.
This function clusters together records so that at most one record from each dataset is in each cluster.
Args:
Returns:
SplinkDataFrame: A dataframe containing the connected components list
for your link or dedupe job.
This function clusters together records so that at most one record from each
dataset is in each cluster.
"""

@@ -39,7 +35,8 @@
if threshold_match_probability is None:
match_prob_expr = ""

# Add 'reverse-edges' so that the algorithm can rank all incoming and outgoing edges
# Add 'reverse-edges' so that the algorithm can rank all incoming and outgoing
# edges
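# (e.g. an edge a -> b with match probability p also contributes a reversed
# b -> a row, so each node can rank every link it participates in)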
sql = f"""
select
{edge_id_column_name_left} as node_id,
@@ -82,9 +79,14 @@

pipeline = CTEPipeline([neighbours, prev_representatives])

# might need to quote the value here?
contains_expr = ", ".join([f"max(source_dataset == '{sd}') as contains_{sd}" for sd in source_datasets])

# might need to quote the value here?
contains_expr = ", ".join(
[
f"max(source_dataset == '{sd}') as contains_{sd}"
for sd in source_datasets
]
)
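# e.g. with source_datasets = ["A", "B"] this expands to:
#   max(source_dataset == 'A') as contains_A, max(source_dataset == 'B') as contains_B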

sql = f"""
select
representative,
@@ -93,55 +95,66 @@
group by representative
"""

pipeline.enqueue_sql(sql, f"__splink__representative_contains_flags_{iteration}")
pipeline.enqueue_sql(
sql, f"__splink__representative_contains_flags_{iteration}"
)

sql = f"""
select
r.node_id,
r.source_dataset,
cf.*
from {prev_representatives.physical_name} as r
inner join __splink__representative_contains_flags_{iteration} as cf
on r.representative = cf.representative
"""

pipeline.enqueue_sql(sql, f"__splink__df_representatives_with_flags_{iteration}")
pipeline.enqueue_sql(
sql, f"__splink__df_representatives_with_flags_{iteration}"
)

duplicate_criteria = " or ".join(
[f"(l.contains_{sd} and r.contains_{sd})" for sd in source_datasets]
)

duplicate_criteria = " or ".join([f"(l.contains_{sd} and r.contains_{sd})" for sd in source_datasets])

# must be calculated every iteration since the where condition changes as the clustering progresses
# must be calculated every iteration since the where condition changes as
# the clustering progresses
sql = f"""
select
neighbours.node_id,
neighbours.neighbour,
{duplicate_criteria} as duplicate_criteria,
row_number() over (partition by l.representative order by match_probability desc) as rank_l,
row_number() over (partition by r.representative order by match_probability desc) as rank_r,
row_number() over (
partition by l.representative order by match_probability desc
) as rank_l,
row_number() over (
partition by r.representative order by match_probability desc
) as rank_r,
from {neighbours.physical_name} as neighbours
inner join __splink__df_representatives_with_flags_{iteration} as l
on neighbours.node_id = l.node_id
inner join __splink__df_representatives_with_flags_{iteration} as r
on neighbours.neighbour = r.node_id
where l.representative <> r.representative
"""
# note for the future: a strategy to handle ties would go right here.

pipeline.enqueue_sql(sql, f"__splink__df_ranked_{iteration}")

sql = f"""
select
node_id,
neighbour
from __splink__df_ranked_{iteration}
where rank_l = 1 and rank_r = 1 and not duplicate_criteria
"""

pipeline.enqueue_sql(sql, f"__splink__df_neighbours_{iteration}")

sql = f"""
select
source.node_id,
min(source.representative) as representative
from
(
@@ -151,9 +164,9 @@
from __splink__df_neighbours_{iteration} as neighbours
left join {prev_representatives.physical_name} as repr_neighbour
on neighbours.neighbour = repr_neighbour.node_id
union all
select
node_id,
representative
Expand All @@ -162,8 +175,8 @@ def one_to_one_clustering(
group by source.node_id
"""

pipeline.enqueue_sql(sql, f"r")
pipeline.enqueue_sql(sql, "r")

sql = f"""
select
r.node_id,
@@ -186,7 +199,7 @@

# assess if the exit condition has been met
sql = f"""
select
count(*) as count_of_nodes_needing_updating
from {representatives.physical_name}
where needs_updating
@@ -211,9 +224,9 @@
pipeline = CTEPipeline()

sql = f"""
select node_id as {node_id_column_name}, representative as cluster_id
from {representatives.physical_name}
"""
"""

pipeline.enqueue_sql(sql, "__splink__clustering_output_final")

@@ -222,4 +235,4 @@
representatives.drop_table_from_database_and_remove_from_cache()
neighbours.drop_table_from_database_and_remove_from_cache()

return final_result
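To make the rank_l / rank_r selection above concrete, here is a small self-contained pandas sketch (illustrative only, not splink code) of keeping mutually best edges; the duplicate-criteria filter is omitted for brevity.

```python
import pandas as pd

# Toy candidate links between cluster representatives (illustrative data only).
edges = pd.DataFrame(
    {
        "rep_l": [1, 1, 2],
        "rep_r": [10, 11, 10],
        "match_probability": [0.99, 0.80, 0.95],
    }
)

# Rank each edge within its left and its right representative, best match
# probability first, mirroring the two row_number() window functions above.
edges["rank_l"] = edges.groupby("rep_l")["match_probability"].rank(
    ascending=False, method="first"
)
edges["rank_r"] = edges.groupby("rep_r")["match_probability"].rank(
    ascending=False, method="first"
)

# Keep only edges that are the best option for both sides; in the SQL this is
# combined with `not duplicate_criteria`, which is omitted here.
mutual_best = edges[(edges["rank_l"] == 1) & (edges["rank_r"] == 1)]
print(mutual_best)  # only the 1 <-> 10 edge (probability 0.99) survives
```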
