Skip to content

Commit

Permalink
Merge pull request #2585 from moj-analytical-services/feature/column-…
Browse files Browse the repository at this point in the history
…expression-first-index

`ColumnExpression` first/last index
  • Loading branch information
ADBond authored Jan 15, 2025
2 parents 885f1fb + 3cec1ee commit 24667bb
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 2 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- `ColumnExpression` now supports accessing first or last element of an array column via method `access_extreme_array_element()` ([#2520](https://github.com/moj-analytical-services/splink/pull/2585))

### Deprecated

- Deprecated support for python `3.8.x` following end of support for that minor version ([#2520](https://github.com/moj-analytical-services/splink/pull/2520))
Expand Down
1 change: 1 addition & 0 deletions docs/api_docs/column_expression.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ However, there may be situations where you don't wish to derive a new column, pe

This is where a `ColumnExpression` may be used. It represents some SQL expression, which may be a column, or some more complicated construct,
to which you can also apply zero or more transformations. These are lazily evaluated, and in particular will not be tied to a specific SQL dialect until they are put (via [settings](./settings_dict_guide.md) into a linker).
This can be particularly useful if you want to write code that can easily be switched between different backends.

??? warning "Term frequency adjustments"
One caveat to using a `ColumnExpression` is that it cannot be combined with term frequency adjustments.
Expand Down
32 changes: 31 additions & 1 deletion splink/internals/column_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import string
from copy import copy
from functools import partial
from typing import Protocol, Union
from typing import Literal, Protocol, Union

import sqlglot

Expand Down Expand Up @@ -249,6 +249,36 @@ def try_parse_timestamp(self, timestamp_format: str = None) -> "ColumnExpression

return clone

def _access_extreme_array_element_dialected(
self,
name: str,
sql_dialect: SplinkDialect,
first_or_last: Literal["first", "last"],
) -> str:
return sql_dialect.access_extreme_array_element(
name, first_or_last=first_or_last
)

def access_extreme_array_element(
self, first_or_last: Literal["first", "last"]
) -> "ColumnExpression":
"""
Applies a transformation to access either the first or the last element
of an array
Args:
first_or_last (str): 'first' for returning the first elemen of the array,
'last' for the last element
"""
clone = self._clone()
op = partial(
clone._access_extreme_array_element_dialected,
first_or_last=first_or_last,
)
clone.operations.append(op)

return clone

@property
def name(self) -> str:
sql_expression = self._parse_input_string(self.sql_dialect)
Expand Down
50 changes: 49 additions & 1 deletion splink/internals/dialects.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractproperty
from typing import TYPE_CHECKING, Type, TypeVar, final
from typing import TYPE_CHECKING, Literal, Type, TypeVar, final

if TYPE_CHECKING:
from splink.internals.comparison_level_library import (
Expand Down Expand Up @@ -190,6 +190,14 @@ def _regex_extract_raw(
"'regex_extract' function"
)

def access_extreme_array_element(
self, name: str, first_or_last: Literal["first", "last"]
) -> str:
raise NotImplementedError(
f"Backend '{self.sql_dialect_str}' does not have an "
"'access_extreme_array_element' function"
)

def explode_arrays_sql(
self,
tbl_name: str,
Expand Down Expand Up @@ -288,6 +296,18 @@ def random_sample_sql(
else:
return f"USING SAMPLE {percent}% (bernoulli)"

def access_extreme_array_element(
self, name: str, first_or_last: Literal["first", "last"]
) -> str:
if first_or_last == "first":
return f"{name}[{self.array_first_index}]"
if first_or_last == "last":
return f"{name}[-1]"
raise ValueError(
f"Argument 'first_or_last' should be 'first' or 'last', "
f"received: '{first_or_last}'"
)

def explode_arrays_sql(
self,
tbl_name: str,
Expand Down Expand Up @@ -397,6 +417,18 @@ def random_sample_sql(
else:
return f" TABLESAMPLE ({percent} PERCENT) "

def access_extreme_array_element(
self, name: str, first_or_last: Literal["first", "last"]
) -> str:
if first_or_last == "first":
return f"{name}[{self.array_first_index}]"
if first_or_last == "last":
return f"element_at({name}, -1)"
raise ValueError(
f"Argument 'first_or_last' should be 'first' or 'last', "
f"received: '{first_or_last}'"
)

def explode_arrays_sql(
self,
tbl_name: str,
Expand Down Expand Up @@ -552,6 +584,22 @@ def random_sample_sql(
def infinity_expression(self):
return "'infinity'"

@property
def array_first_index(self):
return 1

def access_extreme_array_element(
self, name: str, first_or_last: Literal["first", "last"]
) -> str:
if first_or_last == "first":
return f"{name}[{self.array_first_index}]"
if first_or_last == "last":
return f"{name}[array_length({name}, 1)]"
raise ValueError(
f"Argument 'first_or_last' should be 'first' or 'last', "
f"received: '{first_or_last}'"
)


class AthenaDialect(SplinkDialect):
_dialect_name_for_factory = "athena"
Expand Down
50 changes: 50 additions & 0 deletions tests/test_column_expression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pandas as pd

from splink import ColumnExpression
from splink.internals.dialects import SplinkDialect

from .decorator import mark_with_dialects_excluding


@mark_with_dialects_excluding("sqlite")
def test_access_extreme_array_element(test_helpers, dialect):
# test data!
df_arr = pd.DataFrame(
[
{"unique_id": 1, "name_arr": ["first", "middle", "last"]},
{"unique_id": 2, "name_arr": ["aardvark", "llama", "ssnake", "zzebra"]},
{"unique_id": 3, "name_arr": ["only"]},
]
)

helper = test_helpers[dialect]
# register table with backend
table_name = "arr_tab"
table = helper.convert_frame(df_arr)
db_api = helper.DatabaseAPI(**helper.db_api_args())
arr_tab = db_api.register_table(table, table_name)

# construct a SQL query from ColumnExpressions and run it against backend
splink_dialect = SplinkDialect.from_string(dialect)
first_element = ColumnExpression(
"name_arr", sql_dialect=splink_dialect
).access_extreme_array_element("first")
last_element = ColumnExpression(
"name_arr", sql_dialect=splink_dialect
).access_extreme_array_element("last")
sql = (
f"SELECT unique_id, {first_element.name} AS first_element, "
f"{last_element.name} AS last_element "
f"FROM {arr_tab.physical_name} ORDER BY unique_id"
)
res = db_api.sql_to_splink_dataframe_checking_cache(
sql, "test_first"
).as_pandas_dataframe()

pd.testing.assert_series_equal(
res["first_element"],
pd.Series(["first", "aardvark", "only"], name="first_element"),
)
pd.testing.assert_series_equal(
res["last_element"], pd.Series(["last", "zzebra", "only"], name="last_element")
)

0 comments on commit 24667bb

Please sign in to comment.