Skip to content

Commit

Permalink
Merge pull request #1926 from moj-analytical-services/add_comparison_…
Browse files Browse the repository at this point in the history
…level_validation_check

Add comparison level validation check
  • Loading branch information
ThomasHepworth authored Feb 6, 2024
2 parents da3b332 + 70dd549 commit 670389e
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 120 deletions.
59 changes: 34 additions & 25 deletions splink/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import warnings
from typing import List, Union


def format_text_as_red(text):
return f"\033[91m{text}\033[0m"


# base class for any type of custom exception
class SplinkException(Exception):
pass
Expand All @@ -18,12 +21,10 @@ class EMTrainingException(SplinkException):
class ComparisonSettingsException(SplinkException):
def __init__(self, message=""):
# Add the default message to the beginning of the provided message
full_message = (
"The following errors were identified within your "
"settings object's comparisons"
)
full_message = "Errors were detected in your settings object's comparisons"

if message:
full_message += ":\n" + message
full_message += "\n" + message
super().__init__(full_message)


Expand Down Expand Up @@ -54,30 +55,40 @@ def __init__(self):
@property
def errors(self) -> str:
"""Return concatenated error messages."""
return "\n".join(self.error_queue)

def append(self, error: Union[str, Exception]) -> None:
"""Append an error to the error list.
return "\n\n".join([f"{e}" for e in self.error_queue])

def log_error(self, error: Union[list, str, Exception]) -> None:
"""
Log an error or a list of errors.
Args:
error: An error message string or an Exception instance.
error: An error message, an Exception instance, or a list
containing strings or exceptions.
"""
# Ignore if None
if isinstance(error, (list, tuple)):
for e in error:
self._log_single_error(e)
else:
self._log_single_error(error)

def _log_single_error(self, error: Union[str, Exception]) -> None:
"""Log a single error message or Exception."""
if error is None:
return

self.raw_errors.append(error)

if isinstance(error, str):
self.error_queue.append(error)
elif isinstance(error, Exception):
error_name = error.__class__.__name__
logged_exception = f"{error_name}: {str(error)}"
self.error_queue.append(logged_exception)
error_str = self._format_error(error)
self.error_queue.append(error_str)

def _format_error(self, error: Union[str, Exception]) -> str:
"""Format an error message or Exception for logging."""
if isinstance(error, Exception):
return f"{format_text_as_red(error.__class__.__name__)}: {error}"
elif isinstance(error, str):
return f"{error}"
else:
raise ValueError(
"The 'error' argument must be a string or an Exception instance."
)
raise ValueError("Error must be a string or an Exception instance.")

def raise_and_log_all_errors(self, exception=SplinkException, additional_txt=""):
"""Raise a custom exception with all logged errors.
Expand All @@ -87,7 +98,5 @@ def raise_and_log_all_errors(self, exception=SplinkException, additional_txt="")
additional_txt: Additional text to append to the error message.
"""
if self.error_queue:
raise exception(f"\n{self.errors}{additional_txt}")


warnings.simplefilter("always", SplinkDeprecated)
error_message = f"\n{self.errors}\n{additional_txt}"
raise exception(error_message)
66 changes: 65 additions & 1 deletion splink/settings_validation/settings_validation_log_strings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from functools import partial
from typing import NamedTuple
from typing import List, NamedTuple, Tuple


def indent_error_message(message):
"""Indents an error message by 4 spaces."""
return "\n ".join(message.splitlines())


class InvalidColumnsLogGenerator(NamedTuple):
Expand Down Expand Up @@ -136,3 +141,62 @@ def construct_missing_column_in_comparison_level_log(invalid_cls) -> str:
output_warning.append(construct_invalid_sql_log_string(comp_lvls))

return "\n".join(output_warning)


def create_invalid_comparison_level_log_string(
comparison_string: str, invalid_comparison_levels: List[Tuple[str, str]]
):
invalid_levels_str = ",\n".join(
[
f"- Type: {type_name}. Level: {level}"
for level, type_name in invalid_comparison_levels
]
)

log_message = (
f"\nThe comparison `{comparison_string}` contains the following invalid "
f"levels:\n{invalid_levels_str}\n\n"
"Please only include dictionaries or objects of "
"the `ComparisonLevel` class."
)

return indent_error_message(log_message)


def create_invalid_comparison_log_string(
comparison_string: str, comparison_level: bool
):
if comparison_level:
type_msg = "is a comparison level"
else:
type_msg = "is of an invalid data type"

log_message = (
f"\n{comparison_string}\n"
f"{type_msg} and must be nested within a comparison.\n"
"Please only include dictionaries or objects of the `Comparison` class.\n"
)
return indent_error_message(log_message)


def create_no_comparison_levels_error_log_string(comparison_string: str):
log_message = (
f"\n{comparison_string}\n"
"is missing the required `comparison_levels` dictionary"
"key. Please ensure you\ninclude this in all comparisons"
"used in your settings object.\n"
)
return indent_error_message(log_message)


def create_incorrect_dialect_import_log_string(
comparison_string: str, comparison_dialects: List[str]
):
log_message = (
f"\n{comparison_string}\n"
"contains the following invalid SQL dialect(s)"
f"within its comparison levels - {', '.join(comparison_dialects)}.\n"
"Please ensure that you're importing comparisons designed "
"for your specified linker.\n"
)
return indent_error_message(log_message)
176 changes: 90 additions & 86 deletions splink/settings_validation/valid_types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from __future__ import annotations

import logging
from typing import Dict, Union

from ..comparison import Comparison
from ..comparison_level import ComparisonLevel
from ..exceptions import ComparisonSettingsException, ErrorLogger, InvalidDialect
from .settings_validation_log_strings import (
create_incorrect_dialect_import_log_string,
create_invalid_comparison_level_log_string,
create_invalid_comparison_log_string,
create_no_comparison_levels_error_log_string,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -44,8 +51,12 @@ def validate_comparison_levels(
# Extract problematic comparisons
for c_dict in comparisons:
# If no error is found, append won't do anything
error_logger.append(evaluate_comparison_dtype_and_contents(c_dict))
error_logger.append(evaluate_comparison_dialects(c_dict, linker_dialect))
error_logger.log_error(evaluate_comparison_dtype_and_contents(c_dict))
error_logger.log_error(
evaluate_comparisons_for_imports_from_incorrect_dialects(
c_dict, linker_dialect
)
)

return error_logger

Expand All @@ -65,62 +76,79 @@ def log_comparison_errors(comparisons, linker_dialect):

# Raise and log any errors identified
plural_this = "this" if len(error_logger.raw_errors) == 1 else "these"
comp_hyperlink_txt = f"""
For more info on {plural_this} error, please visit:
https://moj-analytical-services.github.io/splink/topic_guides/comparisons/customising_comparisons.html
"""
comp_hyperlink_txt = (
f"\nFor more info on how to construct comparisons and avoid {plural_this} "
"error, please visit:\n"
"https://moj-analytical-services.github.io/splink/topic_guides/comparisons/customising_comparisons.html"
)

error_logger.raise_and_log_all_errors(
exception=ComparisonSettingsException, additional_txt=comp_hyperlink_txt
)


def check_comparison_level_types(
comparison_levels: Union[Comparison, Dict], comparison_str: str
):
"""
Given a comparison, check all of its contents are either a dictionary
or a ComparisonLevel.
"""

# Error to be handled in another function
if len(comparison_levels) == 0:
return

# Loop through our CLs and check their types. Report any invalid types to the user.
invalid_levels = []
for comp_level in comparison_levels:
if not isinstance(comp_level, (ComparisonLevel, dict)):
cl_str = f"{str(comp_level)[:50]}... "
invalid_levels.append((cl_str, type(comp_level)))

if invalid_levels:
error_message = create_invalid_comparison_level_log_string(
comparison_str, invalid_levels
)
return TypeError(error_message)


def evaluate_comparison_dtype_and_contents(comparison_dict):
"""
Given a comparison_dict, evaluate the comparison is valid.
If it's invalid, queue up an error.
Given a Comparison class or a comparison_dict, check that the comparison
is of a valid type and contains the required contents.
This can then be logged with `ErrorLogger.raise_and_log_all_errors`
or raised instantly as an error.
Checks include:
- Is the comparison a Comparison class, a dict or other?
- Does the comparison contain any comparison levels?
- Are the comparison levels contained in the comparison all valid?
Any identified errors are subsequently passed to the error Logger.
"""

comp_str = f"{str(comparison_dict)[:65]}... "

if not isinstance(comparison_dict, (Comparison, dict)):
if isinstance(comparison_dict, ComparisonLevel):
return TypeError(
f"""
{comp_str}
is a comparison level and
cannot be used as a standalone comparison.
"""
)
else:
return TypeError(
f"""
The comparison level `{comp_str}`
is of an invalid data type.
Please only include dictionaries or objects of
the `Comparison` class.
"""
)
else:
if isinstance(comparison_dict, Comparison):
comparison_dict = comparison_dict.as_dict()
comp_level = comparison_dict.get("comparison_levels")

if comp_level is None:
return SyntaxError(
f"""
{comp_str}
is missing the required `comparison_levels` dictionary
key. Please ensure you include this in all comparisons
used in your settings object.
"""
)
log_str = create_invalid_comparison_log_string(
comp_str, comparison_level=isinstance(comparison_dict, ComparisonLevel)
)
return TypeError(log_str)

if isinstance(comparison_dict, Comparison):
comparison_dict = comparison_dict.as_dict()

comp_levels = comparison_dict.get("comparison_levels")

if comp_levels is None:
return SyntaxError(create_no_comparison_levels_error_log_string(comp_str))

def evaluate_comparison_dialects(comparison_dict, sql_dialect):
# Check comparisons
return check_comparison_level_types(comp_levels, comp_str)


def evaluate_comparisons_for_imports_from_incorrect_dialects(
comparison_dict, sql_dialect
):
"""
Given a comparison_dict, assess whether the sql dialect is valid for
your selected linker.
Expand All @@ -144,53 +172,29 @@ def evaluate_comparison_dialects(comparison_dict, sql_dialect):
SparkLinker(df, settings) # errors due to mismatched CL imports
```
"""
comp_str = f"{str(comparison_dict)[:65]}... "

# This function doesn't currently support dictionary comparisons
# as it's assumed users won't submit the sql dialect when using
# that functionality.

# All other scenarios will be captured by
# `evaluate_comparison_dtype_and_contents` and can be skipped.
# Where sql_dialect is empty, we also can't verify the CLs.
if not isinstance(comparison_dict, (Comparison, dict)) or sql_dialect is None:
if not sql_dialect or not isinstance(comparison_dict, (Comparison, dict)):
return

# Alternatively, if we just want to check where the import's origin,
# we could evalute the import signature using:
# `comparison_dict.__class__`
if isinstance(comparison_dict, Comparison):
comparison_dialects = set(
[
extract_sql_dialect_from_cll(comp_level)
for comp_level in comparison_dict.comparison_levels
]
)
else: # only other value here is a dict
comparison_dialects = set(
[
extract_sql_dialect_from_cll(comp_level)
for comp_level in comparison_dict.get("comparison_levels", [])
]
)
# if no dialect has been supplied, ignore it
comp_str = f"{str(comparison_dict)[:65]}... "
comparison_levels = (
comparison_dict.comparison_levels
if isinstance(comparison_dict, Comparison)
else comparison_dict.get("comparison_levels", [])
)

comparison_dialects = set(
extract_sql_dialect_from_cll(cl) for cl in comparison_levels
)
comparison_dialects.discard(None)

comparison_dialects = [
dialect for dialect in comparison_dialects if sql_dialect != dialect
] # sort only needed to ensure our tests pass
# comparison_dialects = [
# dialect for dialect in comparison_dialects if sql_dialect != dialect
# ].sort() # sort only needed to ensure our tests pass

if comparison_dialects:
comparison_dialects.sort()
return InvalidDialect(
f"""
{comp_str}
contains the following invalid SQL dialect(s)
within its comparison levels - {', '.join(comparison_dialects)}.
Please ensure that you're importing comparisons designed
for your specified linker.
"""
# Filter out dialects that match the expected sql_dialect
invalid_dialects = [
dialect for dialect in comparison_dialects if dialect != sql_dialect
]

if invalid_dialects:
error_message = create_incorrect_dialect_import_log_string(
comp_str, sorted(invalid_dialects)
)
return InvalidDialect(error_message)
Loading

0 comments on commit 670389e

Please sign in to comment.