Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update load settings and make it the defacto load logic #1921

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions scripts/create_venv.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,27 @@
#!/bin/bash

deactivate
python3 -m venv splink-venv
source splink-venv/bin/activate
# Deactivate any currently active virtual environment
deactivate 2>/dev/null

# Check if a virtual environment name is provided as an argument
if [ -z "$1" ]; then
# No name provided, use default name 'venv'
VENV_NAME="venv"
else
VENV_NAME=$1
fi

python3 -m venv $VENV_NAME
source $VENV_NAME/bin/activate

pip install --upgrade pip
pip install poetry
poetry install
pip install pytest
# Attempt to install dependencies using poetry
if ! poetry install; then
echo "Poetry failed to install your packages. You may need to update the poetry.toml "
echo 'file to use py>=3.8 and then run \`poetry add "pandas>=2.0.0"\` before continuing.'
exit 1
fi


echo "Virtual environment '$VENV_NAME' set up and activated using poetry."
56 changes: 0 additions & 56 deletions scripts/preprocess_markdown_includes.py

This file was deleted.

59 changes: 34 additions & 25 deletions splink/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import warnings
from typing import List, Union


def format_text_as_red(text):
return f"\033[91m{text}\033[0m"


# base class for any type of custom exception
class SplinkException(Exception):
pass
Expand All @@ -18,12 +21,10 @@ class EMTrainingException(SplinkException):
class ComparisonSettingsException(SplinkException):
def __init__(self, message=""):
# Add the default message to the beginning of the provided message
full_message = (
"The following errors were identified within your "
"settings object's comparisons"
)
full_message = "Errors were detected in your settings object's comparisons"

if message:
full_message += ":\n" + message
full_message += "\n" + message
super().__init__(full_message)


Expand Down Expand Up @@ -54,30 +55,40 @@ def __init__(self):
@property
def errors(self) -> str:
"""Return concatenated error messages."""
return "\n".join(self.error_queue)

def append(self, error: Union[str, Exception]) -> None:
"""Append an error to the error list.
return "\n\n".join([f"{e}" for e in self.error_queue])

def log_error(self, error: Union[list, str, Exception]) -> None:
"""
Log an error or a list of errors.

Args:
error: An error message string or an Exception instance.
error: An error message, an Exception instance, or a list
containing strings or exceptions.
"""
# Ignore if None
if isinstance(error, (list, tuple)):
for e in error:
self._log_single_error(e)
else:
self._log_single_error(error)

def _log_single_error(self, error: Union[str, Exception]) -> None:
"""Log a single error message or Exception."""
if error is None:
return

self.raw_errors.append(error)

if isinstance(error, str):
self.error_queue.append(error)
elif isinstance(error, Exception):
error_name = error.__class__.__name__
logged_exception = f"{error_name}: {str(error)}"
self.error_queue.append(logged_exception)
error_str = self._format_error(error)
self.error_queue.append(error_str)

def _format_error(self, error: Union[str, Exception]) -> str:
"""Format an error message or Exception for logging."""
if isinstance(error, Exception):
return f"{format_text_as_red(error.__class__.__name__)}: {error}"
elif isinstance(error, str):
return f"{error}"
else:
raise ValueError(
"The 'error' argument must be a string or an Exception instance."
)
raise ValueError("Error must be a string or an Exception instance.")

def raise_and_log_all_errors(self, exception=SplinkException, additional_txt=""):
"""Raise a custom exception with all logged errors.
Expand All @@ -87,7 +98,5 @@ def raise_and_log_all_errors(self, exception=SplinkException, additional_txt="")
additional_txt: Additional text to append to the error message.
"""
if self.error_queue:
raise exception(f"\n{self.errors}{additional_txt}")


warnings.simplefilter("always", SplinkDeprecated)
error_message = f"\n{self.errors}\n{additional_txt}"
raise exception(error_message)
96 changes: 37 additions & 59 deletions splink/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,16 +226,6 @@ def __init__(

self._intermediate_table_cache: dict = CacheDictWithLogging()

if not isinstance(settings_dict, (dict, type(None))):
# Run if you've entered a filepath
# feed it a blank settings dictionary
self._setup_settings_objs(None)
self.load_settings(settings_dict)
else:
self._validate_settings_components(settings_dict)
settings_dict = deepcopy(settings_dict)
self._setup_settings_objs(settings_dict)

homogenised_tables, homogenised_aliases = self._register_input_tables(
input_table_or_tables,
input_table_aliases,
Expand All @@ -246,8 +236,8 @@ def __init__(
homogenised_tables, homogenised_aliases
)

self._validate_input_dfs()
self._validate_settings(validate_settings)
self._setup_settings_objs(deepcopy(settings_dict), validate_settings)

self._em_training_sessions = []

self._find_new_matches_mode = False
Expand Down Expand Up @@ -328,14 +318,14 @@ def _source_dataset_column_already_exists(self):

@property
def _cache_uid(self):
if self._settings_dict:
if getattr(self, "_settings_dict", None):
return self._settings_obj._cache_uid
else:
return self._cache_uid_no_settings

@_cache_uid.setter
def _cache_uid(self, value):
if self._settings_dict:
if getattr(self, "_settings_dict", None):
self._settings_obj._cache_uid = value
else:
self._cache_uid_no_settings = value
Expand Down Expand Up @@ -477,25 +467,22 @@ def _register_input_tables(self, input_tables, input_aliases, accepted_df_dtypes

return homogenised_tables, homogenised_aliases

def _setup_settings_objs(self, settings_dict):
# Setup the linker class's required settings
self._settings_dict = settings_dict

# if settings_dict is passed, set sql_dialect on it if missing, and make sure
# incompatible dialect not passed
if settings_dict is not None and settings_dict.get("sql_dialect", None) is None:
settings_dict["sql_dialect"] = self._sql_dialect

if settings_dict is None:
self._cache_uid_no_settings = ascii_uid(8)
else:
uid = settings_dict.get("linker_uid", ascii_uid(8))
settings_dict["linker_uid"] = uid
def _setup_settings_objs(self, settings_dict, validate_settings: bool = True):
# Always sets a default cache uid -> _cache_uid_no_settings
self._cache_uid = ascii_uid(8)

if settings_dict is None:
self._settings_obj_ = None
else:
self._settings_obj_ = Settings(settings_dict)
return

if not isinstance(settings_dict, (str, dict)):
raise ValueError(
"Invalid settings object supplied. Ensure this is either "
"None, a dictionary or a filepath to a settings object saved "
"as a json file."
)

self.load_settings(settings_dict, validate_settings)

def _check_for_valid_settings(self):
if (
Expand All @@ -509,24 +496,14 @@ def _check_for_valid_settings(self):
else:
return True

def _validate_settings_components(self, settings_dict):
# Vaidate our settings after plugging them through
# `Settings(<settings>)`
if settings_dict is None:
return

log_comparison_errors(
# null if not in dict - check using value is ignored
settings_dict.get("comparisons", None),
self._sql_dialect,
)

def _validate_settings(self, validate_settings):
# Vaidate our settings after plugging them through
# `Settings(<settings>)`
if not self._check_for_valid_settings():
return

self._validate_input_dfs()

# Run miscellaneous checks on our settings dictionary.
_validate_dialect(
settings_dialect=self._settings_obj._sql_dialect,
Expand Down Expand Up @@ -1142,27 +1119,23 @@ def load_settings(
settings_dict = json.loads(p.read_text())

# Store the cache ID so it can be reloaded after cache invalidation
cache_id = self._cache_uid
# So we don't run into any issues with generated tables having
# invalid columns as settings have been tweaked, invalidate
# the cache and allow these tables to be recomputed.

# This is less efficient, but triggers infrequently and ensures we don't
# run into issues where the defaults used conflict with the actual values
# supplied in settings.
cache_uid = self._cache_uid

# This is particularly relevant with `source_dataset`, which appears within
# concat_with_tf.
# Invalidate the cache if anything currently exists. If the settings are
# changing, our charts, tf tables, etc may need changing.
self.invalidate_cache()

# If a uid already exists in your settings object, prioritise this
settings_dict["linker_uid"] = settings_dict.get("linker_uid", cache_id)
settings_dict["sql_dialect"] = settings_dict.get(
"sql_dialect", self._sql_dialect
)
self._settings_dict = settings_dict
self._settings_dict = settings_dict # overwrite or add

# Get the SQL dialect from settings_dict or use the default
sql_dialect = settings_dict.get("sql_dialect", self._sql_dialect)
settings_dict["sql_dialect"] = sql_dialect
settings_dict["linker_uid"] = settings_dict.get("linker_uid", cache_uid)

# Check the user's comparisons (if they exist)
log_comparison_errors(settings_dict.get("comparisons"), sql_dialect)
self._settings_obj_ = Settings(settings_dict)
self._validate_input_dfs()
# Check the final settings object
self._validate_settings(validate_settings)

def load_model(self, model_path: Path):
Expand Down Expand Up @@ -3753,6 +3726,11 @@ def invalidate_cache(self):
will be recomputed.
This is useful, for example, if the input data tables have changed.
"""

# Nothing to delete
if len(self._intermediate_table_cache) == 0:
return

# Before Splink executes a SQL command, it checks the cache to see
# whether a table already exists with the name of the output table

Expand Down
Loading
Loading