diff --git a/great_expectations/experimental/metric_repository/batch_inspector.py b/great_expectations/experimental/metric_repository/batch_inspector.py
index 71f92bfbc047..1c98e9be3477 100644
--- a/great_expectations/experimental/metric_repository/batch_inspector.py
+++ b/great_expectations/experimental/metric_repository/batch_inspector.py
@@ -1,11 +1,12 @@
 from __future__ import annotations
 
 import uuid
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, List, Optional
 
 from great_expectations.experimental.metric_repository.metrics import (
     Metric,
     MetricRun,
+    MetricTypes,
 )
 
 if TYPE_CHECKING:
@@ -28,6 +29,34 @@ def __init__(
         self._context = context
         self._metric_retrievers = metric_retrievers
 
+    def compute_metric_list_run(
+        self,
+        data_asset_id: uuid.UUID,
+        batch_request: BatchRequest,
+        metric_list: Optional[List[MetricTypes]],
+    ) -> MetricRun:
+        """Compute a MetricRun for a list of metrics.
+
+        Called by the GX Agent to compute a MetricRun as part of a RunMetricsEvent.
+
+        Args:
+            data_asset_id (uuid.UUID): ID of the current Data Asset.
+            batch_request (BatchRequest): BatchRequest for the current batch.
+            metric_list (Optional[List[MetricTypes]]): List of metrics to compute.
+
+        Returns:
+            MetricRun: The metrics computed for this batch by all configured metric retrievers.
+        """
+        # TODO: this method will eventually replace `compute_metric_run`, which will then be retired.
+        metrics: list[Metric] = []
+        for metric_retriever in self._metric_retrievers:
+            metrics.extend(
+                metric_retriever.get_metrics(
+                    batch_request=batch_request, metric_list=metric_list
+                )
+            )
+        return MetricRun(data_asset_id=data_asset_id, metrics=metrics)
+
     def compute_metric_run(
         self, data_asset_id: uuid.UUID, batch_request: BatchRequest
     ) -> MetricRun:
diff --git a/great_expectations/experimental/metric_repository/column_descriptive_metrics_metric_retriever.py b/great_expectations/experimental/metric_repository/column_descriptive_metrics_metric_retriever.py
index ae8fa6dc4cf7..5309049281c7 100644
--- a/great_expectations/experimental/metric_repository/column_descriptive_metrics_metric_retriever.py
+++ b/great_expectations/experimental/metric_repository/column_descriptive_metrics_metric_retriever.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from itertools import chain
-from typing import TYPE_CHECKING, List, Sequence
+from typing import TYPE_CHECKING, List, Optional, Sequence
 
 from great_expectations.compatibility.typing_extensions import override
 from great_expectations.experimental.metric_repository.metric_retriever import (
@@ -10,6 +10,7 @@
 from great_expectations.experimental.metric_repository.metrics import (
     ColumnMetric,
     Metric,
+    MetricTypes,
 )
 
 if TYPE_CHECKING:
@@ -27,7 +28,14 @@ def __init__(self, context: AbstractDataContext):
         super().__init__(context=context)
 
     @override
-    def get_metrics(self, batch_request: BatchRequest) -> Sequence[Metric]:
+    def get_metrics(
+        self,
+        batch_request: BatchRequest,
+        metric_list: Optional[List[MetricTypes]] = None,
+    ) -> Sequence[Metric]:
+        # Note: Signature includes metric_list for compatibility with the MetricRetriever interface.
+        # It is not used by ColumnDescriptiveMetricsMetricRetriever.
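+        # BatchInspector.compute_metric_list_run passes metric_list to every configured retriever, so it is accepted here and ignored.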
+ table_metrics = self._calculate_table_metrics(batch_request) # We need to skip columns that do not report a type, because the metric computation diff --git a/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py b/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py index efbb08e00991..635216b40003 100644 --- a/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py +++ b/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from itertools import chain from typing import TYPE_CHECKING, List, Optional, Sequence @@ -13,6 +14,9 @@ MetricTypes, ) +logger = logging.getLogger(__name__) + + if TYPE_CHECKING: from great_expectations.data_context import AbstractDataContext from great_expectations.datasource.fluent.batch_request import BatchRequest @@ -44,8 +48,14 @@ def get_metrics( ) metrics_result.extend(table_metrics) - # exit early if only Table Metrics exist if not self._column_metrics_in_metric_list(metric_list): + # if no column metrics are present in the metric list, we can return the table metrics + return metrics_result + + if MetricTypes.TABLE_COLUMN_TYPES not in metric_list: + logger.warning( + "TABLE_COLUMN_TYPES metric is required to compute column metrics. Skipping column metrics." + ) return metrics_result table_column_types = list( diff --git a/great_expectations/experimental/metric_repository/metric_retriever.py b/great_expectations/experimental/metric_repository/metric_retriever.py index 7b1a472c28e1..d860665e7c03 100644 --- a/great_expectations/experimental/metric_repository/metric_retriever.py +++ b/great_expectations/experimental/metric_repository/metric_retriever.py @@ -6,6 +6,7 @@ TYPE_CHECKING, Any, List, + Optional, Sequence, ) @@ -50,7 +51,11 @@ def get_validator(self, batch_request: BatchRequest) -> Validator: return self._validator @abc.abstractmethod - def get_metrics(self, batch_request: BatchRequest) -> Sequence[Metric]: + def get_metrics( + self, + batch_request: BatchRequest, + metric_list: Optional[List[MetricTypes]] = None, + ) -> Sequence[Metric]: raise NotImplementedError def _generate_metric_id(self) -> uuid.UUID: @@ -58,11 +63,16 @@ def _generate_metric_id(self) -> uuid.UUID: def _get_metric_from_computed_metrics( self, - metric_name: str, + metric_name: str | MetricTypes, computed_metrics: _MetricsDict, aborted_metrics: _AbortedMetricsInfoDict, metric_lookup_key: _MetricKey | None = None, ) -> tuple[Any, MetricException | None]: + # look up is done by string + # TODO: update to be MetricTypes once MetricListMetricRetriever implementation is complete. 
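+        # e.g. MetricTypes.TABLE_ROW_COUNT is converted to its string value "table.row_count" before the lookup.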
+ if isinstance(metric_name, MetricTypes): + metric_name = metric_name.value + if metric_lookup_key is None: metric_lookup_key = ( metric_name, @@ -91,7 +101,7 @@ def _get_metric_from_computed_metrics( return value, metric_exception def _generate_table_metric_configurations( - self, table_metric_names: list[str] + self, table_metric_names: list[str | MetricTypes] ) -> list[MetricConfiguration]: table_metric_configs = [ MetricConfiguration( @@ -271,7 +281,6 @@ def _get_table_columns(self, batch_request: BatchRequest) -> Metric: def _get_table_column_types(self, batch_request: BatchRequest) -> Metric: metric_name = MetricTypes.TABLE_COLUMN_TYPES - metric_lookup_key: _MetricKey = (metric_name, tuple(), "include_nested=True") table_metric_configs = self._generate_table_metric_configurations( table_metric_names=[metric_name] diff --git a/great_expectations/validator/metric_configuration.py b/great_expectations/validator/metric_configuration.py index 71289105eb3f..b19526dc6772 100644 --- a/great_expectations/validator/metric_configuration.py +++ b/great_expectations/validator/metric_configuration.py @@ -1,11 +1,12 @@ import json -from typing import Optional, Tuple +from typing import Optional, Tuple, Union from great_expectations.core._docs_decorators import public_api from great_expectations.core.domain import Domain from great_expectations.core.id_dict import IDDict from great_expectations.core.metric_domain_types import MetricDomainTypes from great_expectations.core.util import convert_to_json_serializable +from great_expectations.experimental.metric_repository.metrics import MetricTypes @public_api @@ -17,7 +18,7 @@ class MetricConfiguration: be used to evaluate Expectations or to summarize the result of the Validation. Args: - metric_name (str): name of the Metric defined by the current MetricConfiguration. + metric_name (str or MetricTypes enum): name of the Metric defined by the current MetricConfiguration. metric_domain_kwargs (dict): provides information on where the Metric can be calculated. For instance, a MapCondition metric can include the name of the column that the Metric is going to be run on. metric_value_kwargs (optional[dict]): Optional kwargs that define values specific to each Metric. 
For instance, @@ -27,10 +28,12 @@ class MetricConfiguration: def __init__( self, - metric_name: str, + metric_name: Union[str, MetricTypes], metric_domain_kwargs: dict, metric_value_kwargs: Optional[dict] = None, ) -> None: + if isinstance(metric_name, MetricTypes): + metric_name = metric_name.value self._metric_name = metric_name if not isinstance(metric_domain_kwargs, IDDict): diff --git a/tests/experimental/metric_repository/test_batch_inspector.py b/tests/experimental/metric_repository/test_batch_inspector.py index b30f5fbdc591..10c83ecbd115 100644 --- a/tests/experimental/metric_repository/test_batch_inspector.py +++ b/tests/experimental/metric_repository/test_batch_inspector.py @@ -13,12 +13,97 @@ ) from great_expectations.experimental.metric_repository.metrics import ( MetricRun, + MetricTypes, TableMetric, ) pytestmark = pytest.mark.unit +# compute_metric_list_run tests +def test_compute_metric_list_run_with_no_metric_retrievers(mocker): + mock_context = Mock(spec=CloudDataContext) + batch_inspector = BatchInspector(context=mock_context, metric_retrievers=[]) + mock_batch_request = Mock(spec=BatchRequest) + + data_asset_id = uuid.uuid4() + + metric_run = batch_inspector.compute_metric_list_run( + data_asset_id=data_asset_id, batch_request=mock_batch_request, metric_list=[] + ) + assert metric_run == MetricRun(data_asset_id=data_asset_id, metrics=[]) + + +def test_compute_metric_list_run_calls_metric_retrievers(): + mock_context = Mock(spec=CloudDataContext) + mock_metric_retriever = MagicMock(spec=MetricRetriever) + batch_inspector = BatchInspector( + context=mock_context, metric_retrievers=[mock_metric_retriever] + ) + mock_batch_request = Mock(spec=BatchRequest) + + data_asset_id = uuid.uuid4() + + metric_list = [ + MetricTypes.TABLE_ROW_COUNT, + MetricTypes.TABLE_COLUMNS, + MetricTypes.TABLE_COLUMN_TYPES, + MetricTypes.COLUMN_MIN, + MetricTypes.COLUMN_MAX, + MetricTypes.COLUMN_MEAN, + MetricTypes.COLUMN_MEDIAN, + MetricTypes.COLUMN_NULL_COUNT, + ] + + batch_inspector.compute_metric_list_run( + data_asset_id=data_asset_id, + batch_request=mock_batch_request, + metric_list=metric_list, + ) + + assert mock_metric_retriever.get_metrics.call_count == 1 + + mock_metric_retriever.get_metrics.assert_called_once_with( + batch_request=mock_batch_request, metric_list=metric_list + ) + + +def test_compute_metric_list_run_returns_metric_run(): + mock_context = Mock(spec=CloudDataContext) + mock_metric_retriever = MagicMock(spec=MetricRetriever) + + mock_metric = Mock(spec=TableMetric) + mock_metric_retriever.get_metrics.return_value = [mock_metric] + + batch_inspector = BatchInspector( + context=mock_context, metric_retrievers=[mock_metric_retriever] + ) + mock_batch_request = Mock(spec=BatchRequest) + + data_asset_id = uuid.uuid4() + + metric_run = batch_inspector.compute_metric_list_run( + data_asset_id=data_asset_id, + batch_request=mock_batch_request, + metric_list=[ + MetricTypes.TABLE_ROW_COUNT, + MetricTypes.TABLE_COLUMNS, + MetricTypes.TABLE_COLUMN_TYPES, + MetricTypes.COLUMN_MIN, + MetricTypes.COLUMN_MAX, + MetricTypes.COLUMN_MEAN, + MetricTypes.COLUMN_MEDIAN, + MetricTypes.COLUMN_NULL_COUNT, + ], + ) + + assert metric_run == MetricRun( + data_asset_id=data_asset_id, + metrics=[mock_metric], + ) + + +# compute_metric_run tests. Will eventually go away once compute_metric_list_run is fully implemented. 
 def test_compute_metric_run_with_no_metric_retrievers():
     mock_context = Mock(spec=CloudDataContext)
     batch_inspector = BatchInspector(context=mock_context, metric_retrievers=[])
diff --git a/tests/experimental/metric_repository/test_metric_list_metric_retriever.py b/tests/experimental/metric_repository/test_metric_list_metric_retriever.py
index a6c0c3919262..b49cbdccb3d5 100644
--- a/tests/experimental/metric_repository/test_metric_list_metric_retriever.py
+++ b/tests/experimental/metric_repository/test_metric_list_metric_retriever.py
@@ -20,8 +24,12 @@
 
 pytestmark = pytest.mark.unit
 
+import logging
+
 from pytest_mock import MockerFixture
 
+LOGGER = logging.getLogger(__name__)
+
 
 def test_get_metrics_table_metrics_only(mocker: MockerFixture):
     mock_context = mocker.Mock(spec=CloudDataContext)
@@ -232,6 +236,69 @@ def test_get_metrics_full_list(mocker: MockerFixture):
     ]
 
 
+def test_column_metrics_not_returned_if_column_types_missing(
+    mocker: MockerFixture, caplog
+):
+    mock_context = mocker.Mock(spec=CloudDataContext)
+    mock_validator = mocker.Mock(spec=Validator)
+    mock_context.get_validator.return_value = mock_validator
+    computed_metrics = {
+        ("table.row_count", (), ()): 2,
+        ("table.columns", (), ()): ["timestamp_col"],
+    }
+    cdm_metrics_list: List[MetricTypes] = [
+        MetricTypes.TABLE_ROW_COUNT,
+        MetricTypes.TABLE_COLUMNS,
+        # MetricTypes.TABLE_COLUMN_TYPES,
+        MetricTypes.COLUMN_MIN,
+        MetricTypes.COLUMN_MAX,
+        MetricTypes.COLUMN_NULL_COUNT,
+    ]
+    aborted_metrics = {}
+    mock_validator.compute_metrics.return_value = (
+        computed_metrics,
+        aborted_metrics,
+    )
+    mock_batch = mocker.Mock(spec=Batch)
+    mock_batch.id = "batch_id"
+    mock_validator.active_batch = mock_batch
+
+    metric_retriever = MetricListMetricRetriever(context=mock_context)
+
+    mock_batch_request = mocker.Mock(spec=BatchRequest)
+
+    mocker.patch(
+        f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_numeric_column_names",
+        return_value=[],
+    )
+    mocker.patch(
+        f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_timestamp_column_names",
+        return_value=["timestamp_col"],
+    )
+    metrics = metric_retriever.get_metrics(
+        batch_request=mock_batch_request, metric_list=cdm_metrics_list
+    )
+
+    assert metrics == [
+        TableMetric[int](
+            batch_id="batch_id",
+            metric_name="table.row_count",
+            value=2,
+            exception=None,
+        ),
+        TableMetric[List[str]](
+            batch_id="batch_id",
+            metric_name="table.columns",
+            value=["timestamp_col"],
+            exception=None,
+        ),
+    ]
+    assert (
+        "TABLE_COLUMN_TYPES metric is required to compute column metrics. Skipping column metrics."
+        in caplog.text
+    )
+
+
 def test_get_metrics_metrics_missing(mocker: MockerFixture):
     """This test is meant to simulate metrics missing from the computed metrics."""
     mock_context = mocker.Mock(spec=CloudDataContext)