From 454ca968c1b362decd12ebb63b4c4850afc887c2 Mon Sep 17 00:00:00 2001 From: Dazhong Xia Date: Tue, 14 Jan 2025 16:30:43 -0500 Subject: [PATCH] Finish unittest -> pytest conversion (#4014) * Refactor CsvExtractor unit tests for better test independence * Refactor unit tests for better test independence * Refactor unit tests for better test independence * Refactor unit tests for better test independence * Refactor unit tests for better test independence * Remove unnecessary IDE configuration files * [pre-commit.ci] auto fixes from pre-commit.com hooks For more information, see https://pre-commit.ci * Refactor excel_test.py for better test independence * Refactor excel_test.py for better test independence * [pre-commit.ci] auto fixes from pre-commit.com hooks For more information, see https://pre-commit.ci * Refactor excel_test.py for better test independence * Refactor excel_test.py for better test independance * [pre-commit.ci] auto fixes from pre-commit.com hooks For more information, see https://pre-commit.ci * Refactor classes_test.py for improved test organization * Refactor mocks in classes_test.py for improved readability * Convert unittest assertions to pytest style in resource_cache_test.py * Convert unittest assertions to pytest style in resource_cache_test.py * Convert unittest assertions to pytest style in resource_cache_test.py * chore: revert to old setup structure for ferc1 output test The main improvement I see to the test setup is to make the graph definitions more atomic. Currently we have the graph definitions split between node creation in setup and edge definition in the tests themselves. * [pre-commit.ci] auto fixes from pre-commit.com hooks For more information, see https://pre-commit.ci --------- Co-authored-by: Gaurav Gurjar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Zane Selvans --- test/unit/extract/csv_test.py | 74 +++-- test/unit/extract/excel_test.py | 54 ++-- test/unit/extract/phmsagas_test.py | 28 +- test/unit/output/ferc1_test.py | 13 +- test/unit/workspace/datastore_test.py | 298 +++++++++------------ test/unit/workspace/resource_cache_test.py | 136 +++++----- 6 files changed, 277 insertions(+), 326 deletions(-) diff --git a/test/unit/extract/csv_test.py b/test/unit/extract/csv_test.py index 8f47bf653c..cb328851bb 100644 --- a/test/unit/extract/csv_test.py +++ b/test/unit/extract/csv_test.py @@ -1,7 +1,3 @@ -"""Unit tests for pudl.extract.csv module.""" - -from unittest.mock import MagicMock, patch - import pandas as pd import pytest @@ -16,16 +12,16 @@ class FakeExtractor(CsvExtractor): - def __init__(self): + def __init__(self, mocker): # TODO: Make these tests independent of the eia176 implementation self.METADATA = GenericMetadata("eia176") - super().__init__(ds=MagicMock()) + super().__init__(ds=mocker.MagicMock()) @pytest.fixture -def extractor(): - # Create an instance of the CsvExtractor class - return FakeExtractor() +def extractor(mocker): + # Create an instance of the CsvExtractor class with mocker + return FakeExtractor(mocker) def test_source_filename_valid_partition(extractor): @@ -45,8 +41,8 @@ def test_source_filename_multiple_selections(extractor): extractor.source_filename(PAGE, **multiple_selections) -@patch("pudl.extract.csv.pd") -def test_load_source(mock_pd, extractor): +def test_load_source(mocker, extractor): + mock_pd = mocker.patch("pudl.extract.csv.pd") assert extractor.load_source(PAGE, **PARTITION) == mock_pd.read_csv.return_value extractor.ds.get_zipfile_resource.assert_called_once_with(DATASET, **PARTITION) zipfile = extractor.ds.get_zipfile_resource.return_value.__enter__.return_value @@ -55,7 +51,7 @@ def test_load_source(mock_pd, extractor): mock_pd.read_csv.assert_called_once_with(file) -def test_extract(extractor): +def test_extract(mocker, extractor): # Create a sample of data we could expect from an EIA CSV company_field = "company" company_data = "Total of All Companies" @@ -64,22 +60,20 @@ def test_extract(extractor): # TODO: Once FakeExtractor is independent of eia176, mock out populating _column_map for PARTITION_SELECTION; # Also include negative tests, i.e., for partition selections not in the _column_map - with ( - patch.object(CsvExtractor, "load_source", return_value=df), - patch.object( - # Testing the rename - GenericMetadata, - "get_column_map", - return_value={"company_rename": company_field}, - ), - patch.object( - # Transposing the df here to get the orientation we expect get_page_cols to return - CsvExtractor, - "get_page_cols", - return_value=df.T.index, - ), - ): - res = extractor.extract(**PARTITION) + mocker.patch.object(CsvExtractor, "load_source", return_value=df) + # Testing the rename + mocker.patch.object( + GenericMetadata, + "get_column_map", + return_value={"company_rename": company_field}, + ) + # Transposing the df here to get the orientation we expect get_page_cols to return + mocker.patch.object( + CsvExtractor, + "get_page_cols", + return_value=df.T.index, + ) + res = extractor.extract(**PARTITION) assert len(res) == 1 # Assert only one page extracted assert list(res.keys()) == [PAGE] # Assert it is named correctly assert ( @@ -87,11 +81,9 @@ def test_extract(extractor): ) # Assert that column correctly renamed and data is there. -@patch.object(FakeExtractor, "METADATA") -def test_validate_exact_columns(mock_metadata, extractor): +def test_validate_exact_columns(mocker, extractor): # Mock the partition selection and page columns - # mock_metadata._get_partition_selection.return_value = "partition1" - extractor.get_page_cols = MagicMock(return_value={"col1", "col2"}) + extractor.get_page_cols = mocker.MagicMock(return_value={"col1", "col2"}) # Create a DataFrame with the exact expected columns df = pd.DataFrame(columns=["col1", "col2"]) @@ -100,11 +92,9 @@ def test_validate_exact_columns(mock_metadata, extractor): extractor.validate(df, "page1", partition="partition1") -@patch.object(FakeExtractor, "METADATA") -def test_validate_extra_columns(mock_metadata, extractor): +def test_validate_extra_columns(mocker, extractor): # Mock the partition selection and page columns - mock_metadata._get_partition_selection.return_value = "partition1" - extractor.get_page_cols = MagicMock(return_value={"col1", "col2"}) + extractor.get_page_cols = mocker.MagicMock(return_value={"col1", "col2"}) # Create a DataFrame with extra columns df = pd.DataFrame(columns=["col1", "col2", "col3"]) @@ -114,11 +104,9 @@ def test_validate_extra_columns(mock_metadata, extractor): extractor.validate(df, "page1", partition="partition1") -@patch.object(FakeExtractor, "METADATA") -def test_validate_missing_columns(mock_metadata, extractor): +def test_validate_missing_columns(mocker, extractor): # Mock the partition selection and page columns - mock_metadata._get_partition_selection.return_value = "partition1" - extractor.get_page_cols = MagicMock(return_value={"col1", "col2"}) + extractor.get_page_cols = mocker.MagicMock(return_value={"col1", "col2"}) # Create a DataFrame with missing columns df = pd.DataFrame(columns=["col1"]) @@ -130,11 +118,9 @@ def test_validate_missing_columns(mock_metadata, extractor): extractor.validate(df, "page1", partition="partition1") -@patch.object(FakeExtractor, "METADATA") -def test_validate_extra_and_missing_columns(mock_metadata, extractor): +def test_validate_extra_and_missing_columns(mocker, extractor): # Mock the partition selection and page columns - mock_metadata._get_partition_selection.return_value = "partition1" - extractor.get_page_cols = MagicMock(return_value={"col1", "col2"}) + extractor.get_page_cols = mocker.MagicMock(return_value={"col1", "col2"}) # Create a DataFrame with both extra and missing columns df = pd.DataFrame(columns=["col1", "col3"]) diff --git a/test/unit/extract/excel_test.py b/test/unit/extract/excel_test.py index f9a85f0ceb..2d2aa7de24 100644 --- a/test/unit/extract/excel_test.py +++ b/test/unit/extract/excel_test.py @@ -1,35 +1,42 @@ """Unit tests for pudl.extract.excel module.""" -import unittest -from unittest import mock as mock - import pandas as pd +import pytest from pudl.extract import excel -class TestMetadata(unittest.TestCase): +class TestMetadata: """Tests basic operation of the excel.Metadata object.""" + @pytest.fixture(autouse=True) def setUp(self): - """Cosntructs test metadata instance for testing.""" + """Constructs test metadata instance for testing.""" self._metadata = excel.ExcelMetadata("test") def test_basics(self): """Test that basic API method return expected results.""" - self.assertEqual("test", self._metadata.get_dataset_name()) - self.assertListEqual( - ["books", "boxes", "shoes"], self._metadata.get_all_pages() - ) - self.assertListEqual( - ["author", "pages", "title"], self._metadata.get_all_columns("books") - ) - self.assertDictEqual( - {"book_title": "title", "name": "author", "pages": "pages"}, - self._metadata.get_column_map("books", year=2010), - ) - self.assertEqual(10, self._metadata.get_skiprows("boxes", year=2011)) - self.assertEqual(1, self._metadata.get_sheet_name("boxes", year=2011)) + assert self._metadata.get_dataset_name() == "test" + assert self._metadata.get_all_pages() == ["books", "boxes", "shoes"] + assert self._metadata.get_all_columns("books") == ["author", "pages", "title"] + assert self._metadata.get_column_map("books", year=2010) == { + "book_title": "title", + "name": "author", + "pages": "pages", + } + assert self._metadata.get_skiprows("boxes", year=2011) == 10 + assert self._metadata.get_sheet_name("boxes", year=2011) == 1 + + def test_metadata_methods(self): + """Test various metadata methods.""" + assert self._metadata.get_all_columns("books") == ["author", "pages", "title"] + assert self._metadata.get_column_map("books", year=2010) == { + "book_title": "title", + "name": "author", + "pages": "pages", + } + assert self._metadata.get_skiprows("boxes", year=2011) == 10 + assert self._metadata.get_sheet_name("boxes", year=2011) == 1 class FakeExtractor(excel.ExcelExtractor): @@ -77,11 +84,10 @@ def _fake_data_frames(page_name, **kwargs): return fake_data[page_name] -class TestExtractor(unittest.TestCase): +class TestExtractor: """Test operation of the excel.Extractor class.""" - @staticmethod - def test_extract(): + def test_extract(self): extractor = FakeExtractor() res = extractor.extract(year=[2010, 2011]) expected_books = { @@ -103,7 +109,7 @@ def test_extract(): # def test_resulting_dataframes(self): # """Checks that pages across years are merged and columns are translated.""" # dfs = FakeExtractor().extract([2010, 2011], testing=True) - # self.assertEqual(set(['books', 'boxes']), set(dfs.keys())) + # assert set(['books', 'boxes']) == set(dfs.keys()) # pd.testing.assert_frame_equal( # pd.DataFrame(data={ # 'author': ['Laozi', 'Benjamin Hoff'], @@ -118,5 +124,5 @@ def test_extract(): # }), # dfs['boxes']) - # TODO(rousik@gmail.com): need to figure out how to test process_$x methods. - # TODO(rousik@gmail.com): we should test that empty columns are properly added. + # TODO: need to figure out how to test process_$x methods. + # TODO: we should test that empty columns are properly added. diff --git a/test/unit/extract/phmsagas_test.py b/test/unit/extract/phmsagas_test.py index ace2b9d966..2a6b21f1fc 100644 --- a/test/unit/extract/phmsagas_test.py +++ b/test/unit/extract/phmsagas_test.py @@ -1,5 +1,3 @@ -from unittest.mock import MagicMock, patch - import pandas as pd import pytest @@ -8,20 +6,20 @@ class FakeExtractor(Extractor): - def __init__(self): + def __init__(self, mocker): self.METADATA = ExcelMetadata("phmsagas") - super().__init__(ds=MagicMock()) - self._metadata = MagicMock() + super().__init__(ds=mocker.Mock()) + self._metadata = mocker.Mock() @pytest.fixture -def extractor(): +def extractor(mocker): # Create an instance of the CsvExtractor class - return FakeExtractor() + return FakeExtractor(mocker) -@patch("pudl.extract.phmsagas.logger") -def test_process_renamed_drop_columns(mock_logger, extractor): +def test_process_renamed_drop_columns(mocker, extractor): + mock_logger = mocker.patch("pudl.extract.phmsagas.logger") # Mock metadata methods extractor._metadata.get_form.return_value = "gas_transmission_gathering" extractor._metadata.get_all_columns.return_value = ["col1", "col2"] @@ -38,8 +36,8 @@ def test_process_renamed_drop_columns(mock_logger, extractor): mock_logger.info.assert_called_once() -@patch("pudl.extract.phmsagas.logger") -def test_process_renamed_keep_columns(mock_logger, extractor): +def test_process_renamed_keep_columns(mocker, extractor): + mock_logger = mocker.patch("pudl.extract.phmsagas.logger") # Mock metadata methods extractor._metadata.get_form.return_value = "gas_transmission_gathering" extractor._metadata.get_all_columns.return_value = ["col1", "col2"] @@ -56,8 +54,8 @@ def test_process_renamed_keep_columns(mock_logger, extractor): mock_logger.info.assert_not_called() -@patch("pudl.extract.phmsagas.logger") -def test_process_renamed_drop_unnamed_columns(mock_logger, extractor): +def test_process_renamed_drop_unnamed_columns(mocker, extractor): + mock_logger = mocker.patch("pudl.extract.phmsagas.logger") # Mock metadata methods extractor._metadata.get_form.return_value = "some_form" extractor._metadata.get_all_columns.return_value = ["col1", "col2"] @@ -74,8 +72,8 @@ def test_process_renamed_drop_unnamed_columns(mock_logger, extractor): mock_logger.warning.assert_not_called() -@patch("pudl.extract.phmsagas.logger") -def test_process_renamed_warn_unnamed_columns(mock_logger, extractor): +def test_process_renamed_warn_unnamed_columns(mocker, extractor): + mock_logger = mocker.patch("pudl.extract.phmsagas.logger") # Mock metadata methods extractor._metadata.get_form.return_value = "some_form" extractor._metadata.get_all_columns.return_value = ["col1", "col2"] diff --git a/test/unit/output/ferc1_test.py b/test/unit/output/ferc1_test.py index 7b87401f1d..c32478746f 100644 --- a/test/unit/output/ferc1_test.py +++ b/test/unit/output/ferc1_test.py @@ -19,7 +19,6 @@ """ import logging -import unittest from io import StringIO import networkx as nx @@ -37,11 +36,7 @@ logger = logging.getLogger(__name__) -class TestForestSetup(unittest.TestCase): - def setUp(self): - # this is where you add nodes you want to use - pass - +class TestForestSetup: def _exploded_calcs_from_edges(self, edges: list[tuple[NodeId, NodeId]]): records = [] for parent, child in edges: @@ -89,8 +84,8 @@ def build_forest_and_annotated_tags( return annotated_tags -class TestPrunnedNode(TestForestSetup): - def setUp(self): +class TestPrunedNode(TestForestSetup): + def setup_method(self): self.root = NodeId( table_name="table_1", xbrl_factoid="reported_1", @@ -133,7 +128,7 @@ def test_pruned_nodes(self): class TestTagPropagation(TestForestSetup): - def setUp(self): + def setup_method(self): self.parent = NodeId( table_name="table_1", xbrl_factoid="reported_1", diff --git a/test/unit/workspace/datastore_test.py b/test/unit/workspace/datastore_test.py index d6840ad1a6..06c6972246 100644 --- a/test/unit/workspace/datastore_test.py +++ b/test/unit/workspace/datastore_test.py @@ -3,7 +3,6 @@ import io import json import re -import unittest import zipfile from typing import Any @@ -46,154 +45,121 @@ def _make_descriptor( ) -class TestDatapackageDescriptor(unittest.TestCase): - """Unit tests for the DatapackageDescriptor class.""" +def test_get_partition_filters(): + desc = _make_descriptor( + "blabla", + "doi-123", + _make_resource("foo", group="first", color="red"), + _make_resource("bar", group="first", color="blue"), + _make_resource("baz", group="second", color="black", order=1), + ) + assert list(desc.get_partition_filters()) == [ + {"group": "first", "color": "red"}, + {"group": "first", "color": "blue"}, + {"group": "second", "color": "black", "order": 1}, + ] + assert list(desc.get_partition_filters(group="first")) == [ + {"group": "first", "color": "red"}, + {"group": "first", "color": "blue"}, + ] + assert list(desc.get_partition_filters(color="blue")) == [ + {"group": "first", "color": "blue"}, + ] + assert list(desc.get_partition_filters(color="blue", group="second")) == [] + + +def test_get_resource_path(): + """Check that get_resource_path returns correct paths.""" + desc = _make_descriptor( + "blabla", + "doi-123", + _make_resource("foo", group="first", color="red"), + _make_resource("bar", group="first", color="blue"), + ) + assert desc.get_resource_path("foo") == "http://localhost/foo" + assert desc.get_resource_path("bar") == "http://localhost/bar" + with pytest.raises(KeyError): + desc.get_resource_path("other") - def test_get_partition_filters(self): - desc = _make_descriptor( - "blabla", - "doi-123", - _make_resource("foo", group="first", color="red"), - _make_resource("bar", group="first", color="blue"), - _make_resource("baz", group="second", color="black", order=1), - ) - self.assertEqual( - [ - {"group": "first", "color": "red"}, - {"group": "first", "color": "blue"}, - {"group": "second", "color": "black", "order": 1}, - ], - list(desc.get_partition_filters()), - ) - self.assertEqual( - [ - {"group": "first", "color": "red"}, - {"group": "first", "color": "blue"}, - ], - list(desc.get_partition_filters(group="first")), - ) - self.assertEqual( - [ - {"group": "first", "color": "blue"}, - ], - list(desc.get_partition_filters(color="blue")), - ) - self.assertEqual( - [], list(desc.get_partition_filters(color="blue", group="second")) - ) - def test_get_resource_path(self): - """Check that get_resource_path returns correct paths.""" - desc = _make_descriptor( - "blabla", - "doi-123", - _make_resource("foo", group="first", color="red"), - _make_resource("bar", group="first", color="blue"), - ) - self.assertEqual("http://localhost/foo", desc.get_resource_path("foo")) - self.assertEqual("http://localhost/bar", desc.get_resource_path("bar")) - # The following resource does not exist and should throw KeyError - self.assertRaises(KeyError, desc.get_resource_path, "other") - - def test_modernize_zenodo_legacy_api_url(self): - legacy_url = "https://zenodo.org/api/files/082e4932-c772-4e9c-a670-376a1acc3748/datapackage.json" - - descriptor = datastore.DatapackageDescriptor( - {"resources": [{"name": "datapackage.json", "path": legacy_url}]}, - dataset="test", - doi="10.5281/zenodo.123123", - ) +def test_modernize_zenodo_legacy_api_url(): + legacy_url = "https://zenodo.org/api/files/082e4932-c772-4e9c-a670-376a1acc3748/datapackage.json" - assert ( - descriptor.get_resource_path("datapackage.json") - == "https://zenodo.org/records/123123/files/datapackage.json" - ) + descriptor = datastore.DatapackageDescriptor( + {"resources": [{"name": "datapackage.json", "path": legacy_url}]}, + dataset="test", + doi="10.5281/zenodo.123123", + ) - def test_get_resources_filtering(self): - """Verifies correct operation of get_resources().""" - desc = _make_descriptor( - "data", - "doi-123", - _make_resource("foo", group="first", color="red"), - _make_resource("bar", group="first", color="blue", rank=5), - _make_resource( - "baz", group="second", color="blue", rank=5, mood="VeryHappy" - ), - ) - self.assertEqual( - [ - PudlResourceKey("data", "doi-123", "foo"), - PudlResourceKey("data", "doi-123", "bar"), - PudlResourceKey("data", "doi-123", "baz"), - ], - list(desc.get_resources()), - ) - # Simple filtering by one attribute. - self.assertEqual( - [ - PudlResourceKey("data", "doi-123", "foo"), - PudlResourceKey("data", "doi-123", "bar"), - ], - list(desc.get_resources(group="first")), - ) - # Filter by two attributes - self.assertEqual( - [ - PudlResourceKey("data", "doi-123", "bar"), - ], - list(desc.get_resources(group="first", rank=5)), - ) - # Attributes that do not match anything - self.assertEqual( - [], - list(desc.get_resources(group="second", shape="square")), - ) - # Search attribute values are cast to lowercase strings - self.assertEqual( - [ - PudlResourceKey("data", "doi-123", "baz"), - ], - list(desc.get_resources(rank="5", mood="VERYhappy")), - ) - # Test lookup by name - self.assertEqual( - [ - PudlResourceKey("data", "doi-123", "foo"), - ], - list(desc.get_resources("foo")), - ) + assert ( + descriptor.get_resource_path("datapackage.json") + == "https://zenodo.org/records/123123/files/datapackage.json" + ) - def test_json_string_representation(self): - """Checks that json representation parses to the same dict.""" - desc = _make_descriptor( - "data", - "doi-123", - _make_resource("foo", group="first"), - _make_resource("bar", group="second"), - _make_resource("baz"), - ) - self.assertEqual( + +def test_get_resources_filtering(): + """Verifies correct operation of get_resources().""" + desc = _make_descriptor( + "data", + "doi-123", + _make_resource("foo", group="first", color="red"), + _make_resource("bar", group="first", color="blue", rank=5), + _make_resource("baz", group="second", color="blue", rank=5, mood="VeryHappy"), + ) + assert list(desc.get_resources()) == [ + PudlResourceKey("data", "doi-123", "foo"), + PudlResourceKey("data", "doi-123", "bar"), + PudlResourceKey("data", "doi-123", "baz"), + ] + # Simple filtering by one attribute. + assert list(desc.get_resources(group="first")) == [ + PudlResourceKey("data", "doi-123", "foo"), + PudlResourceKey("data", "doi-123", "bar"), + ] + # Filter by two attributes + assert list(desc.get_resources(group="first", rank=5)) == [ + PudlResourceKey("data", "doi-123", "bar"), + ] + # Attributes that do not match anything + assert list(desc.get_resources(group="second", shape="square")) == [] + # Search attribute values are cast to lowercase strings + assert list(desc.get_resources(rank="5", mood="VERYhappy")) == [ + PudlResourceKey("data", "doi-123", "baz"), + ] + # Test lookup by name + assert list(desc.get_resources("foo")) == [ + PudlResourceKey("data", "doi-123", "foo"), + ] + + +def test_json_string_representation(): + """Checks that json representation parses to the same dict.""" + desc = _make_descriptor( + "data", + "doi-123", + _make_resource("foo", group="first"), + _make_resource("bar", group="second"), + _make_resource("baz"), + ) + assert json.loads(desc.get_json_string()) == { + "resources": [ { - "resources": [ - { - "name": "foo", - "path": "http://localhost/foo", - "parts": {"group": "first"}, - }, - { - "name": "bar", - "path": "http://localhost/bar", - "parts": {"group": "second"}, - }, - { - "name": "baz", - "path": "http://localhost/baz", - "parts": {}, - }, - ], + "name": "foo", + "path": "http://localhost/foo", + "parts": {"group": "first"}, }, - json.loads(desc.get_json_string()), - ) + { + "name": "bar", + "path": "http://localhost/bar", + "parts": {"group": "second"}, + }, + { + "name": "baz", + "path": "http://localhost/baz", + "parts": {}, + }, + ], + } class MockableZenodoFetcher(datastore.ZenodoFetcher): @@ -210,7 +176,7 @@ def __init__( self._descriptor_cache = descriptors -class TestZenodoFetcher(unittest.TestCase): +class TestZenodoFetcher: """Unit tests for ZenodoFetcher class.""" MOCK_EPACEMS_DEPOSITION = { @@ -243,7 +209,8 @@ class TestZenodoFetcher(unittest.TestCase): r"^10\.(5072|5281)/zenodo\.(\d+)$", PROD_EPACEMS_DOI ).group(2) - def setUp(self): + @pytest.fixture(autouse=True) + def setup(self): """Constructs mockable Zenodo fetcher based on MOCK_EPACEMS_DATAPACKAGE.""" self.fetcher = MockableZenodoFetcher( descriptors={ @@ -263,38 +230,36 @@ def test_doi_format_is_correct(self): identified. """ zf = datastore.ZenodoFetcher() - self.assertTrue(zf.get_known_datasets()) + assert zf.get_known_datasets() for dataset, doi in zf.zenodo_dois: - self.assertTrue( - zf.get_doi(dataset) == doi, - msg=f"Zenodo DOI for {dataset} matches result of get_doi()", + assert zf.get_doi(dataset) == doi, ( + f"Zenodo DOI for {dataset} matches result of get_doi()" ) - self.assertFalse( - re.fullmatch(r"10\.5072/zenodo\.[0-9]{5,10}", doi), - msg=f"Zenodo sandbox DOI found for {dataset}: {doi}", + assert not re.fullmatch(r"10\.5072/zenodo\.[0-9]{5,10}", doi), ( + f"Zenodo sandbox DOI found for {dataset}: {doi}" ) - self.assertTrue( - re.fullmatch(r"10\.5281/zenodo\.[0-9]{5,10}", doi), - msg=f"Zenodo production DOI for {dataset} is {doi}", + assert re.fullmatch(r"10\.5281/zenodo\.[0-9]{5,10}", doi), ( + f"Zenodo production DOI for {dataset} is {doi}" ) def test_get_known_datasets(self): """Call to get_known_datasets() produces the expected results.""" - self.assertEqual( - sorted(name for name, doi in datastore.ZenodoFetcher().zenodo_dois), - self.fetcher.get_known_datasets(), + assert ( + sorted(name for name, doi in datastore.ZenodoFetcher().zenodo_dois) + == self.fetcher.get_known_datasets() ) def test_get_unknown_dataset(self): """Ensure that we get a failure when attempting to access an invalid dataset.""" - self.assertRaises(AttributeError, self.fetcher.get_doi, "unknown") + with pytest.raises(AttributeError): + self.fetcher.get_doi("unknown") def test_doi_of_prod_epacems_matches(self): """Most of the tests assume specific DOI for production epacems dataset. This test verifies that the expected value is in use. """ - self.assertEqual(self.PROD_EPACEMS_DOI, self.fetcher.get_doi("epacems")) + assert self.fetcher.get_doi("epacems") == self.PROD_EPACEMS_DOI @responses.activate def test_get_descriptor_http_calls(self): @@ -311,8 +276,7 @@ def test_get_descriptor_http_calls(self): json=self.MOCK_EPACEMS_DATAPACKAGE, ) desc = fetcher.get_descriptor("epacems") - self.assertEqual(self.MOCK_EPACEMS_DATAPACKAGE, desc.datapackage_json) - # self.assertTrue(responses.assert_call_count("http://localhost/my/datapackage.json", 1)) + assert desc.datapackage_json == self.MOCK_EPACEMS_DATAPACKAGE @responses.activate def test_get_resource(self): @@ -321,21 +285,21 @@ def test_get_resource(self): res = self.fetcher.get_resource( PudlResourceKey("epacems", self.PROD_EPACEMS_DOI, "first") ) - self.assertEqual(b"blah", res) + assert res == b"blah" @responses.activate def test_get_resource_with_invalid_checksum(self): """Test that resource with bad checksum raises ChecksumMismatchError.""" responses.add(responses.GET, "http://localhost/first", body="wrongContent") res = PudlResourceKey("epacems", self.PROD_EPACEMS_DOI, "first") - self.assertRaises( - datastore.ChecksumMismatchError, self.fetcher.get_resource, res - ) + with pytest.raises(datastore.ChecksumMismatchError): + self.fetcher.get_resource(res) def test_get_resource_with_nonexistent_resource_fails(self): """If resource does not exist, get_resource() throws KeyError.""" res = PudlResourceKey("epacems", self.PROD_EPACEMS_DOI, "nonexistent") - self.assertRaises(KeyError, self.fetcher.get_resource, res) + with pytest.raises(KeyError): + self.fetcher.get_resource(res) def test_get_zipfile_resource_failure(mocker): @@ -415,4 +379,4 @@ def test_get_zipfile_resources_eventual_success(mocker): assert test_file.read().decode(encoding="utf-8") == file_contents -# TODO(rousik): add unit tests for Datasource class as well +# TODO: add unit tests for Datasource class as well diff --git a/test/unit/workspace/resource_cache_test.py b/test/unit/workspace/resource_cache_test.py index 95d582d680..feba668eed 100644 --- a/test/unit/workspace/resource_cache_test.py +++ b/test/unit/workspace/resource_cache_test.py @@ -2,7 +2,6 @@ import shutil import tempfile -import unittest from pathlib import Path import requests.exceptions as requests_exceptions @@ -13,7 +12,7 @@ from pudl.workspace.resource_cache import PudlResourceKey, extend_gcp_retry_predicate -class TestGoogleCloudStorageCache(unittest.TestCase): +class TestGoogleCloudStorageCache: """Unit tests for the GoogleCloudStorageCache class.""" def test_bad_request_predicate(self): @@ -21,77 +20,77 @@ def test_bad_request_predicate(self): bad_request_predicate = extend_gcp_retry_predicate(_should_retry, BadRequest) # Check default exceptions. - self.assertFalse(_should_retry(BadRequest(message="Bad request!"))) - self.assertTrue(_should_retry(requests_exceptions.Timeout())) + assert not _should_retry(BadRequest(message="Bad request!")) + assert _should_retry(requests_exceptions.Timeout()) - # Check extended predicate handles default exceptionss and BadRequest. - self.assertTrue(bad_request_predicate(requests_exceptions.Timeout())) - self.assertTrue(bad_request_predicate(BadRequest(message="Bad request!"))) + # Check extended predicate handles default exceptions and BadRequest. + assert bad_request_predicate(requests_exceptions.Timeout()) + assert bad_request_predicate(BadRequest(message="Bad request!")) -class TestLocalFileCache(unittest.TestCase): +class TestLocalFileCache: """Unit tests for the LocalFileCache class.""" - def setUp(self): + def setup_method(self): """Prepares temporary directory for storing cache contents.""" self.test_dir = tempfile.mkdtemp() self.cache = resource_cache.LocalFileCache(Path(self.test_dir)) - def tearDown(self): + def teardown_method(self): """Deletes content of the temporary directories.""" shutil.rmtree(self.test_dir) def test_add_single_resource(self): """Adding resource has expected effect on later get() and contains() calls.""" res = PudlResourceKey("ds", "doi", "file.txt") - self.assertFalse(self.cache.contains(res)) + assert not self.cache.contains(res) self.cache.add(res, b"blah") - self.assertTrue(self.cache.contains(res)) - self.assertEqual(b"blah", self.cache.get(res)) + assert self.cache.contains(res) + assert self.cache.get(res) == b"blah" def test_that_two_cache_objects_share_storage(self): """Two LocalFileCache instances with the same path share the object storage.""" second_cache = resource_cache.LocalFileCache(Path(self.test_dir)) res = PudlResourceKey("dataset", "doi", "file.txt") - self.assertFalse(self.cache.contains(res)) - self.assertFalse(second_cache.contains(res)) + assert not self.cache.contains(res) + assert not second_cache.contains(res) self.cache.add(res, b"testContents") - self.assertTrue(self.cache.contains(res)) - self.assertTrue(second_cache.contains(res)) - self.assertEqual(b"testContents", second_cache.get(res)) + assert self.cache.contains(res) + assert second_cache.contains(res) + assert second_cache.get(res) == b"testContents" def test_deletion(self): """Deleting resources has expected effect on later get() / contains() calls.""" res = PudlResourceKey("a", "b", "c") - self.assertFalse(self.cache.contains(res)) + assert not self.cache.contains(res) self.cache.add(res, b"sampleContents") - self.assertTrue(self.cache.contains(res)) + assert self.cache.contains(res) self.cache.delete(res) - self.assertFalse(self.cache.contains(res)) + assert not self.cache.contains(res) def test_read_only_add_and_delete_do_nothing(self): """Test that in read_only mode, add() and delete() calls are ignored.""" res = PudlResourceKey("a", "b", "c") ro_cache = resource_cache.LocalFileCache(Path(self.test_dir), read_only=True) - self.assertTrue(ro_cache.is_read_only()) + assert ro_cache.is_read_only() ro_cache.add(res, b"sample") - self.assertFalse(ro_cache.contains(res)) + assert not ro_cache.contains(res) # Use read-write cache to insert resource self.cache.add(res, b"sample") - self.assertFalse(self.cache.is_read_only()) - self.assertTrue(ro_cache.contains(res)) + assert not self.cache.is_read_only() + assert ro_cache.contains(res) # Deleting via ro cache should not happen ro_cache.delete(res) - self.assertTrue(ro_cache.contains(res)) + assert ro_cache.contains(res) -class TestLayeredCache(unittest.TestCase): +class TestLayeredCache: """Unit tests for LayeredCache class.""" - def setUp(self): + def setup_method(self): """Constructs two LocalFileCache layers pointed at temporary directories.""" self.layered_cache = resource_cache.LayeredCache() self.test_dir_1 = tempfile.mkdtemp() @@ -106,53 +105,56 @@ def tearDown(self): def test_add_caching_layers(self): """Adding layers has expected effect on the subsequent num_layers() calls.""" - self.assertEqual(0, self.layered_cache.num_layers()) + # self.assertEqual(0, self.layered_cache.num_layers()) + assert self.layered_cache.num_layers() == 0 self.layered_cache.add_cache_layer(self.cache_1) - self.assertEqual(1, self.layered_cache.num_layers()) + assert self.layered_cache.num_layers() == 1 self.layered_cache.add_cache_layer(self.cache_2) - self.assertEqual(2, self.layered_cache.num_layers()) + assert self.layered_cache.num_layers() == 2 def test_add_to_first_layer(self): """Adding to layered cache by default stores entires in the first layer.""" self.layered_cache.add_cache_layer(self.cache_1) self.layered_cache.add_cache_layer(self.cache_2) res = PudlResourceKey("a", "b", "x.txt") - self.assertFalse(self.layered_cache.contains(res)) + # self.assertFalse(self.layered_cache.contains(res)) + assert not self.layered_cache.contains(res) self.layered_cache.add(res, b"sampleContent") - self.assertTrue(self.layered_cache.contains(res)) - self.assertTrue(self.cache_1.contains(res)) - self.assertFalse(self.cache_2.contains(res)) + assert self.layered_cache.contains(res) + assert self.cache_1.contains(res) + assert not self.cache_2.contains(res) def test_get_uses_innermost_layer(self): """Resource is retrieved from the leftmost layer that contains it.""" res = PudlResourceKey("a", "b", "x.txt") self.layered_cache.add_cache_layer(self.cache_1) self.layered_cache.add_cache_layer(self.cache_2) - # self.cache_1.add(res, "firstLayer") + self.cache_1.add(res, b"firstLayer") self.cache_2.add(res, b"secondLayer") - self.assertEqual(b"secondLayer", self.layered_cache.get(res)) + # assert self.layered_cache.get(res) == b"secondLayer" + assert self.layered_cache.get(res) == b"firstLayer" self.cache_1.add(res, b"firstLayer") - self.assertEqual(b"firstLayer", self.layered_cache.get(res)) + assert self.layered_cache.get(res) == b"firstLayer" # Set on layered cache updates innermost layer self.layered_cache.add(res, b"newContents") - self.assertEqual(b"newContents", self.layered_cache.get(res)) - self.assertEqual(b"newContents", self.cache_1.get(res)) - self.assertEqual(b"secondLayer", self.cache_2.get(res)) + assert self.layered_cache.get(res) == b"newContents" + assert self.cache_1.get(res) == b"newContents" + assert self.cache_2.get(res) == b"secondLayer" # Deletion also only affects innermost layer self.layered_cache.delete(res) - self.assertTrue(self.layered_cache.contains(res)) - self.assertFalse(self.cache_1.contains(res)) - self.assertTrue(self.cache_2.contains(res)) - self.assertEqual(b"secondLayer", self.layered_cache.get(res)) + assert self.layered_cache.contains(res) + assert not self.cache_1.contains(res) + assert self.cache_2.contains(res) + assert self.cache_2.get(res) == b"secondLayer" def test_add_with_no_layers_does_nothing(self): """When add() is called on cache with no layers nothing happens.""" res = PudlResourceKey("a", "b", "c") - self.assertFalse(self.layered_cache.contains(res)) + assert not self.layered_cache.contains(res) self.layered_cache.add(res, b"sample") - self.assertFalse(self.layered_cache.contains(res)) + assert not self.layered_cache.contains(res) self.layered_cache.delete(res) def test_read_only_layers_skipped_when_adding(self): @@ -163,19 +165,19 @@ def test_read_only_layers_skipped_when_adding(self): res = PudlResourceKey("a", "b", "c") - self.assertFalse(lc.contains(res)) - self.assertFalse(c1.contains(res)) - self.assertFalse(c2.contains(res)) + assert not lc.contains(res) + assert not c1.contains(res) + assert not c2.contains(res) lc.add(res, b"test") - self.assertTrue(lc.contains(res)) - self.assertFalse(c1.contains(res)) - self.assertTrue(c2.contains(res)) + assert lc.contains(res) + assert not c1.contains(res) + assert c2.contains(res) lc.delete(res) - self.assertFalse(lc.contains(res)) - self.assertFalse(c1.contains(res)) - self.assertFalse(c2.contains(res)) + assert not lc.contains(res) + assert not c1.contains(res) + assert not c2.contains(res) def test_read_only_cache_ignores_modifications(self): """When cache is marked as read_only, add() and delete() calls are ignored.""" @@ -183,22 +185,22 @@ def test_read_only_cache_ignores_modifications(self): r2 = PudlResourceKey("a", "b", "r2") self.cache_1.add(r1, b"xxx") self.cache_2.add(r2, b"yyy") - self.assertTrue(self.cache_1.contains(r1)) - self.assertTrue(self.cache_2.contains(r2)) + assert self.cache_1.contains(r1) + assert self.cache_2.contains(r2) lc = resource_cache.LayeredCache(self.cache_1, self.cache_2, read_only=True) - self.assertTrue(lc.contains(r1)) - self.assertTrue(lc.contains(r2)) + assert lc.contains(r1) + assert lc.contains(r2) lc.delete(r1) lc.delete(r2) - self.assertTrue(lc.contains(r1)) - self.assertTrue(lc.contains(r2)) - self.assertTrue(self.cache_1.contains(r1)) - self.assertTrue(self.cache_2.contains(r2)) + assert lc.contains(r1) + assert lc.contains(r2) + assert self.cache_1.contains(r1) + assert self.cache_2.contains(r2) r_new = PudlResourceKey("a", "b", "new") lc.add(r_new, b"xyz") - self.assertFalse(lc.contains(r_new)) - self.assertFalse(self.cache_1.contains(r_new)) - self.assertFalse(self.cache_2.contains(r_new)) + assert not lc.contains(r_new) + assert not self.cache_1.contains(r_new) + assert not self.cache_2.contains(r_new)