Skip to content

Commit

Permalink
Merge pull request #1921 from moj-analytical-services/update_load_set…
Browse files Browse the repository at this point in the history
…tings_and_make_it_the_defacto_load_logic

Update load settings and make it the defacto load logic
  • Loading branch information
ThomasHepworth authored Feb 6, 2024
2 parents 2a2d24c + 670389e commit 051e4ac
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 179 deletions.
59 changes: 34 additions & 25 deletions splink/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import warnings
from typing import List, Union


def format_text_as_red(text):
    """Wrap *text* in ANSI escape codes so terminals render it in red."""
    red_start, colour_reset = "\033[91m", "\033[0m"
    return f"{red_start}{text}{colour_reset}"


# base class for any type of custom exception
class SplinkException(Exception):
    """Base class for all Splink-specific exceptions.

    Catching this type catches every custom error Splink raises.
    """

    pass
Expand All @@ -18,12 +21,10 @@ class EMTrainingException(SplinkException):
class ComparisonSettingsException(SplinkException):
    """Raised when errors are detected in a settings object's comparisons."""

    def __init__(self, message=""):
        # Add the default message to the beginning of the provided message
        full_message = "Errors were detected in your settings object's comparisons"

        if message:
            full_message += "\n" + message
        super().__init__(full_message)


Expand Down Expand Up @@ -54,30 +55,40 @@ def __init__(self):
@property
def errors(self) -> str:
    """Return all logged errors, concatenated into a single string.

    Entries are separated by a blank line so multi-line error
    messages remain visually distinct from one another.
    """
    return "\n\n".join([f"{e}" for e in self.error_queue])

def log_error(self, error: Union[list, str, Exception]) -> None:
    """
    Log an error or a list of errors.

    Args:
        error: An error message, an Exception instance, or a list
            containing strings or exceptions.
    """
    if isinstance(error, (list, tuple)):
        for e in error:
            self._log_single_error(e)
    else:
        self._log_single_error(error)

def _log_single_error(self, error: Union[str, Exception]) -> None:
    """Log a single error message or Exception."""
    # Ignore if None
    if error is None:
        return

    # Keep the raw object for programmatic access; queue the
    # formatted string for display.
    self.raw_errors.append(error)
    error_str = self._format_error(error)
    self.error_queue.append(error_str)

def _format_error(self, error: Union[str, Exception]) -> str:
    """Format an error message or Exception for logging.

    Raises:
        ValueError: If `error` is neither a string nor an Exception.
    """
    if isinstance(error, Exception):
        # Highlight the exception class name in red for readability.
        return f"{format_text_as_red(error.__class__.__name__)}: {error}"
    elif isinstance(error, str):
        return f"{error}"
    else:
        raise ValueError("Error must be a string or an Exception instance.")

def raise_and_log_all_errors(self, exception=SplinkException, additional_txt=""):
    """Raise a custom exception with all logged errors.

    Args:
        exception: The exception class to raise. Defaults to SplinkException.
        additional_txt: Additional text to append to the error message.
    """
    # Only raise if at least one error has actually been logged.
    if self.error_queue:
        error_message = f"\n{self.errors}\n{additional_txt}"
        raise exception(error_message)
96 changes: 37 additions & 59 deletions splink/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,16 +226,6 @@ def __init__(

self._intermediate_table_cache: dict = CacheDictWithLogging()

if not isinstance(settings_dict, (dict, type(None))):
# Run if you've entered a filepath
# feed it a blank settings dictionary
self._setup_settings_objs(None)
self.load_settings(settings_dict)
else:
self._validate_settings_components(settings_dict)
settings_dict = deepcopy(settings_dict)
self._setup_settings_objs(settings_dict)

homogenised_tables, homogenised_aliases = self._register_input_tables(
input_table_or_tables,
input_table_aliases,
Expand All @@ -246,8 +236,8 @@ def __init__(
homogenised_tables, homogenised_aliases
)

self._validate_input_dfs()
self._validate_settings(validate_settings)
self._setup_settings_objs(deepcopy(settings_dict), validate_settings)

self._em_training_sessions = []

self._find_new_matches_mode = False
Expand Down Expand Up @@ -328,14 +318,14 @@ def _source_dataset_column_already_exists(self):

@property
def _cache_uid(self):
if self._settings_dict:
if getattr(self, "_settings_dict", None):
return self._settings_obj._cache_uid
else:
return self._cache_uid_no_settings

@_cache_uid.setter
def _cache_uid(self, value):
if self._settings_dict:
if getattr(self, "_settings_dict", None):
self._settings_obj._cache_uid = value
else:
self._cache_uid_no_settings = value
Expand Down Expand Up @@ -477,25 +467,22 @@ def _register_input_tables(self, input_tables, input_aliases, accepted_df_dtypes

return homogenised_tables, homogenised_aliases

def _setup_settings_objs(self, settings_dict, validate_settings: bool = True):
    """Initialise the linker's settings-related state.

    Args:
        settings_dict: None, a settings dictionary, or a filepath (str)
            to a settings json file.
        validate_settings: Whether to run settings validation checks
            when settings are loaded.

    Raises:
        ValueError: If `settings_dict` is not None, a dict, or a str.
    """
    # Always sets a default cache uid -> _cache_uid_no_settings
    self._cache_uid = ascii_uid(8)

    if settings_dict is None:
        self._settings_obj_ = None
        return

    if not isinstance(settings_dict, (str, dict)):
        raise ValueError(
            "Invalid settings object supplied. Ensure this is either "
            "None, a dictionary or a filepath to a settings object saved "
            "as a json file."
        )

    # Defer to load_settings, which handles both dicts and filepaths.
    self.load_settings(settings_dict, validate_settings)

def _check_for_valid_settings(self):
if (
Expand All @@ -509,24 +496,14 @@ def _check_for_valid_settings(self):
else:
return True

def _validate_settings_components(self, settings_dict):
    """Check the comparisons in a raw settings dict and log any errors found.

    Args:
        settings_dict: The raw settings dictionary, or None (no-op).
    """
    # Validate our settings after plugging them through
    # `Settings(<settings>)`
    if settings_dict is None:
        return

    log_comparison_errors(
        # null if not in dict - check using value is ignored
        settings_dict.get("comparisons", None),
        self._sql_dialect,
    )

def _validate_settings(self, validate_settings):
# Validate our settings after plugging them through
# `Settings(<settings>)`
if not self._check_for_valid_settings():
return

self._validate_input_dfs()

# Run miscellaneous checks on our settings dictionary.
_validate_dialect(
settings_dialect=self._settings_obj._sql_dialect,
Expand Down Expand Up @@ -1142,27 +1119,23 @@ def load_settings(
settings_dict = json.loads(p.read_text())

# Store the cache ID so it can be reloaded after cache invalidation
cache_id = self._cache_uid
# So we don't run into any issues with generated tables having
# invalid columns as settings have been tweaked, invalidate
# the cache and allow these tables to be recomputed.

# This is less efficient, but triggers infrequently and ensures we don't
# run into issues where the defaults used conflict with the actual values
# supplied in settings.
cache_uid = self._cache_uid

# This is particularly relevant with `source_dataset`, which appears within
# concat_with_tf.
# Invalidate the cache if anything currently exists. If the settings are
# changing, our charts, tf tables, etc may need changing.
self.invalidate_cache()

# If a uid already exists in your settings object, prioritise this
settings_dict["linker_uid"] = settings_dict.get("linker_uid", cache_id)
settings_dict["sql_dialect"] = settings_dict.get(
"sql_dialect", self._sql_dialect
)
self._settings_dict = settings_dict
self._settings_dict = settings_dict # overwrite or add

# Get the SQL dialect from settings_dict or use the default
sql_dialect = settings_dict.get("sql_dialect", self._sql_dialect)
settings_dict["sql_dialect"] = sql_dialect
settings_dict["linker_uid"] = settings_dict.get("linker_uid", cache_uid)

# Check the user's comparisons (if they exist)
log_comparison_errors(settings_dict.get("comparisons"), sql_dialect)
self._settings_obj_ = Settings(settings_dict)
self._validate_input_dfs()
# Check the final settings object
self._validate_settings(validate_settings)

def load_model(self, model_path: Path):
Expand Down Expand Up @@ -3753,6 +3726,11 @@ def invalidate_cache(self):
will be recomputed.
This is useful, for example, if the input data tables have changed.
"""

# Nothing to delete
if len(self._intermediate_table_cache) == 0:
return

# Before Splink executes a SQL command, it checks the cache to see
# whether a table already exists with the name of the output table

Expand Down
66 changes: 65 additions & 1 deletion splink/settings_validation/settings_validation_log_strings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from functools import partial
from typing import NamedTuple
from typing import List, NamedTuple, Tuple


def indent_error_message(message):
    """Indents an error message by 4 spaces."""
    # Rejoin with an indented newline so every continuation line is
    # nested 4 spaces under the first line, matching the docstring.
    return "\n    ".join(message.splitlines())


class InvalidColumnsLogGenerator(NamedTuple):
Expand Down Expand Up @@ -136,3 +141,62 @@ def construct_missing_column_in_comparison_level_log(invalid_cls) -> str:
output_warning.append(construct_invalid_sql_log_string(comp_lvls))

return "\n".join(output_warning)


def create_invalid_comparison_level_log_string(
    comparison_string: str, invalid_comparison_levels: List[Tuple[str, str]]
):
    """Build an indented error message listing a comparison's invalid levels."""
    level_lines = (
        f"- Type: {type_name}. Level: {level}"
        for level, type_name in invalid_comparison_levels
    )
    invalid_levels_str = ",\n".join(level_lines)

    log_message = (
        f"\nThe comparison `{comparison_string}` contains the following invalid "
        f"levels:\n{invalid_levels_str}\n\n"
        "Please only include dictionaries or objects of "
        "the `ComparisonLevel` class."
    )

    return indent_error_message(log_message)


def create_invalid_comparison_log_string(
    comparison_string: str, comparison_level: bool
):
    """Build an indented error message for an object that is not a valid comparison."""
    # The offending object is either a comparison level placed at the
    # wrong nesting depth, or something that isn't a comparison at all.
    type_msg = (
        "is a comparison level" if comparison_level else "is of an invalid data type"
    )

    log_message = (
        f"\n{comparison_string}\n"
        f"{type_msg} and must be nested within a comparison.\n"
        "Please only include dictionaries or objects of the `Comparison` class.\n"
    )
    return indent_error_message(log_message)


def create_no_comparison_levels_error_log_string(comparison_string: str):
    """Build an indented error message for a comparison missing `comparison_levels`."""
    # NOTE: trailing spaces on the concatenated fragments are required —
    # without them the message reads "dictionarykey" / "comparisonsused".
    log_message = (
        f"\n{comparison_string}\n"
        "is missing the required `comparison_levels` dictionary "
        "key. Please ensure you\ninclude this in all comparisons "
        "used in your settings object.\n"
    )
    return indent_error_message(log_message)


def create_incorrect_dialect_import_log_string(
    comparison_string: str, comparison_dialects: List[str]
):
    """Build an indented error message for comparison levels using the wrong SQL dialect.

    Args:
        comparison_string: String representation of the offending comparison.
        comparison_dialects: The invalid SQL dialect names found in its levels.
    """
    # NOTE: trailing space after "dialect(s)" is required — without it the
    # concatenated message reads "dialect(s)within".
    log_message = (
        f"\n{comparison_string}\n"
        "contains the following invalid SQL dialect(s) "
        f"within its comparison levels - {', '.join(comparison_dialects)}.\n"
        "Please ensure that you're importing comparisons designed "
        "for your specified linker.\n"
    )
    return indent_error_message(log_message)
Loading

0 comments on commit 051e4ac

Please sign in to comment.