Allow df_concat to be created without a linker
RobinL committed Apr 24, 2024
1 parent 0df2cc6 commit bc5eb21
Showing 6 changed files with 106 additions and 58 deletions.
7 changes: 1 addition & 6 deletions splink/database_api.py
@@ -233,12 +233,7 @@ def register_multiple_tables(
existing_tables = []

if not input_aliases:
# If any of the input_tables are strings, this means they refer
# to tables that already exist in the database and an alias is not needed
input_aliases = [
table if isinstance(table, str) else f"__splink__{ascii_uid(8)}"
for table in input_tables
]
input_aliases = [f"__splink__{ascii_uid(8)}" for table in input_tables]

for table, alias in zip(input_tables, input_aliases):
if isinstance(table, str):
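
The effect of this change is that an alias is now generated for every input, including tables that already exist in the database under their own name. A minimal sketch of the new alias generation, using a stand-in for splink's internal ascii_uid helper and purely illustrative inputs:

    import random
    import string

    def ascii_uid(len_uid: int) -> str:
        # stand-in for splink's internal helper: a short random identifier
        return "".join(random.choices(string.ascii_lowercase + string.digits, k=len_uid))

    # hypothetical inputs: a table name that already exists in the database plus a dict of records
    input_tables = ["existing_db_table", {"first_name": ["ann", "bob"]}]

    # new behaviour: every table gets a generated __splink__ alias, even the pre-existing one
    input_aliases = [f"__splink__{ascii_uid(8)}" for table in input_tables]
    print(input_aliases)  # e.g. ['__splink__k3f9x2qa', '__splink__p0d7m1vz']
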
15 changes: 13 additions & 2 deletions splink/linker.py
@@ -124,6 +124,8 @@
_composite_unique_id_from_nodes_sql,
)
from .unlinkables import unlinkables_data

from .column_expression import ColumnExpression
from .vertically_concatenate import (
vertically_concatenate_sql,
split_df_concat_with_tf_into_two_tables_sqls,
@@ -1669,7 +1671,10 @@ def compute_graph_metrics(
)

def profile_columns(
self, column_expressions: str | list[str] | None = None, top_n=10, bottom_n=10
self,
column_expressions: Optional[List[Union[str, ColumnExpression]]] = None,
top_n=10,
bottom_n=10,
):
"""
Profiles the specified columns of the dataframe initiated with the linker.
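
With the widened signature, column_expressions accepts a mix of plain column names and ColumnExpression objects. A hedged usage sketch (the linker, column names and expression string are illustrative, not taken from this commit):

    from splink.column_expression import ColumnExpression

    # assumes `linker` has already been constructed, as in the tests elsewhere in this diff,
    # against input data containing first_name and surname columns
    linker.profile_columns(
        column_expressions=[
            "first_name",
            ColumnExpression("substr(surname, 1, 2)"),  # a derived expression, profiled like a column
        ],
        top_n=10,
        bottom_n=10,
    )
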
@@ -2695,7 +2700,13 @@ def count_num_comparisons_from_blocking_rule(

pipeline = CTEPipeline()

sql = vertically_concatenate_sql(self)
sds_name = self._settings_obj.column_info_settings.source_dataset_column_name

sql = vertically_concatenate_sql(
input_tables=self._input_tables_dict,
salting_reqiured=self._settings_obj.salting_required,
source_dataset_column_name=sds_name,
)
pipeline.enqueue_sql(sql, "__splink__df_concat")

sql = number_of_comparisons_generated_by_blocking_rule_post_filters_sql(
52 changes: 30 additions & 22 deletions splink/profile_data.py
@@ -1,10 +1,14 @@
import logging
import re
from copy import deepcopy
from typing import List, Optional, Union

from .charts import altair_or_json, load_chart_definition
from .column_expression import ColumnExpression
from .database_api import DatabaseAPI
from .misc import ensure_is_list
from .pipeline import CTEPipeline
from .vertically_concatenate import vertically_concatenate_sql

logger = logging.getLogger(__name__)

@@ -195,7 +199,11 @@ def _add_100_percentile_to_df_percentiles(percentile_rows):


def profile_columns(
table_or_tables, db_api, column_expressions=None, top_n=10, bottom_n=10
table_or_tables,
db_api: DatabaseAPI,
column_expressions: Optional[List[Union[str, ColumnExpression]]] = None,
top_n=10,
bottom_n=10,
):
"""
Profiles the specified columns of the dataframe initiated with the linker.
Expand Down Expand Up @@ -235,32 +243,33 @@ def profile_columns(
"""

splink_df_dict = db_api.register_multiple_tables(table_or_tables)

pipeline = CTEPipeline()
sql = vertically_concatenate_sql(
splink_df_dict, salting_reqiured=False, source_dataset_column_name=None
)
pipeline.enqueue_sql(sql, "__splink__df_concat")

input_dataframes = list(splink_df_dict.values())
input_aliases = list(splink_df_dict.keys())

input_columns = input_dataframes[0].columns_escaped

if not column_expressions:
column_expressions_raw = input_columns
else:
column_expressions_raw = ensure_is_list(column_expressions)
column_expressions = expressions_to_sql(column_expressions_raw)

pipeline = CTEPipeline(input_dataframes)
# If the user has provided any ColumnExpression, convert to string
for i in column_expressions_raw:
if isinstance(i, ColumnExpression):
i.sql_dialect = db_api.sql_dialect

cols_to_select = ", ".join(input_columns)
template = """
select {cols_to_select}
from {table_name}
"""

sql_df_concat = " UNION ALL".join(
[
template.format(cols_to_select=cols_to_select, table_name=table_name)
for table_name in input_aliases
]
)
column_expressions_raw = [
ce.name if isinstance(ce, ColumnExpression) else ce
for ce in column_expressions_raw
]

pipeline.enqueue_sql(sql_df_concat, "__splink__df_concat")
column_expressions_as_sql = expressions_to_sql(column_expressions_raw)

sql = _col_or_expr_frequencies_raw_data_sql(
column_expressions_raw, "__splink__df_concat"
@@ -271,27 +280,26 @@ def profile_columns(

pipeline = CTEPipeline(input_dataframes=[df_raw])
sqls = _get_df_percentiles()
for sql in sqls:
pipeline.enqueue_sql(sql["sql"], sql["output_table_name"])
pipeline.enqueue_list_of_sqls(sqls)

df_percentiles = db_api.sql_pipeline_to_splink_dataframe(pipeline)
percentile_rows_all = df_percentiles.as_record_dict()

pipeline = CTEPipeline(input_dataframes=[df_raw])
sql = _get_df_top_bottom_n(column_expressions, top_n, "desc")
sql = _get_df_top_bottom_n(column_expressions_as_sql, top_n, "desc")
pipeline.enqueue_sql(sql, "__splink__df_top_n")
df_top_n = db_api.sql_pipeline_to_splink_dataframe(pipeline)
top_n_rows_all = df_top_n.as_record_dict()

pipeline = CTEPipeline(input_dataframes=[df_raw])
sql = _get_df_top_bottom_n(column_expressions, bottom_n, "asc")
sql = _get_df_top_bottom_n(column_expressions_as_sql, bottom_n, "asc")
pipeline.enqueue_sql(sql, "__splink__df_bottom_n")
df_bottom_n = db_api.sql_pipeline_to_splink_dataframe(pipeline)
bottom_n_rows_all = df_bottom_n.as_record_dict()

inner_charts = []

for expression in column_expressions:
for expression in column_expressions_as_sql:
percentile_rows = [
p for p in percentile_rows_all if p["group_name"] == _group_name(expression)
]
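
Because the chart is now built from table_or_tables and a db_api rather than a linker, column profiling can be run before any linker or settings exist. A usage sketch under the assumption that a DuckDB-backed implementation is importable as DuckDBAPI and constructible with no arguments (the dataframe and column names are illustrative):

    import pandas as pd

    from splink.database_api import DuckDBAPI  # assumed location of the DuckDB implementation
    from splink.profile_data import profile_columns

    df = pd.DataFrame(
        {
            "unique_id": [1, 2, 3],
            "first_name": ["ann", "ann", "bob"],
            "surname": ["smith", "jones", "jones"],
        }
    )

    db_api = DuckDBAPI()
    chart = profile_columns(
        df, db_api, column_expressions=["first_name", "surname"], top_n=5, bottom_n=5
    )
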
5 changes: 5 additions & 0 deletions splink/settings.py
@@ -656,6 +656,11 @@ def human_readable_description(self):

@property
def salting_required(self):
# see https://github.com/duckdb/duckdb/discussions/9710
# in duckdb to parallelise we need salting
if self._sql_dialect == "duckdb":
return True

for br in self._blocking_rules_to_generate_predictions:
if isinstance(br, SaltedBlockingRule):
return True
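
With the duckdb short-circuit added, the property on the settings object reads roughly as follows (a sketch reconstructed from the hunk above; the final return False for other dialects with no salted blocking rules is an assumption about the unchanged tail of the property):

    @property
    def salting_required(self):
        # see https://github.com/duckdb/duckdb/discussions/9710
        # in duckdb to parallelise we need salting
        if self._sql_dialect == "duckdb":
            return True

        for br in self._blocking_rules_to_generate_predictions:
            if isinstance(br, SaltedBlockingRule):
                return True
        return False
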
76 changes: 49 additions & 27 deletions splink/vertically_concatenate.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Dict

from .pipeline import CTEPipeline
from .splink_dataframe import SplinkDataFrame
@@ -14,9 +14,13 @@
from .linker import Linker


def vertically_concatenate_sql(linker: Linker) -> str:
def vertically_concatenate_sql(
input_tables: Dict[str, SplinkDataFrame],
salting_reqiured: bool,
source_dataset_column_name: str = None,
) -> str:
"""
Using `input_table_or_tables`, create a single table with the columns and
Using `input_tables`, create a single table with the columns and
rows required for linking.
This table will later be the basis for the generation of pairwise record comparisons,
@@ -30,37 +34,32 @@ def vertically_concatenate_sql(linker: Linker) -> str:
"""

# Use column order from first table in dict
df_obj = next(iter(linker._input_tables_dict.values()))
df_obj = next(iter(input_tables.values()))
columns = df_obj.columns_escaped

select_columns_sql = ", ".join(columns)

salting_reqiured = False

source_dataset_col_req = (
linker._settings_obj.column_info_settings.source_dataset_column_name is not None
)

salting_reqiured = linker._settings_obj.salting_required

# see https://github.com/duckdb/duckdb/discussions/9710
# in duckdb to parallelise we need salting
if linker._sql_dialect == "duckdb":
salting_reqiured = True

if salting_reqiured:
salt_sql = ", random() as __splink_salt"
else:
salt_sql = ""

if source_dataset_col_req:
sqls_to_union = []
source_dataset_column_already_exists = False
if source_dataset_column_name:
source_dataset_column_already_exists = source_dataset_column_name in [
c.unquote().name for c in df_obj.columns
]

create_sds_if_needed = ""
select_columns_sql = ", ".join(columns)
if len(input_tables) > 1:
sqls_to_union = []

for df_obj in linker._input_tables_dict.values():
if not linker._source_dataset_column_already_exists:
for df_obj in input_tables.values():
if source_dataset_column_already_exists:
create_sds_if_needed = ""
else:
create_sds_if_needed = f"'{df_obj.templated_name}' as source_dataset,"

sql = f"""
select
{create_sds_if_needed}
@@ -81,14 +80,19 @@ def vertically_concatenate_sql(linker: Linker) -> str:
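
This refactor is what the commit title refers to: the concat SQL can now be generated from a plain dict of SplinkDataFrames, with no Linker involved. A sketch mirroring how profile_data.py and the updated test use it (the db_api and input tables are assumed to exist; the salting and source-dataset arguments are illustrative):

    # assumes `db_api` is a DatabaseAPI instance and df_1 / df_2 are registerable tables
    splink_df_dict = db_api.register_multiple_tables([df_1, df_2])

    pipeline = CTEPipeline()
    sql = vertically_concatenate_sql(
        input_tables=splink_df_dict,
        salting_reqiured=False,  # spelling follows the function signature in this commit
        source_dataset_column_name=None,
    )
    pipeline.enqueue_sql(sql, "__splink__df_concat")
    df_concat = db_api.sql_pipeline_to_splink_dataframe(pipeline)
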


def enqueue_df_concat_with_tf(linker: Linker, pipeline: CTEPipeline) -> CTEPipeline:

cache = linker._intermediate_table_cache
if "__splink__df_concat_with_tf" in cache:
nodes_with_tf = cache.get_with_logging("__splink__df_concat_with_tf")
pipeline.append_input_dataframe(nodes_with_tf)
return pipeline

sql = vertically_concatenate_sql(linker)
sds_name = linker._settings_obj.column_info_settings.source_dataset_column_name

sql = vertically_concatenate_sql(
input_tables=linker._input_tables_dict,
salting_reqiured=linker._settings_obj.salting_required,
source_dataset_column_name=sds_name,
)
pipeline.enqueue_sql(sql, "__splink__df_concat")

sqls = compute_all_term_frequencies_sqls(linker, pipeline)
@@ -104,7 +108,13 @@ def compute_df_concat_with_tf(linker: Linker, pipeline: CTEPipeline) -> SplinkDa
if "__splink__df_concat_with_tf" in cache:
return cache.get_with_logging("__splink__df_concat_with_tf")

sql = vertically_concatenate_sql(linker)
sds_name = linker._settings_obj.column_info_settings.source_dataset_column_name

sql = vertically_concatenate_sql(
input_tables=linker._input_tables_dict,
salting_reqiured=linker._settings_obj.salting_required,
source_dataset_column_name=sds_name,
)
pipeline.enqueue_sql(sql, "__splink__df_concat")

sqls = compute_all_term_frequencies_sqls(linker, pipeline)
@@ -131,7 +141,13 @@ def enqueue_df_concat(linker: Linker, pipeline: CTEPipeline) -> CTEPipeline:
pipeline.append_input_dataframe(nodes_with_tf)
return pipeline

sql = vertically_concatenate_sql(linker)
sds_name = linker._settings_obj.column_info_settings.source_dataset_column_name

sql = vertically_concatenate_sql(
input_tables=linker._input_tables_dict,
salting_reqiured=linker._settings_obj.salting_required,
source_dataset_column_name=sds_name,
)
pipeline.enqueue_sql(sql, "__splink__df_concat")

return pipeline
@@ -148,7 +164,13 @@ def compute_df_concat(linker: Linker, pipeline: CTEPipeline) -> SplinkDataFrame:
df.templated_name = "__splink__df_concat"
return df

sql = vertically_concatenate_sql(linker)
sds_name = linker._settings_obj.column_info_settings.source_dataset_column_name

sql = vertically_concatenate_sql(
input_tables=linker._input_tables_dict,
salting_reqiured=linker._settings_obj.salting_required,
source_dataset_column_name=sds_name,
)
pipeline.enqueue_sql(sql, "__splink__df_concat")

nodes_with_tf = db_api.sql_pipeline_to_splink_dataframe(pipeline)
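
For orientation, with two input tables, a source_dataset column required and salting on, the SQL assembled by vertically_concatenate_sql has roughly this shape. This is a hand-written illustration pieced together from the template fragments in the hunks above, not captured output; the exact column list, quoting and table names depend on the inputs:

    illustrative_df_concat_sql = """
    select
    'df_left' as source_dataset,
    unique_id, first_name, surname
    , random() as __splink_salt
    from __splink__input_table_0
    UNION ALL
    select
    'df_right' as source_dataset,
    unique_id, first_name, surname
    , random() as __splink_salt
    from __splink__input_table_1
    """
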
9 changes: 8 additions & 1 deletion tests/test_total_comparison_count.py
@@ -91,7 +91,14 @@ def make_dummy_frame(row_count):

linker = Linker(dfs, settings, database_api=db_api)
pipeline = CTEPipeline()
sql = vertically_concatenate_sql(linker)

sds_name = linker._settings_obj.column_info_settings.source_dataset_column_name

sql = vertically_concatenate_sql(
input_tables=linker._input_tables_dict,
salting_reqiured=linker._settings_obj.salting_required,
source_dataset_column_name=sds_name,
)

pipeline.enqueue_sql(sql, "__splink__df_concat")
df_concat = linker.db_api.sql_pipeline_to_splink_dataframe(pipeline)
