Feat: Add a transform_index() method to transform the dataset index (#41)
adrien-berchet authored Mar 10, 2023
1 parent ce9ea3e commit da5bb83
Showing 2 changed files with 116 additions and 2 deletions.
16 changes: 14 additions & 2 deletions data_validation_framework/task.py
@@ -297,6 +297,15 @@ def read_dataset(self):
"""
return pd.read_csv(self.dataset_df, index_col=self.input_index_col)

def transform_index(self, df):
"""Method executed after loading the dataset to transform its index.
.. note::
This transformation is applied to both the dataset and the input reports.
"""
return df

def pre_process(self, df, args, kwargs):
"""Method executed before applying the external function."""

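For context, the new transform_index() hook above is meant to be overridden in a subclass; below is a minimal sketch of such an override (the IndexAsStrWorkflow name and the str cast are illustrative, mirroring the tests added later in this commit):

from data_validation_framework import task

class IndexAsStrWorkflow(task.ValidationWorkflow):
    """Illustrative workflow whose index is cast to str right after the dataset is read."""

    def transform_index(self, df):
        # Applied to the dataset and to every imported report, so later joins stay aligned.
        df.index = df.index.astype(str)
        return df

    @staticmethod
    def validation_function(df, output_path, *args, **kwargs):
        # By the time validation runs, the index has already been transformed.
        assert all(isinstance(i, str) for i in df.index)
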
@@ -440,6 +449,7 @@ def _get_dataset(self):
if self.dataset_df is not None:
L.info("Input dataset: %s", Path(self.dataset_df).resolve())
new_df = self.read_dataset()
new_df = self.transform_index(new_df)
duplicated_index = new_df.index.duplicated()
if duplicated_index.any():
raise IndexError(
@@ -473,8 +483,10 @@ def _join_inputs(self, new_df):
}
L.debug("Importing the following reports: %s", all_report_paths)
all_dfs = {
task_obj: self._rename_cols(
pd.read_csv(path, index_col=INDEX_LABEL).rename_axis(index="index")
task_obj: self.transform_index(
self._rename_cols(
pd.read_csv(path, index_col=INDEX_LABEL).rename_axis(index="index")
)
)
for task_obj, path in all_report_paths.items()
}
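In _get_dataset() and _join_inputs() above, the same transform_index() call is applied both to the freshly read dataset and to every imported report, so the two sides of the later join carry comparable index labels. A minimal pandas sketch (not taken from the library) of the alignment problem this avoids:

import pandas as pd

# Dataset read with a string index, report written with an integer index.
dataset = pd.DataFrame({"a": [1, 2]}, index=pd.Index(["0", "1"]))
report = pd.DataFrame({"is_valid": [True, True]}, index=pd.Index([0, 1]))

print(dataset.join(report))  # "is_valid" is all NaN because "0" != 0

# Passing both indices through the same transformation restores the alignment.
report.index = report.index.astype(str)
print(dataset.join(report))  # "is_valid" is now True for both rows
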
102 changes: 102 additions & 0 deletions tests/test_task.py
@@ -1326,6 +1326,107 @@ def check_exception(failed_task, exception): # pylint: disable=unused-variable
)
]

def test_pre_process_change_index(self, tmpdir, TestTask):
"""Test that the process fails if the index is changed by the preprocess."""
dataset_df_path = str(tmpdir / "dataset.csv")
base_dataset_df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}, index=[0, 1, 2, 3])
base_dataset_df.to_csv(dataset_df_path, index=True, index_label="index_col")

class TestTaskUpdateIndex(TestTask):
def pre_process(self, df, args, kwargs):
df.sort_index(ascending=False, inplace=True)

@staticmethod
def validation_function(df, output_path, *args, **kwargs):
pass

failed_tasks = []
exceptions = []

@TestTaskUpdateIndex.event_handler(luigi.Event.FAILURE)
def check_exception(failed_task, exception): # pylint: disable=unused-variable
failed_tasks.append(str(failed_task))
exceptions.append(str(exception))

failing_task = TestTaskUpdateIndex(
dataset_df=dataset_df_path,
input_index_col="index_col",
result_path=str(tmpdir / "out_preprocess_update_index"),
)
assert not luigi.build([failing_task], local_scheduler=True)

assert failed_tasks == [str(failing_task)]
assert exceptions == [
str(
IndexError(
"The index changed during the process. Please update your validation function "
"or your pre/post process functions to avoid this behaviour."
)
)
]
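
The error message asserted above comes from the framework's index consistency check. A minimal sketch of that kind of guard, assuming it simply compares the index before and after the user-provided functions run (run_with_index_guard is a hypothetical helper, not the actual implementation):

import pandas as pd

def run_with_index_guard(df, func):
    """Illustrative only: run func on df and fail loudly if it changed the index."""
    original_index = df.index.copy()
    func(df)
    if not df.index.equals(original_index):
        raise IndexError(
            "The index changed during the process. Please update your validation function "
            "or your pre/post process functions to avoid this behaviour."
        )

df = pd.DataFrame({"a": [1, 2, 3, 4]}, index=[0, 1, 2, 3])
run_with_index_guard(df, lambda d: d.sort_index(ascending=False, inplace=True))  # raises IndexError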

@pytest.mark.parametrize(
"task_type",
[int, str, object, float],
)
@pytest.mark.parametrize(
"workflow_type",
[int, str, object, float],
)
def test_read_dataset_change_index(
self, tmpdir, TestTask, dataset_df_path, task_type, workflow_type
):
"""Test that the process succeeds if the index is only changed by the preprocess."""

class TestTaskUpdateIndex(TestTask):
"""A simple Task."""

def transform_index(self, df):
df.index = df.index.astype(task_type)
return df

class TestWorkflow(task.ValidationWorkflow):
"""A validation workflow."""

def transform_index(self, df):
df.index = df.index.astype(workflow_type)
return df

def inputs(self):
return {
TestTaskUpdateIndex: {},
}

@staticmethod
def validation_function(df, output_path, *args, **kwargs):
if task_type == float and workflow_type == str:
assert len(df) == 0
else:
assert len(df) == 2

failed_tasks = []
exceptions = []

@TestWorkflow.event_handler(luigi.Event.FAILURE)
def check_exception(failed_task, exception): # pylint: disable=unused-variable
failed_tasks.append(str(failed_task))
exceptions.append(str(exception))

workflow_with_index_cast = TestWorkflow(
dataset_df=dataset_df_path,
result_path=str(tmpdir / "out_preprocess_update_index"),
)
assert luigi.build([workflow_with_index_cast], local_scheduler=True)

assert not failed_tasks
assert not exceptions
res = pd.read_csv(tmpdir / "out_preprocess_update_index" / "TestWorkflow" / "report.csv")
if task_type == float and workflow_type == str:
assert len(res) == 0
else:
assert len(res) == 2
assert res["is_valid"].all()
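
The only combination expected to yield an empty result is task_type == float with workflow_type == str: the task's report is written with a float index, and casting those labels to str produces "0.0"/"1.0", which can never match the workflow's "0"/"1". A short pandas illustration (not part of the test suite):

import pandas as pd

report_index = pd.Index([0, 1]).astype(float)   # the task cast its index to float: 0.0, 1.0
dataset_index = pd.Index([0, 1]).astype(str)    # the workflow casts its index to str: "0", "1"

# Casting the float labels to str gives "0.0"/"1.0", which share nothing with "0"/"1",
# so the join keeps no rows and the final report is empty.
print(report_index.astype(str).intersection(dataset_index))  # Index([], dtype='object')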

def test_missing_retcodes(self, tmpdir, dataset_df_path, TestTask):
"""Test invalid retcodes."""

@@ -1512,6 +1613,7 @@ def validation_function(df, output_path, *args, **kwargs):
res = pd.read_csv(tmpdir / "extra_requires" / "TestTaskB" / "report.csv")
assert (res["extra_path"] == str(tmpdir / "file.test")).all()
assert (res["extra_result"] == "result of TestTaskA").all()
assert Path(res.loc[0, "extra_path"]).exists()

def test_static_args_kwargs(self, dataset_df_path):
"""Test the args and kwargs feature."""
