Skip to content

Commit

Permalink
Merge branch 'splink4_dev' into fix_2059
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinL authored Apr 30, 2024
2 parents 71a53f5 + 1b9a3ea commit fee3f05
Show file tree
Hide file tree
Showing 51 changed files with 402 additions and 257 deletions.
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,5 @@ strict_equality = true
strict_concatenate = true
check_untyped_defs = true
disallow_subclassing_any = true
# next set to enforce:
# disallow_untyped_decorators = true
# disallow_any_generics = true
disallow_untyped_decorators = true
disallow_any_generics = true
8 changes: 4 additions & 4 deletions splink/blocking.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Any, List, Optional

from sqlglot import parse_one
from sqlglot.expressions import Column, Expression, Join
Expand All @@ -23,7 +23,7 @@
from .settings import LinkTypeLiteralType


def blocking_rule_to_obj(br: BlockingRule | dict | str) -> BlockingRule:
def blocking_rule_to_obj(br: BlockingRule | dict[str, Any] | str) -> BlockingRule:
if isinstance(br, BlockingRule):
return br
elif isinstance(br, dict):
Expand Down Expand Up @@ -280,9 +280,9 @@ def create_blocked_pairs_sql(
class ExplodingBlockingRule(BlockingRule):
def __init__(
self,
blocking_rule: BlockingRule | dict | str,
blocking_rule: BlockingRule | dict[str, Any] | str,
sqlglot_dialect: str = None,
array_columns_to_explode: list = [],
array_columns_to_explode: list[str] = [],
):
if isinstance(blocking_rule, BlockingRule):
blocking_rule_sql = blocking_rule.blocking_rule_sql
Expand Down
6 changes: 4 additions & 2 deletions splink/blocking_rule_creator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import final
from typing import Any, final

from .blocking import BlockingRule, blocking_rule_to_obj
from .dialects import SplinkDialect
Expand All @@ -24,7 +26,7 @@ def create_sql(self, sql_dialect: SplinkDialect) -> str:
pass

@final
def create_blocking_rule_dict(self, sql_dialect_str: str) -> dict:
def create_blocking_rule_dict(self, sql_dialect_str: str) -> dict[str, Any]:
sql_dialect = SplinkDialect.from_string(sql_dialect_str)
level_dict = {
"blocking_rule": self.create_sql(sql_dialect),
Expand Down
6 changes: 4 additions & 2 deletions splink/blocking_rule_creator_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import List, Union
from __future__ import annotations

from typing import Any, List, Union

from .blocking_rule_creator import BlockingRuleCreator
from .blocking_rule_library import CustomRule
from .misc import ensure_is_iterable


def to_blocking_rule_creator(
blocking_rule_creator: Union[dict, str, BlockingRuleCreator]
blocking_rule_creator: Union[dict[str, Any], str, BlockingRuleCreator]
):
if isinstance(blocking_rule_creator, dict):
return CustomRule(**blocking_rule_creator)
Expand Down
6 changes: 4 additions & 2 deletions splink/blocking_rule_library.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Union, final
from __future__ import annotations

from typing import Any, Union, final

from sqlglot import TokenError, parse_one

Expand Down Expand Up @@ -80,7 +82,7 @@ class _Merge(BlockingRuleCreator):
@final
def __init__(
self,
*blocking_rules: Union[BlockingRuleCreator, dict],
*blocking_rules: Union[BlockingRuleCreator, dict[str, Any]],
salting_partitions=None,
arrays_to_explode=None,
):
Expand Down
8 changes: 7 additions & 1 deletion splink/cache_dict_with_logging.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import logging
from collections import UserDict
from copy import copy
from typing import TYPE_CHECKING

from .splink_dataframe import SplinkDataFrame

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
TypedUserDict = UserDict[str, SplinkDataFrame]
else:
TypedUserDict = UserDict

class CacheDictWithLogging(UserDict):

class CacheDictWithLogging(TypedUserDict):
def __init__(self):
super().__init__()
self.executed_queries = []
Expand Down
16 changes: 8 additions & 8 deletions splink/cluster_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _quo_if_str(x):
return str(x)


def _clusters_sql(df_clustered_nodes, cluster_ids: list) -> str:
def _clusters_sql(df_clustered_nodes, cluster_ids: list[str]) -> str:
cluster_ids = [_quo_if_str(x) for x in cluster_ids]
cluster_ids_joined = ", ".join(cluster_ids)

Expand All @@ -42,8 +42,8 @@ def _clusters_sql(df_clustered_nodes, cluster_ids: list) -> str:


def df_clusters_as_records(
linker: "Linker", df_clustered_nodes: SplinkDataFrame, cluster_ids: list
) -> list[dict]:
linker: "Linker", df_clustered_nodes: SplinkDataFrame, cluster_ids: list[str]
) -> list[dict[str, Any]]:
"""Retrieves distinct clusters which exist in df_clustered_nodes based on
list of cluster IDs provided and converts them to a record dictionary.
Expand Down Expand Up @@ -87,7 +87,7 @@ def _nodes_sql(df_clustered_nodes, cluster_ids) -> str:


def create_df_nodes(
linker: "Linker", df_clustered_nodes: SplinkDataFrame, cluster_ids: list
linker: "Linker", df_clustered_nodes: SplinkDataFrame, cluster_ids: list[str]
) -> SplinkDataFrame:
"""Retrieves nodes from df_clustered_nodes for list of cluster IDs provided.
Expand Down Expand Up @@ -192,7 +192,7 @@ def _get_random_cluster_ids(

def _get_cluster_id_of_each_size(
linker: "Linker", connected_components: SplinkDataFrame, rows_per_partition: int
) -> list[dict]:
) -> list[dict[str, Any]]:
unique_id_col_name = linker._settings_obj.column_info_settings.unique_id_column_name
pipeline = CTEPipeline()
sql = f"""
Expand Down Expand Up @@ -239,7 +239,7 @@ def _get_lowest_density_clusters(
df_cluster_metrics: SplinkDataFrame,
rows_per_partition: int,
min_nodes: int,
) -> list[dict]:
) -> list[dict[str, Any]]:
"""Returns lowest density clusters of different sizes by
performing stratified sampling.
Expand Down Expand Up @@ -292,7 +292,7 @@ def _get_cluster_ids(
sample_size,
sample_seed,
_df_cluster_metrics: Optional[SplinkDataFrame] = None,
) -> tuple[list, list]:
) -> tuple[list[str], list[str]]:
if sampling_method == "random":
cluster_ids = _get_random_cluster_ids(
linker, df_clustered_nodes, sample_size, sample_seed
Expand Down Expand Up @@ -343,7 +343,7 @@ def render_splink_cluster_studio_html(
sample_size=10,
sample_seed=None,
cluster_ids: list[str] = None,
cluster_names: list = None,
cluster_names: list[str] = None,
overwrite: bool = False,
_df_cluster_metrics: SplinkDataFrame = None,
):
Expand Down
34 changes: 17 additions & 17 deletions splink/column_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ class ColumnExpression:

def __init__(self, sql_expression: str, sql_dialect: SplinkDialect = None):
self.raw_sql_expression = sql_expression
self.operations: list[Callable] = []
self.operations: list[Callable[..., str]] = []
if sql_dialect is not None:
self.sql_dialect: SplinkDialect = sql_dialect

def _clone(self):
def _clone(self) -> "ColumnExpression":
clone = copy(self)
clone.operations = [op for op in self.operations]
return clone
Expand All @@ -58,7 +58,7 @@ def instantiate_if_str(
elif isinstance(str_or_column_expression, str):
return ColumnExpression(str_or_column_expression)

def parse_input_string(self, dialect: SplinkDialect):
def parse_input_string(self, dialect: SplinkDialect) -> str:
"""
The input into a ColumnExpression can be
- a column name or column reference e.g. first_name, first name
Expand All @@ -79,7 +79,7 @@ def parse_input_string(self, dialect: SplinkDialect):
).sql

@property
def raw_sql_is_pure_column_or_column_reference(self):
def raw_sql_is_pure_column_or_column_reference(self) -> bool:
# It's difficult (possibly impossible) to find a completely general
# algorithm that can distinguish between the two cases (col name, or sql
# expression), since lower(first_name) could technically be a column name
Expand All @@ -93,25 +93,25 @@ def raw_sql_is_pure_column_or_column_reference(self):
return True

@property
def is_pure_column_or_column_reference(self):
def is_pure_column_or_column_reference(self) -> bool:
if len(self.operations) > 0:
return False

return self.raw_sql_is_pure_column_or_column_reference

def apply_operations(self, name: str, dialect: SplinkDialect):
def apply_operations(self, name: str, dialect: SplinkDialect) -> str:
for op in self.operations:
name = op(name=name, dialect=dialect)
return name

def _lower_dialected(self, name, dialect):
def _lower_dialected(self, name: str, dialect: SplinkDialect) -> str:
lower_sql = sqlglot.parse_one("lower(___col___)").sql(
dialect=dialect.sqlglot_name
)

return lower_sql.replace("___col___", name)

def lower(self):
def lower(self) -> "ColumnExpression":
"""
Applies a lowercase transform to the input expression.
"""
Expand All @@ -121,14 +121,14 @@ def lower(self):

def _substr_dialected(
self, name: str, start: int, end: int, dialect: SplinkDialect
):
) -> str:
substr_sql = sqlglot.parse_one(f"substring(___col___, {start}, {end})").sql(
dialect=dialect.sqlglot_name
)

return substr_sql.replace("___col___", name)

def substr(self, start: int, length: int):
def substr(self, start: int, length: int) -> "ColumnExpression":
"""
Applies a substring transform to the input expression of a given length
starting from a specified index.
Expand All @@ -142,13 +142,13 @@ def substr(self, start: int, length: int):

return clone

def _cast_to_string_dialected(self, name: str, dialect: SplinkDialect):
def _cast_to_string_dialected(self, name: str, dialect: SplinkDialect) -> str:
cast_sql = sqlglot.parse_one("cast(___col___ as string)").sql(
dialect=dialect.sqlglot_name
)
return cast_sql.replace("___col___", name)

def cast_to_string(self):
def cast_to_string(self) -> "ColumnExpression":
"""
Applies a cast to string transform to the input expression.
"""
Expand All @@ -173,7 +173,7 @@ def _regex_extract_dialected(
capture_group=capture_group,
)

def regex_extract(self, pattern: str, capture_group: int = 0):
def regex_extract(self, pattern: str, capture_group: int = 0) -> "ColumnExpression":
"""Applies a regex extract transform to the input expression.
Args:
Expand All @@ -197,10 +197,10 @@ def _try_parse_date_dialected(
name: str,
dialect: SplinkDialect,
date_format: str = None,
):
) -> str:
return dialect.try_parse_date(name, date_format=date_format)

def try_parse_date(self, date_format: str = None):
def try_parse_date(self, date_format: str = None) -> "ColumnExpression":
clone = self._clone()
op = partial(
clone._try_parse_date_dialected,
Expand All @@ -215,10 +215,10 @@ def _try_parse_timestamp_dialected(
name: str,
dialect: SplinkDialect,
timestamp_format: str = None,
):
) -> str:
return dialect.try_parse_timestamp(name, timestamp_format=timestamp_format)

def try_parse_timestamp(self, timestamp_format: str = None):
def try_parse_timestamp(self, timestamp_format: str = None) -> "ColumnExpression":
clone = self._clone()
op = partial(
clone._try_parse_timestamp_dialected,
Expand Down
4 changes: 2 additions & 2 deletions splink/comparison.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Any, List, Optional

from .comparison_level import ComparisonLevel, _default_m_values, _default_u_values
from .misc import dedupe_preserving_order, join_list_with_commas_final_and
Expand Down Expand Up @@ -361,7 +361,7 @@ def _is_trained(self):
return self._all_m_are_trained and self._all_u_are_trained

@property
def _as_detailed_records(self) -> list[dict]:
def _as_detailed_records(self) -> list[dict[str, Any]]:
records = []
for cl in self.comparison_levels:
record = {}
Expand Down
6 changes: 4 additions & 2 deletions splink/comparison_creator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union, final
from typing import Any, Dict, List, Optional, Union, final

from .column_expression import ColumnExpression
from .comparison import Comparison
Expand Down Expand Up @@ -126,7 +128,7 @@ def get_comparison(self, sql_dialect_str: str) -> Comparison:
)

@final
def create_comparison_dict(self, sql_dialect_str: str) -> dict:
def create_comparison_dict(self, sql_dialect_str: str) -> dict[str, Any]:
level_dict = {
"comparison_description": self.create_description(),
"output_column_name": self.create_output_column_name(),
Expand Down
4 changes: 2 additions & 2 deletions splink/comparison_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ def __init__(
self._max_level: Optional[bool] = None

# Enable the level to 'know' when it's been trained
self._trained_m_probabilities: list = []
self._trained_u_probabilities: list = []
self._trained_m_probabilities: list[dict[str, Any]] = []
self._trained_u_probabilities: list[dict[str, Any]] = []
# controls warnings from model training - ensures we only send once
self._m_warning_sent = False
self._u_warning_sent = False
Expand Down
10 changes: 6 additions & 4 deletions splink/comparison_level_composition.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Iterable, Union, final
from typing import Any, Iterable, Union, final

from .blocking import BlockingRule
from .comparison_creator import ComparisonLevelCreator
Expand All @@ -9,7 +9,7 @@


def _ensure_is_comparison_level_creator(
cl: Union[ComparisonLevelCreator, dict]
cl: Union[ComparisonLevelCreator, dict[str, Any]]
) -> ComparisonLevelCreator:
if isinstance(cl, dict):
from .comparison_level_library import CustomLevel
Expand All @@ -28,7 +28,9 @@ class _Merge(ComparisonLevelCreator):
_clause: str = ""

@final
def __init__(self, *comparison_levels: Union[ComparisonLevelCreator, dict]):
def __init__(
self, *comparison_levels: Union[ComparisonLevelCreator, dict[str, Any]]
):
num_levels = len(comparison_levels)
if num_levels == 0:
raise ValueError(f"Must provide at least one level to {type(self)}()")
Expand Down Expand Up @@ -92,7 +94,7 @@ class Not(ComparisonLevelCreator):
comparison level you wish to negate with 'NOT'
"""

def __init__(self, comparison_level: Union[ComparisonLevelCreator, dict]):
def __init__(self, comparison_level: Union[ComparisonLevelCreator, dict[str, Any]]):
self.comparison_level = _ensure_is_comparison_level_creator(comparison_level)
# turn null levels into non-null levels, otherwise do nothing
if self.comparison_level.is_null_level:
Expand Down
Loading

0 comments on commit fee3f05

Please sign in to comment.