Skip to content

Commit

Permalink
Allow field overrides via Annotated
Browse files Browse the repository at this point in the history
  • Loading branch information
slykar committed Nov 24, 2024
1 parent dbe138b commit 7c98e50
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 3 deletions.
8 changes: 6 additions & 2 deletions src/cattrs/gen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from ._consts import AttributeOverride, already_generating, neutral
from ._generics import generate_mapping
from ._lc import generate_unique_filename
from ._shared import find_structure_handler
from ._shared import find_structure_handler, get_fields_annotated_by

if TYPE_CHECKING:
from ..converters import BaseConverter
Expand Down Expand Up @@ -264,6 +264,10 @@ def make_dict_unstructure_fn(

working_set.add(cl)

# Merge overrides provided via Annotated with kwargs
annotated_overrides = get_fields_annotated_by(cl, AttributeOverride)
annotated_overrides.update(kwargs)

try:
return make_dict_unstructure_fn_from_attrs(
attrs,
Expand All @@ -274,7 +278,7 @@ def make_dict_unstructure_fn(
_cattrs_use_linecache=_cattrs_use_linecache,
_cattrs_use_alias=_cattrs_use_alias,
_cattrs_include_init_false=_cattrs_include_init_false,
**kwargs,
**annotated_overrides,
)
finally:
working_set.remove(cl)
Expand Down
29 changes: 28 additions & 1 deletion src/cattrs/gen/_shared.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypeVar, get_type_hints

from attrs import NOTHING, Attribute, Factory

Expand All @@ -10,8 +10,10 @@
from ..fns import raise_error

if TYPE_CHECKING:
from collections.abc import Mapping
from ..converters import BaseConverter

T = TypeVar("T")

def find_structure_handler(
a: Attribute, type: Any, c: BaseConverter, prefer_attrs_converters: bool = False
Expand Down Expand Up @@ -62,3 +64,28 @@ def handler(v, _, _h=handler):
except RecursionError:
# This means we're dealing with a reference cycle, so use late binding.
return c.structure


def get_fields_annotated_by(cls: type, annotation_type: type[T] | T) -> dict[str, T]:
type_hints = get_type_hints(cls, include_extras=True)
# Support for both AttributeOverride and AttributeOverride()
annotation_type_ = annotation_type if isinstance(annotation_type, type) else type(annotation_type)

# First pass of filtering to get only fields with annotations
fields_with_annotations = (
(field_name, param_spec.__metadata__)
for field_name, param_spec in type_hints.items()
if hasattr(param_spec, "__metadata__")
)

# Now that we have fields with ANY annotations, we need to remove unwanted annotations.
fields_with_specific_annotation = (
(
field_name,
next((a for a in annotations if isinstance(a, annotation_type_)), None),
)
for field_name, annotations in fields_with_annotations
)

# We still might have some `None` values from previous filtering.
return {field_name: annotation for field_name, annotation in fields_with_specific_annotation if annotation}
72 changes: 72 additions & 0 deletions tests/test_annotated_overrides.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import Annotated

import attrs
import pytest

from cattrs.gen._shared import get_fields_annotated_by


class NotThere: ...


class IgnoreMe:
def __init__(self, why: str | None = None):
self.why = why


class FindMe:
def __init__(self, taint: str):
self.taint = taint


class EmptyClassExample:
pass


class PureClassExample:
id: Annotated[int, FindMe("red")]
name: Annotated[str, FindMe]


class MultipleAnnotationsExample:
id: Annotated[int, FindMe("red"), IgnoreMe()]
name: Annotated[str, IgnoreMe()]
surface: Annotated[str, IgnoreMe("sorry"), FindMe("shiny")]


@attrs.define
class AttrsClassExample:
id: int = attrs.field(default=0)
color: Annotated[str, FindMe("blue")] = attrs.field(default="red")
config: Annotated[dict, FindMe("required")] = attrs.field(factory=dict)


class PureClassInheritanceExample(PureClassExample):
include: dict
exclude: Annotated[dict, FindMe("boring things")]
extras: Annotated[dict, FindMe]


@pytest.mark.parametrize(
"klass,expected",
[
(EmptyClassExample, {}),
(PureClassExample, {"id": isinstance}),
(AttrsClassExample, {"color": isinstance, "config": isinstance}),
(MultipleAnnotationsExample, {"id": isinstance, "surface": isinstance}),
(PureClassInheritanceExample, {"id": isinstance, "exclude": isinstance}),
],
)
@pytest.mark.parametrize("instantiate", [True, False])
def test_gets_annotated_types(klass, expected, instantiate: bool):
annotated = get_fields_annotated_by(klass, FindMe("irrelevant") if instantiate else FindMe)

assert set(annotated.keys()) == set(expected.keys()), "Too many or too few annotations"
assert all(
assertion_func(annotated[field_name], FindMe) for field_name, assertion_func in expected.items()
), "Unexpected type of annotation"


def test_empty_result_for_missing_annotation():
annotated = get_fields_annotated_by(MultipleAnnotationsExample, NotThere)
assert not annotated, "No annotation should be found."

0 comments on commit 7c98e50

Please sign in to comment.