From 04e152f879f64cba0233fa538146c726c1f4407e Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Thu, 18 Apr 2024 11:41:06 +0200 Subject: [PATCH 01/28] Do not assume that default manager is named `objects` in `join()` --- model_utils/managers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model_utils/managers.py b/model_utils/managers.py index aaa4be84..3ca58267 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -331,7 +331,7 @@ def join(self, qs=None): else: fk_column = 'id' qs = self.only(fk_column) - new_qs = self.model.objects.all() + new_qs = self.model._default_manager.all() TABLE_NAME = 'temp_stuff' query, params = qs.query.sql_with_params() From db9336031e3b8fe30924dbcd8034ca5396e6a871 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Tue, 11 Jun 2024 19:46:51 +0200 Subject: [PATCH 02/28] Call `Field.get_default()` instead of `_get_default()` We can call the inherited public method using `super()` instead of calling the private method directly. --- model_utils/fields.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model_utils/fields.py b/model_utils/fields.py index 69188bba..b392ff0b 100644 --- a/model_utils/fields.py +++ b/model_utils/fields.py @@ -34,7 +34,7 @@ class AutoLastModifiedField(AutoCreatedField): def get_default(self): """Return the default value for this field.""" if not hasattr(self, "_default"): - self._default = self._get_default() + self._default = super().get_default() return self._default def pre_save(self, model_instance, add): From 885da698d57f6edfa52c8340f6343394f9db6d35 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Wed, 1 May 2024 17:24:18 +0200 Subject: [PATCH 03/28] Upgrade mypy and django-stubs --- requirements-mypy.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements-mypy.txt b/requirements-mypy.txt index c0422a75..65a16c5c 100644 --- a/requirements-mypy.txt +++ b/requirements-mypy.txt @@ -1,2 +1,2 @@ -mypy==1.9.0 -django-stubs==4.2.7 +mypy==1.10.0 +django-stubs==5.0.2 From 0043fedf463daa816563cc6055d0a5874cce1d3c Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Thu, 16 Mar 2023 18:19:56 +0100 Subject: [PATCH 04/28] Enable postponed evaluation of annotations for all source modules This allows using the latest annotation syntax supported by the type checker regardless of the runtime Python version. --- model_utils/choices.py | 2 ++ model_utils/fields.py | 2 ++ model_utils/managers.py | 2 ++ model_utils/models.py | 2 ++ model_utils/tracker.py | 2 ++ 5 files changed, 10 insertions(+) diff --git a/model_utils/choices.py b/model_utils/choices.py index 46c3877c..c6cc438f 100644 --- a/model_utils/choices.py +++ b/model_utils/choices.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy diff --git a/model_utils/fields.py b/model_utils/fields.py index b392ff0b..788d07cb 100644 --- a/model_utils/fields.py +++ b/model_utils/fields.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import secrets import uuid diff --git a/model_utils/managers.py b/model_utils/managers.py index 3ca58267..10e9c349 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import warnings from django.core.exceptions import ObjectDoesNotExist diff --git a/model_utils/models.py b/model_utils/models.py index 4eb0e63d..f5b717f1 100644 --- a/model_utils/models.py +++ b/model_utils/models.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.core.exceptions import ImproperlyConfigured from django.db import models from django.db.models.functions import Now diff --git a/model_utils/tracker.py b/model_utils/tracker.py index 118d1acf..7c912b04 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from copy import deepcopy from functools import wraps From 632441ea539f3243f39b7b84361e2d51b86b9231 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Thu, 16 Mar 2023 18:24:33 +0100 Subject: [PATCH 05/28] Require full annotation in mypy configuration --- mypy.ini | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mypy.ini b/mypy.ini index 918bc713..d0264e95 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,4 +1,6 @@ [mypy] +disallow_incomplete_defs=True +disallow_untyped_defs=True implicit_reexport=False pretty=True show_error_codes=True From 56ea5272861a9e82215d6c903a38d9ead9a82849 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Fri, 17 Mar 2023 12:35:37 +0100 Subject: [PATCH 06/28] Annotate the `tracker` module --- model_utils/tracker.py | 149 ++++++++++++++++++++++++++++------------- 1 file changed, 101 insertions(+), 48 deletions(-) diff --git a/model_utils/tracker.py b/model_utils/tracker.py index 7c912b04..d89f4a8b 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -2,11 +2,23 @@ from copy import deepcopy from functools import wraps +from typing import TYPE_CHECKING, Any, Iterable, TypeVar, cast, overload from django.core.exceptions import FieldError from django.db import models from django.db.models.fields.files import FieldFile +if TYPE_CHECKING: + from collections.abc import Callable, Mapping + from types import TracebackType + + class _AugmentedModel(models.Model): + _instance_initialized: bool + _deferred_fields: set[str] + + +T = TypeVar("T") + class LightStateFieldFile(FieldFile): """ @@ -18,7 +30,7 @@ class LightStateFieldFile(FieldFile): Django 3.1+ can make the app unusable, as CPU and memory usage gets easily multiplied by magnitudes. """ - def __getstate__(self): + def __getstate__(self) -> dict[str, Any]: """ We don't need to deepcopy the instance, so nullify if provided. """ @@ -28,27 +40,35 @@ def __getstate__(self): return state -def lightweight_deepcopy(value): +def lightweight_deepcopy(value: T) -> T: """ Use our lightweight class to avoid copying the instance on a FieldFile deepcopy. """ if isinstance(value, FieldFile): - value = LightStateFieldFile( + value = cast(T, LightStateFieldFile( instance=value.instance, field=value.field, name=value.name, - ) + )) return deepcopy(value) class DescriptorWrapper: - def __init__(self, field_name, descriptor, tracker_attname): + def __init__(self, field_name: str, descriptor: models.Field, tracker_attname: str): self.field_name = field_name self.descriptor = descriptor self.tracker_attname = tracker_attname - def __get__(self, instance, owner): + @overload + def __get__(self, instance: None, owner: type[models.Model]) -> DescriptorWrapper: + ... + + @overload + def __get__(self, instance: models.Model, owner: type[models.Model]) -> models.Field: + ... + + def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> DescriptorWrapper | models.Field: if instance is None: return self was_deferred = self.field_name in instance.get_deferred_fields() @@ -58,7 +78,7 @@ def __get__(self, instance, owner): tracker_instance.saved_data[self.field_name] = lightweight_deepcopy(value) return value - def __set__(self, instance, value): + def __set__(self, instance: models.Model, value: models.Field) -> None: initialized = hasattr(instance, '_instance_initialized') was_deferred = self.field_name in instance.get_deferred_fields() @@ -81,11 +101,11 @@ def __set__(self, instance, value): else: instance.__dict__[self.field_name] = value - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> models.Field: return getattr(self.descriptor, attr) @staticmethod - def cls_for_descriptor(descriptor): + def cls_for_descriptor(descriptor: models.Field) -> type[DescriptorWrapper]: if hasattr(descriptor, '__delete__'): return FullDescriptorWrapper else: @@ -96,8 +116,8 @@ class FullDescriptorWrapper(DescriptorWrapper): """ Wrapper for descriptors with all three descriptor methods. """ - def __delete__(self, obj): - self.descriptor.__delete__(obj) + def __delete__(self, obj: models.Field) -> None: + self.descriptor.__delete__(obj) # type: ignore[attr-defined] class FieldsContext: @@ -121,7 +141,12 @@ class FieldsContext: """ - def __init__(self, tracker, *fields, state=None): + def __init__( + self, + tracker: FieldInstanceTracker, + *fields: str, + state: dict[str, int] | None = None + ): """ :param tracker: FieldInstanceTracker instance to be reset after context exit @@ -139,7 +164,7 @@ def __init__(self, tracker, *fields, state=None): self.fields = fields self.state = state - def __enter__(self): + def __enter__(self) -> FieldsContext: """ Increments tracked fields occurrences count in shared state. """ @@ -148,7 +173,12 @@ def __enter__(self): self.state[f] += 1 return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: """ Decrements tracked fields occurrences count in shared state. @@ -166,29 +196,34 @@ def __exit__(self, exc_type, exc_val, exc_tb): class FieldInstanceTracker: - def __init__(self, instance, fields, field_map): - self.instance = instance + def __init__(self, instance: models.Model, fields: Iterable[str], field_map: Mapping[str, str]): + self.instance = cast('_AugmentedModel', instance) self.fields = fields self.field_map = field_map self.context = FieldsContext(self, *self.fields) - def __enter__(self): + def __enter__(self) -> FieldsContext: return self.context.__enter__() - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: return self.context.__exit__(exc_type, exc_val, exc_tb) - def __call__(self, *fields): + def __call__(self, *fields: str) -> FieldsContext: return FieldsContext(self, *fields, state=self.context.state) @property - def deferred_fields(self): + def deferred_fields(self) -> set[str]: return self.instance.get_deferred_fields() - def get_field_value(self, field): + def get_field_value(self, field: str) -> Any: return getattr(self.instance, self.field_map[field]) - def set_saved_fields(self, fields=None): + def set_saved_fields(self, fields: Iterable[str] | None = None) -> None: if not self.instance.pk: self.saved_data = {} elif fields is None: @@ -200,7 +235,7 @@ def set_saved_fields(self, fields=None): for field, field_value in self.saved_data.items(): self.saved_data[field] = lightweight_deepcopy(field_value) - def current(self, fields=None): + def current(self, fields: Iterable[str] | None = None) -> dict[str, Any]: """Returns dict of current values for all tracked fields""" if fields is None: deferred_fields = self.deferred_fields @@ -214,7 +249,7 @@ def current(self, fields=None): return {f: self.get_field_value(f) for f in fields} - def has_changed(self, field): + def has_changed(self, field: str) -> bool: """Returns ``True`` if field has changed from currently saved value""" if field in self.fields: # deferred fields haven't changed @@ -224,7 +259,7 @@ def has_changed(self, field): else: raise FieldError('field "%s" not tracked' % field) - def previous(self, field): + def previous(self, field: str) -> Any: """Returns currently saved value of given field""" # handle deferred fields that have not yet been loaded from the database @@ -244,7 +279,7 @@ def previous(self, field): return self.saved_data.get(field) - def changed(self): + def changed(self) -> dict[str, Any]: """Returns dict of fields that changed since save (with old values)""" return { field: self.previous(field) @@ -257,13 +292,18 @@ class FieldTracker: tracker_class = FieldInstanceTracker - def __init__(self, fields=None): - self.fields = fields + def __init__(self, fields: Iterable[str] | None = None): + # finalize_class() will replace None; pretend it is never None. + self.fields = cast(Iterable[str], fields) - def __call__(self, func=None, fields=None): - def decorator(f): + def __call__( + self, + func: Callable | None = None, + fields: Iterable[str] | None = None + ) -> Any: + def decorator(f: Callable) -> Callable: @wraps(f) - def inner(obj, *args, **kwargs): + def inner(obj: models.Model, *args: object, **kwargs: object) -> object: tracker = getattr(obj, self.attname) field_list = tracker.fields if fields is None else fields with tracker(*field_list): @@ -274,7 +314,7 @@ def inner(obj, *args, **kwargs): return decorator return decorator(func) - def get_field_map(self, cls): + def get_field_map(self, cls: type[models.Model]) -> dict[str, str]: """Returns dict mapping fields names to model attribute names""" field_map = {field: field for field in self.fields} all_fields = {f.name: f.attname for f in cls._meta.fields} @@ -282,17 +322,17 @@ def get_field_map(self, cls): if k in field_map}) return field_map - def contribute_to_class(self, cls, name): + def contribute_to_class(self, cls: type[models.Model], name: str) -> None: self.name = name self.attname = '_%s' % name models.signals.class_prepared.connect(self.finalize_class, sender=cls) - def finalize_class(self, sender, **kwargs): - if self.fields is None: + def finalize_class(self, sender: type[models.Model], **kwargs: object) -> None: + if self.fields is None or TYPE_CHECKING: self.fields = (field.attname for field in sender._meta.fields) self.fields = set(self.fields) for field_name in self.fields: - descriptor = getattr(sender, field_name) + descriptor: models.Field = getattr(sender, field_name) wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor) wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname) setattr(sender, field_name, wrapped_descriptor) @@ -302,34 +342,39 @@ def finalize_class(self, sender, **kwargs): setattr(sender, self.name, self) self.patch_save(sender) - def initialize_tracker(self, sender, instance, **kwargs): + def initialize_tracker( + self, + sender: type[models.Model], + instance: models.Model, + **kwargs: object + ) -> None: if not isinstance(instance, self.model_class): return # Only init instances of given model (including children) tracker = self.tracker_class(instance, self.fields, self.field_map) setattr(instance, self.attname, tracker) tracker.set_saved_fields() - instance._instance_initialized = True + cast('_AugmentedModel', instance)._instance_initialized = True - def patch_init(self, model): + def patch_init(self, model: type[models.Model]) -> None: original = getattr(model, '__init__') @wraps(original) - def inner(instance, *args, **kwargs): + def inner(instance: models.Model, *args: Any, **kwargs: Any) -> None: original(instance, *args, **kwargs) self.initialize_tracker(model, instance) setattr(model, '__init__', inner) - def patch_save(self, model): + def patch_save(self, model: type[models.Model]) -> None: self._patch(model, 'save_base', 'update_fields') self._patch(model, 'refresh_from_db', 'fields') - def _patch(self, model, method, fields_kwarg): + def _patch(self, model: type[models.Model], method: str, fields_kwarg: str) -> None: original = getattr(model, method) @wraps(original) - def inner(instance, *args, **kwargs): - update_fields = kwargs.get(fields_kwarg) + def inner(instance: models.Model, *args: object, **kwargs: Any) -> object: + update_fields: Iterable[str] | None = kwargs.get(fields_kwarg) if update_fields is None: fields = self.fields else: @@ -343,7 +388,15 @@ def inner(instance, *args, **kwargs): setattr(model, method, inner) - def __get__(self, instance, owner): + @overload + def __get__(self, instance: None, owner: type[models.Model]) -> FieldTracker: + ... + + @overload + def __get__(self, instance: models.Model, owner: type[models.Model]) -> FieldInstanceTracker: + ... + + def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> FieldTracker | FieldInstanceTracker: if instance is None: return self else: @@ -352,7 +405,7 @@ def __get__(self, instance, owner): class ModelInstanceTracker(FieldInstanceTracker): - def has_changed(self, field): + def has_changed(self, field: str) -> bool: """Returns ``True`` if field has changed from currently saved value""" if not self.instance.pk: return True @@ -361,7 +414,7 @@ def has_changed(self, field): else: raise FieldError('field "%s" not tracked' % field) - def changed(self): + def changed(self) -> dict[str, Any]: """Returns dict of fields that changed since save (with old values)""" if not self.instance.pk: return {} @@ -373,5 +426,5 @@ def changed(self): class ModelTracker(FieldTracker): tracker_class = ModelInstanceTracker - def get_field_map(self, cls): + def get_field_map(self, cls: type[models.Model]) -> dict[str, str]: return {field: field for field in self.fields} From bde2d8f9a9d1a965a9f7a4d98c734ebcbc3aeeb1 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Mon, 20 Mar 2023 19:16:03 +0100 Subject: [PATCH 07/28] Annotate the `managers` module --- model_utils/managers.py | 220 +++++++++++++++++++++++++--------------- 1 file changed, 138 insertions(+), 82 deletions(-) diff --git a/model_utils/managers.py b/model_utils/managers.py index 10e9c349..c4631d2a 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -1,20 +1,35 @@ from __future__ import annotations import warnings +from typing import TYPE_CHECKING, Any, Generic, Sequence, TypeVar, cast, overload from django.core.exceptions import ObjectDoesNotExist from django.db import connection, models from django.db.models.constants import LOOKUP_SEP from django.db.models.fields.related import OneToOneField, OneToOneRel -from django.db.models.query import ModelIterable, QuerySet from django.db.models.sql.datastructures import Join +ModelT = TypeVar('ModelT', bound=models.Model, covariant=True) + +if TYPE_CHECKING: + from collections.abc import Iterator + + from django.db.models.query import BaseIterable + from django.db.models.query import ModelIterable as ModelIterableGeneric + from django.db.models.query import QuerySet as QuerySetGeneric + + ModelIterable = ModelIterableGeneric[ModelT] + QuerySet = QuerySetGeneric[ModelT] +else: + from django.db.models.query import ModelIterable, QuerySet + class InheritanceIterable(ModelIterable): - def __iter__(self): + def __iter__(self) -> Iterator[ModelT]: queryset = self.queryset - iter = ModelIterable(queryset) - if getattr(queryset, 'subclasses', False): + iter: ModelIterableGeneric[ModelT] = ModelIterable(queryset) + if hasattr(queryset, 'subclasses'): + assert hasattr(queryset, '_get_sub_obj_recurse') extras = tuple(queryset.query.extra.keys()) # sort the subclass names longest first, # so with 'a' and 'a__b' it goes as deep as possible @@ -28,7 +43,7 @@ def __iter__(self): if not sub_obj: sub_obj = obj - if getattr(queryset, '_annotated', False): + if hasattr(queryset, '_annotated'): for k in queryset._annotated: setattr(sub_obj, k, getattr(obj, k)) @@ -40,26 +55,31 @@ def __iter__(self): yield from iter -class InheritanceQuerySetMixin: - def __init__(self, *args, **kwargs): +class InheritanceQuerySetMixin(Generic[ModelT]): + + model: type[ModelT] + subclasses: Sequence[str] + + def __init__(self, *args: object, **kwargs: object): super().__init__(*args, **kwargs) - self._iterable_class = InheritanceIterable + self._iterable_class: type[BaseIterable[ModelT]] = InheritanceIterable - def select_subclasses(self, *subclasses): - calculated_subclasses = self._get_subclasses_recurse(self.model) + def select_subclasses(self, *subclasses: str | type[models.Model]) -> InheritanceQuerySet[ModelT]: + model: type[ModelT] = self.model + calculated_subclasses = self._get_subclasses_recurse(model) # if none were passed in, we can just short circuit and select all if not subclasses: - subclasses = calculated_subclasses + selected_subclasses = calculated_subclasses else: - verified_subclasses = [] + verified_subclasses: list[str] = [] for subclass in subclasses: # special case for passing in the same model as the queryset # is bound against. Rather than raise an error later, we know # we can allow this through. - if subclass is self.model: + if subclass is model: continue - if not isinstance(subclass, (str,)): + if not isinstance(subclass, str): subclass = self._get_ancestors_path(subclass) if subclass in calculated_subclasses: @@ -69,38 +89,39 @@ def select_subclasses(self, *subclasses): '{!r} is not in the discovered subclasses, tried: {}'.format( subclass, ', '.join(calculated_subclasses)) ) - subclasses = verified_subclasses + selected_subclasses = verified_subclasses - if subclasses: - new_qs = self.select_related(*subclasses) - else: - new_qs = self - new_qs.subclasses = subclasses + new_qs = cast('InheritanceQuerySet[ModelT]', self) + if selected_subclasses: + new_qs = new_qs.select_related(*selected_subclasses) + new_qs.subclasses = selected_subclasses return new_qs - def _chain(self, **kwargs): + def _chain(self, **kwargs: object) -> InheritanceQuerySet[ModelT]: update = {} for name in ['subclasses', '_annotated']: if hasattr(self, name): update[name] = getattr(self, name) - chained = super()._chain(**kwargs) + # django-stubs doesn't include this private API. + chained = super()._chain(**kwargs) # type: ignore[misc] chained.__dict__.update(update) return chained - def _clone(self): - qs = super()._clone() + def _clone(self) -> InheritanceQuerySet[ModelT]: + # django-stubs doesn't include this private API. + qs = super()._clone() # type: ignore[misc] for name in ['subclasses', '_annotated']: if hasattr(self, name): setattr(qs, name, getattr(self, name)) return qs - def annotate(self, *args, **kwargs): - qset = super().annotate(*args, **kwargs) + def annotate(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]: + qset = cast(QuerySet[ModelT], super()).annotate(*args, **kwargs) qset._annotated = [a.default_alias for a in args] + list(kwargs.keys()) return qset - def _get_subclasses_recurse(self, model): + def _get_subclasses_recurse(self, model: type[models.Model]) -> list[str]: """ Given a Model class, find all related objects, exploring children recursively, returning a `list` of strings representing the @@ -126,7 +147,7 @@ def _get_subclasses_recurse(self, model): subclasses.append(rel.get_accessor_name()) return subclasses - def _get_ancestors_path(self, model): + def _get_ancestors_path(self, model: type[models.Model]) -> str: """ Serves as an opposite to _get_subclasses_recurse, instead walking from the Model class up the Model's ancestry and constructing the desired @@ -136,7 +157,7 @@ def _get_ancestors_path(self, model): raise ValueError( f"{model!r} is not a subclass of {self.model!r}") - ancestry = [] + ancestry: list[str] = [] # should be a OneToOneField or None parent_link = model._meta.get_ancestor_link(self.model) @@ -149,7 +170,7 @@ def _get_ancestors_path(self, model): return LOOKUP_SEP.join(ancestry) - def _get_sub_obj_recurse(self, obj, s): + def _get_sub_obj_recurse(self, obj: models.Model, s: str) -> ModelT | None: rel, _, s = s.partition(LOOKUP_SEP) try: @@ -162,12 +183,14 @@ def _get_sub_obj_recurse(self, obj, s): else: return node - def get_subclass(self, *args, **kwargs): + def get_subclass(self, *args: object, **kwargs: object) -> ModelT: return self.select_subclasses().get(*args, **kwargs) -class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet): - def instance_of(self, *models): +# Defining the 'model' attribute using a generic type triggers a bug in mypy: +# https://github.com/python/mypy/issues/9031 +class InheritanceQuerySet(InheritanceQuerySetMixin[ModelT], QuerySet[ModelT]): # type: ignore[misc] + def instance_of(self, *models: type[ModelT]) -> InheritanceQuerySet[ModelT]: """ Fetch only objects that are instances of the provided model(s). """ @@ -190,88 +213,118 @@ def instance_of(self, *models): ) for field in model._meta.parents.values() ]) + ')') - return self.select_subclasses(*models).extra(where=[' OR '.join(where_queries)]) + return cast( + 'InheritanceQuerySet[ModelT]', + self.select_subclasses(*models).extra(where=[' OR '.join(where_queries)]) + ) -class InheritanceManagerMixin: +class InheritanceManagerMixin(Generic[ModelT]): _queryset_class = InheritanceQuerySet - def get_queryset(self): - return self._queryset_class(self.model) + if TYPE_CHECKING: + def filter(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]: + ... - def select_subclasses(self, *subclasses): + def get_queryset(self) -> InheritanceQuerySet[ModelT]: + model: type[ModelT] = self.model # type: ignore[attr-defined] + return self._queryset_class(model) + + def select_subclasses( + self, *subclasses: str | type[models.Model] + ) -> InheritanceQuerySet[ModelT]: return self.get_queryset().select_subclasses(*subclasses) - def get_subclass(self, *args, **kwargs): + def get_subclass(self, *args: object, **kwargs: object) -> ModelT: return self.get_queryset().get_subclass(*args, **kwargs) - def instance_of(self, *models): + def instance_of(self, *models: type[ModelT]) -> InheritanceQuerySet[ModelT]: return self.get_queryset().instance_of(*models) -class InheritanceManager(InheritanceManagerMixin, models.Manager): +class InheritanceManager(InheritanceManagerMixin[ModelT], models.Manager[ModelT]): pass -class QueryManagerMixin: +class QueryManagerMixin(Generic[ModelT]): + + @overload + def __init__(self, *args: models.Q): + ... - def __init__(self, *args, **kwargs): + @overload + def __init__(self, **kwargs: object): + ... + + def __init__(self, *args: models.Q, **kwargs: object): if args: self._q = args[0] else: self._q = models.Q(**kwargs) - self._order_by = None + self._order_by: tuple[Any, ...] | None = None super().__init__() - def order_by(self, *args): + def order_by(self, *args: Any) -> QueryManager[ModelT]: self._order_by = args - return self + return cast('QueryManager[ModelT]', self) - def get_queryset(self): - qs = super().get_queryset().filter(self._q) + def get_queryset(self) -> QuerySet[ModelT]: + qs = super().get_queryset() # type: ignore[misc] + qs = qs.filter(self._q) if self._order_by is not None: return qs.order_by(*self._order_by) return qs -class QueryManager(QueryManagerMixin, models.Manager): +class QueryManager(QueryManagerMixin[ModelT], models.Manager[ModelT]): # type: ignore[misc] pass -class SoftDeletableQuerySetMixin: +class SoftDeletableQuerySetMixin(Generic[ModelT]): """ QuerySet for SoftDeletableModel. Instead of removing instance sets its ``is_removed`` field to True. """ - def delete(self): + def delete(self) -> None: """ Soft delete objects from queryset (set their ``is_removed`` field to True) """ - self.update(is_removed=True) + cast(QuerySet[ModelT], self).update(is_removed=True) -class SoftDeletableQuerySet(SoftDeletableQuerySetMixin, QuerySet): +# Note that our delete() method does not return anything, unlike Django's. +# https://github.com/jazzband/django-model-utils/issues/541 +class SoftDeletableQuerySet(SoftDeletableQuerySetMixin[ModelT], QuerySet[ModelT]): # type: ignore[misc] pass -class SoftDeletableManagerMixin: +class SoftDeletableManagerMixin(Generic[ModelT]): """ Manager that limits the queryset by default to show only not removed instances of model. """ _queryset_class = SoftDeletableQuerySet - def __init__(self, *args, _emit_deprecation_warnings=False, **kwargs): + _db: str | None + + def __init__( + self, + *args: object, + _emit_deprecation_warnings: bool = False, + **kwargs: object + ): self.emit_deprecation_warnings = _emit_deprecation_warnings super().__init__(*args, **kwargs) - def get_queryset(self): + def get_queryset(self) -> SoftDeletableQuerySet[ModelT]: """ Return queryset limited to not removed entries. """ + model: type[ModelT] = self.model # type: ignore[attr-defined] + if self.emit_deprecation_warnings: warning_message = ( "{0}.objects model manager will include soft-deleted objects in an " @@ -279,23 +332,23 @@ def get_queryset(self): "excluding soft-deleted objects. See " "https://django-model-utils.readthedocs.io/en/stable/models.html" "#softdeletablemodel for more information." - ).format(self.model.__class__.__name__) + ).format(model.__class__.__name__) warnings.warn(warning_message, DeprecationWarning) - kwargs = {'model': self.model, 'using': self._db} - if hasattr(self, '_hints'): - kwargs['hints'] = self._hints + return self._queryset_class( + model=model, + using=self._db, + **({'hints': self._hints} if hasattr(self, '_hints') else {}) + ).filter(is_removed=False) - return self._queryset_class(**kwargs).filter(is_removed=False) - -class SoftDeletableManager(SoftDeletableManagerMixin, models.Manager): +class SoftDeletableManager(SoftDeletableManagerMixin[ModelT], models.Manager[ModelT]): pass -class JoinQueryset(models.QuerySet): +class JoinQueryset(models.QuerySet[Any]): - def join(self, qs=None): + def join(self, qs: QuerySet[Any] | None = None) -> QuerySet[Any]: ''' Join one queryset together with another using a temporary table. If no queryset is used, it will use the current queryset and join that @@ -310,11 +363,11 @@ def join(self, qs=None): to_field = 'id' if qs: - fk = [ + fks = [ fk for fk in qs.model._meta.fields if getattr(fk, 'related_model', None) == self.model ] - fk = fk[0] if fk else None + fk = fks[0] if fks else None model_set = f'{self.model.__name__.lower()}_set' key = fk or getattr(qs.model, model_set, None) @@ -371,21 +424,24 @@ class Meta: return new_qs -class JoinManagerMixin: - """ - Manager that adds a method join. This method allows you to join two - querysets together. - """ - _queryset_class = JoinQueryset +if not TYPE_CHECKING: + # Hide deprecated API during type checking, to encourage switch to + # 'JoinQueryset.as_manager()', which is supported by the mypy plugin + # of django-stubs. - def get_queryset(self): - warnings.warn( - "JoinManager and JoinManagerMixin are deprecated. " - "Please use 'JoinQueryset.as_manager()' instead.", - DeprecationWarning - ) - return self._queryset_class(model=self.model, using=self._db) + class JoinManagerMixin: + """ + Manager that adds a method join. This method allows you to join two + querysets together. + """ + def get_queryset(self): + warnings.warn( + "JoinManager and JoinManagerMixin are deprecated. " + "Please use 'JoinQueryset.as_manager()' instead.", + DeprecationWarning + ) + return self._queryset_class(model=self.model, using=self._db) -class JoinManager(JoinManagerMixin, models.Manager): - pass + class JoinManager(JoinManagerMixin): + pass From 172cc72ec638f64d81e023d8c4531d3ada9daeb4 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Mon, 20 Mar 2023 19:16:18 +0100 Subject: [PATCH 08/28] Annotate the `models` module --- model_utils/models.py | 39 +++++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/model_utils/models.py b/model_utils/models.py index f5b717f1..71d0055c 100644 --- a/model_utils/models.py +++ b/model_utils/models.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Any, Literal, TypeVar, overload + from django.core.exceptions import ImproperlyConfigured from django.db import models from django.db.models.functions import Now @@ -14,6 +16,8 @@ ) from model_utils.managers import QueryManager, SoftDeletableManager +ModelT = TypeVar('ModelT', bound=models.Model, covariant=True) + now = Now() @@ -26,7 +30,7 @@ class TimeStampedModel(models.Model): created = AutoCreatedField(_('created')) modified = AutoLastModifiedField(_('modified')) - def save(self, *args, **kwargs): + def save(self, *args: Any, **kwargs: Any) -> None: """ Overriding the save method in order to make sure that modified field is updated even if it is not given as @@ -67,7 +71,7 @@ class StatusModel(models.Model): status = StatusField(_('status')) status_changed = MonitorField(_('status changed'), monitor='status') - def save(self, *args, **kwargs): + def save(self, *args: Any, **kwargs: Any) -> None: """ Overriding the save method in order to make sure that status_changed field is updated even if it is not given as @@ -83,7 +87,7 @@ class Meta: abstract = True -def add_status_query_managers(sender, **kwargs): +def add_status_query_managers(sender: type[models.Model], **kwargs: Any) -> None: """ Add a Querymanager for each status item dynamically. @@ -92,6 +96,7 @@ def add_status_query_managers(sender, **kwargs): return default_manager = sender._meta.default_manager + assert default_manager is not None for value, display in getattr(sender, 'STATUS', ()): if _field_exists(sender, value): @@ -105,7 +110,7 @@ def add_status_query_managers(sender, **kwargs): sender._meta.default_manager_name = default_manager.name -def add_timeframed_query_manager(sender, **kwargs): +def add_timeframed_query_manager(sender: type[models.Model], **kwargs: Any) -> None: """ Add a QueryManager for a specific timeframe. @@ -128,7 +133,7 @@ def add_timeframed_query_manager(sender, **kwargs): models.signals.class_prepared.connect(add_timeframed_query_manager) -def _field_exists(model_class, field_name): +def _field_exists(model_class: type[models.Model], field_name: str) -> bool: return field_name in [f.attname for f in model_class._meta.local_fields] @@ -144,11 +149,28 @@ class SoftDeletableModel(models.Model): class Meta: abstract = True - objects = SoftDeletableManager(_emit_deprecation_warnings=True) - available_objects = SoftDeletableManager() + objects: models.Manager[SoftDeletableModel] = SoftDeletableManager(_emit_deprecation_warnings=True) + available_objects: models.Manager[SoftDeletableModel] = SoftDeletableManager() all_objects = models.Manager() - def delete(self, using=None, *args, soft=True, **kwargs): + # Note that soft delete does not return anything, + # which doesn't conform to Django's interface. + # https://github.com/jazzband/django-model-utils/issues/541 + @overload # type: ignore[override] + def delete( + self, using: Any = None, *args: Any, soft: Literal[True] = True, **kwargs: Any + ) -> None: + ... + + @overload + def delete( + self, using: Any = None, *args: Any, soft: Literal[False], **kwargs: Any + ) -> tuple[int, dict[str, int]]: + ... + + def delete( + self, using: Any = None, *args: Any, soft: bool = True, **kwargs: Any + ) -> tuple[int, dict[str, int]] | None: """ Soft delete object (set its ``is_removed`` field to True). Actually delete object if setting ``soft`` to False. @@ -156,6 +178,7 @@ def delete(self, using=None, *args, soft=True, **kwargs): if soft: self.is_removed = True self.save(using=using) + return None else: return super().delete(using, *args, **kwargs) From ebfb3455dcd51be531a95906f25a1cedf0f94366 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Mon, 20 Mar 2023 19:16:32 +0100 Subject: [PATCH 09/28] Annotate the `fields` module --- model_utils/fields.py | 111 +++++++++++++++++++++++++++--------------- 1 file changed, 72 insertions(+), 39 deletions(-) diff --git a/model_utils/fields.py b/model_utils/fields.py index 788d07cb..143528ff 100644 --- a/model_utils/fields.py +++ b/model_utils/fields.py @@ -2,12 +2,18 @@ import secrets import uuid +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Union from django.conf import settings from django.core.exceptions import ValidationError from django.db import models from django.utils.timezone import now +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + from datetime import datetime + DEFAULT_CHOICES_NAME = 'STATUS' @@ -20,7 +26,7 @@ class AutoCreatedField(models.DateTimeField): """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): kwargs.setdefault('editable', False) kwargs.setdefault('default', now) super().__init__(*args, **kwargs) @@ -33,13 +39,13 @@ class AutoLastModifiedField(AutoCreatedField): By default, sets editable=False and default=datetime.now. """ - def get_default(self): + def get_default(self) -> datetime: """Return the default value for this field.""" if not hasattr(self, "_default"): self._default = super().get_default() return self._default - def pre_save(self, model_instance, add): + def pre_save(self, model_instance: models.Model, add: bool) -> datetime: value = now() if add: current_value = getattr(model_instance, self.attname, self.get_default()) @@ -70,13 +76,19 @@ class StatusField(models.CharField): South can handle this field when it freezes a model. """ - def __init__(self, *args, no_check_for_status=False, choices_name=DEFAULT_CHOICES_NAME, **kwargs): + def __init__( + self, + *args: Any, + no_check_for_status: bool = False, + choices_name: str = DEFAULT_CHOICES_NAME, + **kwargs: Any + ): kwargs.setdefault('max_length', 100) self.check_for_status = not no_check_for_status self.choices_name = choices_name super().__init__(*args, **kwargs) - def prepare_class(self, sender, **kwargs): + def prepare_class(self, sender: type[models.Model], **kwargs: Any) -> None: if not sender._meta.abstract and self.check_for_status: assert hasattr(sender, self.choices_name), \ "To use StatusField, the model '%s' must have a %s choices class attribute." \ @@ -85,7 +97,7 @@ def prepare_class(self, sender, **kwargs): if not self.has_default(): self.default = tuple(getattr(sender, self.choices_name))[0][0] # set first as default - def contribute_to_class(self, cls, name, *args, **kwargs): + def contribute_to_class(self, cls: type[models.Model], name: str, *args: Any, **kwargs: Any) -> None: models.signals.class_prepared.connect(self.prepare_class, sender=cls) # we don't set the real choices until class_prepared (so we can rely on # the STATUS class attr being available), but we need to set some dummy @@ -93,7 +105,7 @@ def contribute_to_class(self, cls, name, *args, **kwargs): self.choices = [(0, 'dummy')] super().contribute_to_class(cls, name, *args, **kwargs) - def deconstruct(self): + def deconstruct(self) -> tuple[str, str, Sequence[Any], dict[str, Any]]: name, path, args, kwargs = super().deconstruct() kwargs['no_check_for_status'] = True return name, path, args, kwargs @@ -107,30 +119,28 @@ class MonitorField(models.DateTimeField): """ - def __init__(self, *args, monitor, when=None, **kwargs): + def __init__(self, *args: Any, monitor: str, when: Iterable[Any] | None = None, **kwargs: Any): default = None if kwargs.get("null") else now kwargs.setdefault('default', default) self.monitor = monitor - if when is not None: - when = set(when) - self.when = when + self.when = None if when is None else set(when) super().__init__(*args, **kwargs) - def contribute_to_class(self, cls, name, *args, **kwargs): + def contribute_to_class(self, cls: type[models.Model], name: str, *args: Any, **kwargs: Any) -> None: self.monitor_attname = '_monitor_%s' % name models.signals.post_init.connect(self._save_initial, sender=cls) super().contribute_to_class(cls, name, *args, **kwargs) - def get_monitored_value(self, instance): + def get_monitored_value(self, instance: models.Model) -> Any: return getattr(instance, self.monitor) - def _save_initial(self, sender, instance, **kwargs): + def _save_initial(self, sender: type[models.Model], instance: models.Model, **kwargs: Any) -> None: if self.monitor in instance.get_deferred_fields(): # Fix related to issue #241 to avoid recursive error on double monitor fields return setattr(instance, self.monitor_attname, self.get_monitored_value(instance)) - def pre_save(self, model_instance, add): + def pre_save(self, model_instance: models.Model, add: bool) -> Any: value = now() previous = getattr(model_instance, self.monitor_attname, None) current = self.get_monitored_value(model_instance) @@ -140,7 +150,7 @@ def pre_save(self, model_instance, add): self._save_initial(model_instance.__class__, model_instance) return super().pre_save(model_instance, add) - def deconstruct(self): + def deconstruct(self) -> tuple[str, str, Sequence[Any], dict[str, Any]]: name, path, args, kwargs = super().deconstruct() kwargs['monitor'] = self.monitor if self.when is not None: @@ -154,12 +164,12 @@ def deconstruct(self): SPLIT_DEFAULT_PARAGRAPHS = getattr(settings, 'SPLIT_DEFAULT_PARAGRAPHS', 2) -def _excerpt_field_name(name): +def _excerpt_field_name(name: str) -> str: return '_%s_excerpt' % name -def get_excerpt(content): - excerpt = [] +def get_excerpt(content: str) -> str: + excerpt: list[str] = [] default_excerpt = [] paras_seen = 0 for line in content.splitlines(): @@ -175,7 +185,7 @@ def get_excerpt(content): class SplitText: - def __init__(self, instance, field_name, excerpt_field_name): + def __init__(self, instance: models.Model, field_name: str, excerpt_field_name: str): # instead of storing actual values store a reference to the instance # along with field names, this makes assignment possible self.instance = instance @@ -183,36 +193,36 @@ def __init__(self, instance, field_name, excerpt_field_name): self.excerpt_field_name = excerpt_field_name @property - def content(self): + def content(self) -> str: return self.instance.__dict__[self.field_name] @content.setter - def content(self, val): + def content(self, val: str) -> None: setattr(self.instance, self.field_name, val) @property - def excerpt(self): + def excerpt(self) -> str: return getattr(self.instance, self.excerpt_field_name) @property - def has_more(self): + def has_more(self) -> bool: return self.excerpt.strip() != self.content.strip() - def __str__(self): + def __str__(self) -> str: return self.content class SplitDescriptor: - def __init__(self, field): + def __init__(self, field: SplitField): self.field = field self.excerpt_field_name = _excerpt_field_name(self.field.name) - def __get__(self, instance, owner): + def __get__(self, instance: models.Model, owner: type[models.Model]) -> SplitText: if instance is None: raise AttributeError('Can only be accessed via an instance.') return SplitText(instance, self.field.name, self.excerpt_field_name) - def __set__(self, obj, value): + def __set__(self, obj: models.Model, value: SplitText | str) -> None: if isinstance(value, SplitText): obj.__dict__[self.field.name] = value.content setattr(obj, self.excerpt_field_name, value.excerpt) @@ -220,25 +230,32 @@ def __set__(self, obj, value): obj.__dict__[self.field.name] = value -class SplitField(models.TextField): - def contribute_to_class(self, cls, name, *args, **kwargs): +if TYPE_CHECKING: + _SplitFieldBase = models.TextField[Union[SplitText, str], SplitText] +else: + _SplitFieldBase = models.TextField + + +class SplitField(_SplitFieldBase): + + def contribute_to_class(self, cls: type[models.Model], name: str, *args: Any, **kwargs: Any) -> None: if not cls._meta.abstract: - excerpt_field = models.TextField(editable=False) + excerpt_field: models.TextField = models.TextField(editable=False) cls.add_to_class(_excerpt_field_name(name), excerpt_field) super().contribute_to_class(cls, name, *args, **kwargs) setattr(cls, self.name, SplitDescriptor(self)) - def pre_save(self, model_instance, add): - value = super().pre_save(model_instance, add) + def pre_save(self, model_instance: models.Model, add: bool) -> str: + value: SplitText = super().pre_save(model_instance, add) excerpt = get_excerpt(value.content) setattr(model_instance, _excerpt_field_name(self.attname), excerpt) return value.content - def value_to_string(self, obj): + def value_to_string(self, obj: models.Model) -> str: value = self.value_from_object(obj) return value.content - def get_prep_value(self, value): + def get_prep_value(self, value: Any) -> str: try: return value.content except AttributeError: @@ -250,7 +267,14 @@ class UUIDField(models.UUIDField): A field for storing universally unique identifiers. Use Python UUID class. """ - def __init__(self, primary_key=True, version=4, editable=False, *args, **kwargs): + def __init__( + self, + primary_key: bool = True, + version: int = 4, + editable: bool = False, + *args: Any, + **kwargs: Any + ): """ Parameters ---------- @@ -276,6 +300,7 @@ def __init__(self, primary_key=True, version=4, editable=False, *args, **kwargs) raise ValidationError( 'UUID version is not valid.') + default: Callable[..., uuid.UUID] if version == 1: default = uuid.uuid1 elif version == 3: @@ -296,7 +321,15 @@ class UrlsafeTokenField(models.CharField): A field for storing a unique token in database. """ - def __init__(self, editable=False, max_length=128, factory=None, **kwargs): + max_length: int + + def __init__( + self, + editable: bool = False, + max_length: int = 128, + factory: Callable[[int], str] | None = None, + **kwargs: Any + ): """ Parameters ---------- @@ -321,14 +354,14 @@ def __init__(self, editable=False, max_length=128, factory=None, **kwargs): super().__init__(editable=editable, max_length=max_length, **kwargs) - def get_default(self): + def get_default(self) -> str: if self._factory is not None: return self._factory(self.max_length) # generate a token of length x1.33 approx. trim up to max length token = secrets.token_urlsafe(self.max_length)[:self.max_length] return token - def deconstruct(self): + def deconstruct(self) -> tuple[str, str, Sequence[Any], dict[str, Any]]: name, path, args, kwargs = super().deconstruct() kwargs['factory'] = self._factory return name, path, args, kwargs From 2b0b4827a5eeaedd1bba16e928756d4b0d34c5c1 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Mon, 20 Mar 2023 18:57:52 +0100 Subject: [PATCH 10/28] Annotate the `choices` module The `Choices` constructor actually only accepts lists and tuples, not arbitrary sequences. However, using `list` in the annotation opens a can of worms related to type variance. --- model_utils/choices.py | 134 +++++++++++++++++++++++++++++++---------- 1 file changed, 101 insertions(+), 33 deletions(-) diff --git a/model_utils/choices.py b/model_utils/choices.py index c6cc438f..16ef2ce3 100644 --- a/model_utils/choices.py +++ b/model_utils/choices.py @@ -1,9 +1,32 @@ from __future__ import annotations import copy +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload +T = TypeVar("T") -class Choices: +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Sequence + + from django_stubs_ext import StrOrPromise + + # The type argument 'T' to 'Choices' is the database representation type. + _Double = tuple[T, StrOrPromise] + _Triple = tuple[T, str, StrOrPromise] + _Group = tuple[StrOrPromise, Sequence["_Choice[T]"]] + _Choice = _Double[T] | _Triple[T] | _Group[T] + # Choices can only be given as a single string if 'T' is 'str'. + _GroupStr = tuple[StrOrPromise, Sequence["_ChoiceStr"]] + _ChoiceStr = str | _Double[str] | _Triple[str] | _GroupStr + # Note that we only accept lists and tuples in groups, not arbitrary sequences. + # However, annotating it as such causes many problems. + + _DoubleRead = _Double[T] | tuple[StrOrPromise, Iterable["_DoubleRead[T]"]] + _DoubleCollector = list[_Double[T] | tuple[StrOrPromise, "_DoubleCollector[T]"]] + _TripleCollector = list[_Triple[T] | tuple[StrOrPromise, "_TripleCollector[T]"]] + + +class Choices(Generic[T]): """ A class to encapsulate handy functionality for lists of choices for a Django model field. @@ -43,36 +66,60 @@ class Choices: """ - def __init__(self, *choices): + @overload + def __init__(self: Choices[str], *choices: _ChoiceStr): + ... + + @overload + def __init__(self, *choices: _Choice[T]): + ... + + def __init__(self, *choices: _ChoiceStr | _Choice[T]): # list of choices expanded to triples - can include optgroups - self._triples = [] + self._triples: _TripleCollector[T] = [] # list of choices as (db, human-readable) - can include optgroups - self._doubles = [] + self._doubles: _DoubleCollector[T] = [] # dictionary mapping db representation to human-readable - self._display_map = {} + self._display_map: dict[T, StrOrPromise | list[_Triple[T]]] = {} # dictionary mapping Python identifier to db representation - self._identifier_map = {} + self._identifier_map: dict[str, T] = {} # set of db representations - self._db_values = set() + self._db_values: set[T] = set() self._process(choices) - def _store(self, triple, triple_collector, double_collector): + def _store( + self, + triple: tuple[T, str, StrOrPromise], + triple_collector: _TripleCollector[T], + double_collector: _DoubleCollector[T] + ) -> None: self._identifier_map[triple[1]] = triple[0] self._display_map[triple[0]] = triple[2] self._db_values.add(triple[0]) triple_collector.append(triple) double_collector.append((triple[0], triple[2])) - def _process(self, choices, triple_collector=None, double_collector=None): + def _process( + self, + choices: Iterable[_ChoiceStr | _Choice[T]], + triple_collector: _TripleCollector[T] | None = None, + double_collector: _DoubleCollector[T] | None = None + ) -> None: if triple_collector is None: triple_collector = self._triples if double_collector is None: double_collector = self._doubles - store = lambda c: self._store(c, triple_collector, double_collector) + def store(c: tuple[Any, str, StrOrPromise]) -> None: + self._store(c, triple_collector, double_collector) for choice in choices: + # The type inference is not very accurate here: + # - we lied in the type aliases, stating groups contain an arbitrary Sequence + # rather than only list or tuple + # - there is no way to express that _ChoiceStr is only used when T=str + # - mypy 1.9.0 doesn't narrow types based on the value of len() if isinstance(choice, (list, tuple)): if len(choice) == 3: store(choice) @@ -81,13 +128,13 @@ def _process(self, choices, triple_collector=None, double_collector=None): # option group group_name = choice[0] subchoices = choice[1] - tc = [] + tc: _TripleCollector[T] = [] triple_collector.append((group_name, tc)) - dc = [] + dc: _DoubleCollector[T] = [] double_collector.append((group_name, dc)) self._process(subchoices, tc, dc) else: - store((choice[0], choice[0], choice[1])) + store((choice[0], cast(str, choice[0]), cast('StrOrPromise', choice[1]))) else: raise ValueError( "Choices can't take a list of length %s, only 2 or 3" @@ -96,54 +143,74 @@ def _process(self, choices, triple_collector=None, double_collector=None): else: store((choice, choice, choice)) - def __len__(self): + def __len__(self) -> int: return len(self._doubles) - def __iter__(self): + def __iter__(self) -> Iterator[_DoubleRead[T]]: return iter(self._doubles) - def __reversed__(self): + def __reversed__(self) -> Iterator[_DoubleRead[T]]: return reversed(self._doubles) - def __getattr__(self, attname): + def __getattr__(self, attname: str) -> T: try: return self._identifier_map[attname] except KeyError: raise AttributeError(attname) - def __getitem__(self, key): + def __getitem__(self, key: T) -> StrOrPromise | Sequence[_Triple[T]]: return self._display_map[key] - def __add__(self, other): + @overload + def __add__(self: Choices[str], other: Choices[str] | Iterable[_ChoiceStr]) -> Choices[str]: + ... + + @overload + def __add__(self, other: Choices[T] | Iterable[_Choice[T]]) -> Choices[T]: + ... + + def __add__(self, other: Choices[Any] | Iterable[_ChoiceStr | _Choice[Any]]) -> Choices[Any]: + other_args: list[Any] if isinstance(other, self.__class__): - other = other._triples + other_args = other._triples else: - other = list(other) - return Choices(*(self._triples + other)) + other_args = list(other) + return Choices(*(self._triples + other_args)) + + @overload + def __radd__(self: Choices[str], other: Iterable[_ChoiceStr]) -> Choices[str]: + ... + + @overload + def __radd__(self, other: Iterable[_Choice[T]]) -> Choices[T]: + ... - def __radd__(self, other): + def __radd__(self, other: Iterable[_ChoiceStr] | Iterable[_Choice[T]]) -> Choices[Any]: # radd is never called for matching types, so we don't check here - other = list(other) - return Choices(*(other + self._triples)) + other_args = list(other) + # The exact type of 'other' depends on our type argument 'T', which + # is expressed in the overloading, but lost within this method body. + return Choices(*(other_args + self._triples)) # type: ignore[arg-type] - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, self.__class__): return self._triples == other._triples return False - def __repr__(self): + def __repr__(self) -> str: return '{}({})'.format( self.__class__.__name__, ', '.join("%s" % repr(i) for i in self._triples) ) - def __contains__(self, item): + def __contains__(self, item: T) -> bool: return item in self._db_values - def __deepcopy__(self, memo): - return self.__class__(*copy.deepcopy(self._triples, memo)) + def __deepcopy__(self, memo: dict[int, Any] | None) -> Choices[T]: + args: list[Any] = copy.deepcopy(self._triples, memo) + return self.__class__(*args) - def subset(self, *new_identifiers): + def subset(self, *new_identifiers: str) -> Choices[T]: identifiers = set(self._identifier_map.keys()) if not identifiers.issuperset(new_identifiers): @@ -152,7 +219,8 @@ def subset(self, *new_identifiers): identifiers.symmetric_difference(new_identifiers), ) - return self.__class__(*[ + args: list[Any] = [ choice for choice in self._triples if choice[1] in new_identifiers - ]) + ] + return self.__class__(*args) From 713a3fec88e74990652a247bb548a0ed1f67a7d0 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Wed, 10 Apr 2024 18:24:37 +0200 Subject: [PATCH 11/28] Make type aliases compatible with old Python versions --- model_utils/choices.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/model_utils/choices.py b/model_utils/choices.py index 16ef2ce3..ce709be3 100644 --- a/model_utils/choices.py +++ b/model_utils/choices.py @@ -8,22 +8,29 @@ if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Sequence + # The type aliases defined here are evaluated when the django-stubs mypy plugin + # loads this module, so they must be able to execute under the lowest supported + # Python VM: + # - typing.List, typing.Tuple become obsolete in Pyton 3.9 + # - typing.Union becomes obsolete in Pyton 3.10 + from typing import List, Tuple, Union + from django_stubs_ext import StrOrPromise # The type argument 'T' to 'Choices' is the database representation type. - _Double = tuple[T, StrOrPromise] - _Triple = tuple[T, str, StrOrPromise] - _Group = tuple[StrOrPromise, Sequence["_Choice[T]"]] - _Choice = _Double[T] | _Triple[T] | _Group[T] + _Double = Tuple[T, StrOrPromise] + _Triple = Tuple[T, str, StrOrPromise] + _Group = Tuple[StrOrPromise, Sequence["_Choice[T]"]] + _Choice = Union[_Double[T], _Triple[T], _Group[T]] # Choices can only be given as a single string if 'T' is 'str'. - _GroupStr = tuple[StrOrPromise, Sequence["_ChoiceStr"]] - _ChoiceStr = str | _Double[str] | _Triple[str] | _GroupStr + _GroupStr = Tuple[StrOrPromise, Sequence["_ChoiceStr"]] + _ChoiceStr = Union[str, _Double[str], _Triple[str], _GroupStr] # Note that we only accept lists and tuples in groups, not arbitrary sequences. # However, annotating it as such causes many problems. - _DoubleRead = _Double[T] | tuple[StrOrPromise, Iterable["_DoubleRead[T]"]] - _DoubleCollector = list[_Double[T] | tuple[StrOrPromise, "_DoubleCollector[T]"]] - _TripleCollector = list[_Triple[T] | tuple[StrOrPromise, "_TripleCollector[T]"]] + _DoubleRead = Union[_Double[T], Tuple[StrOrPromise, Iterable["_DoubleRead[T]"]]] + _DoubleCollector = List[Union[_Double[T], Tuple[StrOrPromise, "_DoubleCollector[T]"]]] + _TripleCollector = List[Union[_Triple[T], Tuple[StrOrPromise, "_TripleCollector[T]"]]] class Choices(Generic[T]): From aeeb69a9dd7c5274b6ed1795888d1880052d26da Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Tue, 21 Mar 2023 14:21:03 +0100 Subject: [PATCH 12/28] Enable postponed evaluation of annotations for all test modules This allows using the latest annotation syntax supported by the type checker regardless of the runtime Python version. --- tests/fields.py | 2 ++ tests/test_choices.py | 2 ++ tests/test_fields/test_monitor_field.py | 2 ++ tests/test_fields/test_split_field.py | 2 ++ tests/test_fields/test_status_field.py | 2 ++ tests/test_fields/test_urlsafe_token_field.py | 2 ++ tests/test_fields/test_uuid_field.py | 2 ++ tests/test_inheritance_iterable.py | 2 ++ tests/test_managers/test_inheritance_manager.py | 2 ++ tests/test_managers/test_join_manager.py | 2 ++ tests/test_managers/test_query_manager.py | 2 ++ tests/test_managers/test_softdelete_manager.py | 2 ++ tests/test_managers/test_status_manager.py | 2 ++ tests/test_miscellaneous.py | 2 ++ tests/test_models/test_deferred_fields.py | 2 ++ tests/test_models/test_softdeletable_model.py | 2 ++ tests/test_models/test_status_model.py | 2 ++ tests/test_models/test_timeframed_model.py | 2 ++ tests/test_models/test_timestamped_model.py | 2 ++ tests/test_models/test_uuid_model.py | 2 ++ 20 files changed, 40 insertions(+) diff --git a/tests/fields.py b/tests/fields.py index fa302039..df074f19 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.db import models diff --git a/tests/test_choices.py b/tests/test_choices.py index 6d09319b..8916e977 100644 --- a/tests/test_choices.py +++ b/tests/test_choices.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.test import TestCase from model_utils import Choices diff --git a/tests/test_fields/test_monitor_field.py b/tests/test_fields/test_monitor_field.py index f0041368..143417c9 100644 --- a/tests/test_fields/test_monitor_field.py +++ b/tests/test_fields/test_monitor_field.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime, timezone import time_machine diff --git a/tests/test_fields/test_split_field.py b/tests/test_fields/test_split_field.py index 6028f895..3199b56e 100644 --- a/tests/test_fields/test_split_field.py +++ b/tests/test_fields/test_split_field.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.test import TestCase from tests.models import Article, SplitFieldAbstractParent diff --git a/tests/test_fields/test_status_field.py b/tests/test_fields/test_status_field.py index ab250c6d..a24ce454 100644 --- a/tests/test_fields/test_status_field.py +++ b/tests/test_fields/test_status_field.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.test import TestCase from model_utils.fields import StatusField diff --git a/tests/test_fields/test_urlsafe_token_field.py b/tests/test_fields/test_urlsafe_token_field.py index 66aeb29d..3feffde1 100644 --- a/tests/test_fields/test_urlsafe_token_field.py +++ b/tests/test_fields/test_urlsafe_token_field.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock from django.db.models import NOT_PROVIDED diff --git a/tests/test_fields/test_uuid_field.py b/tests/test_fields/test_uuid_field.py index cc354f09..7401566a 100644 --- a/tests/test_fields/test_uuid_field.py +++ b/tests/test_fields/test_uuid_field.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import uuid from django.core.exceptions import ValidationError diff --git a/tests/test_inheritance_iterable.py b/tests/test_inheritance_iterable.py index f4202e53..d7cf943f 100644 --- a/tests/test_inheritance_iterable.py +++ b/tests/test_inheritance_iterable.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.db.models import Prefetch from django.test import TestCase diff --git a/tests/test_managers/test_inheritance_manager.py b/tests/test_managers/test_inheritance_manager.py index 60987bfe..89c5bd9d 100644 --- a/tests/test_managers/test_inheritance_manager.py +++ b/tests/test_managers/test_inheritance_manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.db import models from django.test import TestCase diff --git a/tests/test_managers/test_join_manager.py b/tests/test_managers/test_join_manager.py index 57c1c344..3a6699c7 100644 --- a/tests/test_managers/test_join_manager.py +++ b/tests/test_managers/test_join_manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.test import TestCase from tests.models import BoxJoinModel, JoinItemForeignKey diff --git a/tests/test_managers/test_query_manager.py b/tests/test_managers/test_query_manager.py index 2339dbe0..85aa9e95 100644 --- a/tests/test_managers/test_query_manager.py +++ b/tests/test_managers/test_query_manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.test import TestCase from tests.models import Post diff --git a/tests/test_managers/test_softdelete_manager.py b/tests/test_managers/test_softdelete_manager.py index aec53576..e4c64b5b 100644 --- a/tests/test_managers/test_softdelete_manager.py +++ b/tests/test_managers/test_softdelete_manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.test import TestCase from tests.models import CustomSoftDelete diff --git a/tests/test_managers/test_status_manager.py b/tests/test_managers/test_status_manager.py index 106881a2..ef467a0f 100644 --- a/tests/test_managers/test_status_manager.py +++ b/tests/test_managers/test_status_manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.core.exceptions import ImproperlyConfigured from django.db import models from django.test import TestCase diff --git a/tests/test_miscellaneous.py b/tests/test_miscellaneous.py index 684d824c..8faaeddb 100644 --- a/tests/test_miscellaneous.py +++ b/tests/test_miscellaneous.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.core.management import call_command from django.test import TestCase diff --git a/tests/test_models/test_deferred_fields.py b/tests/test_models/test_deferred_fields.py index 7a839fab..a3a42839 100644 --- a/tests/test_models/test_deferred_fields.py +++ b/tests/test_models/test_deferred_fields.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.test import TestCase from tests.models import ModelWithCustomDescriptor diff --git a/tests/test_models/test_softdeletable_model.py b/tests/test_models/test_softdeletable_model.py index c2ffd54f..38332d75 100644 --- a/tests/test_models/test_softdeletable_model.py +++ b/tests/test_models/test_softdeletable_model.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.test import TestCase from django.utils.connection import ConnectionDoesNotExist diff --git a/tests/test_models/test_status_model.py b/tests/test_models/test_status_model.py index 1e36a58f..35c38089 100644 --- a/tests/test_models/test_status_model.py +++ b/tests/test_models/test_status_model.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime, timezone import time_machine diff --git a/tests/test_models/test_timeframed_model.py b/tests/test_models/test_timeframed_model.py index 3038828b..3cf3d730 100644 --- a/tests/test_models/test_timeframed_model.py +++ b/tests/test_models/test_timeframed_model.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime, timedelta from django.core.exceptions import ImproperlyConfigured diff --git a/tests/test_models/test_timestamped_model.py b/tests/test_models/test_timestamped_model.py index 1087da10..193af3e6 100644 --- a/tests/test_models/test_timestamped_model.py +++ b/tests/test_models/test_timestamped_model.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime, timedelta, timezone import time_machine diff --git a/tests/test_models/test_uuid_model.py b/tests/test_models/test_uuid_model.py index 3c7663d7..62dcae3f 100644 --- a/tests/test_models/test_uuid_model.py +++ b/tests/test_models/test_uuid_model.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from django.test import TestCase from tests.models import CustomNotPrimaryUUIDModel, CustomUUIDModel From 218843d754af411f449a8efc70254b50c0c69e6b Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Wed, 22 Mar 2023 12:42:31 +0100 Subject: [PATCH 13/28] Add minimal annotations to unit tests These annotations are sufficient to pass mypy inspection if mypy is configured to allow unannotated functions. --- tests/models.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/models.py b/tests/models.py index 8b496540..58c97c13 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import ClassVar +from typing import ClassVar, TypeVar from django.db import models from django.db.models import Manager @@ -26,6 +26,8 @@ from model_utils.tracker import FieldTracker, ModelTracker from tests.fields import MutableField +ModelT = TypeVar('ModelT', bound=models.Model, covariant=True) + class InheritanceManagerTestRelated(models.Model): pass @@ -128,7 +130,7 @@ class DoubleMonitored(models.Model): class Status(StatusModel): - STATUS = Choices( + STATUS: Choices[str] = Choices( ("active", _("active")), ("deleted", _("deleted")), ("on_hold", _("on hold")), @@ -184,7 +186,8 @@ class Post(models.Model): public: ClassVar[QueryManager[Post]] = QueryManager(published=True) public_confirmed: ClassVar[QueryManager[Post]] = QueryManager( models.Q(published=True) & models.Q(confirmed=True)) - public_reversed = QueryManager(published=True).order_by("-order") + public_reversed: QueryManager[Post] = QueryManager( + published=True).order_by("-order") class Meta: ordering = ("order",) @@ -246,7 +249,7 @@ class TrackedFK(models.Model): class TrackedAbstract(AbstractTracked): name = models.CharField(max_length=20) - number = models.IntegerField() + number = models.IntegerField() # type: ignore[assignment] mutable = MutableField(default=None) tracker = ModelTracker() From 23f1811b9dc83baf1b60c5ecf2afe74bbda513f7 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Wed, 22 Mar 2023 18:50:18 +0100 Subject: [PATCH 14/28] Annotate return type of test methods --- tests/test_choices.py | 100 ++++++------ tests/test_fields/test_field_tracker.py | 150 +++++++++--------- tests/test_fields/test_monitor_field.py | 40 ++--- tests/test_fields/test_split_field.py | 28 ++-- tests/test_fields/test_status_field.py | 10 +- tests/test_fields/test_urlsafe_token_field.py | 24 +-- tests/test_fields/test_uuid_field.py | 16 +- tests/test_inheritance_iterable.py | 2 +- .../test_managers/test_inheritance_manager.py | 96 +++++------ tests/test_managers/test_join_manager.py | 10 +- tests/test_managers/test_query_manager.py | 8 +- .../test_managers/test_softdelete_manager.py | 8 +- tests/test_managers/test_status_manager.py | 4 +- tests/test_miscellaneous.py | 10 +- tests/test_models/test_deferred_fields.py | 6 +- tests/test_models/test_softdeletable_model.py | 14 +- tests/test_models/test_status_model.py | 14 +- tests/test_models/test_timeframed_model.py | 16 +- tests/test_models/test_timestamped_model.py | 24 +-- tests/test_models/test_uuid_model.py | 4 +- 20 files changed, 295 insertions(+), 289 deletions(-) diff --git a/tests/test_choices.py b/tests/test_choices.py index 8916e977..d653cd71 100644 --- a/tests/test_choices.py +++ b/tests/test_choices.py @@ -6,61 +6,61 @@ class ChoicesTests(TestCase): - def setUp(self): + def setUp(self) -> None: self.STATUS = Choices('DRAFT', 'PUBLISHED') - def test_getattr(self): + def test_getattr(self) -> None: self.assertEqual(self.STATUS.DRAFT, 'DRAFT') - def test_indexing(self): + def test_indexing(self) -> None: self.assertEqual(self.STATUS['PUBLISHED'], 'PUBLISHED') - def test_iteration(self): + def test_iteration(self) -> None: self.assertEqual(tuple(self.STATUS), (('DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED'))) - def test_reversed(self): + def test_reversed(self) -> None: self.assertEqual(tuple(reversed(self.STATUS)), (('PUBLISHED', 'PUBLISHED'), ('DRAFT', 'DRAFT'))) - def test_len(self): + def test_len(self) -> None: self.assertEqual(len(self.STATUS), 2) - def test_repr(self): + def test_repr(self) -> None: self.assertEqual(repr(self.STATUS), "Choices" + repr(( ('DRAFT', 'DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED', 'PUBLISHED'), ))) - def test_wrong_length_tuple(self): + def test_wrong_length_tuple(self) -> None: with self.assertRaises(ValueError): Choices(('a',)) - def test_contains_value(self): + def test_contains_value(self) -> None: self.assertTrue('PUBLISHED' in self.STATUS) self.assertTrue('DRAFT' in self.STATUS) - def test_doesnt_contain_value(self): + def test_doesnt_contain_value(self) -> None: self.assertFalse('UNPUBLISHED' in self.STATUS) - def test_deepcopy(self): + def test_deepcopy(self) -> None: import copy self.assertEqual(list(self.STATUS), list(copy.deepcopy(self.STATUS))) - def test_equality(self): + def test_equality(self) -> None: self.assertEqual(self.STATUS, Choices('DRAFT', 'PUBLISHED')) - def test_inequality(self): + def test_inequality(self) -> None: self.assertNotEqual(self.STATUS, ['DRAFT', 'PUBLISHED']) self.assertNotEqual(self.STATUS, Choices('DRAFT')) - def test_composability(self): + def test_composability(self) -> None: self.assertEqual(Choices('DRAFT') + Choices('PUBLISHED'), self.STATUS) self.assertEqual(Choices('DRAFT') + ('PUBLISHED',), self.STATUS) self.assertEqual(('DRAFT',) + Choices('PUBLISHED'), self.STATUS) - def test_option_groups(self): + def test_option_groups(self) -> None: c = Choices(('group a', ['one', 'two']), ['group b', ('three',)]) self.assertEqual( list(c), @@ -72,47 +72,47 @@ def test_option_groups(self): class LabelChoicesTests(ChoicesTests): - def setUp(self): + def setUp(self) -> None: self.STATUS = Choices( ('DRAFT', 'is draft'), ('PUBLISHED', 'is published'), 'DELETED', ) - def test_iteration(self): + def test_iteration(self) -> None: self.assertEqual(tuple(self.STATUS), ( ('DRAFT', 'is draft'), ('PUBLISHED', 'is published'), ('DELETED', 'DELETED'), )) - def test_reversed(self): + def test_reversed(self) -> None: self.assertEqual(tuple(reversed(self.STATUS)), ( ('DELETED', 'DELETED'), ('PUBLISHED', 'is published'), ('DRAFT', 'is draft'), )) - def test_indexing(self): + def test_indexing(self) -> None: self.assertEqual(self.STATUS['PUBLISHED'], 'is published') - def test_default(self): + def test_default(self) -> None: self.assertEqual(self.STATUS.DELETED, 'DELETED') - def test_provided(self): + def test_provided(self) -> None: self.assertEqual(self.STATUS.DRAFT, 'DRAFT') - def test_len(self): + def test_len(self) -> None: self.assertEqual(len(self.STATUS), 3) - def test_equality(self): + def test_equality(self) -> None: self.assertEqual(self.STATUS, Choices( ('DRAFT', 'is draft'), ('PUBLISHED', 'is published'), 'DELETED', )) - def test_inequality(self): + def test_inequality(self) -> None: self.assertNotEqual(self.STATUS, [ ('DRAFT', 'is draft'), ('PUBLISHED', 'is published'), @@ -120,27 +120,27 @@ def test_inequality(self): ]) self.assertNotEqual(self.STATUS, Choices('DRAFT')) - def test_repr(self): + def test_repr(self) -> None: self.assertEqual(repr(self.STATUS), "Choices" + repr(( ('DRAFT', 'DRAFT', 'is draft'), ('PUBLISHED', 'PUBLISHED', 'is published'), ('DELETED', 'DELETED', 'DELETED'), ))) - def test_contains_value(self): + def test_contains_value(self) -> None: self.assertTrue('PUBLISHED' in self.STATUS) self.assertTrue('DRAFT' in self.STATUS) # This should be True, because both the display value # and the internal representation are both DELETED. self.assertTrue('DELETED' in self.STATUS) - def test_doesnt_contain_value(self): + def test_doesnt_contain_value(self) -> None: self.assertFalse('UNPUBLISHED' in self.STATUS) - def test_doesnt_contain_display_value(self): + def test_doesnt_contain_display_value(self) -> None: self.assertFalse('is draft' in self.STATUS) - def test_composability(self): + def test_composability(self) -> None: self.assertEqual( Choices(('DRAFT', 'is draft',)) + Choices(('PUBLISHED', 'is published'), 'DELETED'), self.STATUS @@ -156,7 +156,7 @@ def test_composability(self): self.STATUS ) - def test_option_groups(self): + def test_option_groups(self) -> None: c = Choices( ('group a', [(1, 'one'), (2, 'two')]), ['group b', ((3, 'three'),)] @@ -171,64 +171,64 @@ def test_option_groups(self): class IdentifierChoicesTests(ChoicesTests): - def setUp(self): + def setUp(self) -> None: self.STATUS = Choices( (0, 'DRAFT', 'is draft'), (1, 'PUBLISHED', 'is published'), (2, 'DELETED', 'is deleted')) - def test_iteration(self): + def test_iteration(self) -> None: self.assertEqual(tuple(self.STATUS), ( (0, 'is draft'), (1, 'is published'), (2, 'is deleted'), )) - def test_reversed(self): + def test_reversed(self) -> None: self.assertEqual(tuple(reversed(self.STATUS)), ( (2, 'is deleted'), (1, 'is published'), (0, 'is draft'), )) - def test_indexing(self): + def test_indexing(self) -> None: self.assertEqual(self.STATUS[1], 'is published') - def test_getattr(self): + def test_getattr(self) -> None: self.assertEqual(self.STATUS.DRAFT, 0) - def test_len(self): + def test_len(self) -> None: self.assertEqual(len(self.STATUS), 3) - def test_repr(self): + def test_repr(self) -> None: self.assertEqual(repr(self.STATUS), "Choices" + repr(( (0, 'DRAFT', 'is draft'), (1, 'PUBLISHED', 'is published'), (2, 'DELETED', 'is deleted'), ))) - def test_contains_value(self): + def test_contains_value(self) -> None: self.assertTrue(0 in self.STATUS) self.assertTrue(1 in self.STATUS) self.assertTrue(2 in self.STATUS) - def test_doesnt_contain_value(self): + def test_doesnt_contain_value(self) -> None: self.assertFalse(3 in self.STATUS) - def test_doesnt_contain_display_value(self): + def test_doesnt_contain_display_value(self) -> None: self.assertFalse('is draft' in self.STATUS) - def test_doesnt_contain_python_attr(self): + def test_doesnt_contain_python_attr(self) -> None: self.assertFalse('PUBLISHED' in self.STATUS) - def test_equality(self): + def test_equality(self) -> None: self.assertEqual(self.STATUS, Choices( (0, 'DRAFT', 'is draft'), (1, 'PUBLISHED', 'is published'), (2, 'DELETED', 'is deleted') )) - def test_inequality(self): + def test_inequality(self) -> None: self.assertNotEqual(self.STATUS, [ (0, 'DRAFT', 'is draft'), (1, 'PUBLISHED', 'is published'), @@ -236,7 +236,7 @@ def test_inequality(self): ]) self.assertNotEqual(self.STATUS, Choices('DRAFT')) - def test_composability(self): + def test_composability(self) -> None: self.assertEqual( Choices( (0, 'DRAFT', 'is draft'), @@ -267,7 +267,7 @@ def test_composability(self): self.STATUS ) - def test_option_groups(self): + def test_option_groups(self) -> None: c = Choices( ('group a', [(1, 'ONE', 'one'), (2, 'TWO', 'two')]), ['group b', ((3, 'THREE', 'three'),)] @@ -283,26 +283,26 @@ def test_option_groups(self): class SubsetChoicesTest(TestCase): - def setUp(self): + def setUp(self) -> None: self.choices = Choices( (0, 'a', 'A'), (1, 'b', 'B'), ) - def test_nonexistent_identifiers_raise(self): + def test_nonexistent_identifiers_raise(self) -> None: with self.assertRaises(ValueError): self.choices.subset('a', 'c') - def test_solo_nonexistent_identifiers_raise(self): + def test_solo_nonexistent_identifiers_raise(self) -> None: with self.assertRaises(ValueError): self.choices.subset('c') - def test_empty_subset_passes(self): + def test_empty_subset_passes(self) -> None: subset = self.choices.subset() self.assertEqual(subset, Choices()) - def test_subset_returns_correct_subset(self): + def test_subset_returns_correct_subset(self) -> None: subset = self.choices.subset('a') self.assertEqual(subset, Choices((0, 'a', 'A'))) diff --git a/tests/test_fields/test_field_tracker.py b/tests/test_fields/test_field_tracker.py index f239bd37..91a84c9d 100644 --- a/tests/test_fields/test_field_tracker.py +++ b/tests/test_fields/test_field_tracker.py @@ -65,7 +65,7 @@ def update_instance(self, **kwargs): class FieldTrackerCommonTests: - def test_pre_save_previous(self): + def test_pre_save_previous(self) -> None: self.assertPrevious(name=None, number=None) self.instance.name = 'new age' self.instance.number = 8 @@ -76,14 +76,14 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): tracked_class: type[models.Model] = Tracked - def setUp(self): + def setUp(self) -> None: self.instance = self.tracked_class() self.tracker = self.instance.tracker - def test_descriptor(self): + def test_descriptor(self) -> None: self.assertTrue(isinstance(self.tracked_class.tracker, FieldTracker)) - def test_pre_save_changed(self): + def test_pre_save_changed(self) -> None: self.assertChanged(name=None) self.instance.name = 'new age' self.assertChanged(name=None) @@ -94,7 +94,7 @@ def test_pre_save_changed(self): self.instance.mutable = [1, 2, 3] self.assertChanged(name=None, number=None, mutable=None) - def test_pre_save_has_changed(self): + def test_pre_save_has_changed(self) -> None: self.assertHasChanged(name=True, number=False, mutable=False) self.instance.name = 'new age' self.assertHasChanged(name=True, number=False, mutable=False) @@ -103,12 +103,12 @@ def test_pre_save_has_changed(self): self.instance.mutable = [1, 2, 3] self.assertHasChanged(name=True, number=True, mutable=True) - def test_save_with_args(self): + def test_save_with_args(self) -> None: self.instance.number = 1 self.instance.save(False, False, None, None) self.assertChanged() - def test_first_save(self): + def test_first_save(self) -> None: self.assertHasChanged(name=True, number=False, mutable=False) self.assertPrevious(name=None, number=None, mutable=None) self.assertCurrent(name='', number=None, id=None, mutable=None) @@ -129,7 +129,7 @@ def test_first_save(self): with self.assertRaises(ValueError): self.instance.save(update_fields=['number']) - def test_post_save_has_changed(self): + def test_post_save_has_changed(self) -> None: self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) self.assertHasChanged(name=False, number=False, mutable=False) self.instance.name = 'new age' @@ -141,14 +141,14 @@ def test_post_save_has_changed(self): self.instance.name = 'retro' self.assertHasChanged(name=False, number=True, mutable=True) - def test_post_save_previous(self): + def test_post_save_previous(self) -> None: self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) self.instance.name = 'new age' self.assertPrevious(name='retro', number=4, mutable=[1, 2, 3]) self.instance.mutable[1] = 4 self.assertPrevious(name='retro', number=4, mutable=[1, 2, 3]) - def test_post_save_changed(self): + def test_post_save_changed(self) -> None: self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) self.assertChanged() self.instance.name = 'new age' @@ -162,7 +162,7 @@ def test_post_save_changed(self): self.instance.mutable = [1, 2, 3] self.assertChanged(number=4) - def test_current(self): + def test_current(self) -> None: self.assertCurrent(id=None, name='', number=None, mutable=None) self.instance.name = 'new age' self.assertCurrent(id=None, name='new age', number=None, mutable=None) @@ -175,7 +175,7 @@ def test_current(self): self.instance.save() self.assertCurrent(id=self.instance.id, name='new age', number=8, mutable=[1, 4, 3]) - def test_update_fields(self): + def test_update_fields(self) -> None: self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) self.assertChanged() self.instance.name = 'new age' @@ -198,7 +198,7 @@ def test_update_fields(self): self.assertEqual(in_db.number, self.instance.number) self.assertEqual(in_db.mutable, self.instance.mutable) - def test_refresh_from_db(self): + def test_refresh_from_db(self) -> None: self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) self.tracked_class.objects.filter(pk=self.instance.pk).update( name='new age', number=8, mutable=[3, 2, 1]) @@ -214,7 +214,7 @@ def test_refresh_from_db(self): self.instance.refresh_from_db() self.assertChanged() - def test_with_deferred(self): + def test_with_deferred(self) -> None: self.instance.name = 'new age' self.instance.number = 1 self.instance.save() @@ -268,7 +268,7 @@ def test_with_deferred(self): class FieldTrackerMultipleInstancesTests(TestCase): - def test_with_deferred_fields_access_multiple(self): + def test_with_deferred_fields_access_multiple(self) -> None: Tracked.objects.create(pk=1, name='foo', number=1) Tracked.objects.create(pk=2, name='bar', number=2) @@ -283,11 +283,11 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase, tracked_class: type[models.Model] = TrackedNotDefault - def setUp(self): + def setUp(self) -> None: self.instance = self.tracked_class() self.tracker = self.instance.name_tracker - def test_pre_save_changed(self): + def test_pre_save_changed(self) -> None: self.assertChanged(name=None) self.instance.name = 'new age' self.assertChanged(name=None) @@ -296,7 +296,7 @@ def test_pre_save_changed(self): self.instance.name = '' self.assertChanged(name=None) - def test_first_save(self): + def test_first_save(self) -> None: self.assertHasChanged(name=True, number=None) self.assertPrevious(name=None, number=None) self.assertCurrent(name='') @@ -308,14 +308,14 @@ def test_first_save(self): self.assertCurrent(name='retro') self.assertChanged(name=None) - def test_pre_save_has_changed(self): + def test_pre_save_has_changed(self) -> None: self.assertHasChanged(name=True, number=None) self.instance.name = 'new age' self.assertHasChanged(name=True, number=None) self.instance.number = 7 self.assertHasChanged(name=True, number=None) - def test_post_save_has_changed(self): + def test_post_save_has_changed(self) -> None: self.update_instance(name='retro', number=4) self.assertHasChanged(name=False, number=None) self.instance.name = 'new age' @@ -325,12 +325,12 @@ def test_post_save_has_changed(self): self.instance.name = 'retro' self.assertHasChanged(name=False, number=None) - def test_post_save_previous(self): + def test_post_save_previous(self) -> None: self.update_instance(name='retro', number=4) self.instance.name = 'new age' self.assertPrevious(name='retro', number=None) - def test_post_save_changed(self): + def test_post_save_changed(self) -> None: self.update_instance(name='retro', number=4) self.assertChanged() self.instance.name = 'new age' @@ -340,7 +340,7 @@ def test_post_save_changed(self): self.instance.name = 'retro' self.assertChanged() - def test_current(self): + def test_current(self) -> None: self.assertCurrent(name='') self.instance.name = 'new age' self.assertCurrent(name='new age') @@ -349,7 +349,7 @@ def test_current(self): self.instance.save() self.assertCurrent(name='new age') - def test_update_fields(self): + def test_update_fields(self) -> None: self.update_instance(name='retro', number=4) self.assertChanged() self.instance.name = 'new age' @@ -362,11 +362,11 @@ class FieldTrackedModelAttributeTests(FieldTrackerTestCase): tracked_class = TrackedNonFieldAttr - def setUp(self): + def setUp(self) -> None: self.instance = self.tracked_class() self.tracker = self.instance.tracker - def test_previous(self): + def test_previous(self) -> None: self.assertPrevious(rounded=None) self.instance.number = 7.5 self.assertPrevious(rounded=None) @@ -377,7 +377,7 @@ def test_previous(self): self.instance.save() self.assertPrevious(rounded=7) - def test_has_changed(self): + def test_has_changed(self) -> None: self.assertHasChanged(rounded=False) self.instance.number = 7.5 self.assertHasChanged(rounded=True) @@ -388,7 +388,7 @@ def test_has_changed(self): self.instance.number = 7.8 self.assertHasChanged(rounded=False) - def test_changed(self): + def test_changed(self) -> None: self.assertChanged() self.instance.number = 7.5 self.assertPrevious(rounded=None) @@ -401,7 +401,7 @@ def test_changed(self): self.instance.save() self.assertPrevious() - def test_current(self): + def test_current(self) -> None: self.assertCurrent(rounded=None) self.instance.number = 7.5 self.assertCurrent(rounded=8) @@ -414,12 +414,12 @@ class FieldTrackedModelMultiTests(FieldTrackerTestCase, tracked_class: type[models.Model] = TrackedMultiple - def setUp(self): + def setUp(self) -> None: self.instance = self.tracked_class() self.trackers = [self.instance.name_tracker, self.instance.number_tracker] - def test_pre_save_changed(self): + def test_pre_save_changed(self) -> None: self.tracker = self.instance.name_tracker self.assertChanged(name=None) self.instance.name = 'new age' @@ -435,7 +435,7 @@ def test_pre_save_changed(self): self.instance.number = 8 self.assertChanged(number=None) - def test_pre_save_has_changed(self): + def test_pre_save_has_changed(self) -> None: self.tracker = self.instance.name_tracker self.assertHasChanged(name=True, number=None) self.instance.name = 'new age' @@ -445,12 +445,12 @@ def test_pre_save_has_changed(self): self.instance.name = 'new age' self.assertHasChanged(name=None, number=False) - def test_pre_save_previous(self): + def test_pre_save_previous(self) -> None: for tracker in self.trackers: self.tracker = tracker super().test_pre_save_previous() - def test_post_save_has_changed(self): + def test_post_save_has_changed(self) -> None: self.update_instance(name='retro', number=4) self.assertHasChanged(tracker=self.trackers[0], name=False, number=None) self.assertHasChanged(tracker=self.trackers[1], name=None, number=False) @@ -465,14 +465,14 @@ def test_post_save_has_changed(self): self.assertHasChanged(tracker=self.trackers[0], name=False, number=None) self.assertHasChanged(tracker=self.trackers[1], name=None, number=False) - def test_post_save_previous(self): + def test_post_save_previous(self) -> None: self.update_instance(name='retro', number=4) self.instance.name = 'new age' self.instance.number = 8 self.assertPrevious(tracker=self.trackers[0], name='retro', number=None) self.assertPrevious(tracker=self.trackers[1], name=None, number=4) - def test_post_save_changed(self): + def test_post_save_changed(self) -> None: self.update_instance(name='retro', number=4) self.assertChanged(tracker=self.trackers[0]) self.assertChanged(tracker=self.trackers[1]) @@ -487,7 +487,7 @@ def test_post_save_changed(self): self.assertChanged(tracker=self.trackers[0]) self.assertChanged(tracker=self.trackers[1]) - def test_current(self): + def test_current(self) -> None: self.assertCurrent(tracker=self.trackers[0], name='') self.assertCurrent(tracker=self.trackers[1], number=None) self.instance.name = 'new age' @@ -506,11 +506,11 @@ class FieldTrackerForeignKeyTests(FieldTrackerTestCase): fk_class: type[models.Model] = Tracked tracked_class: type[models.Model] = TrackedFK - def setUp(self): + def setUp(self) -> None: self.old_fk = self.fk_class.objects.create(number=8) self.instance = self.tracked_class.objects.create(fk=self.old_fk) - def test_default(self): + def test_default(self) -> None: self.tracker = self.instance.tracker self.assertChanged() self.assertPrevious() @@ -520,7 +520,7 @@ def test_default(self): self.assertPrevious(fk_id=self.old_fk.id) self.assertCurrent(id=self.instance.id, fk_id=self.instance.fk_id) - def test_custom(self): + def test_custom(self) -> None: self.tracker = self.instance.custom_tracker self.assertChanged() self.assertPrevious() @@ -530,7 +530,7 @@ def test_custom(self): self.assertPrevious(fk_id=self.old_fk.id) self.assertCurrent(fk_id=self.instance.fk_id) - def test_custom_without_id(self): + def test_custom_without_id(self) -> None: with self.assertNumQueries(1): self.tracked_class.objects.get() self.tracker = self.instance.custom_tracker_without_id @@ -549,19 +549,19 @@ class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerTestCase): fk_class = Tracked tracked_class = TrackedFK - def setUp(self): + def setUp(self) -> None: model_tracked = self.fk_class.objects.create(name="", number=0) self.instance = self.tracked_class.objects.create(fk=model_tracked) - def test_default(self): + def test_default(self) -> None: self.tracker = self.instance.tracker self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk"))) - def test_custom(self): + def test_custom(self) -> None: self.tracker = self.instance.custom_tracker self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk"))) - def test_custom_without_id(self): + def test_custom_without_id(self) -> None: self.tracker = self.instance.custom_tracker_without_id self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk"))) @@ -571,18 +571,18 @@ class FieldTrackerTimeStampedTests(FieldTrackerTestCase): fk_class = Tracked tracked_class = TrackerTimeStamped - def setUp(self): + def setUp(self) -> None: self.instance = self.tracked_class.objects.create(name='old', number=1) self.tracker = self.instance.tracker - def test_set_modified_on_save(self): + def test_set_modified_on_save(self) -> None: old_modified = self.instance.modified self.instance.name = 'new' self.instance.save() self.assertGreater(self.instance.modified, old_modified) self.assertChanged() - def test_set_modified_on_save_update_fields(self): + def test_set_modified_on_save_update_fields(self) -> None: old_modified = self.instance.modified self.instance.name = 'new' self.instance.save(update_fields=('name',)) @@ -594,7 +594,7 @@ class InheritedFieldTrackerTests(FieldTrackerTests): tracked_class = InheritedTracked - def test_child_fields_not_tracked(self): + def test_child_fields_not_tracked(self) -> None: self.name2 = 'test' self.assertEqual(self.tracker.previous('name2'), None) self.assertRaises(FieldError, self.tracker.has_changed, 'name2') @@ -609,13 +609,13 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase): tracked_class = TrackedFileField - def setUp(self): + def setUp(self) -> None: self.instance = self.tracked_class() self.tracker = self.instance.tracker self.some_file = 'something.txt' self.another_file = 'another.txt' - def test_saved_data_without_instance(self): + def test_saved_data_without_instance(self) -> None: """ Tests that instance won't get copied by the Field Tracker. @@ -634,22 +634,22 @@ def test_saved_data_without_instance(self): self.assertEqual(self.instance.some_file.instance, self.instance) self.assertIsInstance(self.instance.some_file, FieldFile) - def test_pre_save_changed(self): + def test_pre_save_changed(self) -> None: self.assertChanged(some_file=None) self.instance.some_file = self.some_file self.assertChanged(some_file=None) - def test_pre_save_has_changed(self): + def test_pre_save_has_changed(self) -> None: self.assertHasChanged(some_file=True) self.instance.some_file = self.some_file self.assertHasChanged(some_file=True) - def test_pre_save_previous(self): + def test_pre_save_previous(self) -> None: self.assertPrevious(some_file=None) self.instance.some_file = self.some_file self.assertPrevious(some_file=None) - def test_post_save_changed(self): + def test_post_save_changed(self) -> None: self.update_instance(some_file=self.some_file) self.assertChanged() previous_file = self.instance.some_file @@ -667,7 +667,7 @@ def test_post_save_changed(self): some_file=previous_file, ) - def test_post_save_has_changed(self): + def test_post_save_has_changed(self) -> None: self.update_instance(some_file=self.some_file) self.assertHasChanged(some_file=False) self.instance.some_file = self.another_file @@ -687,7 +687,7 @@ def test_post_save_has_changed(self): some_file=True, ) - def test_post_save_previous(self): + def test_post_save_previous(self) -> None: self.update_instance(some_file=self.some_file) previous_file = self.instance.some_file self.instance.some_file = self.another_file @@ -707,7 +707,7 @@ def test_post_save_previous(self): some_file=previous_file, ) - def test_current(self): + def test_current(self) -> None: self.assertCurrent(some_file=self.instance.some_file, id=None) self.instance.some_file = self.some_file self.assertCurrent(some_file=self.instance.some_file, id=None) @@ -732,7 +732,7 @@ class ModelTrackerTests(FieldTrackerTests): tracked_class: type[models.Model] = ModelTracked - def test_cache_compatible(self): + def test_cache_compatible(self) -> None: cache.set('key', self.instance) instance = cache.get('key') instance.number = 1 @@ -742,7 +742,7 @@ def test_cache_compatible(self): instance.number = 2 self.assertHasChanged(number=True) - def test_pre_save_changed(self): + def test_pre_save_changed(self) -> None: self.assertChanged() self.instance.name = 'new age' self.assertChanged() @@ -753,7 +753,7 @@ def test_pre_save_changed(self): self.instance.mutable = [1, 2, 3] self.assertChanged() - def test_first_save(self): + def test_first_save(self) -> None: self.assertHasChanged(name=True, number=True, mutable=True) self.assertPrevious(name=None, number=None, mutable=None) self.assertCurrent(name='', number=None, id=None, mutable=None) @@ -774,7 +774,7 @@ def test_first_save(self): with self.assertRaises(ValueError): self.instance.save(update_fields=['number']) - def test_pre_save_has_changed(self): + def test_pre_save_has_changed(self) -> None: self.assertHasChanged(name=True, number=True) self.instance.name = 'new age' self.assertHasChanged(name=True, number=True) @@ -786,7 +786,7 @@ class ModelTrackedModelCustomTests(FieldTrackedModelCustomTests): tracked_class = ModelTrackedNotDefault - def test_first_save(self): + def test_first_save(self) -> None: self.assertHasChanged(name=True, number=True) self.assertPrevious(name=None, number=None) self.assertCurrent(name='') @@ -798,14 +798,14 @@ def test_first_save(self): self.assertCurrent(name='retro') self.assertChanged() - def test_pre_save_has_changed(self): + def test_pre_save_has_changed(self) -> None: self.assertHasChanged(name=True, number=True) self.instance.name = 'new age' self.assertHasChanged(name=True, number=True) self.instance.number = 7 self.assertHasChanged(name=True, number=True) - def test_pre_save_changed(self): + def test_pre_save_changed(self) -> None: self.assertChanged() self.instance.name = 'new age' self.assertChanged() @@ -819,7 +819,7 @@ class ModelTrackedModelMultiTests(FieldTrackedModelMultiTests): tracked_class = ModelTrackedMultiple - def test_pre_save_has_changed(self): + def test_pre_save_has_changed(self) -> None: self.tracker = self.instance.name_tracker self.assertHasChanged(name=True, number=True) self.instance.name = 'new age' @@ -829,7 +829,7 @@ def test_pre_save_has_changed(self): self.instance.name = 'new age' self.assertHasChanged(name=True, number=True) - def test_pre_save_changed(self): + def test_pre_save_changed(self) -> None: self.tracker = self.instance.name_tracker self.assertChanged() self.instance.name = 'new age' @@ -851,7 +851,7 @@ class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyTests): fk_class = ModelTracked tracked_class = ModelTrackedFK - def test_custom_without_id(self): + def test_custom_without_id(self) -> None: with self.assertNumQueries(2): self.tracked_class.objects.get() self.tracker = self.instance.custom_tracker_without_id @@ -869,7 +869,7 @@ class InheritedModelTrackerTests(ModelTrackerTests): tracked_class = InheritedModelTracked - def test_child_fields_not_tracked(self): + def test_child_fields_not_tracked(self) -> None: self.name2 = 'test' self.assertEqual(self.tracker.previous('name2'), None) self.assertTrue(self.tracker.has_changed('name2')) @@ -882,7 +882,7 @@ class AbstractModelTrackerTests(ModelTrackerTests): class TrackerContextDecoratorTests(TestCase): - def setUp(self): + def setUp(self) -> None: self.instance = Tracked.objects.create(number=1) self.tracker = self.instance.tracker @@ -894,7 +894,7 @@ def assertNotChanged(self, *fields): for f in fields: self.assertFalse(self.tracker.has_changed(f)) - def test_context_manager(self): + def test_context_manager(self) -> None: with self.tracker: with self.tracker: self.instance.name = 'new' @@ -905,7 +905,7 @@ def test_context_manager(self): self.assertNotChanged('name') - def test_context_manager_fields(self): + def test_context_manager_fields(self) -> None: with self.tracker('number'): with self.tracker('number', 'name'): self.instance.name = 'new' @@ -918,7 +918,7 @@ def test_context_manager_fields(self): self.assertNotChanged('number', 'name') - def test_tracker_decorator(self): + def test_tracker_decorator(self) -> None: @Tracked.tracker def tracked_method(obj): @@ -929,7 +929,7 @@ def tracked_method(obj): self.assertNotChanged('name') - def test_tracker_decorator_fields(self): + def test_tracker_decorator_fields(self) -> None: @Tracked.tracker(fields=['name']) def tracked_method(obj): @@ -942,7 +942,7 @@ def tracked_method(obj): self.assertChanged('number') self.assertNotChanged('name') - def test_tracker_context_with_save(self): + def test_tracker_context_with_save(self) -> None: with self.tracker: self.instance.name = 'new' diff --git a/tests/test_fields/test_monitor_field.py b/tests/test_fields/test_monitor_field.py index 143417c9..9c9ba841 100644 --- a/tests/test_fields/test_monitor_field.py +++ b/tests/test_fields/test_monitor_field.py @@ -10,33 +10,33 @@ class MonitorFieldTests(TestCase): - def setUp(self): + def setUp(self) -> None: with time_machine.travel(datetime(2016, 1, 1, 10, 0, 0, tzinfo=timezone.utc)): self.instance = Monitored(name='Charlie') self.created = self.instance.name_changed - def test_save_no_change(self): + def test_save_no_change(self) -> None: self.instance.save() self.assertEqual(self.instance.name_changed, self.created) - def test_save_changed(self): + def test_save_changed(self) -> None: with time_machine.travel(datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)): self.instance.name = 'Maria' self.instance.save() self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)) - def test_double_save(self): + def test_double_save(self) -> None: self.instance.name = 'Jose' self.instance.save() changed = self.instance.name_changed self.instance.save() self.assertEqual(self.instance.name_changed, changed) - def test_no_monitor_arg(self): + def test_no_monitor_arg(self) -> None: with self.assertRaises(TypeError): MonitorField() - def test_monitor_default_is_none_when_nullable(self): + def test_monitor_default_is_none_when_nullable(self) -> None: self.assertIsNone(self.instance.name_changed_nullable) expected_datetime = datetime(2022, 1, 18, 12, 0, 0, tzinfo=timezone.utc) @@ -51,33 +51,33 @@ class MonitorWhenFieldTests(TestCase): """ Will record changes only when name is 'Jose' or 'Maria' """ - def setUp(self): + def setUp(self) -> None: with time_machine.travel(datetime(2016, 1, 1, 10, 0, 0, tzinfo=timezone.utc)): self.instance = MonitorWhen(name='Charlie') self.created = self.instance.name_changed - def test_save_no_change(self): + def test_save_no_change(self) -> None: self.instance.save() self.assertEqual(self.instance.name_changed, self.created) - def test_save_changed_to_Jose(self): + def test_save_changed_to_Jose(self) -> None: with time_machine.travel(datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)): self.instance.name = 'Jose' self.instance.save() self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)) - def test_save_changed_to_Maria(self): + def test_save_changed_to_Maria(self) -> None: with time_machine.travel(datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)): self.instance.name = 'Maria' self.instance.save() self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)) - def test_save_changed_to_Pedro(self): + def test_save_changed_to_Pedro(self) -> None: self.instance.name = 'Pedro' self.instance.save() self.assertEqual(self.instance.name_changed, self.created) - def test_double_save(self): + def test_double_save(self) -> None: self.instance.name = 'Jose' self.instance.save() changed = self.instance.name_changed @@ -89,20 +89,20 @@ class MonitorWhenEmptyFieldTests(TestCase): """ Monitor should never be updated id when is an empty list. """ - def setUp(self): + def setUp(self) -> None: self.instance = MonitorWhenEmpty(name='Charlie') self.created = self.instance.name_changed - def test_save_no_change(self): + def test_save_no_change(self) -> None: self.instance.save() self.assertEqual(self.instance.name_changed, self.created) - def test_save_changed_to_Jose(self): + def test_save_changed_to_Jose(self) -> None: self.instance.name = 'Jose' self.instance.save() self.assertEqual(self.instance.name_changed, self.created) - def test_save_changed_to_Maria(self): + def test_save_changed_to_Maria(self) -> None: self.instance.name = 'Maria' self.instance.save() self.assertEqual(self.instance.name_changed, self.created) @@ -110,18 +110,18 @@ def test_save_changed_to_Maria(self): class MonitorDoubleFieldTests(TestCase): - def setUp(self): + def setUp(self) -> None: DoubleMonitored.objects.create(name='Charlie', name2='Charlie2') - def test_recursion_error_with_only(self): + def test_recursion_error_with_only(self) -> None: # Any field passed to only() is generating a recursion error list(DoubleMonitored.objects.only('id')) - def test_recursion_error_with_defer(self): + def test_recursion_error_with_defer(self) -> None: # Only monitored fields passed to defer() are failing list(DoubleMonitored.objects.defer('name')) - def test_monitor_still_works_with_deferred_fields_filtered_out_of_save_initial(self): + def test_monitor_still_works_with_deferred_fields_filtered_out_of_save_initial(self) -> None: obj = DoubleMonitored.objects.defer('name').get(name='Charlie') with time_machine.travel(datetime(2016, 12, 1, tzinfo=timezone.utc)): obj.name = 'Charlie2' diff --git a/tests/test_fields/test_split_field.py b/tests/test_fields/test_split_field.py index 3199b56e..9dd27a17 100644 --- a/tests/test_fields/test_split_field.py +++ b/tests/test_fields/test_split_field.py @@ -9,62 +9,62 @@ class SplitFieldTests(TestCase): full_text = 'summary\n\n\n\nmore' excerpt = 'summary\n' - def setUp(self): + def setUp(self) -> None: self.post = Article.objects.create( title='example post', body=self.full_text) - def test_unicode_content(self): + def test_unicode_content(self) -> None: self.assertEqual(str(self.post.body), self.full_text) - def test_excerpt(self): + def test_excerpt(self) -> None: self.assertEqual(self.post.body.excerpt, self.excerpt) - def test_content(self): + def test_content(self) -> None: self.assertEqual(self.post.body.content, self.full_text) - def test_has_more(self): + def test_has_more(self) -> None: self.assertTrue(self.post.body.has_more) - def test_not_has_more(self): + def test_not_has_more(self) -> None: post = Article.objects.create(title='example 2', body='some text\n\nsome more\n') self.assertFalse(post.body.has_more) - def test_load_back(self): + def test_load_back(self) -> None: post = Article.objects.get(pk=self.post.pk) self.assertEqual(post.body.content, self.post.body.content) self.assertEqual(post.body.excerpt, self.post.body.excerpt) - def test_assign_to_body(self): + def test_assign_to_body(self) -> None: new_text = 'different\n\n\n\nother' self.post.body = new_text self.post.save() self.assertEqual(str(self.post.body), new_text) - def test_assign_to_content(self): + def test_assign_to_content(self) -> None: new_text = 'different\n\n\n\nother' self.post.body.content = new_text self.post.save() self.assertEqual(str(self.post.body), new_text) - def test_assign_to_excerpt(self): + def test_assign_to_excerpt(self) -> None: with self.assertRaises(AttributeError): self.post.body.excerpt = 'this should fail' - def test_access_via_class(self): + def test_access_via_class(self) -> None: with self.assertRaises(AttributeError): Article.body - def test_assign_splittext(self): + def test_assign_splittext(self) -> None: a = Article(title='Some Title') a.body = self.post.body self.assertEqual(a.body.excerpt, 'summary\n') - def test_value_to_string(self): + def test_value_to_string(self) -> None: f = self.post._meta.get_field('body') self.assertEqual(f.value_to_string(self.post), self.full_text) - def test_abstract_inheritance(self): + def test_abstract_inheritance(self) -> None: class Child(SplitFieldAbstractParent): pass diff --git a/tests/test_fields/test_status_field.py b/tests/test_fields/test_status_field.py index a24ce454..fe79e11f 100644 --- a/tests/test_fields/test_status_field.py +++ b/tests/test_fields/test_status_field.py @@ -13,22 +13,22 @@ class StatusFieldTests(TestCase): - def test_status_with_default_filled(self): + def test_status_with_default_filled(self) -> None: instance = StatusFieldDefaultFilled() self.assertEqual(instance.status, instance.STATUS.yes) - def test_status_with_default_not_filled(self): + def test_status_with_default_not_filled(self) -> None: instance = StatusFieldDefaultNotFilled() self.assertEqual(instance.status, instance.STATUS.no) - def test_no_check_for_status(self): + def test_no_check_for_status(self) -> None: field = StatusField(no_check_for_status=True) # this model has no STATUS attribute, so checking for it would error field.prepare_class(Article) - def test_get_status_display(self): + def test_get_status_display(self) -> None: instance = StatusFieldDefaultFilled() self.assertEqual(instance.get_status_display(), "Yes") - def test_choices_name(self): + def test_choices_name(self) -> None: StatusFieldChoicesName() diff --git a/tests/test_fields/test_urlsafe_token_field.py b/tests/test_fields/test_urlsafe_token_field.py index 3feffde1..6146fe61 100644 --- a/tests/test_fields/test_urlsafe_token_field.py +++ b/tests/test_fields/test_urlsafe_token_field.py @@ -9,41 +9,41 @@ class UrlsaftTokenFieldTests(TestCase): - def test_editable_default(self): + def test_editable_default(self) -> None: field = UrlsafeTokenField() self.assertFalse(field.editable) - def test_editable(self): + def test_editable(self) -> None: field = UrlsafeTokenField(editable=True) self.assertTrue(field.editable) - def test_max_length_default(self): + def test_max_length_default(self) -> None: field = UrlsafeTokenField() self.assertEqual(field.max_length, 128) - def test_max_length(self): + def test_max_length(self) -> None: field = UrlsafeTokenField(max_length=256) self.assertEqual(field.max_length, 256) - def test_factory_default(self): + def test_factory_default(self) -> None: field = UrlsafeTokenField() self.assertIsNone(field._factory) - def test_factory_not_callable(self): + def test_factory_not_callable(self) -> None: with self.assertRaises(TypeError): UrlsafeTokenField(factory='INVALID') - def test_get_default(self): + def test_get_default(self) -> None: field = UrlsafeTokenField() value = field.get_default() self.assertEqual(len(value), field.max_length) - def test_get_default_with_non_default_max_length(self): + def test_get_default_with_non_default_max_length(self) -> None: field = UrlsafeTokenField(max_length=64) value = field.get_default() self.assertEqual(len(value), 64) - def test_get_default_with_factory(self): + def test_get_default_with_factory(self) -> None: token = 'SAMPLE_TOKEN' factory = Mock(return_value=token) field = UrlsafeTokenField(factory=factory) @@ -52,12 +52,12 @@ def test_get_default_with_factory(self): self.assertEqual(value, token) factory.assert_called_once_with(field.max_length) - def test_no_default_param(self): + def test_no_default_param(self) -> None: field = UrlsafeTokenField(default='DEFAULT') self.assertIs(field.default, NOT_PROVIDED) - def test_deconstruct(self): - def test_factory(): + def test_deconstruct(self) -> None: + def test_factory() -> None: pass instance = UrlsafeTokenField(factory=test_factory) name, path, args, kwargs = instance.deconstruct() diff --git a/tests/test_fields/test_uuid_field.py b/tests/test_fields/test_uuid_field.py index 7401566a..4c77aaad 100644 --- a/tests/test_fields/test_uuid_field.py +++ b/tests/test_fields/test_uuid_field.py @@ -10,31 +10,31 @@ class UUIDFieldTests(TestCase): - def test_uuid_version_default(self): + def test_uuid_version_default(self) -> None: instance = UUIDField() self.assertEqual(instance.default, uuid.uuid4) - def test_uuid_version_1(self): + def test_uuid_version_1(self) -> None: instance = UUIDField(version=1) self.assertEqual(instance.default, uuid.uuid1) - def test_uuid_version_2_error(self): + def test_uuid_version_2_error(self) -> None: self.assertRaises(ValidationError, UUIDField, 'version', 2) - def test_uuid_version_3(self): + def test_uuid_version_3(self) -> None: instance = UUIDField(version=3) self.assertEqual(instance.default, uuid.uuid3) - def test_uuid_version_4(self): + def test_uuid_version_4(self) -> None: instance = UUIDField(version=4) self.assertEqual(instance.default, uuid.uuid4) - def test_uuid_version_5(self): + def test_uuid_version_5(self) -> None: instance = UUIDField(version=5) self.assertEqual(instance.default, uuid.uuid5) - def test_uuid_version_bellow_min(self): + def test_uuid_version_bellow_min(self) -> None: self.assertRaises(ValidationError, UUIDField, 'version', 0) - def test_uuid_version_above_max(self): + def test_uuid_version_above_max(self) -> None: self.assertRaises(ValidationError, UUIDField, 'version', 6) diff --git a/tests/test_inheritance_iterable.py b/tests/test_inheritance_iterable.py index d7cf943f..e9896c56 100644 --- a/tests/test_inheritance_iterable.py +++ b/tests/test_inheritance_iterable.py @@ -7,7 +7,7 @@ class InheritanceIterableTest(TestCase): - def test_prefetch(self): + def test_prefetch(self) -> None: qs = InheritanceManagerTestChild1.objects.all().prefetch_related( Prefetch( 'normal_field', diff --git a/tests/test_managers/test_inheritance_manager.py b/tests/test_managers/test_inheritance_manager.py index 89c5bd9d..68e8a743 100644 --- a/tests/test_managers/test_inheritance_manager.py +++ b/tests/test_managers/test_inheritance_manager.py @@ -1,8 +1,11 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from django.db import models from django.test import TestCase +from model_utils.managers import InheritanceManager from tests.models import ( InheritanceManagerTestChild1, InheritanceManagerTestChild2, @@ -16,19 +19,22 @@ TimeFrame, ) +if TYPE_CHECKING: + from django.db.models.fields.related_descriptors import RelatedManager + class InheritanceManagerTests(TestCase): - def setUp(self): + def setUp(self) -> None: self.child1 = InheritanceManagerTestChild1.objects.create() self.child2 = InheritanceManagerTestChild2.objects.create() self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create() self.grandchild1_2 = \ InheritanceManagerTestGrandChild1_2.objects.create() - def get_manager(self): + def get_manager(self) -> InheritanceManager[InheritanceManagerTestParent]: return InheritanceManagerTestParent.objects - def test_normal(self): + def test_normal(self) -> None: children = { InheritanceManagerTestParent(pk=self.child1.pk), InheritanceManagerTestParent(pk=self.child2.pk), @@ -37,14 +43,14 @@ def test_normal(self): } self.assertEqual(set(self.get_manager().all()), children) - def test_select_all_subclasses(self): + def test_select_all_subclasses(self) -> None: children = {self.child1, self.child2} children.add(self.grandchild1) children.add(self.grandchild1_2) self.assertEqual( set(self.get_manager().select_subclasses()), children) - def test_select_subclasses_invalid_relation(self): + def test_select_subclasses_invalid_relation(self) -> None: """ If an invalid relation string is provided, we can provide the user with a list which is valid, rather than just have the select_related() @@ -54,7 +60,7 @@ def test_select_subclasses_invalid_relation(self): with self.assertRaisesRegex(ValueError, regex): self.get_manager().select_subclasses('user') - def test_select_specific_subclasses(self): + def test_select_specific_subclasses(self) -> None: children = { self.child1, InheritanceManagerTestParent(pk=self.child2.pk), @@ -69,7 +75,7 @@ def test_select_specific_subclasses(self): children, ) - def test_select_specific_grandchildren(self): + def test_select_specific_grandchildren(self) -> None: children = { InheritanceManagerTestParent(pk=self.child1.pk), InheritanceManagerTestParent(pk=self.child2.pk), @@ -85,7 +91,7 @@ def test_select_specific_grandchildren(self): children, ) - def test_children_and_grandchildren(self): + def test_children_and_grandchildren(self) -> None: children = { self.child1, InheritanceManagerTestParent(pk=self.child2.pk), @@ -102,24 +108,24 @@ def test_children_and_grandchildren(self): children, ) - def test_get_subclass(self): + def test_get_subclass(self) -> None: self.assertEqual( self.get_manager().get_subclass(pk=self.child1.pk), self.child1) - def test_get_subclass_on_queryset(self): + def test_get_subclass_on_queryset(self) -> None: self.assertEqual( self.get_manager().all().get_subclass(pk=self.child1.pk), self.child1) - def test_prior_select_related(self): + def test_prior_select_related(self) -> None: with self.assertNumQueries(1): obj = self.get_manager().select_related( "inheritancemanagertestchild1").select_subclasses( "inheritancemanagertestchild2").get(pk=self.child1.pk) obj.inheritancemanagertestchild1 - def test_manually_specifying_parent_fk_including_grandchildren(self): + def test_manually_specifying_parent_fk_including_grandchildren(self) -> None: """ given a Model which inherits from another Model, but also declares the OneToOne link manually using `related_name` and `parent_link`, @@ -150,7 +156,7 @@ def test_manually_specifying_parent_fk_including_grandchildren(self): self.assertEqual(set(results.subclasses), set(expected_related_names)) - def test_manually_specifying_parent_fk_single_subclass(self): + def test_manually_specifying_parent_fk_single_subclass(self) -> None: """ Using a string related_name when the relation is manually defined instead of implicit should still work in the same way. @@ -170,11 +176,11 @@ def test_manually_specifying_parent_fk_single_subclass(self): self.assertEqual(set(results.subclasses), set(expected_related_names)) - def test_filter_on_values_queryset(self): + def test_filter_on_values_queryset(self) -> None: queryset = InheritanceManagerTestChild1.objects.values('id').filter(pk=self.child1.pk) self.assertEqual(list(queryset), [{'id': self.child1.pk}]) - def test_values_list_on_select_subclasses(self): + def test_values_list_on_select_subclasses(self) -> None: """ Using `select_subclasses` in conjunction with `values_list()` raised an exception in `_get_sub_obj_recurse()` because the result of `values_list()` @@ -219,14 +225,14 @@ def test_values_list_on_select_subclasses(self): class InheritanceManagerUsingModelsTests(TestCase): - def setUp(self): + def setUp(self) -> None: self.parent1 = InheritanceManagerTestParent.objects.create() self.child1 = InheritanceManagerTestChild1.objects.create() self.child2 = InheritanceManagerTestChild2.objects.create() self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create() self.grandchild1_2 = InheritanceManagerTestGrandChild1_2.objects.create() - def test_select_subclass_by_child_model(self): + def test_select_subclass_by_child_model(self) -> None: """ Confirm that passing a child model works the same as passing the select_related manually @@ -238,7 +244,7 @@ def test_select_subclass_by_child_model(self): self.assertEqual(objs.subclasses, objsmodels.subclasses) self.assertEqual(list(objs), list(objsmodels)) - def test_select_subclass_by_grandchild_model(self): + def test_select_subclass_by_grandchild_model(self) -> None: """ Confirm that passing a grandchild model works the same as passing the select_related manually @@ -251,7 +257,7 @@ def test_select_subclass_by_grandchild_model(self): self.assertEqual(objs.subclasses, objsmodels.subclasses) self.assertEqual(list(objs), list(objsmodels)) - def test_selecting_all_subclasses_specifically_grandchildren(self): + def test_selecting_all_subclasses_specifically_grandchildren(self) -> None: """ A bare select_subclasses() should achieve the same results as doing select_subclasses and specifying all possible subclasses. @@ -268,7 +274,7 @@ def test_selecting_all_subclasses_specifically_grandchildren(self): self.assertEqual(set(objs.subclasses), set(objsmodels.subclasses)) self.assertEqual(list(objs), list(objsmodels)) - def test_selecting_all_subclasses_specifically_children(self): + def test_selecting_all_subclasses_specifically_children(self) -> None: """ A bare select_subclasses() should achieve the same results as doing select_subclasses and specifying all possible subclasses. @@ -296,7 +302,7 @@ def test_selecting_all_subclasses_specifically_children(self): self.assertEqual(set(objs.subclasses), set(objsmodels.subclasses)) self.assertEqual(list(objs), list(objsmodels)) - def test_select_subclass_just_self(self): + def test_select_subclass_just_self(self) -> None: """ Passing in the same model as the manager/queryset is bound against (ie: the root parent) should have no effect on the result set. @@ -312,7 +318,7 @@ def test_select_subclass_just_self(self): InheritanceManagerTestParent(pk=self.grandchild1_2.pk), ]) - def test_select_subclass_invalid_related_model(self): + def test_select_subclass_invalid_related_model(self) -> None: """ Confirming that giving a stupid model doesn't work. """ @@ -321,7 +327,7 @@ def test_select_subclass_invalid_related_model(self): InheritanceManagerTestParent.objects.select_subclasses( TimeFrame).order_by('pk') - def test_mixing_strings_and_classes_with_grandchildren(self): + def test_mixing_strings_and_classes_with_grandchildren(self) -> None: """ Given arguments consisting of both strings and model classes, ensure the right resolutions take place, accounting for the extra @@ -342,7 +348,7 @@ def test_mixing_strings_and_classes_with_grandchildren(self): ] self.assertEqual(list(objs), expecting2) - def test_mixing_strings_and_classes_with_children(self): + def test_mixing_strings_and_classes_with_children(self) -> None: """ Given arguments consisting of both strings and model classes, ensure the right resolutions take place, walking down as far as @@ -364,7 +370,7 @@ def test_mixing_strings_and_classes_with_children(self): ] self.assertEqual(list(objs), expecting2) - def test_duplications(self): + def test_duplications(self) -> None: """ Check that even if the same thing is provided as a string and a model that the right results are retrieved. @@ -381,7 +387,7 @@ def test_duplications(self): InheritanceManagerTestParent(pk=self.grandchild1_2.pk), ]) - def test_child_doesnt_accidentally_get_parent(self): + def test_child_doesnt_accidentally_get_parent(self) -> None: """ Given a Child model which also has an InheritanceManager, none of the returned objects should be Parent objects. @@ -394,7 +400,7 @@ def test_child_doesnt_accidentally_get_parent(self): InheritanceManagerTestChild1(pk=self.grandchild1_2.pk), ], list(objs)) - def test_manually_specifying_parent_fk_only_specific_child(self): + def test_manually_specifying_parent_fk_only_specific_child(self) -> None: """ given a Model which inherits from another Model, but also declares the OneToOne link manually using `related_name` and `parent_link`, @@ -418,7 +424,7 @@ def test_manually_specifying_parent_fk_only_specific_child(self): self.assertEqual(set(results.subclasses), set(expected_related_names)) - def test_extras_descend(self): + def test_extras_descend(self) -> None: """ Ensure that extra(select=) values are copied onto sub-classes. """ @@ -427,25 +433,25 @@ def test_extras_descend(self): ) self.assertTrue(all(result.foo == (result.id + 1) for result in results)) - def test_limit_to_specific_subclass(self): + def test_limit_to_specific_subclass(self) -> None: child3 = InheritanceManagerTestChild3.objects.create() results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3) self.assertEqual([child3], list(results)) - def test_limit_to_specific_subclass_with_custom_db_column(self): + def test_limit_to_specific_subclass_with_custom_db_column(self) -> None: item = InheritanceManagerTestChild3_1.objects.create() results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3_1) self.assertEqual([item], list(results)) - def test_limit_to_specific_grandchild_class(self): + def test_limit_to_specific_grandchild_class(self) -> None: grandchild1 = InheritanceManagerTestGrandChild1.objects.get() results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestGrandChild1) self.assertEqual([grandchild1], list(results)) - def test_limit_to_child_fetches_grandchildren_as_child_class(self): + def test_limit_to_child_fetches_grandchildren_as_child_class(self) -> None: # Not sure if this is the desired behaviour...? children = InheritanceManagerTestChild1.objects.all() @@ -453,7 +459,7 @@ def test_limit_to_child_fetches_grandchildren_as_child_class(self): self.assertEqual(set(children), set(results)) - def test_can_fetch_limited_class_grandchildren(self): + def test_can_fetch_limited_class_grandchildren(self) -> None: # Not sure if this is the desired behaviour...? children = InheritanceManagerTestChild1.objects.select_subclasses() @@ -461,7 +467,7 @@ def test_can_fetch_limited_class_grandchildren(self): self.assertEqual(set(children), set(results)) - def test_selecting_multiple_instance_classes(self): + def test_selecting_multiple_instance_classes(self) -> None: child3 = InheritanceManagerTestChild3.objects.create() children1 = InheritanceManagerTestChild1.objects.all() @@ -469,7 +475,7 @@ def test_selecting_multiple_instance_classes(self): self.assertEqual(set([child3] + list(children1)), set(results)) - def test_selecting_multiple_instance_classes_including_grandchildren(self): + def test_selecting_multiple_instance_classes_including_grandchildren(self) -> None: child3 = InheritanceManagerTestChild3.objects.create() grandchild1 = InheritanceManagerTestGrandChild1.objects.get() @@ -477,7 +483,7 @@ def test_selecting_multiple_instance_classes_including_grandchildren(self): self.assertEqual({child3, grandchild1}, set(results)) - def test_select_subclasses_interaction_with_instance_of(self): + def test_select_subclasses_interaction_with_instance_of(self) -> None: child3 = InheritanceManagerTestChild3.objects.create() results = InheritanceManagerTestParent.objects.select_subclasses(InheritanceManagerTestChild1).instance_of(InheritanceManagerTestChild3) @@ -486,7 +492,7 @@ def test_select_subclasses_interaction_with_instance_of(self): class InheritanceManagerRelatedTests(InheritanceManagerTests): - def setUp(self): + def setUp(self) -> None: self.related = InheritanceManagerTestRelated.objects.create() self.child1 = InheritanceManagerTestChild1.objects.create( related=self.related) @@ -495,16 +501,16 @@ def setUp(self): self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create(related=self.related) self.grandchild1_2 = InheritanceManagerTestGrandChild1_2.objects.create(related=self.related) - def get_manager(self): + def get_manager(self) -> RelatedManager[InheritanceManagerTestParent]: # type: ignore[override] return self.related.imtests - def test_get_method_with_select_subclasses(self): + def test_get_method_with_select_subclasses(self) -> None: self.assertEqual( InheritanceManagerTestParent.objects.select_subclasses().get( id=self.child1.id), self.child1) - def test_get_method_with_select_subclasses_check_for_useless_join(self): + def test_get_method_with_select_subclasses_check_for_useless_join(self) -> None: child4 = InheritanceManagerTestChild4.objects.create(related=self.related, other_onetoone=self.child1) self.assertEqual( str(InheritanceManagerTestChild4.objects.select_subclasses().filter( @@ -512,26 +518,26 @@ def test_get_method_with_select_subclasses_check_for_useless_join(self): str(InheritanceManagerTestChild4.objects.select_subclasses().select_related(None).filter( id=child4.id).query)) - def test_annotate_with_select_subclasses(self): + def test_annotate_with_select_subclasses(self) -> None: qs = InheritanceManagerTestParent.objects.select_subclasses().annotate( models.Count('id')) self.assertEqual(qs.get(id=self.child1.id).id__count, 1) - def test_annotate_with_named_arguments_with_select_subclasses(self): + def test_annotate_with_named_arguments_with_select_subclasses(self) -> None: qs = InheritanceManagerTestParent.objects.select_subclasses().annotate( test_count=models.Count('id')) self.assertEqual(qs.get(id=self.child1.id).test_count, 1) - def test_annotate_before_select_subclasses(self): + def test_annotate_before_select_subclasses(self) -> None: qs = InheritanceManagerTestParent.objects.annotate( models.Count('id')).select_subclasses() self.assertEqual(qs.get(id=self.child1.id).id__count, 1) - def test_annotate_with_named_arguments_before_select_subclasses(self): + def test_annotate_with_named_arguments_before_select_subclasses(self) -> None: qs = InheritanceManagerTestParent.objects.annotate( test_count=models.Count('id')).select_subclasses() self.assertEqual(qs.get(id=self.child1.id).test_count, 1) - def test_clone_when_inheritance_queryset_selects_subclasses_should_clone_them_too(self): + def test_clone_when_inheritance_queryset_selects_subclasses_should_clone_them_too(self) -> None: qs = InheritanceManagerTestParent.objects.select_subclasses() self.assertEqual(qs.subclasses, qs._clone().subclasses) diff --git a/tests/test_managers/test_join_manager.py b/tests/test_managers/test_join_manager.py index 3a6699c7..44bdcfc8 100644 --- a/tests/test_managers/test_join_manager.py +++ b/tests/test_managers/test_join_manager.py @@ -6,7 +6,7 @@ class JoinManagerTest(TestCase): - def setUp(self): + def setUp(self) -> None: for i in range(20): BoxJoinModel.objects.create(name=f'name_{i}') @@ -15,24 +15,24 @@ def setUp(self): ) JoinItemForeignKey.objects.create(weight=20) - def test_self_join(self): + def test_self_join(self) -> None: a_slice = BoxJoinModel.objects.all()[0:10] with self.assertNumQueries(1): result = a_slice.join() self.assertEqual(result.count(), 10) - def test_self_join_with_where_statement(self): + def test_self_join_with_where_statement(self) -> None: qs = BoxJoinModel.objects.filter(name='name_1') result = qs.join() self.assertEqual(result.count(), 1) - def test_join_with_other_qs(self): + def test_join_with_other_qs(self) -> None: item_qs = JoinItemForeignKey.objects.filter(weight=10) boxes = BoxJoinModel.objects.all().join(qs=item_qs) self.assertEqual(boxes.count(), 1) self.assertEqual(boxes[0].name, 'name_1') - def test_reverse_join(self): + def test_reverse_join(self) -> None: box_qs = BoxJoinModel.objects.filter(name='name_1') items = JoinItemForeignKey.objects.all().join(box_qs) self.assertEqual(items.count(), 1) diff --git a/tests/test_managers/test_query_manager.py b/tests/test_managers/test_query_manager.py index 85aa9e95..03ec814f 100644 --- a/tests/test_managers/test_query_manager.py +++ b/tests/test_managers/test_query_manager.py @@ -6,7 +6,7 @@ class QueryManagerTests(TestCase): - def setUp(self): + def setUp(self) -> None: data = ((True, True, 0), (True, False, 4), (False, False, 2), @@ -16,14 +16,14 @@ def setUp(self): for p, c, o in data: Post.objects.create(published=p, confirmed=c, order=o) - def test_passing_kwargs(self): + def test_passing_kwargs(self) -> None: qs = Post.public.all() self.assertEqual([p.order for p in qs], [0, 1, 4, 5]) - def test_passing_Q(self): + def test_passing_Q(self) -> None: qs = Post.public_confirmed.all() self.assertEqual([p.order for p in qs], [0, 1]) - def test_ordering(self): + def test_ordering(self) -> None: qs = Post.public_reversed.all() self.assertEqual([p.order for p in qs], [5, 4, 1, 0]) diff --git a/tests/test_managers/test_softdelete_manager.py b/tests/test_managers/test_softdelete_manager.py index e4c64b5b..01fffc42 100644 --- a/tests/test_managers/test_softdelete_manager.py +++ b/tests/test_managers/test_softdelete_manager.py @@ -7,21 +7,21 @@ class CustomSoftDeleteManagerTests(TestCase): - def test_custom_manager_empty(self): + def test_custom_manager_empty(self) -> None: qs = CustomSoftDelete.available_objects.only_read() self.assertEqual(qs.count(), 0) - def test_custom_qs_empty(self): + def test_custom_qs_empty(self) -> None: qs = CustomSoftDelete.available_objects.all().only_read() self.assertEqual(qs.count(), 0) - def test_is_read(self): + def test_is_read(self) -> None: for is_read in [True, False, True, False]: CustomSoftDelete.available_objects.create(is_read=is_read) qs = CustomSoftDelete.available_objects.only_read() self.assertEqual(qs.count(), 2) - def test_is_read_removed(self): + def test_is_read_removed(self) -> None: for is_read, is_removed in [(True, True), (True, False), (False, False), (False, True)]: CustomSoftDelete.available_objects.create(is_read=is_read, is_removed=is_removed) qs = CustomSoftDelete.available_objects.only_read() diff --git a/tests/test_managers/test_status_manager.py b/tests/test_managers/test_status_manager.py index ef467a0f..a4b69b2f 100644 --- a/tests/test_managers/test_status_manager.py +++ b/tests/test_managers/test_status_manager.py @@ -10,10 +10,10 @@ class StatusManagerAddedTests(TestCase): - def test_manager_available(self): + def test_manager_available(self) -> None: self.assertTrue(isinstance(StatusManagerAdded.active, QueryManager)) - def test_conflict_error(self): + def test_conflict_error(self) -> None: with self.assertRaises(ImproperlyConfigured): class ErrorModel(StatusModel): STATUS = ( diff --git a/tests/test_miscellaneous.py b/tests/test_miscellaneous.py index 8faaeddb..0fbfecc1 100644 --- a/tests/test_miscellaneous.py +++ b/tests/test_miscellaneous.py @@ -7,23 +7,23 @@ class MigrationsTests(TestCase): - def test_makemigrations(self): + def test_makemigrations(self) -> None: call_command('makemigrations', dry_run=True) class GetExcerptTests(TestCase): - def test_split(self): + def test_split(self) -> None: e = get_excerpt("some content\n\n\n\nsome more") self.assertEqual(e, 'some content\n') - def test_auto_split(self): + def test_auto_split(self) -> None: e = get_excerpt("para one\n\npara two\n\npara three") self.assertEqual(e, 'para one\n\npara two') - def test_middle_of_para(self): + def test_middle_of_para(self) -> None: e = get_excerpt("some text\n\nmore text") self.assertEqual(e, 'some text') - def test_middle_of_line(self): + def test_middle_of_line(self) -> None: e = get_excerpt("some text more text") self.assertEqual(e, "some text more text") diff --git a/tests/test_models/test_deferred_fields.py b/tests/test_models/test_deferred_fields.py index a3a42839..f51e5eb7 100644 --- a/tests/test_models/test_deferred_fields.py +++ b/tests/test_models/test_deferred_fields.py @@ -6,7 +6,7 @@ class CustomDescriptorTests(TestCase): - def setUp(self): + def setUp(self) -> None: self.instance = ModelWithCustomDescriptor.objects.create( custom_field='1', tracked_custom_field='1', @@ -14,7 +14,7 @@ def setUp(self): tracked_regular_field=1, ) - def test_custom_descriptor_works(self): + def test_custom_descriptor_works(self) -> None: instance = self.instance self.assertEqual(instance.custom_field, '1') self.assertEqual(instance.__dict__['custom_field'], 1) @@ -27,7 +27,7 @@ def test_custom_descriptor_works(self): self.assertEqual(instance.custom_field, '2') self.assertEqual(instance.__dict__['custom_field'], 2) - def test_deferred(self): + def test_deferred(self) -> None: instance = ModelWithCustomDescriptor.objects.only('id').get( pk=self.instance.pk) self.assertIn('custom_field', instance.get_deferred_fields()) diff --git a/tests/test_models/test_softdeletable_model.py b/tests/test_models/test_softdeletable_model.py index 38332d75..1f58f435 100644 --- a/tests/test_models/test_softdeletable_model.py +++ b/tests/test_models/test_softdeletable_model.py @@ -7,7 +7,7 @@ class SoftDeletableModelTests(TestCase): - def test_can_only_see_not_removed_entries(self): + def test_can_only_see_not_removed_entries(self) -> None: SoftDeletable.available_objects.create(name='a', is_removed=True) SoftDeletable.available_objects.create(name='b', is_removed=False) @@ -16,7 +16,7 @@ def test_can_only_see_not_removed_entries(self): self.assertEqual(queryset.count(), 1) self.assertEqual(queryset[0].name, 'b') - def test_instance_cannot_be_fully_deleted(self): + def test_instance_cannot_be_fully_deleted(self) -> None: instance = SoftDeletable.available_objects.create(name='a') instance.delete() @@ -24,7 +24,7 @@ def test_instance_cannot_be_fully_deleted(self): self.assertEqual(SoftDeletable.available_objects.count(), 0) self.assertEqual(SoftDeletable.all_objects.count(), 1) - def test_instance_cannot_be_fully_deleted_via_queryset(self): + def test_instance_cannot_be_fully_deleted_via_queryset(self) -> None: SoftDeletable.available_objects.create(name='a') SoftDeletable.available_objects.all().delete() @@ -32,12 +32,12 @@ def test_instance_cannot_be_fully_deleted_via_queryset(self): self.assertEqual(SoftDeletable.available_objects.count(), 0) self.assertEqual(SoftDeletable.all_objects.count(), 1) - def test_delete_instance_no_connection(self): + def test_delete_instance_no_connection(self) -> None: obj = SoftDeletable.available_objects.create(name='a') self.assertRaises(ConnectionDoesNotExist, obj.delete, using='other') - def test_instance_purge(self): + def test_instance_purge(self) -> None: instance = SoftDeletable.available_objects.create(name='a') instance.delete(soft=False) @@ -45,11 +45,11 @@ def test_instance_purge(self): self.assertEqual(SoftDeletable.available_objects.count(), 0) self.assertEqual(SoftDeletable.all_objects.count(), 0) - def test_instance_purge_no_connection(self): + def test_instance_purge_no_connection(self) -> None: instance = SoftDeletable.available_objects.create(name='a') self.assertRaises(ConnectionDoesNotExist, instance.delete, using='other', soft=False) - def test_deprecation_warning(self): + def test_deprecation_warning(self) -> None: self.assertWarns(DeprecationWarning, SoftDeletable.objects.all) diff --git a/tests/test_models/test_status_model.py b/tests/test_models/test_status_model.py index 35c38089..8d6e89d9 100644 --- a/tests/test_models/test_status_model.py +++ b/tests/test_models/test_status_model.py @@ -9,12 +9,12 @@ class StatusModelTests(TestCase): - def setUp(self): + def setUp(self) -> None: self.model = Status self.on_hold = Status.STATUS.on_hold self.active = Status.STATUS.active - def test_created(self): + def test_created(self) -> None: with time_machine.travel(datetime(2016, 1, 1)): c1 = self.model.objects.create() self.assertTrue(c1.status_changed, datetime(2016, 1, 1)) @@ -23,7 +23,7 @@ def test_created(self): self.assertEqual(self.model.active.count(), 2) self.assertEqual(self.model.deleted.count(), 0) - def test_modification(self): + def test_modification(self) -> None: t1 = self.model.objects.create() date_created = t1.status_changed t1.status = self.on_hold @@ -39,7 +39,7 @@ def test_modification(self): t1.save() self.assertTrue(t1.status_changed > date_active_again) - def test_save_with_update_fields_overrides_status_changed_provided(self): + def test_save_with_update_fields_overrides_status_changed_provided(self) -> None: ''' Tests if the save method updated status_changed field accordingly when update_fields is used as an argument @@ -54,7 +54,7 @@ def test_save_with_update_fields_overrides_status_changed_provided(self): self.assertEqual(t1.status_changed, datetime(2020, 1, 2, tzinfo=timezone.utc)) - def test_save_with_update_fields_overrides_status_changed_not_provided(self): + def test_save_with_update_fields_overrides_status_changed_not_provided(self) -> None: ''' Tests if the save method updated status_changed field accordingly when update_fields is used as an argument @@ -71,7 +71,7 @@ def test_save_with_update_fields_overrides_status_changed_not_provided(self): class StatusModelPlainTupleTests(StatusModelTests): - def setUp(self): + def setUp(self) -> None: self.model = StatusPlainTuple self.on_hold = StatusPlainTuple.STATUS[2][0] self.active = StatusPlainTuple.STATUS[0][0] @@ -79,7 +79,7 @@ def setUp(self): class StatusModelDefaultManagerTests(TestCase): - def test_default_manager_is_not_status_model_generated_ones(self): + def test_default_manager_is_not_status_model_generated_ones(self) -> None: # Regression test for GH-251 # The logic behind order for managers seems to have changed in Django 1.10 # and affects default manager. diff --git a/tests/test_models/test_timeframed_model.py b/tests/test_models/test_timeframed_model.py index 3cf3d730..246b3992 100644 --- a/tests/test_models/test_timeframed_model.py +++ b/tests/test_models/test_timeframed_model.py @@ -12,36 +12,36 @@ class TimeFramedModelTests(TestCase): - def setUp(self): + def setUp(self) -> None: self.now = datetime.now() - def test_not_yet_begun(self): + def test_not_yet_begun(self) -> None: TimeFrame.objects.create(start=self.now + timedelta(days=2)) self.assertEqual(TimeFrame.timeframed.count(), 0) - def test_finished(self): + def test_finished(self) -> None: TimeFrame.objects.create(end=self.now - timedelta(days=1)) self.assertEqual(TimeFrame.timeframed.count(), 0) - def test_no_end(self): + def test_no_end(self) -> None: TimeFrame.objects.create(start=self.now - timedelta(days=10)) self.assertEqual(TimeFrame.timeframed.count(), 1) - def test_no_start(self): + def test_no_start(self) -> None: TimeFrame.objects.create(end=self.now + timedelta(days=2)) self.assertEqual(TimeFrame.timeframed.count(), 1) - def test_within_range(self): + def test_within_range(self) -> None: TimeFrame.objects.create(start=self.now - timedelta(days=1), end=self.now + timedelta(days=1)) self.assertEqual(TimeFrame.timeframed.count(), 1) class TimeFrameManagerAddedTests(TestCase): - def test_manager_available(self): + def test_manager_available(self) -> None: self.assertTrue(isinstance(TimeFrameManagerAdded.timeframed, QueryManager)) - def test_conflict_error(self): + def test_conflict_error(self) -> None: with self.assertRaises(ImproperlyConfigured): class ErrorModel(TimeFramedModel): timeframed = models.BooleanField() diff --git a/tests/test_models/test_timestamped_model.py b/tests/test_models/test_timestamped_model.py index 193af3e6..1cc8d511 100644 --- a/tests/test_models/test_timestamped_model.py +++ b/tests/test_models/test_timestamped_model.py @@ -9,19 +9,19 @@ class TimeStampedModelTests(TestCase): - def test_created(self): + def test_created(self) -> None: with time_machine.travel(datetime(2016, 1, 1, tzinfo=timezone.utc)): t1 = TimeStamp.objects.create() self.assertEqual(t1.created, datetime(2016, 1, 1, tzinfo=timezone.utc)) - def test_created_sets_modified(self): + def test_created_sets_modified(self) -> None: ''' Ensure that on creation that modified is set exactly equal to created. ''' t1 = TimeStamp.objects.create() self.assertEqual(t1.created, t1.modified) - def test_modified(self): + def test_modified(self) -> None: with time_machine.travel(datetime(2016, 1, 1, tzinfo=timezone.utc)): t1 = TimeStamp.objects.create() @@ -30,7 +30,7 @@ def test_modified(self): self.assertEqual(t1.modified, datetime(2016, 1, 2, tzinfo=timezone.utc)) - def test_overriding_created_via_object_creation_also_uses_creation_date_for_modified(self): + def test_overriding_created_via_object_creation_also_uses_creation_date_for_modified(self) -> None: """ Setting the created date when first creating an object should be permissible. @@ -40,7 +40,7 @@ def test_overriding_created_via_object_creation_also_uses_creation_date_for_modi self.assertEqual(t1.created, different_date) self.assertEqual(t1.modified, different_date) - def test_overriding_modified_via_object_creation(self): + def test_overriding_modified_via_object_creation(self) -> None: """ Setting the modified date explicitly should be possible when first creating an object, but not thereafter. @@ -50,7 +50,7 @@ def test_overriding_modified_via_object_creation(self): self.assertEqual(t1.modified, different_date) self.assertNotEqual(t1.created, different_date) - def test_overriding_created_after_object_created(self): + def test_overriding_created_after_object_created(self) -> None: """ The created date may be changed post-create """ @@ -60,7 +60,7 @@ def test_overriding_created_after_object_created(self): t1.save() self.assertEqual(t1.created, different_date) - def test_overriding_modified_after_object_created(self): + def test_overriding_modified_after_object_created(self) -> None: """ The modified date should always be updated when the object is saved, regardless of attempts to change it. @@ -71,7 +71,7 @@ def test_overriding_modified_after_object_created(self): t1.save() self.assertNotEqual(t1.modified, different_date) - def test_overrides_using_save(self): + def test_overrides_using_save(self) -> None: """ The first time an object is saved, allow modification of both created and modified fields. @@ -92,7 +92,7 @@ def test_overrides_using_save(self): self.assertNotEqual(t1.modified, different_date2) self.assertNotEqual(t1.modified, different_date) - def test_save_with_update_fields_overrides_modified_provided_within_a(self): + def test_save_with_update_fields_overrides_modified_provided_within_a(self) -> None: """ Tests if the save method updated modified field accordingly when update_fields is used as an argument @@ -113,7 +113,7 @@ def test_save_with_update_fields_overrides_modified_provided_within_a(self): t1.save(update_fields=update_fields) self.assertEqual(t1.modified, datetime(2020, 1, 2, tzinfo=timezone.utc)) - def test_save_is_skipped_for_empty_update_fields_iterable(self): + def test_save_is_skipped_for_empty_update_fields_iterable(self) -> None: tests = ( [], # list (), # tuple @@ -133,7 +133,7 @@ def test_save_is_skipped_for_empty_update_fields_iterable(self): self.assertEqual(t1.test_field, 0) self.assertEqual(t1.modified, datetime(2020, 1, 1, tzinfo=timezone.utc)) - def test_save_updates_modified_value_when_update_fields_explicitly_set_to_none(self): + def test_save_updates_modified_value_when_update_fields_explicitly_set_to_none(self) -> None: with time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc)): t1 = TimeStamp.objects.create() @@ -142,7 +142,7 @@ def test_save_updates_modified_value_when_update_fields_explicitly_set_to_none(s self.assertEqual(t1.modified, datetime(2020, 1, 2, tzinfo=timezone.utc)) - def test_model_inherit_timestampmodel_and_statusmodel(self): + def test_model_inherit_timestampmodel_and_statusmodel(self) -> None: with time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc)): t1 = TimeStampWithStatusModel.objects.create() diff --git a/tests/test_models/test_uuid_model.py b/tests/test_models/test_uuid_model.py index 62dcae3f..d9d71d9d 100644 --- a/tests/test_models/test_uuid_model.py +++ b/tests/test_models/test_uuid_model.py @@ -7,13 +7,13 @@ class UUIDFieldTests(TestCase): - def test_uuid_model_with_uuid_field_as_primary_key(self): + def test_uuid_model_with_uuid_field_as_primary_key(self) -> None: instance = CustomUUIDModel() instance.save() self.assertEqual(instance.id.__class__.__name__, 'UUID') self.assertEqual(instance.id, instance.pk) - def test_uuid_model_with_uuid_field_as_not_primary_key(self): + def test_uuid_model_with_uuid_field_as_not_primary_key(self) -> None: instance = CustomNotPrimaryUUIDModel() instance.save() self.assertEqual(instance.uuid.__class__.__name__, 'UUID') From c83ef89bc9b8f50e48697bb53b0f8696f01d4c11 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Fri, 22 Mar 2024 18:38:11 +0100 Subject: [PATCH 15/28] Annotate test helpers --- tests/fields.py | 13 ++++++++----- tests/models.py | 41 +++++++++++++++++++++++++---------------- 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/tests/fields.py b/tests/fields.py index df074f19..d57960cf 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -1,9 +1,12 @@ from __future__ import annotations +from typing import Any + from django.db import models +from django.db.backends.base.base import BaseDatabaseWrapper -def mutable_from_db(value): +def mutable_from_db(value: object) -> Any: if value == '': return None try: @@ -14,7 +17,7 @@ def mutable_from_db(value): return value -def mutable_to_db(value): +def mutable_to_db(value: object) -> str: if value is None: return '' if isinstance(value, list): @@ -23,12 +26,12 @@ def mutable_to_db(value): class MutableField(models.TextField): - def to_python(self, value): + def to_python(self, value: object) -> Any: return mutable_from_db(value) - def from_db_value(self, value, expression, connection): + def from_db_value(self, value: object, expression: object, connection: BaseDatabaseWrapper) -> Any: return mutable_from_db(value) - def get_db_prep_save(self, value, connection): + def get_db_prep_save(self, value: object, connection: BaseDatabaseWrapper) -> str: value = super().get_db_prep_save(value, connection) return mutable_to_db(value) diff --git a/tests/models.py b/tests/models.py index 58c97c13..f06d15b9 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import ClassVar, TypeVar +from typing import Any, ClassVar, TypeVar, overload from django.db import models from django.db.models import Manager +from django.db.models.query import QuerySet from django.db.models.query_utils import DeferredAttribute from django.utils.translation import gettext_lazy as _ @@ -45,7 +46,7 @@ class InheritanceManagerTestParent(models.Model): on_delete=models.CASCADE) objects: ClassVar[InheritanceManager[InheritanceManagerTestParent]] = InheritanceManager() - def __str__(self): + def __str__(self) -> str: return "{}({})".format( self.__class__.__name__[len('InheritanceManagerTest'):], self.pk, @@ -186,7 +187,7 @@ class Post(models.Model): public: ClassVar[QueryManager[Post]] = QueryManager(published=True) public_confirmed: ClassVar[QueryManager[Post]] = QueryManager( models.Q(published=True) & models.Q(confirmed=True)) - public_reversed: QueryManager[Post] = QueryManager( + public_reversed: ClassVar[QueryManager[Post]] = QueryManager( published=True).order_by("-order") class Meta: @@ -206,7 +207,6 @@ class Meta: class AbstractTracked(models.Model): - number: models.IntegerField class Meta: abstract = True @@ -219,7 +219,7 @@ class Tracked(models.Model): tracker = FieldTracker() - def save(self, *args, **kwargs): + def save(self, *args: Any, **kwargs: Any) -> None: """ No-op save() to ensure that FieldTracker.patch_save() works. """ super().save(*args, **kwargs) @@ -231,7 +231,7 @@ class TrackerTimeStamped(TimeStampedModel): tracker = FieldTracker() - def save(self, *args, **kwargs): + def save(self, *args: Any, **kwargs: Any) -> None: """ Automatically add "modified" to update_fields.""" update_fields = kwargs.get('update_fields') if update_fields is not None: @@ -249,7 +249,7 @@ class TrackedFK(models.Model): class TrackedAbstract(AbstractTracked): name = models.CharField(max_length=20) - number = models.IntegerField() # type: ignore[assignment] + number = models.IntegerField() mutable = MutableField(default=None) tracker = ModelTracker() @@ -266,7 +266,7 @@ class TrackedNonFieldAttr(models.Model): number = models.FloatField() @property - def rounded(self): + def rounded(self) -> int | None: return round(self.number) if self.number is not None else None tracker = FieldTracker(fields=['rounded']) @@ -355,43 +355,52 @@ class SoftDeletable(SoftDeletableModel): all_objects: ClassVar[Manager[SoftDeletable]] = models.Manager() -class CustomSoftDeleteQuerySet(SoftDeletableQuerySet): - def only_read(self): +class CustomSoftDeleteQuerySet(SoftDeletableQuerySet[ModelT]): + def only_read(self) -> QuerySet[ModelT]: return self.filter(is_read=True) class CustomSoftDelete(SoftDeletableModel): is_read = models.BooleanField(default=False) - available_objects = SoftDeletableManager.from_queryset(CustomSoftDeleteQuerySet)() # type: ignore[misc] + available_objects = SoftDeletableManager.from_queryset(CustomSoftDeleteQuerySet)() class StringyDescriptor: """ Descriptor that returns a string version of the underlying integer value. """ - def __init__(self, name): + def __init__(self, name: str): self.name = name - def __get__(self, obj, cls=None): + @overload + def __get__(self, obj: None, cls: type[models.Model] | None = None) -> StringyDescriptor: + ... + + @overload + def __get__(self, obj: models.Model, cls: type[models.Model]) -> str: + ... + + def __get__(self, obj: models.Model | None, cls: type[models.Model] | None = None) -> StringyDescriptor | str: if obj is None: return self if self.name in obj.get_deferred_fields(): # This queries the database, and sets the value on the instance. + assert cls is not None fields_map = {f.name: f for f in cls._meta.fields} field = fields_map[self.name] DeferredAttribute(field=field).__get__(obj, cls) return str(obj.__dict__[self.name]) - def __set__(self, obj, value): + def __set__(self, obj: object, value: str) -> None: obj.__dict__[self.name] = int(value) - def __delete__(self, obj): + def __delete__(self, obj: object) -> None: del obj.__dict__[self.name] class CustomDescriptorField(models.IntegerField): - def contribute_to_class(self, cls, name, *args, **kwargs): + def contribute_to_class(self, cls: type[models.Model], name: str, *args: Any, **kwargs: Any) -> None: super().contribute_to_class(cls, name, *args, **kwargs) setattr(cls, name, StringyDescriptor(name)) From 9d3940a6f244209107ea4133717486bce5389a8d Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Wed, 27 Mar 2024 17:32:12 +0100 Subject: [PATCH 16/28] Annotate the `test_choices` module This required a bit of refactoring to get the type of `STATUS` correct for each suite. There are two cases which I decided not to support in the type system: - passing a list instead of a tuple when defining an option group - `in` checks using a data type that doesn't match the choices --- tests/test_choices.py | 133 +++++++++++++++++++++++++----------------- 1 file changed, 78 insertions(+), 55 deletions(-) diff --git a/tests/test_choices.py b/tests/test_choices.py index d653cd71..973b5968 100644 --- a/tests/test_choices.py +++ b/tests/test_choices.py @@ -1,77 +1,88 @@ from __future__ import annotations +from typing import TYPE_CHECKING, Generic, TypeVar + +import pytest from django.test import TestCase from model_utils import Choices +T = TypeVar("T") -class ChoicesTests(TestCase): - def setUp(self) -> None: - self.STATUS = Choices('DRAFT', 'PUBLISHED') - def test_getattr(self) -> None: - self.assertEqual(self.STATUS.DRAFT, 'DRAFT') +class ChoicesTestsMixin(Generic[T]): - def test_indexing(self) -> None: - self.assertEqual(self.STATUS['PUBLISHED'], 'PUBLISHED') + STATUS: Choices[T] - def test_iteration(self) -> None: - self.assertEqual(tuple(self.STATUS), - (('DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED'))) - - def test_reversed(self) -> None: - self.assertEqual(tuple(reversed(self.STATUS)), - (('PUBLISHED', 'PUBLISHED'), ('DRAFT', 'DRAFT'))) + def test_getattr(self) -> None: + assert self.STATUS.DRAFT == 'DRAFT' def test_len(self) -> None: - self.assertEqual(len(self.STATUS), 2) + assert len(self.STATUS) == 2 def test_repr(self) -> None: - self.assertEqual(repr(self.STATUS), "Choices" + repr(( + assert repr(self.STATUS) == "Choices" + repr(( ('DRAFT', 'DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED', 'PUBLISHED'), - ))) + )) def test_wrong_length_tuple(self) -> None: - with self.assertRaises(ValueError): - Choices(('a',)) - - def test_contains_value(self) -> None: - self.assertTrue('PUBLISHED' in self.STATUS) - self.assertTrue('DRAFT' in self.STATUS) - - def test_doesnt_contain_value(self) -> None: - self.assertFalse('UNPUBLISHED' in self.STATUS) + with pytest.raises(ValueError): + Choices(('a',)) # type: ignore[arg-type] def test_deepcopy(self) -> None: import copy - self.assertEqual(list(self.STATUS), - list(copy.deepcopy(self.STATUS))) + assert list(self.STATUS) == list(copy.deepcopy(self.STATUS)) def test_equality(self) -> None: - self.assertEqual(self.STATUS, Choices('DRAFT', 'PUBLISHED')) + assert self.STATUS == Choices('DRAFT', 'PUBLISHED') def test_inequality(self) -> None: - self.assertNotEqual(self.STATUS, ['DRAFT', 'PUBLISHED']) - self.assertNotEqual(self.STATUS, Choices('DRAFT')) + assert self.STATUS != ['DRAFT', 'PUBLISHED'] + assert self.STATUS != Choices('DRAFT') def test_composability(self) -> None: - self.assertEqual(Choices('DRAFT') + Choices('PUBLISHED'), self.STATUS) - self.assertEqual(Choices('DRAFT') + ('PUBLISHED',), self.STATUS) - self.assertEqual(('DRAFT',) + Choices('PUBLISHED'), self.STATUS) + assert Choices('DRAFT') + Choices('PUBLISHED') == self.STATUS + assert Choices('DRAFT') + ('PUBLISHED',) == self.STATUS + assert ('DRAFT',) + Choices('PUBLISHED') == self.STATUS def test_option_groups(self) -> None: - c = Choices(('group a', ['one', 'two']), ['group b', ('three',)]) - self.assertEqual( - list(c), - [ - ('group a', [('one', 'one'), ('two', 'two')]), - ('group b', [('three', 'three')]), - ], - ) + # Note: The implementation accepts any kind of sequence, but the type system can only + # track per-index types for tuples. + if TYPE_CHECKING: + c = Choices(('group a', ['one', 'two']), ('group b', ('three',))) + else: + c = Choices(('group a', ['one', 'two']), ['group b', ('three',)]) + assert list(c) == [ + ('group a', [('one', 'one'), ('two', 'two')]), + ('group b', [('three', 'three')]), + ] + + +class ChoicesTests(TestCase, ChoicesTestsMixin[str]): + def setUp(self) -> None: + self.STATUS = Choices('DRAFT', 'PUBLISHED') + + def test_indexing(self) -> None: + self.assertEqual(self.STATUS['PUBLISHED'], 'PUBLISHED') + + def test_iteration(self) -> None: + self.assertEqual(tuple(self.STATUS), + (('DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED'))) + + def test_reversed(self) -> None: + self.assertEqual(tuple(reversed(self.STATUS)), + (('PUBLISHED', 'PUBLISHED'), ('DRAFT', 'DRAFT'))) + + def test_contains_value(self) -> None: + self.assertTrue('PUBLISHED' in self.STATUS) + self.assertTrue('DRAFT' in self.STATUS) + + def test_doesnt_contain_value(self) -> None: + self.assertFalse('UNPUBLISHED' in self.STATUS) -class LabelChoicesTests(ChoicesTests): +class LabelChoicesTests(TestCase, ChoicesTestsMixin[str]): def setUp(self) -> None: self.STATUS = Choices( ('DRAFT', 'is draft'), @@ -157,10 +168,16 @@ def test_composability(self) -> None: ) def test_option_groups(self) -> None: - c = Choices( - ('group a', [(1, 'one'), (2, 'two')]), - ['group b', ((3, 'three'),)] - ) + if TYPE_CHECKING: + c = Choices[int]( + ('group a', [(1, 'one'), (2, 'two')]), + ('group b', ((3, 'three'),)) + ) + else: + c = Choices( + ('group a', [(1, 'one'), (2, 'two')]), + ['group b', ((3, 'three'),)] + ) self.assertEqual( list(c), [ @@ -170,7 +187,7 @@ def test_option_groups(self) -> None: ) -class IdentifierChoicesTests(ChoicesTests): +class IdentifierChoicesTests(TestCase, ChoicesTestsMixin[int]): def setUp(self) -> None: self.STATUS = Choices( (0, 'DRAFT', 'is draft'), @@ -216,10 +233,10 @@ def test_doesnt_contain_value(self) -> None: self.assertFalse(3 in self.STATUS) def test_doesnt_contain_display_value(self) -> None: - self.assertFalse('is draft' in self.STATUS) + self.assertFalse('is draft' in self.STATUS) # type: ignore[operator] def test_doesnt_contain_python_attr(self) -> None: - self.assertFalse('PUBLISHED' in self.STATUS) + self.assertFalse('PUBLISHED' in self.STATUS) # type: ignore[operator] def test_equality(self) -> None: self.assertEqual(self.STATUS, Choices( @@ -268,10 +285,16 @@ def test_composability(self) -> None: ) def test_option_groups(self) -> None: - c = Choices( - ('group a', [(1, 'ONE', 'one'), (2, 'TWO', 'two')]), - ['group b', ((3, 'THREE', 'three'),)] - ) + if TYPE_CHECKING: + c = Choices[int]( + ('group a', [(1, 'ONE', 'one'), (2, 'TWO', 'two')]), + ('group b', ((3, 'THREE', 'three'),)) + ) + else: + c = Choices( + ('group a', [(1, 'ONE', 'one'), (2, 'TWO', 'two')]), + ['group b', ((3, 'THREE', 'three'),)] + ) self.assertEqual( list(c), [ @@ -284,7 +307,7 @@ def test_option_groups(self) -> None: class SubsetChoicesTest(TestCase): def setUp(self) -> None: - self.choices = Choices( + self.choices = Choices[int]( (0, 'a', 'A'), (1, 'b', 'B'), ) From e4c88103f5e8db1f68a7ecf6c04b425dff5c0feb Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Wed, 10 Apr 2024 18:27:56 +0200 Subject: [PATCH 17/28] Add pytest to type check dependencies We also type check the unit tests and the unit tests now import functionality from pytest. --- requirements-mypy.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-mypy.txt b/requirements-mypy.txt index 65a16c5c..a7b6bb92 100644 --- a/requirements-mypy.txt +++ b/requirements-mypy.txt @@ -1,2 +1,3 @@ mypy==1.10.0 django-stubs==5.0.2 +pytest==7.4.3 From 949d110d043e8a38ee8d0127c5309174a7d8624a Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Wed, 27 Mar 2024 18:32:59 +0100 Subject: [PATCH 18/28] Annotate the `test_models` package --- tests/test_models/test_status_model.py | 2 ++ tests/test_models/test_timestamped_model.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_models/test_status_model.py b/tests/test_models/test_status_model.py index 8d6e89d9..9c9a4a42 100644 --- a/tests/test_models/test_status_model.py +++ b/tests/test_models/test_status_model.py @@ -9,6 +9,8 @@ class StatusModelTests(TestCase): + model: type[Status] | type[StatusPlainTuple] + def setUp(self) -> None: self.model = Status self.on_hold = Status.STATUS.on_hold diff --git a/tests/test_models/test_timestamped_model.py b/tests/test_models/test_timestamped_model.py index 1cc8d511..66979417 100644 --- a/tests/test_models/test_timestamped_model.py +++ b/tests/test_models/test_timestamped_model.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Iterable from datetime import datetime, timedelta, timezone import time_machine @@ -114,7 +115,7 @@ def test_save_with_update_fields_overrides_modified_provided_within_a(self) -> N self.assertEqual(t1.modified, datetime(2020, 1, 2, tzinfo=timezone.utc)) def test_save_is_skipped_for_empty_update_fields_iterable(self) -> None: - tests = ( + tests: Iterable[Iterable[str]] = ( [], # list (), # tuple set(), # set From 7d6cad0200dba7787707a1396c36fc0fc0036914 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Fri, 29 Mar 2024 17:12:32 +0100 Subject: [PATCH 19/28] Annotate `test_field_tracker` module --- tests/test_fields/test_field_tracker.py | 93 ++++++++++++------- tests/test_fields/test_monitor_field.py | 2 +- tests/test_fields/test_urlsafe_token_field.py | 6 +- 3 files changed, 66 insertions(+), 35 deletions(-) diff --git a/tests/test_fields/test_field_tracker.py b/tests/test_fields/test_field_tracker.py index 91a84c9d..4dc2dc3c 100644 --- a/tests/test_fields/test_field_tracker.py +++ b/tests/test_fields/test_field_tracker.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import TYPE_CHECKING, Any + from django.core.cache import cache from django.core.exceptions import FieldError from django.db import models @@ -7,7 +9,7 @@ from django.test import TestCase from model_utils import FieldTracker -from model_utils.tracker import DescriptorWrapper +from model_utils.tracker import DescriptorWrapper, FieldInstanceTracker from tests.models import ( InheritedModelTracked, InheritedTracked, @@ -26,12 +28,18 @@ TrackerTimeStamped, ) +if TYPE_CHECKING: + MixinBase = TestCase +else: + MixinBase = object + -class FieldTrackerTestCase(TestCase): +class FieldTrackerMixin(MixinBase): - tracker = None + tracker: FieldInstanceTracker + instance: models.Model - def assertHasChanged(self, *, tracker=None, **kwargs): + def assertHasChanged(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None: if tracker is None: tracker = self.tracker for field, value in kwargs.items(): @@ -41,29 +49,35 @@ def assertHasChanged(self, *, tracker=None, **kwargs): else: self.assertEqual(tracker.has_changed(field), value) - def assertPrevious(self, *, tracker=None, **kwargs): + def assertPrevious(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None: if tracker is None: tracker = self.tracker for field, value in kwargs.items(): self.assertEqual(tracker.previous(field), value) - def assertChanged(self, *, tracker=None, **kwargs): + def assertChanged(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None: if tracker is None: tracker = self.tracker self.assertEqual(tracker.changed(), kwargs) - def assertCurrent(self, *, tracker=None, **kwargs): + def assertCurrent(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None: if tracker is None: tracker = self.tracker self.assertEqual(tracker.current(), kwargs) - def update_instance(self, **kwargs): + def update_instance(self, **kwargs: Any) -> None: for field, value in kwargs.items(): setattr(self.instance, field, value) self.instance.save() -class FieldTrackerCommonTests: +class FieldTrackerCommonMixin(FieldTrackerMixin): + + instance: ( + Tracked | TrackedNotDefault | TrackedMultiple + | ModelTracked | ModelTrackedNotDefault | ModelTrackedMultiple + | TrackedAbstract + ) def test_pre_save_previous(self) -> None: self.assertPrevious(name=None, number=None) @@ -72,9 +86,10 @@ def test_pre_save_previous(self) -> None: self.assertPrevious(name=None, number=None) -class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): +class FieldTrackerTests(FieldTrackerCommonMixin, TestCase): - tracked_class: type[models.Model] = Tracked + tracked_class: type[Tracked | ModelTracked | TrackedAbstract] = Tracked + instance: Tracked | ModelTracked | TrackedAbstract def setUp(self) -> None: self.instance = self.tracked_class() @@ -219,6 +234,7 @@ def test_with_deferred(self) -> None: self.instance.number = 1 self.instance.save() item = self.tracked_class.objects.only('name').first() + assert item is not None self.assertTrue(item.get_deferred_fields()) # has_changed() returns False for deferred fields, without un-deferring them. @@ -234,6 +250,7 @@ def test_with_deferred(self) -> None: # examining a deferred field un-defers it item = self.tracked_class.objects.only('name').first() + assert item is not None self.assertEqual(item.number, 1) self.assertTrue('number' not in item.get_deferred_fields()) self.assertEqual(item.tracker.previous('number'), 1) @@ -252,6 +269,7 @@ def test_with_deferred(self) -> None: if self.tracked_class == Tracked: item = self.tracked_class.objects.only('name').first() + assert item is not None item.number = 2 # previous() fetches correct value from database after deferred field is assigned @@ -278,10 +296,10 @@ def test_with_deferred_fields_access_multiple(self) -> None: instance.name -class FieldTrackedModelCustomTests(FieldTrackerTestCase, - FieldTrackerCommonTests): +class FieldTrackedModelCustomTests(FieldTrackerCommonMixin, TestCase): - tracked_class: type[models.Model] = TrackedNotDefault + tracked_class: type[TrackedNotDefault | ModelTrackedNotDefault] = TrackedNotDefault + instance: TrackedNotDefault | ModelTrackedNotDefault def setUp(self) -> None: self.instance = self.tracked_class() @@ -358,9 +376,10 @@ def test_update_fields(self) -> None: self.assertChanged() -class FieldTrackedModelAttributeTests(FieldTrackerTestCase): +class FieldTrackedModelAttributeTests(FieldTrackerMixin, TestCase): tracked_class = TrackedNonFieldAttr + instance: TrackedNonFieldAttr def setUp(self) -> None: self.instance = self.tracked_class() @@ -409,10 +428,10 @@ def test_current(self) -> None: self.assertCurrent(rounded=8) -class FieldTrackedModelMultiTests(FieldTrackerTestCase, - FieldTrackerCommonTests): +class FieldTrackedModelMultiTests(FieldTrackerCommonMixin, TestCase): - tracked_class: type[models.Model] = TrackedMultiple + tracked_class: type[TrackedMultiple | ModelTrackedMultiple] = TrackedMultiple + instance: TrackedMultiple | ModelTrackedMultiple def setUp(self) -> None: self.instance = self.tracked_class() @@ -501,10 +520,11 @@ def test_current(self) -> None: self.assertCurrent(tracker=self.trackers[1], number=8) -class FieldTrackerForeignKeyTests(FieldTrackerTestCase): +class FieldTrackerForeignKeyMixin(FieldTrackerMixin): - fk_class: type[models.Model] = Tracked - tracked_class: type[models.Model] = TrackedFK + fk_class: type[Tracked | ModelTracked] + tracked_class: type[TrackedFK | ModelTrackedFK] + instance: TrackedFK | ModelTrackedFK def setUp(self) -> None: self.old_fk = self.fk_class.objects.create(number=8) @@ -543,11 +563,18 @@ def test_custom_without_id(self) -> None: self.assertCurrent(fk=self.instance.fk_id) -class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerTestCase): +class FieldTrackerForeignKeyTests(FieldTrackerForeignKeyMixin, TestCase): + + fk_class = Tracked + tracked_class = TrackedFK + + +class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerMixin, TestCase): """Test that using `prefetch_related` on a tracked field does not raise a ValueError.""" fk_class = Tracked tracked_class = TrackedFK + instance: TrackedFK def setUp(self) -> None: model_tracked = self.fk_class.objects.create(name="", number=0) @@ -566,10 +593,11 @@ def test_custom_without_id(self) -> None: self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk"))) -class FieldTrackerTimeStampedTests(FieldTrackerTestCase): +class FieldTrackerTimeStampedTests(FieldTrackerMixin, TestCase): fk_class = Tracked tracked_class = TrackerTimeStamped + instance: TrackerTimeStamped def setUp(self) -> None: self.instance = self.tracked_class.objects.create(name='old', number=1) @@ -605,9 +633,10 @@ class FieldTrackerInheritedForeignKeyTests(FieldTrackerForeignKeyTests): tracked_class = InheritedTrackedFK -class FieldTrackerFileFieldTests(FieldTrackerTestCase): +class FieldTrackerFileFieldTests(FieldTrackerMixin, TestCase): tracked_class = TrackedFileField + instance: TrackedFileField def setUp(self) -> None: self.instance = self.tracked_class() @@ -629,7 +658,7 @@ def test_saved_data_without_instance(self) -> None: self.assertEqual(self.tracker.saved_data, {}) self.update_instance(some_file=self.some_file) field_file_copy = self.tracker.saved_data.get('some_file') - self.assertIsNotNone(field_file_copy) + assert field_file_copy is not None self.assertEqual(field_file_copy.__getstate__().get('instance'), None) self.assertEqual(self.instance.some_file.instance, self.instance) self.assertIsInstance(self.instance.some_file, FieldFile) @@ -730,7 +759,8 @@ def test_current(self) -> None: class ModelTrackerTests(FieldTrackerTests): - tracked_class: type[models.Model] = ModelTracked + tracked_class: type[ModelTracked | TrackedAbstract] = ModelTracked + instance: ModelTracked def test_cache_compatible(self) -> None: cache.set('key', self.instance) @@ -846,10 +876,11 @@ def test_pre_save_changed(self) -> None: self.assertChanged() -class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyTests): +class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyMixin, TestCase): fk_class = ModelTracked tracked_class = ModelTrackedFK + instance: ModelTrackedFK def test_custom_without_id(self) -> None: with self.assertNumQueries(2): @@ -886,11 +917,11 @@ def setUp(self) -> None: self.instance = Tracked.objects.create(number=1) self.tracker = self.instance.tracker - def assertChanged(self, *fields): + def assertChanged(self, *fields: str) -> None: for f in fields: self.assertTrue(self.tracker.has_changed(f)) - def assertNotChanged(self, *fields): + def assertNotChanged(self, *fields: str) -> None: for f in fields: self.assertFalse(self.tracker.has_changed(f)) @@ -921,7 +952,7 @@ def test_context_manager_fields(self) -> None: def test_tracker_decorator(self) -> None: @Tracked.tracker - def tracked_method(obj): + def tracked_method(obj: Tracked) -> None: obj.name = 'new' self.assertChanged('name') @@ -932,7 +963,7 @@ def tracked_method(obj): def test_tracker_decorator_fields(self) -> None: @Tracked.tracker(fields=['name']) - def tracked_method(obj): + def tracked_method(obj: Tracked) -> None: obj.name = 'new' obj.number += 1 self.assertChanged('name', 'number') diff --git a/tests/test_fields/test_monitor_field.py b/tests/test_fields/test_monitor_field.py index 9c9ba841..19ed9027 100644 --- a/tests/test_fields/test_monitor_field.py +++ b/tests/test_fields/test_monitor_field.py @@ -34,7 +34,7 @@ def test_double_save(self) -> None: def test_no_monitor_arg(self) -> None: with self.assertRaises(TypeError): - MonitorField() + MonitorField() # type: ignore[call-arg] def test_monitor_default_is_none_when_nullable(self) -> None: self.assertIsNone(self.instance.name_changed_nullable) diff --git a/tests/test_fields/test_urlsafe_token_field.py b/tests/test_fields/test_urlsafe_token_field.py index 6146fe61..72bbcda8 100644 --- a/tests/test_fields/test_urlsafe_token_field.py +++ b/tests/test_fields/test_urlsafe_token_field.py @@ -31,7 +31,7 @@ def test_factory_default(self) -> None: def test_factory_not_callable(self) -> None: with self.assertRaises(TypeError): - UrlsafeTokenField(factory='INVALID') + UrlsafeTokenField(factory='INVALID') # type: ignore[arg-type] def test_get_default(self) -> None: field = UrlsafeTokenField() @@ -57,8 +57,8 @@ def test_no_default_param(self) -> None: self.assertIs(field.default, NOT_PROVIDED) def test_deconstruct(self) -> None: - def test_factory() -> None: - pass + def test_factory(max_length: int) -> str: + assert False instance = UrlsafeTokenField(factory=test_factory) name, path, args, kwargs = instance.deconstruct() new_instance = UrlsafeTokenField(*args, **kwargs) From ecc7312bf45c0d37831e2d8582abed4e4a4baa0e Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Thu, 4 Apr 2024 08:35:45 +0200 Subject: [PATCH 20/28] Override signature of query-returning methods in `InheritanceManager` --- model_utils/managers.py | 69 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/model_utils/managers.py b/model_utils/managers.py index c4631d2a..1a870d10 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -223,9 +223,78 @@ class InheritanceManagerMixin(Generic[ModelT]): _queryset_class = InheritanceQuerySet if TYPE_CHECKING: + from collections.abc import Sequence + + def none(self) -> InheritanceQuerySet[ModelT]: + ... + + def all(self) -> InheritanceQuerySet[ModelT]: + ... + def filter(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]: ... + def exclude(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]: + ... + + def complex_filter(self, filter_obj: Any) -> InheritanceQuerySet[ModelT]: + ... + + def union(self, *other_qs: Any, all: bool = ...) -> InheritanceQuerySet[ModelT]: + ... + + def intersection(self, *other_qs: Any) -> InheritanceQuerySet[ModelT]: + ... + + def difference(self, *other_qs: Any) -> InheritanceQuerySet[ModelT]: + ... + + def select_for_update( + self, nowait: bool = ..., skip_locked: bool = ..., of: Sequence[str] = ..., no_key: bool = ... + ) -> InheritanceQuerySet[ModelT]: + ... + + def select_related(self, *fields: Any) -> InheritanceQuerySet[ModelT]: + ... + + def prefetch_related(self, *lookups: Any) -> InheritanceQuerySet[ModelT]: + ... + + def annotate(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]: + ... + + def alias(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]: + ... + + def order_by(self, *field_names: Any) -> InheritanceQuerySet[ModelT]: + ... + + def distinct(self, *field_names: Any) -> InheritanceQuerySet[ModelT]: + ... + + def extra( + self, + select: dict[str, Any] | None = ..., + where: list[str] | None = ..., + params: list[Any] | None = ..., + tables: list[str] | None = ..., + order_by: Sequence[str] | None = ..., + select_params: Sequence[Any] | None = ..., + ) -> InheritanceQuerySet[Any]: + ... + + def reverse(self) -> InheritanceQuerySet[ModelT]: + ... + + def defer(self, *fields: Any) -> InheritanceQuerySet[ModelT]: + ... + + def only(self, *fields: Any) -> InheritanceQuerySet[ModelT]: + ... + + def using(self, alias: str | None) -> InheritanceQuerySet[ModelT]: + ... + def get_queryset(self) -> InheritanceQuerySet[ModelT]: model: type[ModelT] = self.model # type: ignore[attr-defined] return self._queryset_class(model) From 3c314234ff339e43b1633bfa39fee0ff400ed4be Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Tue, 16 Apr 2024 05:16:04 +0200 Subject: [PATCH 21/28] Drop uses of `JoinManager` from the tests The `JoinManager` class is deprecated and not annotated. --- tests/models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models.py b/tests/models.py index f06d15b9..4d345050 100644 --- a/tests/models.py +++ b/tests/models.py @@ -12,7 +12,7 @@ from model_utils.fields import MonitorField, SplitField, StatusField, UUIDField from model_utils.managers import ( InheritanceManager, - JoinManager, + JoinQueryset, QueryManager, SoftDeletableManager, SoftDeletableQuerySet, @@ -416,7 +416,7 @@ class ModelWithCustomDescriptor(models.Model): class BoxJoinModel(models.Model): name = models.CharField(max_length=32) - objects: ClassVar[JoinManager[BoxJoinModel]] = JoinManager() + objects = JoinQueryset.as_manager() class JoinItemForeignKey(models.Model): @@ -426,7 +426,7 @@ class JoinItemForeignKey(models.Model): null=True, on_delete=models.CASCADE ) - objects: ClassVar[JoinManager[JoinItemForeignKey]] = JoinManager() + objects = JoinQueryset.as_manager() class CustomUUIDModel(UUIDModel): From 8f0b4ee2f82e4b208523ca88ed681c316fd67818 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Tue, 16 Apr 2024 06:36:09 +0200 Subject: [PATCH 22/28] Suppress mypy errors in field tracker tests I can't find a way to inform mypy of the actual types without duplicating a lot of test code. --- tests/test_fields/test_field_tracker.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_fields/test_field_tracker.py b/tests/test_fields/test_field_tracker.py index 4dc2dc3c..81db1ec4 100644 --- a/tests/test_fields/test_field_tracker.py +++ b/tests/test_fields/test_field_tracker.py @@ -96,7 +96,8 @@ def setUp(self) -> None: self.tracker = self.instance.tracker def test_descriptor(self) -> None: - self.assertTrue(isinstance(self.tracked_class.tracker, FieldTracker)) + tracker = self.tracked_class.tracker + self.assertTrue(isinstance(tracker, FieldTracker)) def test_pre_save_changed(self) -> None: self.assertChanged(name=None) @@ -528,14 +529,14 @@ class FieldTrackerForeignKeyMixin(FieldTrackerMixin): def setUp(self) -> None: self.old_fk = self.fk_class.objects.create(number=8) - self.instance = self.tracked_class.objects.create(fk=self.old_fk) + self.instance = self.tracked_class.objects.create(fk=self.old_fk) # type: ignore[misc] def test_default(self) -> None: self.tracker = self.instance.tracker self.assertChanged() self.assertPrevious() self.assertCurrent(id=self.instance.id, fk_id=self.old_fk.id) - self.instance.fk = self.fk_class.objects.create(number=8) + self.instance.fk = self.fk_class.objects.create(number=8) # type: ignore[assignment] self.assertChanged(fk_id=self.old_fk.id) self.assertPrevious(fk_id=self.old_fk.id) self.assertCurrent(id=self.instance.id, fk_id=self.instance.fk_id) @@ -545,7 +546,7 @@ def test_custom(self) -> None: self.assertChanged() self.assertPrevious() self.assertCurrent(fk_id=self.old_fk.id) - self.instance.fk = self.fk_class.objects.create(number=8) + self.instance.fk = self.fk_class.objects.create(number=8) # type: ignore[assignment] self.assertChanged(fk_id=self.old_fk.id) self.assertPrevious(fk_id=self.old_fk.id) self.assertCurrent(fk_id=self.instance.fk_id) @@ -557,7 +558,7 @@ def test_custom_without_id(self) -> None: self.assertChanged() self.assertPrevious() self.assertCurrent(fk=self.old_fk.id) - self.instance.fk = self.fk_class.objects.create(number=8) + self.instance.fk = self.fk_class.objects.create(number=8) # type: ignore[assignment] self.assertChanged(fk=self.old_fk.id) self.assertPrevious(fk=self.old_fk.id) self.assertCurrent(fk=self.instance.fk_id) From 23a756e13e5961d6cbce2507ba1965a7423485bd Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Tue, 16 Apr 2024 08:24:07 +0200 Subject: [PATCH 23/28] Annotate `test_fields` package --- tests/test_fields/test_split_field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_fields/test_split_field.py b/tests/test_fields/test_split_field.py index 9dd27a17..bc94f48a 100644 --- a/tests/test_fields/test_split_field.py +++ b/tests/test_fields/test_split_field.py @@ -49,7 +49,7 @@ def test_assign_to_content(self) -> None: def test_assign_to_excerpt(self) -> None: with self.assertRaises(AttributeError): - self.post.body.excerpt = 'this should fail' + self.post.body.excerpt = 'this should fail' # type: ignore[misc] def test_access_via_class(self) -> None: with self.assertRaises(AttributeError): From f4653f08e521d441dbf7e934dc0e8946b808bf12 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Wed, 17 Apr 2024 16:14:33 +0200 Subject: [PATCH 24/28] Preserve tracked function's return type in `FieldTracker` --- model_utils/tracker.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/model_utils/tracker.py b/model_utils/tracker.py index d89f4a8b..266ff3b8 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -296,14 +296,30 @@ def __init__(self, fields: Iterable[str] | None = None): # finalize_class() will replace None; pretend it is never None. self.fields = cast(Iterable[str], fields) + @overload + def __call__( + self, + func: None = None, + fields: Iterable[str] | None = None + ) -> Callable[[Callable[..., T]], Callable[..., T]]: + ... + + @overload + def __call__( + self, + func: Callable[..., T], + fields: Iterable[str] | None = None + ) -> Callable[..., T]: + ... + def __call__( self, - func: Callable | None = None, + func: Callable[..., T] | None = None, fields: Iterable[str] | None = None - ) -> Any: - def decorator(f: Callable) -> Callable: + ) -> Callable[[Callable[..., T]], Callable[..., T]] | Callable[..., T]: + def decorator(f: Callable[..., T]) -> Callable[..., T]: @wraps(f) - def inner(obj: models.Model, *args: object, **kwargs: object) -> object: + def inner(obj: models.Model, *args: object, **kwargs: object) -> T: tracker = getattr(obj, self.attname) field_list = tracker.fields if fields is None else fields with tracker(*field_list): From 1db7d6ba3380a0a7830a02117280465229ab826d Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Wed, 17 Apr 2024 17:58:42 +0200 Subject: [PATCH 25/28] Fix type generics in `InheritanceIterable` --- model_utils/managers.py | 78 +++++++++++++++++++++++------------------ 1 file changed, 43 insertions(+), 35 deletions(-) diff --git a/model_utils/managers.py b/model_utils/managers.py index 1a870d10..899b9887 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -7,6 +7,7 @@ from django.db import connection, models from django.db.models.constants import LOOKUP_SEP from django.db.models.fields.related import OneToOneField, OneToOneRel +from django.db.models.query import ModelIterable, QuerySet from django.db.models.sql.datastructures import Join ModelT = TypeVar('ModelT', bound=models.Model, covariant=True) @@ -15,44 +16,51 @@ from collections.abc import Iterator from django.db.models.query import BaseIterable - from django.db.models.query import ModelIterable as ModelIterableGeneric - from django.db.models.query import QuerySet as QuerySetGeneric - ModelIterable = ModelIterableGeneric[ModelT] - QuerySet = QuerySetGeneric[ModelT] -else: - from django.db.models.query import ModelIterable, QuerySet - - -class InheritanceIterable(ModelIterable): - def __iter__(self) -> Iterator[ModelT]: - queryset = self.queryset - iter: ModelIterableGeneric[ModelT] = ModelIterable(queryset) - if hasattr(queryset, 'subclasses'): - assert hasattr(queryset, '_get_sub_obj_recurse') - extras = tuple(queryset.query.extra.keys()) - # sort the subclass names longest first, - # so with 'a' and 'a__b' it goes as deep as possible - subclasses = sorted(queryset.subclasses, key=len, reverse=True) - for obj in iter: - sub_obj = None - for s in subclasses: - sub_obj = queryset._get_sub_obj_recurse(obj, s) - if sub_obj: - break - if not sub_obj: - sub_obj = obj - - if hasattr(queryset, '_annotated'): - for k in queryset._annotated: - setattr(sub_obj, k, getattr(obj, k)) - - for k in extras: + +def _iter_inheritance_queryset(queryset: QuerySet[ModelT]) -> Iterator[ModelT]: + iter: ModelIterable[ModelT] = ModelIterable(queryset) + if hasattr(queryset, 'subclasses'): + assert hasattr(queryset, '_get_sub_obj_recurse') + extras = tuple(queryset.query.extra.keys()) + # sort the subclass names longest first, + # so with 'a' and 'a__b' it goes as deep as possible + subclasses = sorted(queryset.subclasses, key=len, reverse=True) + for obj in iter: + sub_obj = None + for s in subclasses: + sub_obj = queryset._get_sub_obj_recurse(obj, s) + if sub_obj: + break + if not sub_obj: + sub_obj = obj + + if hasattr(queryset, '_annotated'): + for k in queryset._annotated: setattr(sub_obj, k, getattr(obj, k)) - yield sub_obj - else: - yield from iter + for k in extras: + setattr(sub_obj, k, getattr(obj, k)) + + yield sub_obj + else: + yield from iter + + +if TYPE_CHECKING: + class InheritanceIterable(ModelIterable[ModelT]): + queryset: QuerySet[ModelT] + + def __init__(self, queryset: QuerySet[ModelT], *args: Any, **kwargs: Any): + ... + + def __iter__(self) -> Iterator[ModelT]: + ... + +else: + class InheritanceIterable(ModelIterable): + def __iter__(self): + return _iter_inheritance_queryset(self.queryset) class InheritanceQuerySetMixin(Generic[ModelT]): From 00937608fa399d82216dda4beefbf6980d9eb106 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Tue, 7 May 2024 09:59:06 +0200 Subject: [PATCH 26/28] Add type argument to `DescriptorWrapper` This preserves the type of the wrapped descriptor (usually a field). Maybe this is overkill, as `DescriptorWrapper` seems to only be used as part of the `FieldTracker` implementation and is not documented and barely tested. But technically, it is public API. --- model_utils/tracker.py | 57 ++++++++++++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 16 deletions(-) diff --git a/model_utils/tracker.py b/model_utils/tracker.py index 266ff3b8..61093802 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -2,7 +2,16 @@ from copy import deepcopy from functools import wraps -from typing import TYPE_CHECKING, Any, Iterable, TypeVar, cast, overload +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Iterable, + Protocol, + TypeVar, + cast, + overload, +) from django.core.exceptions import FieldError from django.db import models @@ -16,10 +25,22 @@ class _AugmentedModel(models.Model): _instance_initialized: bool _deferred_fields: set[str] - T = TypeVar("T") +class Descriptor(Protocol[T]): + def __get__(self, instance: object, owner: type[object]) -> T: + ... + + def __set__(self, instance: object, value: T) -> None: + ... + + +class FullDescriptor(Descriptor[T]): + def __delete__(self, instance: object) -> None: + ... + + class LightStateFieldFile(FieldFile): """ FieldFile subclass with the only aim to remove the instance from the state. @@ -53,22 +74,22 @@ def lightweight_deepcopy(value: T) -> T: return deepcopy(value) -class DescriptorWrapper: +class DescriptorWrapper(Generic[T]): - def __init__(self, field_name: str, descriptor: models.Field, tracker_attname: str): + def __init__(self, field_name: str, descriptor: Descriptor[T], tracker_attname: str): self.field_name = field_name self.descriptor = descriptor self.tracker_attname = tracker_attname @overload - def __get__(self, instance: None, owner: type[models.Model]) -> DescriptorWrapper: + def __get__(self, instance: None, owner: type[models.Model]) -> DescriptorWrapper[T]: ... @overload - def __get__(self, instance: models.Model, owner: type[models.Model]) -> models.Field: + def __get__(self, instance: models.Model, owner: type[models.Model]) -> T: ... - def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> DescriptorWrapper | models.Field: + def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> DescriptorWrapper[T] | T: if instance is None: return self was_deferred = self.field_name in instance.get_deferred_fields() @@ -78,7 +99,7 @@ def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> D tracker_instance.saved_data[self.field_name] = lightweight_deepcopy(value) return value - def __set__(self, instance: models.Model, value: models.Field) -> None: + def __set__(self, instance: models.Model, value: T) -> None: initialized = hasattr(instance, '_instance_initialized') was_deferred = self.field_name in instance.get_deferred_fields() @@ -101,23 +122,23 @@ def __set__(self, instance: models.Model, value: models.Field) -> None: else: instance.__dict__[self.field_name] = value - def __getattr__(self, attr: str) -> models.Field: + def __getattr__(self, attr: str) -> T: return getattr(self.descriptor, attr) @staticmethod - def cls_for_descriptor(descriptor: models.Field) -> type[DescriptorWrapper]: + def cls_for_descriptor(descriptor: Descriptor[T]) -> type[DescriptorWrapper[T]]: if hasattr(descriptor, '__delete__'): return FullDescriptorWrapper else: return DescriptorWrapper -class FullDescriptorWrapper(DescriptorWrapper): +class FullDescriptorWrapper(DescriptorWrapper[T]): """ Wrapper for descriptors with all three descriptor methods. """ - def __delete__(self, obj: models.Field) -> None: - self.descriptor.__delete__(obj) # type: ignore[attr-defined] + def __delete__(self, obj: models.Model) -> None: + cast(FullDescriptor[T], self.descriptor).__delete__(obj) class FieldsContext: @@ -255,7 +276,9 @@ def has_changed(self, field: str) -> bool: # deferred fields haven't changed if field in self.deferred_fields and field not in self.instance.__dict__: return False - return self.previous(field) != self.get_field_value(field) + prev: object = self.previous(field) + curr: object = self.get_field_value(field) + return prev != curr else: raise FieldError('field "%s" not tracked' % field) @@ -348,7 +371,7 @@ def finalize_class(self, sender: type[models.Model], **kwargs: object) -> None: self.fields = (field.attname for field in sender._meta.fields) self.fields = set(self.fields) for field_name in self.fields: - descriptor: models.Field = getattr(sender, field_name) + descriptor: models.Field[Any, Any] = getattr(sender, field_name) wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor) wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname) setattr(sender, field_name, wrapped_descriptor) @@ -426,7 +449,9 @@ def has_changed(self, field: str) -> bool: if not self.instance.pk: return True elif field in self.saved_data: - return self.previous(field) != self.get_field_value(field) + prev: object = self.previous(field) + curr: object = self.get_field_value(field) + return prev != curr else: raise FieldError('field "%s" not tracked' % field) From 5fc37eb4b40983af7e18a76a8fcd78d12e840d63 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Tue, 7 May 2024 10:53:09 +0200 Subject: [PATCH 27/28] Provide type arguments to field base classes --- model_utils/fields.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/model_utils/fields.py b/model_utils/fields.py index 143528ff..ced117cb 100644 --- a/model_utils/fields.py +++ b/model_utils/fields.py @@ -12,12 +12,16 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterable - from datetime import datetime + from datetime import date, datetime + + DateTimeFieldBase = models.DateTimeField[Union[str, datetime, date], datetime] +else: + DateTimeFieldBase = models.DateTimeField DEFAULT_CHOICES_NAME = 'STATUS' -class AutoCreatedField(models.DateTimeField): +class AutoCreatedField(DateTimeFieldBase): """ A DateTimeField that automatically populates itself at object creation. @@ -111,7 +115,7 @@ def deconstruct(self) -> tuple[str, str, Sequence[Any], dict[str, Any]]: return name, path, args, kwargs -class MonitorField(models.DateTimeField): +class MonitorField(DateTimeFieldBase): """ A DateTimeField that monitors another field on the same model and sets itself to the current date/time whenever the monitored field From 9dc224723b8c8008ea994474b2c42ded5cda38d6 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Fri, 14 Jun 2024 11:56:35 +0200 Subject: [PATCH 28/28] Tell coverage tool to ignore lines intended for mypy only In particular, `if TYPE_CHECKING:` blocks and `...` in bodies of overloaded method definitions. --- .coveragerc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.coveragerc b/.coveragerc index 8708371a..77531508 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,2 +1,8 @@ [run] include = model_utils/*.py + +[report] +exclude_also = + # Exclusive to mypy: + if TYPE_CHECKING:$ + \.\.\.$