diff --git a/model_utils/choices.py b/model_utils/choices.py index 46c3877c..fd8946f4 100644 --- a/model_utils/choices.py +++ b/model_utils/choices.py @@ -141,7 +141,7 @@ def __contains__(self, item): def __deepcopy__(self, memo): return self.__class__(*copy.deepcopy(self._triples, memo)) - def subset(self, *new_identifiers): + def subset(self, *new_identifiers) -> "Choices": identifiers = set(self._identifier_map.keys()) if not identifiers.issuperset(new_identifiers): diff --git a/model_utils/managers.py b/model_utils/managers.py index 75d3d98a..f6c71d11 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -1,3 +1,4 @@ +from typing import Optional, Type import warnings from django.core.exceptions import ObjectDoesNotExist @@ -7,6 +8,11 @@ from django.db.models.query import ModelIterable, QuerySet from django.db.models.sql.datastructures import Join +try: + from typing import Self +except ImportError: + from typing import Any as Self + class InheritanceIterable(ModelIterable): def __iter__(self): @@ -43,7 +49,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._iterable_class = InheritanceIterable - def select_subclasses(self, *subclasses): + def select_subclasses(self, *subclasses)-> QuerySet: levels = None calculated_subclasses = self._get_subclasses_recurse( self.model, levels=levels) @@ -170,12 +176,12 @@ def _get_sub_obj_recurse(self, obj, s): else: return node - def get_subclass(self, *args, **kwargs): + def get_subclass(self, *args, **kwargs)-> Optional[models.Model]: return self.select_subclasses().get(*args, **kwargs) class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet): - def instance_of(self, *models): + def instance_of(self, *models: Type[models.Model]) -> QuerySet: """ Fetch only objects that are instances of the provided model(s). """ @@ -204,16 +210,16 @@ def instance_of(self, *models): class InheritanceManagerMixin: _queryset_class = InheritanceQuerySet - def get_queryset(self): + def get_queryset(self) -> InheritanceQuerySet: return self._queryset_class(self.model) - def select_subclasses(self, *subclasses): + def select_subclasses(self, *subclasses: Type[models.Model]) -> QuerySet: return self.get_queryset().select_subclasses(*subclasses) - def get_subclass(self, *args, **kwargs): + def get_subclass(self, *args: Type[models.Model], **kwargs) -> Optional[models.Model]: return self.get_queryset().get_subclass(*args, **kwargs) - def instance_of(self, *models): + def instance_of(self, *models: Type[models.Model]) -> QuerySet: return self.get_queryset().instance_of(*models) @@ -231,11 +237,11 @@ def __init__(self, *args, **kwargs): self._order_by = None super().__init__() - def order_by(self, *args): + def order_by(self, *args) -> Self: self._order_by = args return self - def get_queryset(self): + def get_queryset(self) -> QuerySet: qs = super().get_queryset().filter(self._q) if self._order_by is not None: return qs.order_by(*self._order_by) @@ -252,7 +258,7 @@ class SoftDeletableQuerySetMixin: its ``is_removed`` field to True. """ - def delete(self): + def delete(self) -> None: """ Soft delete objects from queryset (set their ``is_removed`` field to True) @@ -275,7 +281,7 @@ def __init__(self, *args, _emit_deprecation_warnings=False, **kwargs): self.emit_deprecation_warnings = _emit_deprecation_warnings super().__init__(*args, **kwargs) - def get_queryset(self): + def get_queryset(self) -> QuerySet: """ Return queryset limited to not removed entries. """ @@ -318,7 +324,7 @@ def get_quoted_query(self, query): params = tuple(params) return query % params - def join(self, qs=None): + def join(self, qs=None) -> QuerySet: ''' Join one queryset together with another using a temporary table. If no queryset is used, it will use the current queryset and join that diff --git a/model_utils/models.py b/model_utils/models.py index 268db8c5..fb668adb 100644 --- a/model_utils/models.py +++ b/model_utils/models.py @@ -1,3 +1,4 @@ +from typing import Sequence from django.core.exceptions import ImproperlyConfigured from django.db import models, router, transaction from django.db.models.functions import Now @@ -192,8 +193,8 @@ def save(self, signals_to_disable=None, *args, **kwargs): super().save(*args, **kwargs) - def save_base(self, raw=False, force_insert=False, - force_update=False, using=None, update_fields=None): + def save_base(self, raw: bool=False, force_insert: bool=False, + force_update: bool=False, using=None, update_fields: Sequence[str]=None): """ Copied from base class for a minor change. This is an ugly overwriting but since Django's ``save_base`` method diff --git a/model_utils/py.typed b/model_utils/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/model_utils/tracker.py b/model_utils/tracker.py index 682ed293..2dc5ed85 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -1,5 +1,6 @@ from copy import deepcopy from functools import wraps +from typing import Any, Dict, Sequence from django.core.exceptions import FieldError from django.db import models @@ -204,10 +205,10 @@ def __call__(self, *fields): def deferred_fields(self): return self.instance.get_deferred_fields() - def get_field_value(self, field): + def get_field_value(self, field: str): return getattr(self.instance, self.field_map[field]) - def set_saved_fields(self, fields=None): + def set_saved_fields(self, fields: Sequence[str]=None) -> None: if not self.instance.pk: self.saved_data = {} elif fields is None: @@ -219,7 +220,7 @@ def set_saved_fields(self, fields=None): for field, field_value in self.saved_data.items(): self.saved_data[field] = lightweight_deepcopy(field_value) - def current(self, fields=None): + def current(self, fields: Sequence[str]=None) -> Dict[str, Any]: """Returns dict of current values for all tracked fields""" if fields is None: deferred_fields = self.deferred_fields @@ -233,7 +234,7 @@ def current(self, fields=None): return {f: self.get_field_value(f) for f in fields} - def has_changed(self, field): + def has_changed(self, field: str) -> bool: """Returns ``True`` if field has changed from currently saved value""" if field in self.fields: # deferred fields haven't changed @@ -243,7 +244,7 @@ def has_changed(self, field): else: raise FieldError('field "%s" not tracked' % field) - def previous(self, field): + def previous(self, field: str): """Returns currently saved value of given field""" # handle deferred fields that have not yet been loaded from the database @@ -263,7 +264,7 @@ def previous(self, field): return self.saved_data.get(field) - def changed(self): + def changed(self) -> Dict[str, Any]: """Returns dict of fields that changed since save (with old values)""" return { field: self.previous(field) @@ -385,7 +386,7 @@ def __get__(self, instance, owner): class ModelInstanceTracker(FieldInstanceTracker): - def has_changed(self, field): + def has_changed(self, field) -> bool: """Returns ``True`` if field has changed from currently saved value""" if not self.instance.pk: return True @@ -394,7 +395,7 @@ def has_changed(self, field): else: raise FieldError('field "%s" not tracked' % field) - def changed(self): + def changed(self) -> Dict[str, Any]: """Returns dict of fields that changed since save (with old values)""" if not self.instance.pk: return {} diff --git a/setup.py b/setup.py index 2c6f855f..46cb6ec8 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,9 @@ def long_desc(root_path): zip_safe=False, package_data={ 'model_utils': [ - 'locale/*/LC_MESSAGES/django.po', 'locale/*/LC_MESSAGES/django.mo' + 'locale/*/LC_MESSAGES/django.po', + 'locale/*/LC_MESSAGES/django.mo', + 'py.typed', ], }, )