Merge pull request #2143 from moj-analytical-services/process_input_tables_simplification

Process input tables simplification
RobinL authored Apr 24, 2024
2 parents 9d27601 + cb9ffd3 commit 0df2cc6
Showing 4 changed files with 11 additions and 15 deletions.
15 changes: 8 additions & 7 deletions splink/database_api.py
@@ -13,10 +13,7 @@
 )
 from .exceptions import SplinkException
 from .logging_messages import execute_sql_logging_message_info, log_sql
-from .misc import (
-    ascii_uid,
-    parse_duration,
-)
+from .misc import ascii_uid, ensure_is_list, parse_duration
 from .pipeline import CTEPipeline
 from .splink_dataframe import SplinkDataFrame

@@ -191,7 +188,6 @@ def sql_pipeline_to_splink_dataframe(
         """

         if not self.debug_mode:
-
             sql_gen = pipeline.generate_cte_pipeline_sql()
             output_tablename_templated = pipeline.output_table_name

@@ -231,6 +227,7 @@ def register_multiple_tables(
         input_aliases: Optional[List[str]] = None,
         overwrite: bool = False,
     ) -> Dict[str, SplinkDataFrame]:
+        input_tables = self.process_input_tables(input_tables)

         tables_as_splink_dataframes = {}
         existing_tables = []
@@ -271,9 +268,12 @@
         return tables_as_splink_dataframes

     @final
-    def register_table(self, input, table_name, overwrite=False) -> SplinkDataFrame:
+    def register_table(
+        self, input_table, table_name, overwrite=False
+    ) -> SplinkDataFrame:
+
         tables_dict = self.register_multiple_tables(
-            [input], [table_name], overwrite=overwrite
+            [input_table], [table_name], overwrite=overwrite
         )
         return tables_dict[table_name]

@@ -328,6 +328,7 @@ def process_input_tables(self, input_tables) -> List:
         for linker.
         Default just passes through - backends can specialise if desired
         """
+        input_tables = ensure_is_list(input_tables)
         return input_tables

     # should probably also be responsible for cache
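Taken together, the database_api.py hunks settle on a template-method pattern: the base class's process_input_tables hook now owns the ensure_is_list normalisation, and register_multiple_tables calls the hook itself, so callers no longer pre-process. A minimal runnable sketch of that shape (class names and table handling simplified; not the real Splink classes):

from typing import Any, Dict, List


def ensure_is_list(x: Any) -> List:
    # Mirrors splink.misc.ensure_is_list as used in this diff:
    # wrap anything that is not already a list.
    return x if isinstance(x, list) else [x]


class DatabaseAPI:
    def process_input_tables(self, input_tables) -> List:
        # Base behaviour: just guarantee a list; backends can specialise.
        return ensure_is_list(input_tables)

    def register_multiple_tables(self, input_tables) -> Dict[str, Any]:
        # Normalisation now happens here, so callers (Linker,
        # profile_columns) hand over their raw argument unchanged.
        input_tables = self.process_input_tables(input_tables)
        return {f"table_{i}": t for i, t in enumerate(input_tables)}


class DuckDBAPI(DatabaseAPI):
    def process_input_tables(self, input_tables) -> List:
        # Backend specialisation layered on top of the base
        # normalisation: treat bare strings as files to load (stubbed).
        input_tables = super().process_input_tables(input_tables)
        return [f"loaded({t})" if isinstance(t, str) else t for t in input_tables]


print(DuckDBAPI().register_multiple_tables("my_table.csv"))
# -> {'table_0': 'loaded(my_table.csv)'}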
1 change: 1 addition & 0 deletions splink/duckdb/database_api.py
@@ -107,6 +107,7 @@ def accepted_df_dtypes(self):
         return accepted_df_dtypes

     def process_input_tables(self, input_tables):
+        input_tables = super().process_input_tables(input_tables)
         return [
             self.load_from_file(t) if isinstance(t, str) else t for t in input_tables
         ]
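With the super() call added above, the DuckDB override composes with the base normalisation: a bare path string is first wrapped in a list and then routed through load_from_file before registration. A hedged usage sketch (the file name is hypothetical, and a default-constructible DuckDBAPI is assumed):

from splink.duckdb.database_api import DuckDBAPI

db_api = DuckDBAPI()  # assumes the default in-memory connection

# "input_data.csv" is a hypothetical file; per this diff, strings are
# loaded via load_from_file and registered under the given alias.
df = db_api.register_table("input_data.csv", "input_data")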
4 changes: 1 addition & 3 deletions splink/linker.py
@@ -233,11 +233,9 @@ def __init__(

         # TODO: Add test of what happens if the db_api is for a different backend
         # to the sql_dialect set in the settings dict
-        input_tables = ensure_is_list(input_table_or_tables)
-        input_tables = self.db_api.process_input_tables(input_tables)

         self._input_tables_dict = self._register_input_tables(
-            input_tables,
+            input_table_or_tables,
             input_table_aliases,
         )

6 changes: 1 addition & 5 deletions splink/profile_data.py
@@ -234,11 +234,7 @@ def profile_columns(
         values to display in the respective charts.
     """

-    tables = ensure_is_list(table_or_tables)
-
-    tables = db_api.process_input_tables(tables)
-
-    splink_df_dict = db_api.register_multiple_tables(tables)
+    splink_df_dict = db_api.register_multiple_tables(table_or_tables)
    input_dataframes = list(splink_df_dict.values())
    input_aliases = list(splink_df_dict.keys())
    input_columns = input_dataframes[0].columns_escaped
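Net effect of the linker.py and profile_data.py hunks: both callers now pass their raw table_or_tables argument straight through, and register_multiple_tables normalises it. So a single table and a one-element list should take identical paths — a hedged sketch (the DataFrame contents are illustrative, and the column_expressions keyword is an assumption about the dev-API signature):

import pandas as pd

from splink.duckdb.database_api import DuckDBAPI
from splink.profile_data import profile_columns

db_api = DuckDBAPI()
df = pd.DataFrame({"first_name": ["ann", "ann", "bob"]})

# Both calls should be equivalent: ensure_is_list / process_input_tables
# are applied inside register_multiple_tables now.
profile_columns(df, db_api, column_expressions=["first_name"])
profile_columns([df], db_api, column_expressions=["first_name"])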
