better documentation
RobinL committed Apr 25, 2024
1 parent 75ea5ec commit 71a53f5
Showing 2 changed files with 51 additions and 4 deletions.
53 changes: 50 additions & 3 deletions splink/accuracy.py
@@ -23,6 +23,28 @@ def truth_space_table_from_labels_with_predictions_sqls(
total_labels: int = None,
positives_not_captured_by_blocking_rules_scored_as_zero=True,
):
"""
Create truth statistics (TP, TN, FP, FN) at each threshold, suitable for
plotting e.g. on a ROC chart.
A naive implementation would iterate through each threshold computing
the truth statistics. However, if the number of thresholds is large, this
is very slow.
This implementation uses a trick to allow the whole table to be computed 'in one'
But this makes the code quite a lot harder to understand. See comments below
The positives_not_captured_by_blocking_rules_scored_as_zero argument
controls whether we score all the clerical label pairs, irrespective of
whether they were found by blocking rules, or whether any clerical pair
not found by blocking rules is scored as zero.
This is important because, from the users point of view, they might want to
distinguish the case of 'model scores correctly but blocking rules are flawed'
from 'model scores incorrectly'
"""
# Round to match_weight_round_to_nearest.
# e.g. if it's 0.25, 1.27 gets rounded to 1.25
if match_weight_round_to_nearest is not None:
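For illustration only, here is the rounding arithmetic described in the comment above, shown as plain Python; the function name is made up for the sketch, and the real code builds a SQL expression rather than calling a Python helper:

# Sketch only, not part of the commit: round a match weight to the nearest
# multiple of match_weight_round_to_nearest.
def round_to_nearest(match_weight: float, nearest: float) -> float:
    return round(match_weight / nearest) * nearest

assert round_to_nearest(1.27, 0.25) == 1.25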
@@ -34,6 +56,8 @@ def truth_space_table_from_labels_with_predictions_sqls(
truth_thres_expr = "match_weight"

sqls = []

# Classify each record as a clerical positive or negative according to the threshold
sql = f"""
select
*,
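For illustration only (the commit's SQL is truncated by the hunk above), a pandas sketch of the classification step described in the comment; the column name clerical_match_score and the 0.5 cutoff are assumptions, not taken from the commit:

import pandas as pd

# Sketch only: flag each labelled pair as a clerical positive or negative
# according to a threshold on the clerical label score.
labels = pd.DataFrame({"clerical_match_score": [1.0, 0.0, 1.0]})
labels["clerical_positive"] = (labels["clerical_match_score"] >= 0.5).astype(int)
labels["clerical_negative"] = 1 - labels["clerical_positive"]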
@@ -53,6 +77,9 @@ def truth_space_table_from_labels_with_predictions_sqls(
sql_info = {"sql": sql, "output_table_name": "__splink__labels_with_pos_neg"}
sqls.append(sql_info)

# Override the truth threshold (splink score) for any records
# not found by blocking rules

if positives_not_captured_by_blocking_rules_scored_as_zero:
truth_threshold = """CASE
WHEN found_by_blocking_rules then truth_threshold
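For illustration only, the override described above expressed in pandas; the -inf sentinel is an assumption, since the commit's actual replacement value falls outside this hunk:

import numpy as np
import pandas as pd

# Sketch only: if a labelled pair was not found by the blocking rules, treat
# its splink score as a definite non-match.
labels = pd.DataFrame({
    "truth_threshold": [3.2, -1.5, 0.7],
    "found_by_blocking_rules": [True, False, True],
})
labels["truth_threshold_adj"] = np.where(
    labels["found_by_blocking_rules"], labels["truth_threshold"], -np.inf
)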
@@ -61,15 +88,23 @@ def truth_space_table_from_labels_with_predictions_sqls(
else:
truth_threshold = "truth_threshold"

# Originally the case statement was inlined in the follow statement, but spark
# doesn't like that
# (Originally the case statement was inlined in the following CTE, but this
# caused an error in the Spark parser: it could not find the variable
# found_by_blocking_rules)
sql = f"""
select *, {truth_threshold} as truth_threshold_adj
from __splink__labels_with_pos_neg
"""
sql_info = {"sql": sql, "output_table_name": "__splink__labels_with_pos_neg_tt_adj"}
sqls.append(sql_info)

# For each distinct truth threshold (splink score),
# compute the number of pairwise comparisons at that splink score
# and the count of positive and negative clerical labels
# for those pairwise comparisons.
# e.g. for records scored at the truth threshold (splink score threshold) of
# -2.28757, two of our comparisons are in fact matches.
# This single table can be used to derive the whole truth space table
sql = """
select
truth_threshold_adj as truth_threshold,
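For illustration only (the SQL in the hunk above is truncated), a pandas sketch of the per-threshold counts the comment describes; column names other than truth_threshold_adj are assumptions:

import pandas as pd

# Sketch only: for each distinct splink score, count the pairwise comparisons
# and how many of them the clerical labels say are matches.
labels = pd.DataFrame({
    "truth_threshold_adj": [-2.28757, -2.28757, 0.5, 3.1],
    "clerical_positive": [1, 1, 0, 1],
})
per_threshold = (
    labels.groupby("truth_threshold_adj")
    .agg(
        num_comparisons=("clerical_positive", "size"),
        num_clerical_positives=("clerical_positive", "sum"),
    )
    .reset_index()
)
print(per_threshold)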
@@ -88,6 +123,13 @@ def truth_space_table_from_labels_with_predictions_sqls(
}
sqls.append(sql_info)

# By computing cumulative counts on the above table, we can derive the basis for
# computing the TP, TN, FP, FN counts at each threshold.
# For example, for a given threshold, if we know:
# - the number of labelled pairs where splink gave a score above the threshold
# - the number of clerical positives among those pairs
# then the difference between the two is the number of false positives

sql = """
select
truth_threshold,
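A small worked example of the false-positive arithmetic described in the comment above (illustration only; the numbers are made up):

# Sketch only: at one particular threshold...
pairs_scored_above_threshold = 100       # comparisons splink scored above the threshold
clerical_positives_above_threshold = 90  # of those, pairs labelled as true matches

true_positives = clerical_positives_above_threshold               # 90
false_positives = pairs_scored_above_threshold - true_positives   # 10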
@@ -133,7 +175,12 @@ def truth_space_table_from_labels_with_predictions_sqls(
"""
else:
# This code is only needed in the case of labelling from a column
# rather than a pairwise table
# rather than a pairwise table. It's needed because there are some
# 'implicit' 'ghost' labels - i.e. pairs of records which are not in the
# labels table, but which are still negatives. This is because
# we only create the pairwise comparisons that pass the blocking rules
# or are clerical matches.
# All of these 'ghost' pairs are true negatives
total_labels_str = f"cast({total_labels} as float8)"

# When we blocked, some clerical negatives would have been found indirectly
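A numeric illustration of the 'ghost' true negatives described above (illustration only; the counts are made up):

# Sketch only: when labelling from a ground-truth column, total_labels is the
# full count of labelled pairs, but only pairs that pass the blocking rules
# (or are clerical matches) appear in the pairwise table.
total_labels = 500_000      # all labelled pairs implied by the label column
pairs_in_table = 50_000     # pairwise comparisons actually generated

# The missing 'ghost' pairs were never scored, i.e. implicitly treated as
# non-matches, so they count as true negatives at every threshold.
ghost_true_negatives = total_labels - pairs_in_table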
2 changes: 1 addition & 1 deletion splink/database_api.py
@@ -160,7 +160,7 @@ def sql_to_splink_dataframe_checking_cache(
try:
from IPython.display import display

display(df_pd.head())
display(df_pd)
except ModuleNotFoundError:
print(df_pd) # noqa: T201
