Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add comparison level validation check #1926

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 34 additions & 25 deletions splink/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import warnings
from typing import List, Union


def format_text_as_red(text):
return f"\033[91m{text}\033[0m"


# base class for any type of custom exception
class SplinkException(Exception):
pass
Expand All @@ -18,12 +21,10 @@ class EMTrainingException(SplinkException):
class ComparisonSettingsException(SplinkException):
def __init__(self, message=""):
# Add the default message to the beginning of the provided message
full_message = (
"The following errors were identified within your "
"settings object's comparisons"
)
full_message = "Errors were detected in your settings object's comparisons"

if message:
full_message += ":\n" + message
full_message += "\n" + message
super().__init__(full_message)


Expand Down Expand Up @@ -54,30 +55,40 @@ def __init__(self):
@property
def errors(self) -> str:
"""Return concatenated error messages."""
return "\n".join(self.error_queue)

def append(self, error: Union[str, Exception]) -> None:
"""Append an error to the error list.
return "\n\n".join([f"{e}" for e in self.error_queue])

def log_error(self, error: Union[list, str, Exception]) -> None:
"""
Log an error or a list of errors.

Args:
error: An error message string or an Exception instance.
error: An error message, an Exception instance, or a list
containing strings or exceptions.
"""
# Ignore if None
if isinstance(error, (list, tuple)):
for e in error:
self._log_single_error(e)
else:
self._log_single_error(error)

def _log_single_error(self, error: Union[str, Exception]) -> None:
"""Log a single error message or Exception."""
if error is None:
return

self.raw_errors.append(error)

if isinstance(error, str):
self.error_queue.append(error)
elif isinstance(error, Exception):
error_name = error.__class__.__name__
logged_exception = f"{error_name}: {str(error)}"
self.error_queue.append(logged_exception)
error_str = self._format_error(error)
self.error_queue.append(error_str)

def _format_error(self, error: Union[str, Exception]) -> str:
"""Format an error message or Exception for logging."""
if isinstance(error, Exception):
return f"{format_text_as_red(error.__class__.__name__)}: {error}"
elif isinstance(error, str):
return f"{error}"
else:
raise ValueError(
"The 'error' argument must be a string or an Exception instance."
)
raise ValueError("Error must be a string or an Exception instance.")

def raise_and_log_all_errors(self, exception=SplinkException, additional_txt=""):
"""Raise a custom exception with all logged errors.
Expand All @@ -87,7 +98,5 @@ def raise_and_log_all_errors(self, exception=SplinkException, additional_txt="")
additional_txt: Additional text to append to the error message.
"""
if self.error_queue:
raise exception(f"\n{self.errors}{additional_txt}")


warnings.simplefilter("always", SplinkDeprecated)
error_message = f"\n{self.errors}\n{additional_txt}"
raise exception(error_message)
66 changes: 65 additions & 1 deletion splink/settings_validation/settings_validation_log_strings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from functools import partial
from typing import NamedTuple
from typing import List, NamedTuple, Tuple


def indent_error_message(message):
"""Indents an error message by 4 spaces."""
return "\n ".join(message.splitlines())


class InvalidColumnsLogGenerator(NamedTuple):
Expand Down Expand Up @@ -136,3 +141,62 @@ def construct_missing_column_in_comparison_level_log(invalid_cls) -> str:
output_warning.append(construct_invalid_sql_log_string(comp_lvls))

return "\n".join(output_warning)


def create_invalid_comparison_level_log_string(
comparison_string: str, invalid_comparison_levels: List[Tuple[str, str]]
):
invalid_levels_str = ",\n".join(
[
f"- Type: {type_name}. Level: {level}"
for level, type_name in invalid_comparison_levels
]
)

log_message = (
f"\nThe comparison `{comparison_string}` contains the following invalid "
f"levels:\n{invalid_levels_str}\n\n"
"Please only include dictionaries or objects of "
"the `ComparisonLevel` class."
)

return indent_error_message(log_message)


def create_invalid_comparison_log_string(
comparison_string: str, comparison_level: bool
):
if comparison_level:
type_msg = "is a comparison level"
else:
type_msg = "is of an invalid data type"

log_message = (
f"\n{comparison_string}\n"
f"{type_msg} and must be nested within a comparison.\n"
"Please only include dictionaries or objects of the `Comparison` class.\n"
)
return indent_error_message(log_message)


def create_no_comparison_levels_error_log_string(comparison_string: str):
log_message = (
f"\n{comparison_string}\n"
"is missing the required `comparison_levels` dictionary"
"key. Please ensure you\ninclude this in all comparisons"
"used in your settings object.\n"
)
return indent_error_message(log_message)


def create_incorrect_dialect_import_log_string(
comparison_string: str, comparison_dialects: List[str]
):
log_message = (
f"\n{comparison_string}\n"
"contains the following invalid SQL dialect(s)"
f"within its comparison levels - {', '.join(comparison_dialects)}.\n"
"Please ensure that you're importing comparisons designed "
"for your specified linker.\n"
)
return indent_error_message(log_message)
176 changes: 90 additions & 86 deletions splink/settings_validation/valid_types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from __future__ import annotations

import logging
from typing import Dict, Union

from ..comparison import Comparison
from ..comparison_level import ComparisonLevel
from ..exceptions import ComparisonSettingsException, ErrorLogger, InvalidDialect
from .settings_validation_log_strings import (
create_incorrect_dialect_import_log_string,
create_invalid_comparison_level_log_string,
create_invalid_comparison_log_string,
create_no_comparison_levels_error_log_string,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -44,8 +51,12 @@ def validate_comparison_levels(
# Extract problematic comparisons
for c_dict in comparisons:
# If no error is found, append won't do anything
error_logger.append(evaluate_comparison_dtype_and_contents(c_dict))
error_logger.append(evaluate_comparison_dialects(c_dict, linker_dialect))
error_logger.log_error(evaluate_comparison_dtype_and_contents(c_dict))
error_logger.log_error(
evaluate_comparisons_for_imports_from_incorrect_dialects(
c_dict, linker_dialect
)
)

return error_logger

Expand All @@ -65,62 +76,79 @@ def log_comparison_errors(comparisons, linker_dialect):

# Raise and log any errors identified
plural_this = "this" if len(error_logger.raw_errors) == 1 else "these"
comp_hyperlink_txt = f"""
For more info on {plural_this} error, please visit:
https://moj-analytical-services.github.io/splink/topic_guides/comparisons/customising_comparisons.html
"""
comp_hyperlink_txt = (
f"\nFor more info on how to construct comparisons and avoid {plural_this} "
"error, please visit:\n"
"https://moj-analytical-services.github.io/splink/topic_guides/comparisons/customising_comparisons.html"
)

error_logger.raise_and_log_all_errors(
exception=ComparisonSettingsException, additional_txt=comp_hyperlink_txt
)


def check_comparison_level_types(
comparison_levels: Union[Comparison, Dict], comparison_str: str
):
"""
Given a comparison, check all of its contents are either a dictionary
or a ComparisonLevel.
"""

# Error to be handled in another function
if len(comparison_levels) == 0:
return

# Loop through our CLs and check their types. Report any invalid types to the user.
invalid_levels = []
for comp_level in comparison_levels:
if not isinstance(comp_level, (ComparisonLevel, dict)):
cl_str = f"{str(comp_level)[:50]}... "
invalid_levels.append((cl_str, type(comp_level)))

if invalid_levels:
error_message = create_invalid_comparison_level_log_string(
comparison_str, invalid_levels
)
return TypeError(error_message)


def evaluate_comparison_dtype_and_contents(comparison_dict):
"""
Given a comparison_dict, evaluate the comparison is valid.
If it's invalid, queue up an error.
Given a Comparison class or a comparison_dict, check that the comparison
is of a valid type and contains the required contents.

This can then be logged with `ErrorLogger.raise_and_log_all_errors`
or raised instantly as an error.
Checks include:
- Is the comparison a Comparison class, a dict or other?
- Does the comparison contain any comparison levels?
- Are the comparison levels contained in the comparison all valid?

Any identified errors are subsequently passed to the error Logger.
"""

comp_str = f"{str(comparison_dict)[:65]}... "

if not isinstance(comparison_dict, (Comparison, dict)):
if isinstance(comparison_dict, ComparisonLevel):
return TypeError(
f"""
{comp_str}
is a comparison level and
cannot be used as a standalone comparison.
"""
)
else:
return TypeError(
f"""
The comparison level `{comp_str}`
is of an invalid data type.
Please only include dictionaries or objects of
the `Comparison` class.
"""
)
else:
if isinstance(comparison_dict, Comparison):
comparison_dict = comparison_dict.as_dict()
comp_level = comparison_dict.get("comparison_levels")

if comp_level is None:
return SyntaxError(
f"""
{comp_str}
is missing the required `comparison_levels` dictionary
key. Please ensure you include this in all comparisons
used in your settings object.
"""
)
log_str = create_invalid_comparison_log_string(
comp_str, comparison_level=isinstance(comparison_dict, ComparisonLevel)
)
return TypeError(log_str)

if isinstance(comparison_dict, Comparison):
comparison_dict = comparison_dict.as_dict()

comp_levels = comparison_dict.get("comparison_levels")

if comp_levels is None:
return SyntaxError(create_no_comparison_levels_error_log_string(comp_str))

def evaluate_comparison_dialects(comparison_dict, sql_dialect):
# Check comparisons
return check_comparison_level_types(comp_levels, comp_str)


def evaluate_comparisons_for_imports_from_incorrect_dialects(
comparison_dict, sql_dialect
):
"""
Given a comparison_dict, assess whether the sql dialect is valid for
your selected linker.
Expand All @@ -144,53 +172,29 @@ def evaluate_comparison_dialects(comparison_dict, sql_dialect):
SparkLinker(df, settings) # errors due to mismatched CL imports
```
"""
comp_str = f"{str(comparison_dict)[:65]}... "

# This function doesn't currently support dictionary comparisons
# as it's assumed users won't submit the sql dialect when using
# that functionality.

# All other scenarios will be captured by
# `evaluate_comparison_dtype_and_contents` and can be skipped.
# Where sql_dialect is empty, we also can't verify the CLs.
if not isinstance(comparison_dict, (Comparison, dict)) or sql_dialect is None:
if not sql_dialect or not isinstance(comparison_dict, (Comparison, dict)):
return

# Alternatively, if we just want to check where the import's origin,
# we could evalute the import signature using:
# `comparison_dict.__class__`
if isinstance(comparison_dict, Comparison):
comparison_dialects = set(
[
extract_sql_dialect_from_cll(comp_level)
for comp_level in comparison_dict.comparison_levels
]
)
else: # only other value here is a dict
comparison_dialects = set(
[
extract_sql_dialect_from_cll(comp_level)
for comp_level in comparison_dict.get("comparison_levels", [])
]
)
# if no dialect has been supplied, ignore it
comp_str = f"{str(comparison_dict)[:65]}... "
comparison_levels = (
comparison_dict.comparison_levels
if isinstance(comparison_dict, Comparison)
else comparison_dict.get("comparison_levels", [])
)

comparison_dialects = set(
extract_sql_dialect_from_cll(cl) for cl in comparison_levels
)
comparison_dialects.discard(None)

comparison_dialects = [
dialect for dialect in comparison_dialects if sql_dialect != dialect
] # sort only needed to ensure our tests pass
# comparison_dialects = [
# dialect for dialect in comparison_dialects if sql_dialect != dialect
# ].sort() # sort only needed to ensure our tests pass

if comparison_dialects:
comparison_dialects.sort()
return InvalidDialect(
f"""
{comp_str}
contains the following invalid SQL dialect(s)
within its comparison levels - {', '.join(comparison_dialects)}.
Please ensure that you're importing comparisons designed
for your specified linker.
"""
# Filter out dialects that match the expected sql_dialect
invalid_dialects = [
dialect for dialect in comparison_dialects if dialect != sql_dialect
]

if invalid_dialects:
error_message = create_incorrect_dialect_import_log_string(
comp_str, sorted(invalid_dialects)
)
return InvalidDialect(error_message)
Loading
Loading