From 71a53f524ce3aec8060efd6de053e04a657834d5 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Thu, 25 Apr 2024 09:14:32 +0100 Subject: [PATCH] better documentation --- splink/accuracy.py | 53 +++++++++++++++++++++++++++++++++++++++--- splink/database_api.py | 2 +- 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/splink/accuracy.py b/splink/accuracy.py index 3c8155db86..61a24fc93a 100644 --- a/splink/accuracy.py +++ b/splink/accuracy.py @@ -23,6 +23,28 @@ def truth_space_table_from_labels_with_predictions_sqls( total_labels: int = None, positives_not_captured_by_blocking_rules_scored_as_zero=True, ): + """ + Create truth statistics (TP, TN, FP, FN) at each threshold, suitable for + plotting e.g. on a ROC chart. + + A naive implementation would iterate through each threshold computing + the truth statistics. However, if the number of thresholds is large, this + is very slow. + + This implementation uses a trick to allow the whole table to be computed 'in one' + + But this makes the code quite a lot harder to understand. See comments below + + The positives_not_captured_by_blocking_rules_scored_as_zero argument + controls whether we score all the clerical label pairs, irrespective of + whether they were found by blocking rules, or whether any clerical pair + not found by blocking rules is scored as zero. + + This is important because, from the users point of view, they might want to + distinguish the case of 'model scores correctly but blocking rules are flawed' + from 'model scores incorrectly' + + """ # Round to match_weight_round_to_nearest. # e.g. if it's 0.25, 1.27 gets rounded to 1.25 if match_weight_round_to_nearest is not None: @@ -34,6 +56,8 @@ def truth_space_table_from_labels_with_predictions_sqls( truth_thres_expr = "match_weight" sqls = [] + + # Classify each record as a clerical positive or negative according to the threshold sql = f""" select *, @@ -53,6 +77,9 @@ def truth_space_table_from_labels_with_predictions_sqls( sql_info = {"sql": sql, "output_table_name": "__splink__labels_with_pos_neg"} sqls.append(sql_info) + # Override the truth threshold (splink score) for any records + # not found by blocking rules + if positives_not_captured_by_blocking_rules_scored_as_zero: truth_threshold = """CASE WHEN found_by_blocking_rules then truth_threshold @@ -61,8 +88,9 @@ def truth_space_table_from_labels_with_predictions_sqls( else: truth_threshold = "truth_threshold" - # Originally the case statement was inlined in the follow statement, but spark - # doesn't like that + # (Originally the case statement was inlined in the following cte, but caused + # an error with the spark parser that it couldn't find the variable + # found_by_blocking_rules) sql = f""" select *, {truth_threshold} as truth_threshold_adj from __splink__labels_with_pos_neg @@ -70,6 +98,13 @@ def truth_space_table_from_labels_with_predictions_sqls( sql_info = {"sql": sql, "output_table_name": "__splink__labels_with_pos_neg_tt_adj"} sqls.append(sql_info) + # For each distinct truth threshold (splink score) + # compute the number of pairwise comparisons at that splink score + # and the the count of positive and negative clerical labels + # for those pairwise comparisons + # e.g.For records scored at the truth threshold (splink score threshold) of + # -2.28757, two of our comparisons are in fact matches + # This single table can be used to derive the whole truth space table sql = """ select truth_threshold_adj as truth_threshold, @@ -88,6 +123,13 @@ def truth_space_table_from_labels_with_predictions_sqls( } sqls.append(sql_info) + # By computing cumulative counts on the above table, we can derive the basis for + # computing the TP, TN, FP, FN counts at each threshold + # For example, for a given threshold, if we know: + # The number of labels where splink gave a score about the threshold + # The number of clerical positives at the threshold + # The difference between the two is the number of false positives + sql = """ select truth_threshold, @@ -133,7 +175,12 @@ def truth_space_table_from_labels_with_predictions_sqls( """ else: # This code is only needed in the case of labelling from a column - # rather than a pairwise table + # rather than a pairwise table. It's needed because there are some + # 'implicit' 'ghost' label - i.e. pairs of records which are not in the + # labels table, but which are still negatives. This is because + # we only create the pairwise comparisons that pass the blocking rules + # or are clerical matches. + # All of these 'ghost' records are true negatives total_labels_str = f"cast({total_labels} as float8)" # When we blocked, some clerical negatives would have been found indirectly diff --git a/splink/database_api.py b/splink/database_api.py index 2666490339..82de06f24e 100644 --- a/splink/database_api.py +++ b/splink/database_api.py @@ -160,7 +160,7 @@ def sql_to_splink_dataframe_checking_cache( try: from IPython.display import display - display(df_pd.head()) + display(df_pd) except ModuleNotFoundError: print(df_pd) # noqa: T201