From 4aa0f8485c3e3d7272ffe7a4a928e0f442947981 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Wed, 15 Jan 2025 12:44:45 +0000 Subject: [PATCH 1/4] ColumnExpression operation: access_extreme_array_element abstracting selecting the first or last element of an array --- splink/internals/column_expression.py | 32 ++++++++++++++++- splink/internals/dialects.py | 50 ++++++++++++++++++++++++++- 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/splink/internals/column_expression.py b/splink/internals/column_expression.py index 1aa6e2a9d3..a8587685c7 100644 --- a/splink/internals/column_expression.py +++ b/splink/internals/column_expression.py @@ -4,7 +4,7 @@ import string from copy import copy from functools import partial -from typing import Protocol, Union +from typing import Literal, Protocol, Union import sqlglot @@ -249,6 +249,36 @@ def try_parse_timestamp(self, timestamp_format: str = None) -> "ColumnExpression return clone + def _access_extreme_array_element_dialected( + self, + name: str, + sql_dialect: SplinkDialect, + first_or_last: Literal["first", "last"], + ) -> str: + return sql_dialect.access_extreme_array_element( + name, first_or_last=first_or_last + ) + + def access_extreme_array_element( + self, first_or_last: Literal["first", "last"] + ) -> "ColumnExpression": + """ + Applies a transformation to access either the first or the last element + of an array + + Args: + first_or_last (str): 'first' for returning the first element of the array, + 'last' for the last element + """ + clone = self._clone() + op = partial( + clone._access_extreme_array_element_dialected, + first_or_last=first_or_last, + ) + clone.operations.append(op) + + return clone + @property def name(self) -> str: sql_expression = self._parse_input_string(self.sql_dialect) diff --git a/splink/internals/dialects.py b/splink/internals/dialects.py index 90535ed03c..00a7bc0535 100644 --- a/splink/internals/dialects.py +++ 
b/splink/internals/dialects.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractproperty -from typing import TYPE_CHECKING, Type, TypeVar, final +from typing import TYPE_CHECKING, Literal, Type, TypeVar, final if TYPE_CHECKING: from splink.internals.comparison_level_library import ( @@ -190,6 +190,14 @@ def _regex_extract_raw( "'regex_extract' function" ) + def access_extreme_array_element( + self, name: str, first_or_last: Literal["first", "last"] + ) -> str: + raise NotImplementedError( + f"Backend '{self.sql_dialect_str}' does not have an " + "'access_extreme_array_element' function" + ) + def explode_arrays_sql( self, tbl_name: str, @@ -288,6 +296,18 @@ def random_sample_sql( else: return f"USING SAMPLE {percent}% (bernoulli)" + def access_extreme_array_element( + self, name: str, first_or_last: Literal["first", "last"] + ) -> str: + if first_or_last == "first": + return f"{name}[{self.array_first_index}]" + if first_or_last == "last": + return f"{name}[-1]" + raise ValueError( + f"Argument 'first_or_last' should be 'first' or 'last', " + f"received: '{first_or_last}'" + ) + def explode_arrays_sql( self, tbl_name: str, @@ -397,6 +417,18 @@ def random_sample_sql( else: return f" TABLESAMPLE ({percent} PERCENT) " + def access_extreme_array_element( + self, name: str, first_or_last: Literal["first", "last"] + ) -> str: + if first_or_last == "first": + return f"{name}[{self.array_first_index}]" + if first_or_last == "last": + return f"element_at({name}, -1)" + raise ValueError( + f"Argument 'first_or_last' should be 'first' or 'last', " + f"received: '{first_or_last}'" + ) + def explode_arrays_sql( self, tbl_name: str, @@ -552,6 +584,22 @@ def random_sample_sql( def infinity_expression(self): return "'infinity'" + @property + def array_first_index(self): + return 1 + + def access_extreme_array_element( + self, name: str, first_or_last: Literal["first", "last"] + ) -> str: + if first_or_last == "first": + return 
f"{name}[{self.array_first_index}]" + if first_or_last == "last": + return f"{name}[array_length({name}, 1)]" + raise ValueError( + f"Argument 'first_or_last' should be 'first' or 'last', " + f"received: '{first_or_last}'" + ) + class AthenaDialect(SplinkDialect): _dialect_name_for_factory = "athena" From 2793a2a13f96ca181433936d6b073f7df337c08e Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Wed, 15 Jan 2025 12:48:21 +0000 Subject: [PATCH 2/4] test ColumnExpression.access_extreme_array_element --- tests/test_column_expression.py | 50 +++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 tests/test_column_expression.py diff --git a/tests/test_column_expression.py b/tests/test_column_expression.py new file mode 100644 index 0000000000..83b01f510a --- /dev/null +++ b/tests/test_column_expression.py @@ -0,0 +1,50 @@ +import pandas as pd + +from splink import ColumnExpression +from splink.internals.dialects import SplinkDialect + +from .decorator import mark_with_dialects_excluding + + +@mark_with_dialects_excluding("sqlite") +def test_access_extreme_array_element(test_helpers, dialect): + # test data! 
+ df_arr = pd.DataFrame( + [ + {"unique_id": 1, "name_arr": ["first", "middle", "last"]}, + {"unique_id": 2, "name_arr": ["aardvark", "llama", "ssnake", "zzebra"]}, + {"unique_id": 3, "name_arr": ["only"]}, + ] + ) + + helper = test_helpers[dialect] + # register table with backend + table_name = "arr_tab" + table = helper.convert_frame(df_arr) + db_api = helper.DatabaseAPI(**helper.db_api_args()) + arr_tab = db_api.register_table(table, table_name) + + # construct a SQL query from ColumnExpressions and run it against backend + splink_dialect = SplinkDialect.from_string(dialect) + first_element = ColumnExpression( + "name_arr", sql_dialect=splink_dialect + ).access_extreme_array_element("first") + last_element = ColumnExpression( + "name_arr", sql_dialect=splink_dialect + ).access_extreme_array_element("last") + sql = ( + f"SELECT unique_id, {first_element.name} AS first_element, " + f"{last_element.name} AS last_element " + f"FROM {arr_tab.physical_name} ORDER BY unique_id" + ) + res = db_api.sql_to_splink_dataframe_checking_cache( + sql, "test_first" + ).as_pandas_dataframe() + + pd.testing.assert_series_equal( + res["first_element"], + pd.Series(["first", "aardvark", "only"], name="first_element"), + ) + pd.testing.assert_series_equal( + res["last_element"], pd.Series(["last", "zzebra", "only"], name="last_element") + ) From 63d2dfec36976ac4a1368709b106f48b54853420 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Wed, 15 Jan 2025 12:48:41 +0000 Subject: [PATCH 3/4] brief note on ColumnExpression benefit --- docs/api_docs/column_expression.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/api_docs/column_expression.md b/docs/api_docs/column_expression.md index 797fbe9537..09eddcd97b 100644 --- a/docs/api_docs/column_expression.md +++ b/docs/api_docs/column_expression.md @@ -16,6 +16,7 @@ However, there may be situations where you don't wish to derive a new column, pe This is where a `ColumnExpression` may be used. 
It represents some SQL expression, which may be a column, or some more complicated construct, to which you can also apply zero or more transformations. These are lazily evaluated, and in particular will not be tied to a specific SQL dialect until they are put (via [settings](./settings_dict_guide.md) into a linker). +This can be particularly useful if you want to write code that can easily be switched between different backends. ??? warning "Term frequency adjustments" One caveat to using a `ColumnExpression` is that it cannot be combined with term frequency adjustments. From 3cec1eeb1639d3371521b4da8e3be5c8d7f4df0d Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Wed, 15 Jan 2025 12:54:37 +0000 Subject: [PATCH 4/4] +changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c85c807e9..b697b10b7f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added + +- `ColumnExpression` now supports accessing first or last element of an array column via method `access_extreme_array_element()` ([#2585](https://github.com/moj-analytical-services/splink/pull/2585)) + ### Deprecated - Deprecated support for python `3.8.x` following end of support for that minor version ([#2520](https://github.com/moj-analytical-services/splink/pull/2520))