From 16987e281972c199ce3789d6fc00b4abfe1ccb28 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Tue, 16 Apr 2024 12:08:00 +0100 Subject: [PATCH] fix mypy --- splink/profile_data.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/splink/profile_data.py b/splink/profile_data.py index f37d83c3b3..c2ef57eea3 100644 --- a/splink/profile_data.py +++ b/splink/profile_data.py @@ -1,9 +1,11 @@ import logging import re from copy import deepcopy +from typing import List, Optional, Union from .charts import altair_or_json, load_chart_definition from .column_expression import ColumnExpression +from .database_api import DatabaseAPI from .misc import ensure_is_list from .pipeline import CTEPipeline from .vertically_concatenate import vertically_concatenate_sql @@ -197,7 +199,11 @@ def _add_100_percentile_to_df_percentiles(percentile_rows): def profile_columns( - table_or_tables, db_api, column_expressions=None, top_n=10, bottom_n=10 + table_or_tables, + db_api: DatabaseAPI, + column_expressions: Optional[List[Union[str, ColumnExpression]]] = None, + top_n=10, + bottom_n=10, ): """ Profiles the specified columns of the dataframe initiated with the linker. @@ -263,7 +269,7 @@ def profile_columns( for ce in column_expressions_raw ] - column_expressions = expressions_to_sql(column_expressions_raw) + column_expressions_as_sql = expressions_to_sql(column_expressions_raw) sql = _col_or_expr_frequencies_raw_data_sql( column_expressions_raw, "__splink__df_concat" @@ -274,27 +280,26 @@ def profile_columns( pipeline = CTEPipeline(input_dataframes=[df_raw]) sqls = _get_df_percentiles() - for sql in sqls: - pipeline.enqueue_sql(sql["sql"], sql["output_table_name"]) + pipeline.enqueue_list_of_sqls(sqls) df_percentiles = db_api.sql_pipeline_to_splink_dataframe(pipeline) percentile_rows_all = df_percentiles.as_record_dict() pipeline = CTEPipeline(input_dataframes=[df_raw]) - sql = _get_df_top_bottom_n(column_expressions, top_n, "desc") + sql = _get_df_top_bottom_n(column_expressions_as_sql, top_n, "desc") pipeline.enqueue_sql(sql, "__splink__df_top_n") df_top_n = db_api.sql_pipeline_to_splink_dataframe(pipeline) top_n_rows_all = df_top_n.as_record_dict() pipeline = CTEPipeline(input_dataframes=[df_raw]) - sql = _get_df_top_bottom_n(column_expressions, bottom_n, "asc") + sql = _get_df_top_bottom_n(column_expressions_as_sql, bottom_n, "asc") pipeline.enqueue_sql(sql, "__splink__df_bottom_n") df_bottom_n = db_api.sql_pipeline_to_splink_dataframe(pipeline) bottom_n_rows_all = df_bottom_n.as_record_dict() inner_charts = [] - for expression in column_expressions: + for expression in column_expressions_as_sql: percentile_rows = [ p for p in percentile_rows_all if p["group_name"] == _group_name(expression) ]