From 2b1a4537f69752f00973a21c344e1b5b6a1303c2 Mon Sep 17 00:00:00 2001 From: Tom Hepworth Date: Wed, 31 Jan 2024 14:10:37 +0000 Subject: [PATCH 01/10] Clean up how we load settings in the linker --- splink/exceptions.py | 27 +++++++------ splink/linker.py | 94 +++++++++++++++++--------------------------- 2 files changed, 53 insertions(+), 68 deletions(-) diff --git a/splink/exceptions.py b/splink/exceptions.py index 8fb0b503b0..c21e2e183e 100644 --- a/splink/exceptions.py +++ b/splink/exceptions.py @@ -18,10 +18,7 @@ class EMTrainingException(SplinkException): class ComparisonSettingsException(SplinkException): def __init__(self, message=""): # Add the default message to the beginning of the provided message - full_message = ( - "The following errors were identified within your " - "settings object's comparisons" - ) + full_message = "Errors were detected in your settings object's comparisons:" if message: full_message += ":\n" + message super().__init__(full_message) @@ -54,14 +51,23 @@ def __init__(self): @property def errors(self) -> str: """Return concatenated error messages.""" - return "\n".join(self.error_queue) + return "\n".join([f" - {e}" for e in self.error_queue]) - def append(self, error: Union[str, Exception]) -> None: + def append(self, error: Union[list, str, Exception]) -> None: """Append an error to the error list. Args: - error: An error message string or an Exception instance. + error: An error message string, an Exception instance or + a list or tuple containing your strings or exceptions. """ + + if isinstance(error, (list, tuple)): + for e in error: + self._append(e) + else: + self._append(error) + + def _append(self, error: Union[str, Exception]) -> None: # Ignore if None if error is None: return @@ -79,7 +85,9 @@ def append(self, error: Union[str, Exception]) -> None: "The 'error' argument must be a string or an Exception instance." ) - def raise_and_log_all_errors(self, exception=SplinkException, additional_txt=""): + def raise_and_log_all_errors( + self, exception: Exception = SplinkException, additional_txt="" + ) -> Exception: """Raise a custom exception with all logged errors. Args: @@ -88,6 +96,3 @@ def raise_and_log_all_errors(self, exception=SplinkException, additional_txt="") """ if self.error_queue: raise exception(f"\n{self.errors}{additional_txt}") - - -warnings.simplefilter("always", SplinkDeprecated) diff --git a/splink/linker.py b/splink/linker.py index e481f44271..33f680566b 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -226,16 +226,6 @@ def __init__( self._intermediate_table_cache: dict = CacheDictWithLogging() - if not isinstance(settings_dict, (dict, type(None))): - # Run if you've entered a filepath - # feed it a blank settings dictionary - self._setup_settings_objs(None) - self.load_settings(settings_dict) - else: - self._validate_settings_components(settings_dict) - settings_dict = deepcopy(settings_dict) - self._setup_settings_objs(settings_dict) - homogenised_tables, homogenised_aliases = self._register_input_tables( input_table_or_tables, input_table_aliases, @@ -246,8 +236,8 @@ def __init__( homogenised_tables, homogenised_aliases ) - self._validate_input_dfs() - self._validate_settings(validate_settings) + self._setup_settings_objs(deepcopy(settings_dict), validate_settings) + self._em_training_sessions = [] self._find_new_matches_mode = False @@ -328,14 +318,14 @@ def _source_dataset_column_already_exists(self): @property def _cache_uid(self): - if self._settings_dict: + if getattr(self, "_settings_dict", None): return self._settings_obj._cache_uid else: return self._cache_uid_no_settings @_cache_uid.setter def _cache_uid(self, value): - if self._settings_dict: + if getattr(self, "_settings_dict", None): self._settings_obj._cache_uid = value else: self._cache_uid_no_settings = value @@ -477,25 +467,22 @@ def _register_input_tables(self, input_tables, input_aliases, accepted_df_dtypes return homogenised_tables, homogenised_aliases - def _setup_settings_objs(self, settings_dict): - # Setup the linker class's required settings - self._settings_dict = settings_dict - - # if settings_dict is passed, set sql_dialect on it if missing, and make sure - # incompatible dialect not passed - if settings_dict is not None and settings_dict.get("sql_dialect", None) is None: - settings_dict["sql_dialect"] = self._sql_dialect - - if settings_dict is None: - self._cache_uid_no_settings = ascii_uid(8) - else: - uid = settings_dict.get("linker_uid", ascii_uid(8)) - settings_dict["linker_uid"] = uid + def _setup_settings_objs(self, settings_dict, validate_settings: bool = True): + # Always sets a default cache uid -> _cache_uid_no_settings + self._cache_uid = ascii_uid(8) if settings_dict is None: self._settings_obj_ = None - else: - self._settings_obj_ = Settings(settings_dict) + return + + if not isinstance(settings_dict, (str, dict)): + raise ValueError( + "Invalid settings object supplied. Ensure this is either " + "None, a dictionary or a filepath to a settings object saved " + "as a json file." + ) + + self.load_settings(settings_dict, validate_settings) def _check_for_valid_settings(self): if ( @@ -509,24 +496,14 @@ def _check_for_valid_settings(self): else: return True - def _validate_settings_components(self, settings_dict): - # Vaidate our settings after plugging them through - # `Settings()` - if settings_dict is None: - return - - log_comparison_errors( - # null if not in dict - check using value is ignored - settings_dict.get("comparisons", None), - self._sql_dialect, - ) - def _validate_settings(self, validate_settings): # Vaidate our settings after plugging them through # `Settings()` if not self._check_for_valid_settings(): return + self._validate_input_dfs() + # Run miscellaneous checks on our settings dictionary. _validate_dialect( settings_dialect=self._settings_obj._sql_dialect, @@ -1142,27 +1119,25 @@ def load_settings( settings_dict = json.loads(p.read_text()) # Store the cache ID so it can be reloaded after cache invalidation - cache_id = self._cache_uid - # So we don't run into any issues with generated tables having - # invalid columns as settings have been tweaked, invalidate - # the cache and allow these tables to be recomputed. - - # This is less efficient, but triggers infrequently and ensures we don't - # run into issues where the defaults used conflict with the actual values - # supplied in settings. + cache_uid = self._cache_uid - # This is particularly relevant with `source_dataset`, which appears within - # concat_with_tf. + # Invalidate the cache if anything currently exists. If the settings are + # changing, our charts, tf tables, etc may need changing. self.invalidate_cache() - # If a uid already exists in your settings object, prioritise this - settings_dict["linker_uid"] = settings_dict.get("linker_uid", cache_id) - settings_dict["sql_dialect"] = settings_dict.get( + self._settings_dict = settings_dict # overwrite or add + + # Get the SQL dialect from settings_dict or use the default + sql_dialect = settings_dict.get( "sql_dialect", self._sql_dialect ) - self._settings_dict = settings_dict + settings_dict["sql_dialect"] = sql_dialect + settings_dict["linker_uid"] = settings_dict.get("linker_uid", cache_uid) + + # Check the user's comparisons (if they exist) + log_comparison_errors(settings_dict.get("comparisons"), sql_dialect) self._settings_obj_ = Settings(settings_dict) - self._validate_input_dfs() + # Check the final settings object self._validate_settings(validate_settings) def load_model(self, model_path: Path): @@ -3753,6 +3728,11 @@ def invalidate_cache(self): will be recomputed. This is useful, for example, if the input data tables have changed. """ + + # Nothing to delete + if len(self._intermediate_table_cache) == 0: + return + # Before Splink executes a SQL command, it checks the cache to see # whether a table already exists with the name of the output table From 91588b9f72150e43491f6701757fec75d3fe46fe Mon Sep 17 00:00:00 2001 From: Tom Hepworth Date: Thu, 1 Feb 2024 19:05:31 +0000 Subject: [PATCH 02/10] undo changes to exceptions.py --- splink/exceptions.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/splink/exceptions.py b/splink/exceptions.py index c21e2e183e..8fb0b503b0 100644 --- a/splink/exceptions.py +++ b/splink/exceptions.py @@ -18,7 +18,10 @@ class EMTrainingException(SplinkException): class ComparisonSettingsException(SplinkException): def __init__(self, message=""): # Add the default message to the beginning of the provided message - full_message = "Errors were detected in your settings object's comparisons:" + full_message = ( + "The following errors were identified within your " + "settings object's comparisons" + ) if message: full_message += ":\n" + message super().__init__(full_message) @@ -51,23 +54,14 @@ def __init__(self): @property def errors(self) -> str: """Return concatenated error messages.""" - return "\n".join([f" - {e}" for e in self.error_queue]) + return "\n".join(self.error_queue) - def append(self, error: Union[list, str, Exception]) -> None: + def append(self, error: Union[str, Exception]) -> None: """Append an error to the error list. Args: - error: An error message string, an Exception instance or - a list or tuple containing your strings or exceptions. + error: An error message string or an Exception instance. """ - - if isinstance(error, (list, tuple)): - for e in error: - self._append(e) - else: - self._append(error) - - def _append(self, error: Union[str, Exception]) -> None: # Ignore if None if error is None: return @@ -85,9 +79,7 @@ def _append(self, error: Union[str, Exception]) -> None: "The 'error' argument must be a string or an Exception instance." ) - def raise_and_log_all_errors( - self, exception: Exception = SplinkException, additional_txt="" - ) -> Exception: + def raise_and_log_all_errors(self, exception=SplinkException, additional_txt=""): """Raise a custom exception with all logged errors. Args: @@ -96,3 +88,6 @@ def raise_and_log_all_errors( """ if self.error_queue: raise exception(f"\n{self.errors}{additional_txt}") + + +warnings.simplefilter("always", SplinkDeprecated) From 058761247a8c9ac10847656d23a238c59f0e1502 Mon Sep 17 00:00:00 2001 From: Tom Hepworth Date: Thu, 1 Feb 2024 21:15:46 +0000 Subject: [PATCH 03/10] lint with black --- scripts/preprocess_markdown_includes.py | 4 +--- splink/linker.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/scripts/preprocess_markdown_includes.py b/scripts/preprocess_markdown_includes.py index b78b83be09..b1ac4e229b 100644 --- a/scripts/preprocess_markdown_includes.py +++ b/scripts/preprocess_markdown_includes.py @@ -49,8 +49,6 @@ contributing_text = f.read() # in docs CONTRIBUTING.md 'thinks' it's already in docs/, so remove that level from # relative links -new_text = re.sub( - "docs/", "", contributing_text -) +new_text = re.sub("docs/", "", contributing_text) with open(Path(".") / "CONTRIBUTING.md", "w") as f: f.write(new_text) diff --git a/splink/linker.py b/splink/linker.py index 33f680566b..1f07bd5726 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -1128,9 +1128,7 @@ def load_settings( self._settings_dict = settings_dict # overwrite or add # Get the SQL dialect from settings_dict or use the default - sql_dialect = settings_dict.get( - "sql_dialect", self._sql_dialect - ) + sql_dialect = settings_dict.get("sql_dialect", self._sql_dialect) settings_dict["sql_dialect"] = sql_dialect settings_dict["linker_uid"] = settings_dict.get("linker_uid", cache_uid) From 49e0b1e43a835ff366d597d6c243158e7baf5c7d Mon Sep 17 00:00:00 2001 From: Tom Hepworth Date: Mon, 5 Feb 2024 11:17:03 +0000 Subject: [PATCH 04/10] Add a validation check for invalid levels within comparisons --- splink/exceptions.py | 51 ++--- .../settings_validation_log_strings.py | 65 +++++++ splink/settings_validation/valid_types.py | 176 +++++++++--------- 3 files changed, 183 insertions(+), 109 deletions(-) diff --git a/splink/exceptions.py b/splink/exceptions.py index c21e2e183e..812bb8e45a 100644 --- a/splink/exceptions.py +++ b/splink/exceptions.py @@ -2,6 +2,10 @@ from typing import List, Union +def format_text_as_red(text): + return f"\033[91m{text}\033[0m" + + # base class for any type of custom exception class SplinkException(Exception): pass @@ -18,9 +22,9 @@ class EMTrainingException(SplinkException): class ComparisonSettingsException(SplinkException): def __init__(self, message=""): # Add the default message to the beginning of the provided message - full_message = "Errors were detected in your settings object's comparisons:" + full_message = "Errors were detected in your settings object's comparisons" if message: - full_message += ":\n" + message + full_message += "\n" + message super().__init__(full_message) @@ -51,39 +55,39 @@ def __init__(self): @property def errors(self) -> str: """Return concatenated error messages.""" - return "\n".join([f" - {e}" for e in self.error_queue]) - def append(self, error: Union[list, str, Exception]) -> None: - """Append an error to the error list. + return "\n\n".join([f"{e}" for e in self.error_queue]) - Args: - error: An error message string, an Exception instance or - a list or tuple containing your strings or exceptions. + def log_error(self, error: Union[list, str, Exception]) -> None: """ + Log an error or a list of errors. + Args: + error: An error message, an Exception instance, or a list containing strings or exceptions. + """ if isinstance(error, (list, tuple)): for e in error: - self._append(e) + self._log_single_error(e) else: - self._append(error) + self._log_single_error(error) - def _append(self, error: Union[str, Exception]) -> None: - # Ignore if None + def _log_single_error(self, error: Union[str, Exception]) -> None: + """Log a single error message or Exception.""" if error is None: return self.raw_errors.append(error) - - if isinstance(error, str): - self.error_queue.append(error) - elif isinstance(error, Exception): - error_name = error.__class__.__name__ - logged_exception = f"{error_name}: {str(error)}" - self.error_queue.append(logged_exception) + error_str = self._format_error(error) + self.error_queue.append(error_str) + + def _format_error(self, error: Union[str, Exception]) -> str: + """Format an error message or Exception for logging.""" + if isinstance(error, Exception): + return f"{format_text_as_red(error.__class__.__name__)}: {error}" + elif isinstance(error, str): + return f"{error}" else: - raise ValueError( - "The 'error' argument must be a string or an Exception instance." - ) + raise ValueError("Error must be a string or an Exception instance.") def raise_and_log_all_errors( self, exception: Exception = SplinkException, additional_txt="" @@ -95,4 +99,5 @@ def raise_and_log_all_errors( additional_txt: Additional text to append to the error message. """ if self.error_queue: - raise exception(f"\n{self.errors}{additional_txt}") + error_message = f"\n{self.errors}\n{additional_txt}" + raise exception(error_message) diff --git a/splink/settings_validation/settings_validation_log_strings.py b/splink/settings_validation/settings_validation_log_strings.py index 6571ed29f8..29c1e6a4cb 100644 --- a/splink/settings_validation/settings_validation_log_strings.py +++ b/splink/settings_validation/settings_validation_log_strings.py @@ -1,5 +1,11 @@ from functools import partial from typing import NamedTuple +from typing import List, Tuple + + +def indent_error_message(message): + """Indents an error message by 4 spaces.""" + return "\n ".join(message.splitlines()) class InvalidColumnsLogGenerator(NamedTuple): @@ -136,3 +142,62 @@ def construct_missing_column_in_comparison_level_log(invalid_cls) -> str: output_warning.append(construct_invalid_sql_log_string(comp_lvls)) return "\n".join(output_warning) + + +def create_invalid_comparison_level_log_string( + comparison_string: str, invalid_comparison_levels: List[Tuple[str, str]] +): + invalid_levels_str = ",\n".join( + [ + f"- Type: {type_name}. Level: {level}" + for level, type_name in invalid_comparison_levels + ] + ) + + log_message = ( + f"\nThe comparison `{comparison_string}` contains the following invalid levels:\n" + f"{invalid_levels_str}\n\n" + "Please only include dictionaries or objects of " + "the `ComparisonLevel` class." + ) + + return indent_error_message(log_message) + + +def create_invalid_comparison_log_string( + comparison_string: str, comparison_level: bool +): + if comparison_level: + type_msg = "is a comparison level" + else: + type_msg = "is of an invalid data type" + + log_message = ( + f"\n{comparison_string}\n" + f"{type_msg} and cannot be used as a standalone comparison.\n" + "Please only include dictionaries or objects of the `Comparison` class.\n" + ) + return indent_error_message(log_message) + + +def create_no_comparison_levels_error_log_string(comparison_string: str): + log_message = ( + f"\n{comparison_string}\n" + "is missing the required `comparison_levels` dictionary" + "key. Please ensure you\ninclude this in all comparisons" + "used in your settings object.\n" + ) + return indent_error_message(log_message) + + +def create_incorrect_dialect_import_log_string( + comparison_string: str, comparison_dialects: List[str] +): + log_message = ( + f"\n{comparison_string}\n" + "contains the following invalid SQL dialect(s)" + f"within its comparison levels - {', '.join(comparison_dialects)}.\n" + "Please ensure that you're importing comparisons designed " + "for your specified linker.\n" + ) + return indent_error_message(log_message) diff --git a/splink/settings_validation/valid_types.py b/splink/settings_validation/valid_types.py index 98bea5297e..56ffa12d5d 100644 --- a/splink/settings_validation/valid_types.py +++ b/splink/settings_validation/valid_types.py @@ -1,10 +1,17 @@ from __future__ import annotations import logging +from typing import Dict, Union from ..comparison import Comparison from ..comparison_level import ComparisonLevel from ..exceptions import ComparisonSettingsException, ErrorLogger, InvalidDialect +from .settings_validation_log_strings import ( + create_incorrect_dialect_import_log_string, + create_invalid_comparison_level_log_string, + create_invalid_comparison_log_string, + create_no_comparison_levels_error_log_string, +) logger = logging.getLogger(__name__) @@ -44,8 +51,12 @@ def validate_comparison_levels( # Extract problematic comparisons for c_dict in comparisons: # If no error is found, append won't do anything - error_logger.append(evaluate_comparison_dtype_and_contents(c_dict)) - error_logger.append(evaluate_comparison_dialects(c_dict, linker_dialect)) + error_logger.log_error(evaluate_comparison_dtype_and_contents(c_dict)) + error_logger.log_error( + evaluate_comparisons_for_imports_from_incorrect_dialects( + c_dict, linker_dialect + ) + ) return error_logger @@ -65,62 +76,79 @@ def log_comparison_errors(comparisons, linker_dialect): # Raise and log any errors identified plural_this = "this" if len(error_logger.raw_errors) == 1 else "these" - comp_hyperlink_txt = f""" - For more info on {plural_this} error, please visit: - https://moj-analytical-services.github.io/splink/topic_guides/comparisons/customising_comparisons.html - """ + comp_hyperlink_txt = ( + f"\nFor more info on how to construct comparisons and avoid {plural_this} " + "error, please visit:\n" + "https://moj-analytical-services.github.io/splink/topic_guides/comparisons/customising_comparisons.html" + ) error_logger.raise_and_log_all_errors( exception=ComparisonSettingsException, additional_txt=comp_hyperlink_txt ) +def check_comparison_level_types( + comparison_levels: Union[Comparison, Dict], comparison_str: str +): + """ + Given a comparison, check all of its contents are either a dictionary + or a ComparisonLevel. + """ + + # Error to be handled in another function + if len(comparison_levels) == 0: + return + + # Loop through our CLs and check their types. Report any invalid types to the user. + invalid_levels = [] + for comp_level in comparison_levels: + if not isinstance(comp_level, (ComparisonLevel, dict)): + cl_str = f"{str(comp_level)[:50]}... " + invalid_levels.append((cl_str, type(comp_level))) + + if invalid_levels: + error_message = create_invalid_comparison_level_log_string( + comparison_str, invalid_levels + ) + return TypeError(error_message) + + def evaluate_comparison_dtype_and_contents(comparison_dict): """ - Given a comparison_dict, evaluate the comparison is valid. - If it's invalid, queue up an error. + Given a Comparison class or a comparison_dict, check that the comparison + is of a valid type and contains the required contents. - This can then be logged with `ErrorLogger.raise_and_log_all_errors` - or raised instantly as an error. + Checks include: + - Is the comparison a Comparison class, a dict or other? + - Does the comparison contain any comparison levels? + - Are the comparison levels contained in the comparison all valid? + + Any identified errors are subsequently passed to the error Logger. """ comp_str = f"{str(comparison_dict)[:65]}... " if not isinstance(comparison_dict, (Comparison, dict)): - if isinstance(comparison_dict, ComparisonLevel): - return TypeError( - f""" - {comp_str} - is a comparison level and - cannot be used as a standalone comparison. - """ - ) - else: - return TypeError( - f""" - The comparison level `{comp_str}` - is of an invalid data type. - Please only include dictionaries or objects of - the `Comparison` class. - """ - ) - else: - if isinstance(comparison_dict, Comparison): - comparison_dict = comparison_dict.as_dict() - comp_level = comparison_dict.get("comparison_levels") - - if comp_level is None: - return SyntaxError( - f""" - {comp_str} - is missing the required `comparison_levels` dictionary - key. Please ensure you include this in all comparisons - used in your settings object. - """ - ) + log_str = create_invalid_comparison_log_string( + comp_str, comparison_level=isinstance(comparison_dict, ComparisonLevel) + ) + return TypeError(log_str) + + if isinstance(comparison_dict, Comparison): + comparison_dict = comparison_dict.as_dict() + + comp_levels = comparison_dict.get("comparison_levels") + if comp_levels is None: + return SyntaxError(create_no_comparison_levels_error_log_string(comp_str)) -def evaluate_comparison_dialects(comparison_dict, sql_dialect): + # Check comparisons + return check_comparison_level_types(comp_levels, comp_str) + + +def evaluate_comparisons_for_imports_from_incorrect_dialects( + comparison_dict, sql_dialect +): """ Given a comparison_dict, assess whether the sql dialect is valid for your selected linker. @@ -144,53 +172,29 @@ def evaluate_comparison_dialects(comparison_dict, sql_dialect): SparkLinker(df, settings) # errors due to mismatched CL imports ``` """ - comp_str = f"{str(comparison_dict)[:65]}... " - # This function doesn't currently support dictionary comparisons - # as it's assumed users won't submit the sql dialect when using - # that functionality. - - # All other scenarios will be captured by - # `evaluate_comparison_dtype_and_contents` and can be skipped. - # Where sql_dialect is empty, we also can't verify the CLs. - if not isinstance(comparison_dict, (Comparison, dict)) or sql_dialect is None: + if not sql_dialect or not isinstance(comparison_dict, (Comparison, dict)): return - # Alternatively, if we just want to check where the import's origin, - # we could evalute the import signature using: - # `comparison_dict.__class__` - if isinstance(comparison_dict, Comparison): - comparison_dialects = set( - [ - extract_sql_dialect_from_cll(comp_level) - for comp_level in comparison_dict.comparison_levels - ] - ) - else: # only other value here is a dict - comparison_dialects = set( - [ - extract_sql_dialect_from_cll(comp_level) - for comp_level in comparison_dict.get("comparison_levels", []) - ] - ) - # if no dialect has been supplied, ignore it + comp_str = f"{str(comparison_dict)[:65]}... " + comparison_levels = ( + comparison_dict.comparison_levels + if isinstance(comparison_dict, Comparison) + else comparison_dict.get("comparison_levels", []) + ) + + comparison_dialects = set( + extract_sql_dialect_from_cll(cl) for cl in comparison_levels + ) comparison_dialects.discard(None) - comparison_dialects = [ - dialect for dialect in comparison_dialects if sql_dialect != dialect - ] # sort only needed to ensure our tests pass - # comparison_dialects = [ - # dialect for dialect in comparison_dialects if sql_dialect != dialect - # ].sort() # sort only needed to ensure our tests pass - - if comparison_dialects: - comparison_dialects.sort() - return InvalidDialect( - f""" - {comp_str} - contains the following invalid SQL dialect(s) - within its comparison levels - {', '.join(comparison_dialects)}. - Please ensure that you're importing comparisons designed - for your specified linker. - """ + # Filter out dialects that match the expected sql_dialect + invalid_dialects = [ + dialect for dialect in comparison_dialects if dialect != sql_dialect + ] + + if invalid_dialects: + error_message = create_incorrect_dialect_import_log_string( + comp_str, sorted(invalid_dialects) ) + return InvalidDialect(error_message) From 0dc4cf34d89cb78e9184d232ac8e4bc2d4aa705c Mon Sep 17 00:00:00 2001 From: Tom Hepworth Date: Mon, 5 Feb 2024 11:17:34 +0000 Subject: [PATCH 05/10] Test invalid level within a comparison --- tests/test_settings_validation.py | 34 +++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/tests/test_settings_validation.py b/tests/test_settings_validation.py index dcef28c989..a980a3ffe2 100644 --- a/tests/test_settings_validation.py +++ b/tests/test_settings_validation.py @@ -2,6 +2,7 @@ import pandas as pd import pytest +import re from splink.comparison import Comparison from splink.convert_v2_to_v3 import convert_settings_from_v2_to_v3 @@ -330,6 +331,7 @@ def test_validate_sql_dialect(): def test_comparison_validation(): import splink.athena.comparison_level_library as ath_cll import splink.duckdb.comparison_level_library as cll + import splink.duckdb.comparison_library as cl import splink.spark.comparison_level_library as sp_cll from splink.exceptions import InvalidDialect from splink.spark.comparison_library import exact_match @@ -344,10 +346,7 @@ def test_comparison_validation(): settings = get_settings_dict() # Contents aren't tested as of yet - email_no_comp_level = { - "comparison_lvls": [], - } - + email_no_comp_level = {"comparison_lvls": []} # cll instead of cl email_cc = cll.exact_match_level("email") settings["comparisons"][3] = email_cc @@ -368,6 +367,17 @@ def test_comparison_validation(): ] } ) + # a comparison containing another comparison + settings["comparisons"].append( + { + "comparison_levels": [ + cll.null_level("test"), + # Invalid Spark cll + cl.exact_match("test"), + cll.else_level(), + ] + } + ) log_comparison_errors(None, "duckdb") # confirm it works with None as an input... @@ -381,17 +391,25 @@ def test_comparison_validation(): # Check our errors are raised errors = error_logger.raw_errors + # -3 as we have three valid comparisons assert len(error_logger.raw_errors) == len(settings["comparisons"]) - 3 - # Our expected error types and part of the corresponding error text + # These errors are raised in the order they are defined in the settings expected_errors = ( (TypeError, "is a comparison level"), (TypeError, "is of an invalid data type."), (SyntaxError, "missing the required `comparison_levels`"), (InvalidDialect, "within its comparison levels - spark."), - (InvalidDialect, "within its comparison levels - presto, spark."), + (InvalidDialect, re.compile(r".*(presto.*spark|spark.*presto).*")), + (TypeError, "contains the following invalid levels"), ) for n, (e, txt) in enumerate(expected_errors): - with pytest.raises(e, match=txt): - raise errors[n] + if isinstance(txt, re.Pattern): + # If txt is a compiled regular expression, use re.search + with pytest.raises(e) as exc_info: + raise errors[n] + assert txt.search(str(exc_info.value)), f"Regex did not match for error {n}" + else: + with pytest.raises(e, match=txt): + raise errors[n] From ddc13281c8b2b783114fb615ea65e9f6d32e1ef7 Mon Sep 17 00:00:00 2001 From: Tom Hepworth Date: Mon, 5 Feb 2024 11:30:08 +0000 Subject: [PATCH 06/10] lint & format settings validation code --- splink/exceptions.py | 1 - splink/linker.py | 4 +--- splink/settings_validation/settings_validation_log_strings.py | 3 +-- tests/test_settings_validation.py | 2 +- 4 files changed, 3 insertions(+), 7 deletions(-) diff --git a/splink/exceptions.py b/splink/exceptions.py index 812bb8e45a..79df3459e8 100644 --- a/splink/exceptions.py +++ b/splink/exceptions.py @@ -1,4 +1,3 @@ -import warnings from typing import List, Union diff --git a/splink/linker.py b/splink/linker.py index 33f680566b..1f07bd5726 100644 --- a/splink/linker.py +++ b/splink/linker.py @@ -1128,9 +1128,7 @@ def load_settings( self._settings_dict = settings_dict # overwrite or add # Get the SQL dialect from settings_dict or use the default - sql_dialect = settings_dict.get( - "sql_dialect", self._sql_dialect - ) + sql_dialect = settings_dict.get("sql_dialect", self._sql_dialect) settings_dict["sql_dialect"] = sql_dialect settings_dict["linker_uid"] = settings_dict.get("linker_uid", cache_uid) diff --git a/splink/settings_validation/settings_validation_log_strings.py b/splink/settings_validation/settings_validation_log_strings.py index 29c1e6a4cb..786f8462f3 100644 --- a/splink/settings_validation/settings_validation_log_strings.py +++ b/splink/settings_validation/settings_validation_log_strings.py @@ -1,6 +1,5 @@ from functools import partial -from typing import NamedTuple -from typing import List, Tuple +from typing import List, NamedTuple, Tuple def indent_error_message(message): diff --git a/tests/test_settings_validation.py b/tests/test_settings_validation.py index a980a3ffe2..676e30c1c8 100644 --- a/tests/test_settings_validation.py +++ b/tests/test_settings_validation.py @@ -1,8 +1,8 @@ import logging +import re import pandas as pd import pytest -import re from splink.comparison import Comparison from splink.convert_v2_to_v3 import convert_settings_from_v2_to_v3 From c35463f65a2fbaa04e63f8b4a69dc82b87718157 Mon Sep 17 00:00:00 2001 From: Tom Hepworth <45356472+ThomasHepworth@users.noreply.github.com> Date: Mon, 5 Feb 2024 11:37:18 +0000 Subject: [PATCH 07/10] lint with black --- splink/exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/splink/exceptions.py b/splink/exceptions.py index 2bf4dd3c12..9210a73c5d 100644 --- a/splink/exceptions.py +++ b/splink/exceptions.py @@ -22,7 +22,7 @@ class ComparisonSettingsException(SplinkException): def __init__(self, message=""): # Add the default message to the beginning of the provided message full_message = "Errors were detected in your settings object's comparisons" - + if message: full_message += "\n" + message super().__init__(full_message) From 84ff619c5fd85092680e533547879d3f90dfebd3 Mon Sep 17 00:00:00 2001 From: Tom Hepworth Date: Mon, 5 Feb 2024 11:41:33 +0000 Subject: [PATCH 08/10] satisfy the linter --- splink/exceptions.py | 3 ++- splink/settings_validation/settings_validation_log_strings.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/splink/exceptions.py b/splink/exceptions.py index 9210a73c5d..a97e22ae6c 100644 --- a/splink/exceptions.py +++ b/splink/exceptions.py @@ -63,7 +63,8 @@ def log_error(self, error: Union[list, str, Exception]) -> None: Log an error or a list of errors. Args: - error: An error message, an Exception instance, or a list containing strings or exceptions. + error: An error message, an Exception instance, or a list + containing strings or exceptions. """ if isinstance(error, (list, tuple)): for e in error: diff --git a/splink/settings_validation/settings_validation_log_strings.py b/splink/settings_validation/settings_validation_log_strings.py index 786f8462f3..307b5ba4a7 100644 --- a/splink/settings_validation/settings_validation_log_strings.py +++ b/splink/settings_validation/settings_validation_log_strings.py @@ -154,8 +154,8 @@ def create_invalid_comparison_level_log_string( ) log_message = ( - f"\nThe comparison `{comparison_string}` contains the following invalid levels:\n" - f"{invalid_levels_str}\n\n" + f"\nThe comparison `{comparison_string}` contains the following invalid " + f"levels:\n{invalid_levels_str}\n\n" "Please only include dictionaries or objects of " "the `ComparisonLevel` class." ) From da3b33238751fca16f2545f201c6569a57288947 Mon Sep 17 00:00:00 2001 From: Tom Hepworth Date: Tue, 6 Feb 2024 15:31:59 +0000 Subject: [PATCH 09/10] rm markdown includes --- scripts/preprocess_markdown_includes.py | 54 ------------------------- 1 file changed, 54 deletions(-) delete mode 100644 scripts/preprocess_markdown_includes.py diff --git a/scripts/preprocess_markdown_includes.py b/scripts/preprocess_markdown_includes.py deleted file mode 100644 index b1ac4e229b..0000000000 --- a/scripts/preprocess_markdown_includes.py +++ /dev/null @@ -1,54 +0,0 @@ -# this script to be run at the root of the repo -# otherwise will not correctly traverse docs folder -import logging -import os -import re -from pathlib import Path - -regex = ( - # opening tag and any whitespace - r"{%\s*" - # include-markdown literal and more whitespace - r"include-markdown\s*" - # the path in double-quotes (unvalidated) - r"\"(.*)\"" - # more whitespace and closing tag - r"\s*%}" -) - - -for root, _dirs, files in os.walk("docs"): - for file in files: - # only process md files - if re.fullmatch(r".*\.md", file): - with open(Path(root) / file, "r") as f: - current_text = f.read() - if re.search(regex, current_text): - # if we have an include tag, replace text and overwrite - for match in re.finditer(regex, current_text): - text_to_replace = match.group(0) - include_path = match.group(1) - try: - with open(Path(root) / include_path) as f_inc: - include_text = f_inc.read() - new_text = re.sub( - text_to_replace, include_text, current_text - ) - with open(Path(root) / file, "w") as f: - f.write(new_text) - # update text, in case we are iterating - current_text = new_text - # if we can't find include file then warn but carry on - except FileNotFoundError as e: - logging.warning(f"Couldn't find specified include file: {e}") - -# nasty hack: -# also adjust CONTRIBUTING.md links - mismatch between links as on github and -# where it fits in folder hierarchy of docs -with open(Path(".") / "CONTRIBUTING.md", "r") as f: - contributing_text = f.read() -# in docs CONTRIBUTING.md 'thinks' it's already in docs/, so remove that level from -# relative links -new_text = re.sub("docs/", "", contributing_text) -with open(Path(".") / "CONTRIBUTING.md", "w") as f: - f.write(new_text) From 70dd54935c025f804c7c1830e513ddb74f8e0a2e Mon Sep 17 00:00:00 2001 From: Tom Hepworth <45356472+ThomasHepworth@users.noreply.github.com> Date: Tue, 6 Feb 2024 15:34:24 +0000 Subject: [PATCH 10/10] Update splink/settings_validation/settings_validation_log_strings.py Co-authored-by: Robin Linacre --- splink/settings_validation/settings_validation_log_strings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/splink/settings_validation/settings_validation_log_strings.py b/splink/settings_validation/settings_validation_log_strings.py index 307b5ba4a7..5a051b3d66 100644 --- a/splink/settings_validation/settings_validation_log_strings.py +++ b/splink/settings_validation/settings_validation_log_strings.py @@ -173,7 +173,7 @@ def create_invalid_comparison_log_string( log_message = ( f"\n{comparison_string}\n" - f"{type_msg} and cannot be used as a standalone comparison.\n" + f"{type_msg} and must be nested within a comparison.\n" "Please only include dictionaries or objects of the `Comparison` class.\n" ) return indent_error_message(log_message)