Skip to content

Commit

Permalink
Merge pull request #2517 from zmbc/pairwise_string_distance_comparison
Browse files Browse the repository at this point in the history
Pairwise string distance comparison
  • Loading branch information
RobinL authored Dec 3, 2024
2 parents 909e111 + 6f2eb40 commit 02b702b
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 1 deletion.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added
- Added new `PairwiseStringDistanceFunctionLevel` and `PairwiseStringDistanceFunctionAtThresholds`
for comparing array columns by applying a string distance/similarity function to each pair of values
([#2517](https://github.com/moj-analytical-services/splink/pull/2517))

### Fixed

- Various bugfixes for `debug_mode` ([#2481](https://github.com/moj-analytical-services/splink/pull/2481))
Expand Down
2 changes: 2 additions & 0 deletions splink/comparison_level_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Not,
NullLevel,
Or,
PairwiseStringDistanceFunctionLevel,
PercentageDifferenceLevel,
)

Expand All @@ -37,6 +38,7 @@
"JaroLevel",
"JaccardLevel",
"DistanceFunctionLevel",
"PairwiseStringDistanceFunctionLevel",
"AbsoluteTimeDifferenceLevel",
"AbsoluteDateDifferenceLevel",
"DistanceInKMLevel",
Expand Down
2 changes: 2 additions & 0 deletions splink/comparison_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
JaroWinklerAtThresholds,
LevenshteinAtThresholds,
NameComparison,
PairwiseStringDistanceFunctionAtThresholds,
PostcodeComparison,
)

Expand All @@ -28,6 +29,7 @@
"JaroAtThresholds",
"JaroWinklerAtThresholds",
"DistanceFunctionAtThresholds",
"PairwiseStringDistanceFunctionAtThresholds",
"AbsoluteTimeDifferenceAtThresholds",
"AbsoluteDateDifferenceAtThresholds",
"ArrayIntersectAtSizes",
Expand Down
91 changes: 91 additions & 0 deletions splink/internals/comparison_level_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,97 @@ def create_label_for_charts(self) -> str:
)


class PairwiseStringDistanceFunctionLevel(ComparisonLevelCreator):
    def __init__(
        self,
        col_name: str | ColumnExpression,
        distance_function_name: Literal[
            "levenshtein", "damerau_levenshtein", "jaro_winkler", "jaro"
        ],
        distance_threshold: Union[int, float],
    ):
        """A comparison level using the *most similar* string distance
        between any pair of values between arrays in an array column.

        Every value of the left-hand array is compared with every value of the
        right-hand array, and the best-scoring pair is compared against
        `distance_threshold`:

        - for "levenshtein" and "damerau_levenshtein" (lower is more similar)
          the level matches when the *minimum* pairwise distance is
          `<= distance_threshold`;
        - for "jaro_winkler" and "jaro" (higher is more similar) the level
          matches when the *maximum* pairwise similarity is
          `>= distance_threshold`.

        The function given by `distance_function_name` must be one of
        "levenshtein", "damerau_levenshtein", "jaro_winkler" or "jaro".

        Args:
            col_name (str | ColumnExpression): Input column name
            distance_function_name (str): the name of the string distance function
            distance_threshold (Union[int, float]): The threshold to use to assess
                similarity. Validated against a lower bound of 0.
        """

        self.col_expression = ColumnExpression.instantiate_if_str(col_name)
        self.distance_function_name = validate_categorical_parameter(
            allowed_values=[
                "levenshtein",
                "damerau_levenshtein",
                "jaro_winkler",
                "jaro",
            ],
            parameter_value=distance_function_name,
            level_name=self.__class__.__name__,
            parameter_name="distance_function_name",
        )
        self.distance_threshold = validate_numeric_parameter(
            lower_bound=0,
            upper_bound=float("inf"),
            parameter_value=distance_threshold,
            level_name=self.__class__.__name__,
            parameter_name="distance_threshold",
        )

    @unsupported_splink_dialects(["sqlite", "spark", "postgres", "athena"])
    def create_sql(self, sql_dialect: SplinkDialect) -> str:
        """Return the SQL condition for this level in the given dialect.

        The generated expression builds the cross product of the left and
        right arrays as `[x, y]` pairs, applies the dialect's name for the
        chosen string distance function to each pair, aggregates the results
        with `list_min` or `list_max`, and compares that aggregate with the
        threshold.
        """
        self.col_expression.sql_dialect = sql_dialect
        col = self.col_expression
        # Translate the generic function name into whatever the dialect calls it.
        distance_function_name_transpiled = {
            "levenshtein": sql_dialect.levenshtein_function_name,
            "damerau_levenshtein": sql_dialect.damerau_levenshtein_function_name,
            "jaro_winkler": sql_dialect.jaro_winkler_function_name,
            "jaro": sql_dialect.jaro_function_name,
        }[self.distance_function_name]

        # `pair[1]` / `pair[2]` are the left and right members of each
        # `[x, y]` pair built by the inner list_transform calls.
        return f"""list_{self._aggregator()}(
            list_transform(
                flatten(
                    list_transform(
                        {col.name_l},
                        x -> list_transform(
                            {col.name_r},
                            y -> [x,y]
                        )
                    )
                ),
                pair -> {distance_function_name_transpiled}(pair[1], pair[2])
            )
        ) {self._comparator()} {self.distance_threshold}"""

    def create_label_for_charts(self) -> str:
        """Return the human-readable label for this level.

        Example: "Min `levenshtein` distance of 'forename <= than 1'".
        """
        col = self.col_expression
        # NOTE: this wording is asserted verbatim in tests/test_comparison_lib.py,
        # so the text (including quote placement) is kept stable here.
        return (
            f"{self._aggregator().title()} `{self.distance_function_name}` "
            f"distance of '{col.label} "
            f"{self._comparator()} than {self.distance_threshold}'"
        )

    def _aggregator(self) -> str:
        """Return "max" for similarity scores and "min" for distances."""
        return "max" if self._higher_is_more_similar() else "min"

    def _comparator(self) -> str:
        """Return the operator used to compare the aggregate to the threshold."""
        return ">=" if self._higher_is_more_similar() else "<="

    def _higher_is_more_similar(self) -> bool:
        """Return whether a larger score from the chosen function means closer."""
        return {
            "levenshtein": False,
            "damerau_levenshtein": False,
            "jaro_winkler": True,
            "jaro": True,
        }[self.distance_function_name]


DateMetricType = Literal["second", "minute", "hour", "day", "month", "year"]


Expand Down
72 changes: 71 additions & 1 deletion splink/internals/comparison_library.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Iterable, List, Optional, Union
from typing import Any, Iterable, List, Literal, Optional, Union

from splink.internals import comparison_level_library as cll
from splink.internals.column_expression import ColumnExpression
Expand Down Expand Up @@ -345,6 +345,76 @@ def create_output_column_name(self) -> str:
return self.col_expression.output_column_name


class PairwiseStringDistanceFunctionAtThresholds(ComparisonCreator):
    def __init__(
        self,
        col_name: str | ColumnExpression,
        distance_function_name: Literal[
            "levenshtein", "damerau_levenshtein", "jaro_winkler", "jaro"
        ],
        distance_threshold_or_thresholds: Union[Iterable[int | float], int | float],
    ):
        """
        Represents a comparison of the *most similar pair* of values
        where the first value is in the array data in `col_name` for the first record
        and the second value is in the array data in `col_name` for the second record.

        The comparison has three or more levels (a null level is also created
        ahead of these):

        - Exact match between any pair of values
        - User-selected string distance function levels at specified thresholds
        - ...
        - Anything else

        For example, with distance_threshold_or_thresholds = [1, 3]
        and distance_function_name 'levenshtein' the levels are:

        - Exact match between any pair of values
        - Levenshtein distance between the most similar pair of values <= 1
        - Levenshtein distance between the most similar pair of values <= 3
        - Anything else

        Args:
            col_name (str | ColumnExpression): The name of the column to compare.
            distance_function_name (str): the name of the string distance function.
                Must be one of "levenshtein", "damerau_levenshtein",
                "jaro_winkler" or "jaro".
            distance_threshold_or_thresholds (Union[int, float, Iterable]): The
                threshold(s) to use for the distance function level(s). A
                single number is treated as a one-element list; one level is
                created per threshold, in the order supplied.
        """
        thresholds_as_iterable = ensure_is_iterable(distance_threshold_or_thresholds)
        # Materialise once so the thresholds can be iterated repeatedly
        # (levels and description) even if a one-shot iterable was passed.
        self.thresholds = [*thresholds_as_iterable]
        self.distance_function_name = distance_function_name
        super().__init__(col_name)

    def create_comparison_levels(self) -> List[ComparisonLevelCreator]:
        """Return the levels: null, shared value, one per threshold, else."""
        return [
            cll.NullLevel(self.col_expression),
            # Stands in for the exact-match level: it is assumed that any
            # string distance treats two identical values as the most
            # similar, so arrays sharing at least one value rank above
            # every thresholded level.
            cll.ArrayIntersectLevel(self.col_expression, min_intersection=1),
            *[
                cll.PairwiseStringDistanceFunctionLevel(
                    self.col_expression,
                    distance_threshold=threshold,
                    distance_function_name=self.distance_function_name,
                )
                for threshold in self.thresholds
            ],
            cll.ElseLevel(),
        ]

    def create_description(self) -> str:
        """Return a one-line description naming the function and thresholds."""
        comma_separated_thresholds_string = ", ".join(map(str, self.thresholds))
        plural = "s" if len(self.thresholds) > 1 else ""
        return (
            f"Pairwise {self.distance_function_name} distance at threshold{plural} "
            f"{comma_separated_thresholds_string} vs. anything else"
        )

    def create_output_column_name(self) -> str:
        """Return the output column name taken from the column expression."""
        return self.col_expression.output_column_name


class AbsoluteTimeDifferenceAtThresholds(ComparisonCreator):
def __init__(
self,
Expand Down
56 changes: 56 additions & 0 deletions tests/test_comparison_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,62 @@ def test_distance_function_comparison():
assert sum(df_pred[f"gamma_{col}"] == gamma_val) == expected_count


@mark_with_dialects_excluding("sqlite", "spark", "postgres", "athena")
def test_pairwise_stringdistance_function_comparison(test_helpers, dialect):
    """Check gamma values and labels for a pairwise Damerau-Levenshtein comparison.

    Thresholds [1, 2] are used, so the expected gamma values are 3 for a
    shared value, 2 for a best pair within distance 1, 1 for a best pair
    within distance 2, and 0 for everything else.
    """
    db_api = test_helpers[dialect].extra_linker_args()["db_api"]

    shared_value = "Array intersection size >= 1"
    within_1 = "Min `damerau_levenshtein` distance of 'forename <= than 1'"
    within_2 = "Min `damerau_levenshtein` distance of 'forename <= than 2'"
    otherwise = "All other comparisons"

    # (left array, right array, expected gamma, expected level label)
    scenarios = [
        # "Cally" appears on both sides
        (["Cally", "Sally"], ["Cally"], 3, shared_value),
        # single pair, one edit apart
        (["Geof"], ["Geoff"], 2, within_1),
        # two close pairs; the closest is one edit apart
        (["Saly", "Barey"], ["Sally", "Barry"], 2, within_1),
        # only Carry/Barry is close; the other pairs are far apart
        (["Carry", "Different"], ["Barry", "Completely"], 2, within_1),
        # closest pair (Carry/Cally) is two edits apart
        (["Carry", "Sabby"], ["Cally"], 1, within_2),
        # no pair is within either threshold
        (["Completely", "Different"], ["Something", "Else"], 0, otherwise),
    ]

    comparison = cl.PairwiseStringDistanceFunctionAtThresholds(
        "forename",
        "damerau_levenshtein",
        distance_threshold_or_thresholds=[1, 2],
    )
    test_cases = [
        {
            "comparison": comparison,
            "inputs": [
                {
                    "forename_l": left,
                    "forename_r": right,
                    "expected_value": gamma,
                    "expected_label": label,
                }
                for left, right, gamma, label in scenarios
            ],
        }
    ]

    run_comparison_vector_value_tests(test_cases, db_api)


@mark_with_dialects_excluding()
def test_set_to_lowercase(test_helpers, dialect):
helper = test_helpers[dialect]
Expand Down

0 comments on commit 02b702b

Please sign in to comment.