Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pairwise string distance comparison #2517

Merged
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added
- Added new `PairwiseStringDistanceFunctionLevel` and `PairwiseStringDistanceFunctionAtThresholds`
for comparing array columns by applying a string distance function to each pair of values
([#2517](https://github.com/moj-analytical-services/splink/pull/2517))

### Fixed

- Various bugfixes for `debug_mode` ([#2481](https://github.com/moj-analytical-services/splink/pull/2481))
Expand Down
2 changes: 2 additions & 0 deletions splink/comparison_level_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Not,
NullLevel,
Or,
PairwiseStringDistanceFunctionLevel,
PercentageDifferenceLevel,
)

Expand All @@ -37,6 +38,7 @@
"JaroLevel",
"JaccardLevel",
"DistanceFunctionLevel",
"PairwiseStringDistanceFunctionLevel",
"AbsoluteTimeDifferenceLevel",
"AbsoluteDateDifferenceLevel",
"DistanceInKMLevel",
Expand Down
2 changes: 2 additions & 0 deletions splink/comparison_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
JaroWinklerAtThresholds,
LevenshteinAtThresholds,
NameComparison,
PairwiseStringDistanceFunctionAtThresholds,
PostcodeComparison,
)

Expand All @@ -28,6 +29,7 @@
"JaroAtThresholds",
"JaroWinklerAtThresholds",
"DistanceFunctionAtThresholds",
"PairwiseStringDistanceFunctionAtThresholds",
zmbc marked this conversation as resolved.
Show resolved Hide resolved
"AbsoluteTimeDifferenceAtThresholds",
"AbsoluteDateDifferenceAtThresholds",
"ArrayIntersectAtSizes",
Expand Down
91 changes: 91 additions & 0 deletions splink/internals/comparison_level_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,97 @@ def create_label_for_charts(self) -> str:
)


class PairwiseStringDistanceFunctionLevel(ComparisonLevelCreator):
    def __init__(
        self,
        col_name: str | ColumnExpression,
        distance_function_name: Literal[
            "levenshtein", "damerau_levenshtein", "jaro_winkler", "jaro"
        ],
        distance_threshold: Union[int, float],
    ):
        """A comparison level based on the *most similar* pair of values
        taken from the two arrays held in an array column.

        Each value of the left-hand array is paired with each value of the
        right-hand array and the chosen string distance function is evaluated
        on every pair. For "levenshtein" and "damerau_levenshtein" the
        smallest result must be `<=` the threshold; for "jaro_winkler" and
        "jaro" the largest result must be `>=` the threshold.

        Args:
            col_name (str | ColumnExpression): Input column name
            distance_function_name (str): the name of the string distance
                function. Must be one of "levenshtein", "damerau_levenshtein",
                "jaro_winkler" or "jaro".
            distance_threshold (Union[int, float]): The threshold to use to assess
                similarity. Validated against a lower bound of 0.
        """
        level_name = self.__class__.__name__
        permitted_function_names = [
            "levenshtein",
            "damerau_levenshtein",
            "jaro_winkler",
            "jaro",
        ]

        self.col_expression = ColumnExpression.instantiate_if_str(col_name)
        self.distance_function_name = validate_categorical_parameter(
            allowed_values=permitted_function_names,
            parameter_value=distance_function_name,
            level_name=level_name,
            parameter_name="distance_function_name",
        )
        self.distance_threshold = validate_numeric_parameter(
            lower_bound=0,
            upper_bound=float("inf"),
            parameter_value=distance_threshold,
            level_name=level_name,
            parameter_name="distance_threshold",
        )

    @unsupported_splink_dialects(["sqlite", "spark", "postgres", "athena"])
    def create_sql(self, sql_dialect: SplinkDialect) -> str:
        """Build the SQL condition for this level.

        The generated expression forms every `[left value, right value]`
        pair from the two arrays, applies the dialect's string distance
        function to each pair, aggregates the results with `list_min` or
        `list_max`, and compares the aggregate with the threshold.

        Args:
            sql_dialect (SplinkDialect): dialect supplying the function names.

        Returns:
            str: the SQL boolean expression.
        """
        col = self.col_expression
        col.sql_dialect = sql_dialect

        # Dialect-specific spelling of each supported string distance function.
        dialect_function_names = {
            "levenshtein": sql_dialect.levenshtein_function_name,
            "damerau_levenshtein": sql_dialect.damerau_levenshtein_function_name,
            "jaro_winkler": sql_dialect.jaro_winkler_function_name,
            "jaro": sql_dialect.jaro_function_name,
        }
        sql_function = dialect_function_names[self.distance_function_name]
        aggregator = self._aggregator()
        comparator = self._comparator()

        return f"""list_{aggregator}(
                    list_transform(
                        flatten(
                            list_transform(
                                {col.name_l},
                                x -> list_transform(
                                    {col.name_r},
                                    y -> [x,y]
                                )
                            )
                        ),
                        pair -> {sql_function}(pair[1], pair[2])
                    )
                ) {comparator} {self.distance_threshold}"""

    def create_label_for_charts(self) -> str:
        """Return the human-readable label describing this level."""
        label = self.col_expression.label
        aggregator_title = self._aggregator().title()
        return (
            f"{aggregator_title} `{self.distance_function_name}` distance of "
            f"'{label} {self._comparator()} than {self.distance_threshold}'"
        )

    def _aggregator(self) -> str:
        """Return "max" when higher scores mean more similar, else "min"."""
        if self._higher_is_more_similar():
            return "max"
        return "min"

    def _comparator(self) -> str:
        """Return ">=" when higher scores mean more similar, else "<="."""
        if self._higher_is_more_similar():
            return ">="
        return "<="

    def _higher_is_more_similar(self) -> bool:
        """Tell whether a larger result of the chosen function means the
        two strings are more alike (similarity) rather than less (distance).
        """
        similarity_direction = {
            "levenshtein": False,
            "damerau_levenshtein": False,
            "jaro_winkler": True,
            "jaro": True,
        }
        return similarity_direction[self.distance_function_name]


DateMetricType = Literal["second", "minute", "hour", "day", "month", "year"]


Expand Down
72 changes: 71 additions & 1 deletion splink/internals/comparison_library.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Iterable, List, Optional, Union
from typing import Any, Iterable, List, Literal, Optional, Union

from splink.internals import comparison_level_library as cll
from splink.internals.column_expression import ColumnExpression
Expand Down Expand Up @@ -345,6 +345,76 @@ def create_output_column_name(self) -> str:
return self.col_expression.output_column_name


class PairwiseStringDistanceFunctionAtThresholds(ComparisonCreator):
    def __init__(
        self,
        col_name: str,
        distance_function_name: Literal[
            "levenshtein", "damerau_levenshtein", "jaro_winkler", "jaro"
        ],
        distance_threshold_or_thresholds: Union[Iterable[int | float], int | float],
    ):
        """
        Represents a comparison of the *most similar pair* of values, where
        one value comes from the array in `col_name` of the first record and
        the other from the array in `col_name` of the second record.
        After the null level, the comparison has three or more levels:

        - Exact match between any pair of values
        - User-selected string distance function levels at specified thresholds
        - ...
        - Anything else

        For example, with distance_threshold_or_thresholds = [1, 3]
        and distance_function_name 'levenshtein' the levels are:

        - Exact match between any pair of values
        - Levenshtein distance between the most similar pair of values <= 1
        - Levenshtein distance between the most similar pair of values <= 3
        - Anything else

        Args:
            col_name (str): The name of the column to compare.
            distance_function_name (str): the name of the string distance function.
                Must be one of "levenshtein", "damerau_levenshtein",
                "jaro_winkler" or "jaro".
            distance_threshold_or_thresholds (Union[int, float, Iterable]): The
                threshold(s) to use for the distance function level(s). A
                single number is treated as a one-element list.
        """
        self.thresholds = list(ensure_is_iterable(distance_threshold_or_thresholds))
        self.distance_function_name = distance_function_name
        super().__init__(col_name)

    def create_comparison_levels(self) -> List[ComparisonLevelCreator]:
        """Return the levels: null, any shared value, one pairwise distance
        level per threshold (in the order given), then the catch-all.
        """
        levels: List[ComparisonLevelCreator] = [
            cll.NullLevel(self.col_expression),
            # Any string distance is assumed to score identical values as the
            # most similar, so a value shared by both arrays ranks above
            # every distance-based level.
            cll.ArrayIntersectLevel(self.col_expression, min_intersection=1),
        ]
        levels.extend(
            cll.PairwiseStringDistanceFunctionLevel(
                self.col_expression,
                distance_threshold=threshold,
                distance_function_name=self.distance_function_name,
            )
            for threshold in self.thresholds
        )
        levels.append(cll.ElseLevel())
        return levels

    def create_description(self) -> str:
        """Return a one-line description naming the function and thresholds."""
        threshold_list = ", ".join(str(threshold) for threshold in self.thresholds)
        noun = "thresholds" if len(self.thresholds) > 1 else "threshold"
        return (
            f"Pairwise {self.distance_function_name} distance at {noun} "
            f"{threshold_list} vs. anything else"
        )

    def create_output_column_name(self) -> str:
        """Return the output column name of the underlying column expression."""
        return self.col_expression.output_column_name


class AbsoluteTimeDifferenceAtThresholds(ComparisonCreator):
def __init__(
self,
Expand Down
56 changes: 56 additions & 0 deletions tests/test_comparison_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,62 @@ def test_distance_function_comparison():
assert sum(df_pred[f"gamma_{col}"] == gamma_val) == expected_count


@mark_with_dialects_excluding("sqlite", "spark", "postgres", "athena")
zmbc marked this conversation as resolved.
Show resolved Hide resolved
def test_pairwise_stringdistance_function_comparison(test_helpers, dialect):
    """Check the comparison vector value and level label produced by
    `PairwiseStringDistanceFunctionAtThresholds` for several array pairs,
    using Damerau-Levenshtein thresholds of 1 and 2.
    """
    db_api = test_helpers[dialect].extra_linker_args()["db_api"]

    def distance_label(threshold):
        return f"Min `damerau_levenshtein` distance of 'forename <= than {threshold}'"

    # (left array, right array, expected comparison vector value, expected label)
    scenarios = [
        # A value shared by both arrays hits the array-intersection level.
        (["Cally", "Sally"], ["Cally"], 3, "Array intersection size >= 1"),
        # Single-element arrays one edit apart.
        (["Geof"], ["Geoff"], 2, distance_label(1)),
        # Several pairs are within one edit of each other.
        (["Saly", "Barey"], ["Sally", "Barry"], 2, distance_label(1)),
        # One close pair is enough, whatever the other values are.
        (["Carry", "Different"], ["Barry", "Completely"], 2, distance_label(1)),
        # The closest pair is two edits apart.
        (["Carry", "Sabby"], ["Cally"], 1, distance_label(2)),
        # No pair is within either threshold.
        (["Completely", "Different"], ["Something", "Else"], 0, "All other comparisons"),
    ]

    comparison = cl.PairwiseStringDistanceFunctionAtThresholds(
        "forename",
        "damerau_levenshtein",
        distance_threshold_or_thresholds=[1, 2],
    )
    test_cases = [
        {
            "comparison": comparison,
            "inputs": [
                {
                    "forename_l": left,
                    "forename_r": right,
                    "expected_value": expected_value,
                    "expected_label": expected_label,
                }
                for left, right, expected_value, expected_label in scenarios
            ],
        }
    ]

    run_comparison_vector_value_tests(test_cases, db_api)


@mark_with_dialects_excluding()
def test_set_to_lowercase(test_helpers, dialect):
helper = test_helpers[dialect]
Expand Down
Loading