Skip to content

Commit

Permalink
fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinL committed Apr 16, 2024
1 parent f1d5446 commit 16987e2
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions splink/profile_data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging
import re
from copy import deepcopy
from typing import List, Optional, Union

from .charts import altair_or_json, load_chart_definition
from .column_expression import ColumnExpression
from .database_api import DatabaseAPI
from .misc import ensure_is_list
from .pipeline import CTEPipeline
from .vertically_concatenate import vertically_concatenate_sql
Expand Down Expand Up @@ -197,7 +199,11 @@ def _add_100_percentile_to_df_percentiles(percentile_rows):


def profile_columns(
table_or_tables, db_api, column_expressions=None, top_n=10, bottom_n=10
table_or_tables,
db_api: DatabaseAPI,
column_expressions: Optional[List[Union[str, ColumnExpression]]] = None,
top_n=10,
bottom_n=10,
):
"""
Profiles the specified columns of the dataframe initiated with the linker.
Expand Down Expand Up @@ -263,7 +269,7 @@ def profile_columns(
for ce in column_expressions_raw
]

column_expressions = expressions_to_sql(column_expressions_raw)
column_expressions_as_sql = expressions_to_sql(column_expressions_raw)

sql = _col_or_expr_frequencies_raw_data_sql(
column_expressions_raw, "__splink__df_concat"
Expand All @@ -274,27 +280,26 @@ def profile_columns(

pipeline = CTEPipeline(input_dataframes=[df_raw])
sqls = _get_df_percentiles()
for sql in sqls:
pipeline.enqueue_sql(sql["sql"], sql["output_table_name"])
pipeline.enqueue_list_of_sqls(sqls)

df_percentiles = db_api.sql_pipeline_to_splink_dataframe(pipeline)
percentile_rows_all = df_percentiles.as_record_dict()

pipeline = CTEPipeline(input_dataframes=[df_raw])
sql = _get_df_top_bottom_n(column_expressions, top_n, "desc")
sql = _get_df_top_bottom_n(column_expressions_as_sql, top_n, "desc")
pipeline.enqueue_sql(sql, "__splink__df_top_n")
df_top_n = db_api.sql_pipeline_to_splink_dataframe(pipeline)
top_n_rows_all = df_top_n.as_record_dict()

pipeline = CTEPipeline(input_dataframes=[df_raw])
sql = _get_df_top_bottom_n(column_expressions, bottom_n, "asc")
sql = _get_df_top_bottom_n(column_expressions_as_sql, bottom_n, "asc")
pipeline.enqueue_sql(sql, "__splink__df_bottom_n")
df_bottom_n = db_api.sql_pipeline_to_splink_dataframe(pipeline)
bottom_n_rows_all = df_bottom_n.as_record_dict()

inner_charts = []

for expression in column_expressions:
for expression in column_expressions_as_sql:
percentile_rows = [
p for p in percentile_rows_all if p["group_name"] == _group_name(expression)
]
Expand Down

0 comments on commit 16987e2

Please sign in to comment.