Skip to content

Commit

Permalink
[MAINTENANCE] TableMetrics - BatchInspector updates (0.18.x) (#9630)
Browse files Browse the repository at this point in the history
  • Loading branch information
Shinnnyshinshin authored Mar 20, 2024
1 parent 7b483d5 commit abcf671
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import uuid
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List, Optional

from great_expectations.experimental.metric_repository.metrics import (
Metric,
MetricRun,
MetricTypes,
)

if TYPE_CHECKING:
Expand All @@ -28,6 +29,34 @@ def __init__(
self._context = context
self._metric_retrievers = metric_retrievers

def compute_metric_list_run(
    self,
    data_asset_id: uuid.UUID,
    batch_request: BatchRequest,
    metric_list: Optional[List[MetricTypes]],
) -> MetricRun:
    """Compute a MetricRun for the requested list of metrics.

    Called by GX Agent to compute a MetricRun as part of a RunMetricsEvent.

    Args:
        data_asset_id: id of the current data asset.
        batch_request: BatchRequest for the current batch.
        metric_list: metrics to compute; passed through unchanged to every
            configured metric retriever.

    Returns:
        A MetricRun for ``data_asset_id`` containing the metrics returned by
        each metric retriever, concatenated in retriever order.
    """
    # TODO: eventually we will keep this and retire `compute_metric_run`.
    collected: list[Metric] = [
        metric
        for retriever in self._metric_retrievers
        for metric in retriever.get_metrics(
            batch_request=batch_request, metric_list=metric_list
        )
    ]
    return MetricRun(data_asset_id=data_asset_id, metrics=collected)

def compute_metric_run(
self, data_asset_id: uuid.UUID, batch_request: BatchRequest
) -> MetricRun:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from itertools import chain
from typing import TYPE_CHECKING, List, Sequence
from typing import TYPE_CHECKING, List, Optional, Sequence

from great_expectations.compatibility.typing_extensions import override
from great_expectations.experimental.metric_repository.metric_retriever import (
Expand All @@ -10,6 +10,7 @@
from great_expectations.experimental.metric_repository.metrics import (
ColumnMetric,
Metric,
MetricTypes,
)

if TYPE_CHECKING:
Expand All @@ -27,7 +28,14 @@ def __init__(self, context: AbstractDataContext):
super().__init__(context=context)

@override
def get_metrics(self, batch_request: BatchRequest) -> Sequence[Metric]:
def get_metrics(
self,
batch_request: BatchRequest,
metric_list: Optional[List[MetricTypes]] = None,
) -> Sequence[Metric]:
# Note: Signature includes metric_list for compatibility with the MetricRetriever interface.
# It is not used by ColumnDescriptiveMetricsMetricRetriever.

table_metrics = self._calculate_table_metrics(batch_request)

# We need to skip columns that do not report a type, because the metric computation
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
from itertools import chain
from typing import TYPE_CHECKING, List, Optional, Sequence

Expand All @@ -13,6 +14,9 @@
MetricTypes,
)

logger = logging.getLogger(__name__)


if TYPE_CHECKING:
from great_expectations.data_context import AbstractDataContext
from great_expectations.datasource.fluent.batch_request import BatchRequest
Expand Down Expand Up @@ -44,8 +48,14 @@ def get_metrics(
)
metrics_result.extend(table_metrics)

# exit early if only Table Metrics exist
if not self._column_metrics_in_metric_list(metric_list):
# if no column metrics are present in the metric list, we can return the table metrics
return metrics_result

if MetricTypes.TABLE_COLUMN_TYPES not in metric_list:
logger.warning(
"TABLE_COLUMN_TYPES metric is required to compute column metrics. Skipping column metrics."
)
return metrics_result

table_column_types = list(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
TYPE_CHECKING,
Any,
List,
Optional,
Sequence,
)

Expand Down Expand Up @@ -50,19 +51,28 @@ def get_validator(self, batch_request: BatchRequest) -> Validator:
return self._validator

@abc.abstractmethod
def get_metrics(self, batch_request: BatchRequest) -> Sequence[Metric]:
def get_metrics(
    self,
    batch_request: BatchRequest,
    metric_list: Optional[List[MetricTypes]] = None,
) -> Sequence[Metric]:
    """Return the metrics for the batch described by ``batch_request``.

    This base implementation only raises; concrete retrievers are expected
    to provide their own implementation.

    Args:
        batch_request: BatchRequest identifying the batch to compute metrics for.
        metric_list: optional list of metric types, defaulting to None. How
            (or whether) it is used is decided by the overriding retriever.

    Returns:
        A sequence of Metric objects (in overriding implementations).

    Raises:
        NotImplementedError: always, when this base implementation is called.
    """
    raise NotImplementedError

def _generate_metric_id(self) -> uuid.UUID:
return uuid.uuid4()

def _get_metric_from_computed_metrics(
self,
metric_name: str,
metric_name: str | MetricTypes,
computed_metrics: _MetricsDict,
aborted_metrics: _AbortedMetricsInfoDict,
metric_lookup_key: _MetricKey | None = None,
) -> tuple[Any, MetricException | None]:
# look up is done by string
# TODO: update to be MetricTypes once MetricListMetricRetriever implementation is complete.
if isinstance(metric_name, MetricTypes):
metric_name = metric_name.value

if metric_lookup_key is None:
metric_lookup_key = (
metric_name,
Expand Down Expand Up @@ -91,7 +101,7 @@ def _get_metric_from_computed_metrics(
return value, metric_exception

def _generate_table_metric_configurations(
self, table_metric_names: list[str]
self, table_metric_names: list[str | MetricTypes]
) -> list[MetricConfiguration]:
table_metric_configs = [
MetricConfiguration(
Expand Down Expand Up @@ -271,7 +281,6 @@ def _get_table_columns(self, batch_request: BatchRequest) -> Metric:

def _get_table_column_types(self, batch_request: BatchRequest) -> Metric:
metric_name = MetricTypes.TABLE_COLUMN_TYPES

metric_lookup_key: _MetricKey = (metric_name, tuple(), "include_nested=True")
table_metric_configs = self._generate_table_metric_configurations(
table_metric_names=[metric_name]
Expand Down
9 changes: 6 additions & 3 deletions great_expectations/validator/metric_configuration.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import json
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

from great_expectations.core._docs_decorators import public_api
from great_expectations.core.domain import Domain
from great_expectations.core.id_dict import IDDict
from great_expectations.core.metric_domain_types import MetricDomainTypes
from great_expectations.core.util import convert_to_json_serializable
from great_expectations.experimental.metric_repository.metrics import MetricTypes


@public_api
Expand All @@ -17,7 +18,7 @@ class MetricConfiguration:
be used to evaluate Expectations or to summarize the result of the Validation.
Args:
metric_name (str): name of the Metric defined by the current MetricConfiguration.
metric_name (str or MetricTypes enum): name of the Metric defined by the current MetricConfiguration.
metric_domain_kwargs (dict): provides information on where the Metric can be calculated. For instance, a
MapCondition metric can include the name of the column that the Metric is going to be run on.
metric_value_kwargs (optional[dict]): Optional kwargs that define values specific to each Metric. For instance,
Expand All @@ -27,10 +28,12 @@ class MetricConfiguration:

def __init__(
self,
metric_name: str,
metric_name: Union[str, MetricTypes],
metric_domain_kwargs: dict,
metric_value_kwargs: Optional[dict] = None,
) -> None:
if isinstance(metric_name, MetricTypes):
metric_name = metric_name.value
self._metric_name = metric_name

if not isinstance(metric_domain_kwargs, IDDict):
Expand Down
85 changes: 85 additions & 0 deletions tests/experimental/metric_repository/test_batch_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,97 @@
)
from great_expectations.experimental.metric_repository.metrics import (
MetricRun,
MetricTypes,
TableMetric,
)

pytestmark = pytest.mark.unit


# compute_metric_list_run tests
def test_compute_metric_list_run_with_no_metric_retrievers(mocker):
    """With no retrievers configured, the resulting MetricRun has no metrics."""
    asset_id = uuid.uuid4()
    inspector = BatchInspector(
        context=Mock(spec=CloudDataContext), metric_retrievers=[]
    )

    result = inspector.compute_metric_list_run(
        data_asset_id=asset_id,
        batch_request=Mock(spec=BatchRequest),
        metric_list=[],
    )

    expected = MetricRun(data_asset_id=asset_id, metrics=[])
    assert result == expected


def test_compute_metric_list_run_calls_metric_retrievers():
    """The retriever is called exactly once with the batch request and metric list."""
    retriever = MagicMock(spec=MetricRetriever)
    inspector = BatchInspector(
        context=Mock(spec=CloudDataContext), metric_retrievers=[retriever]
    )
    batch_request = Mock(spec=BatchRequest)
    requested_metrics = [
        MetricTypes.TABLE_ROW_COUNT,
        MetricTypes.TABLE_COLUMNS,
        MetricTypes.TABLE_COLUMN_TYPES,
        MetricTypes.COLUMN_MIN,
        MetricTypes.COLUMN_MAX,
        MetricTypes.COLUMN_MEAN,
        MetricTypes.COLUMN_MEDIAN,
        MetricTypes.COLUMN_NULL_COUNT,
    ]

    inspector.compute_metric_list_run(
        data_asset_id=uuid.uuid4(),
        batch_request=batch_request,
        metric_list=requested_metrics,
    )

    get_metrics_mock = retriever.get_metrics
    assert get_metrics_mock.call_count == 1
    get_metrics_mock.assert_called_once_with(
        batch_request=batch_request, metric_list=requested_metrics
    )


def test_compute_metric_list_run_returns_metric_run():
    """The returned MetricRun carries the metrics produced by the retriever."""
    retrieved_metric = Mock(spec=TableMetric)
    retriever = MagicMock(spec=MetricRetriever)
    retriever.get_metrics.return_value = [retrieved_metric]

    inspector = BatchInspector(
        context=Mock(spec=CloudDataContext), metric_retrievers=[retriever]
    )
    asset_id = uuid.uuid4()
    requested_metrics = [
        MetricTypes.TABLE_ROW_COUNT,
        MetricTypes.TABLE_COLUMNS,
        MetricTypes.TABLE_COLUMN_TYPES,
        MetricTypes.COLUMN_MIN,
        MetricTypes.COLUMN_MAX,
        MetricTypes.COLUMN_MEAN,
        MetricTypes.COLUMN_MEDIAN,
        MetricTypes.COLUMN_NULL_COUNT,
    ]

    result = inspector.compute_metric_list_run(
        data_asset_id=asset_id,
        batch_request=Mock(spec=BatchRequest),
        metric_list=requested_metrics,
    )

    expected = MetricRun(data_asset_id=asset_id, metrics=[retrieved_metric])
    assert result == expected


# compute_metric_run tests. Will eventually go away once compute_metric_list_run is fully implemented.
def test_compute_metric_run_with_no_metric_retrievers():
mock_context = Mock(spec=CloudDataContext)
batch_inspector = BatchInspector(context=mock_context, metric_retrievers=[])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@

pytestmark = pytest.mark.unit

import logging

from pytest_mock import MockerFixture

LOGGER = logging.getLogger(__name__)


def test_get_metrics_table_metrics_only(mocker: MockerFixture):
mock_context = mocker.Mock(spec=CloudDataContext)
Expand Down Expand Up @@ -232,6 +236,69 @@ def test_get_metrics_full_list(mocker: MockerFixture):
]


def test_column_metrics_not_returned_of_column_types_missing(
    mocker: MockerFixture, caplog
):
    """Only table metrics come back, with a warning, when TABLE_COLUMN_TYPES is not requested."""
    # The requested metrics deliberately leave out MetricTypes.TABLE_COLUMN_TYPES.
    requested_metrics: List[MetricTypes] = [
        MetricTypes.TABLE_ROW_COUNT,
        MetricTypes.TABLE_COLUMNS,
        MetricTypes.COLUMN_MIN,
        MetricTypes.COLUMN_MAX,
        MetricTypes.COLUMN_NULL_COUNT,
    ]

    # Validator stub: returns two computed table metrics and no aborted metrics.
    batch = mocker.Mock(spec=Batch)
    batch.id = "batch_id"
    validator = mocker.Mock(spec=Validator)
    validator.active_batch = batch
    validator.compute_metrics.return_value = (
        {
            ("table.row_count", (), ()): 2,
            ("table.columns", (), ()): ["timestamp_col"],
        },
        {},
    )

    context = mocker.Mock(spec=CloudDataContext)
    context.get_validator.return_value = validator

    retriever = MetricListMetricRetriever(context=context)
    batch_request = mocker.Mock(spec=BatchRequest)

    # Stub the column-name helpers on the retriever class.
    mocker.patch(
        f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_numeric_column_names",
        return_value=[],
    )
    mocker.patch(
        f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_timestamp_column_names",
        return_value=["timestamp_col"],
    )

    actual = retriever.get_metrics(
        batch_request=batch_request, metric_list=requested_metrics
    )

    expected = [
        TableMetric[int](
            batch_id="batch_id",
            metric_name="table.row_count",
            value=2,
            exception=None,
        ),
        TableMetric[List[str]](
            batch_id="batch_id",
            metric_name="table.columns",
            value=["timestamp_col"],
            exception=None,
        ),
    ]
    assert actual == expected

    warning_text = "TABLE_COLUMN_TYPES metric is required to compute column metrics. Skipping column metrics."
    assert warning_text in caplog.text


def test_get_metrics_metrics_missing(mocker: MockerFixture):
"""This test is meant to simulate metrics missing from the computed metrics."""
mock_context = mocker.Mock(spec=CloudDataContext)
Expand Down

0 comments on commit abcf671

Please sign in to comment.