Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints and the py.typed stub #551

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion model_utils/choices.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __contains__(self, item):
def __deepcopy__(self, memo):
return self.__class__(*copy.deepcopy(self._triples, memo))

def subset(self, *new_identifiers):
def subset(self, *new_identifiers) -> "Choices":
identifiers = set(self._identifier_map.keys())

if not identifiers.issuperset(new_identifiers):
Expand Down
30 changes: 18 additions & 12 deletions model_utils/managers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional, Type
import warnings

from django.core.exceptions import ObjectDoesNotExist
Expand All @@ -7,6 +8,11 @@
from django.db.models.query import ModelIterable, QuerySet
from django.db.models.sql.datastructures import Join

try:
from typing import Self
except ImportError:
from typing import Any as Self


class InheritanceIterable(ModelIterable):
def __iter__(self):
Expand Down Expand Up @@ -43,7 +49,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._iterable_class = InheritanceIterable

def select_subclasses(self, *subclasses):
def select_subclasses(self, *subclasses)-> QuerySet:
levels = None
calculated_subclasses = self._get_subclasses_recurse(
self.model, levels=levels)
Expand Down Expand Up @@ -170,12 +176,12 @@ def _get_sub_obj_recurse(self, obj, s):
else:
return node

def get_subclass(self, *args, **kwargs):
def get_subclass(self, *args, **kwargs)-> Optional[models.Model]:
return self.select_subclasses().get(*args, **kwargs)


class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet):
def instance_of(self, *models):
def instance_of(self, *models: Type[models.Model]) -> QuerySet:
"""
Fetch only objects that are instances of the provided model(s).
"""
Expand Down Expand Up @@ -204,16 +210,16 @@ def instance_of(self, *models):
class InheritanceManagerMixin:
_queryset_class = InheritanceQuerySet

def get_queryset(self):
def get_queryset(self) -> InheritanceQuerySet:
return self._queryset_class(self.model)

def select_subclasses(self, *subclasses):
def select_subclasses(self, *subclasses: Type[models.Model]) -> QuerySet:
return self.get_queryset().select_subclasses(*subclasses)

def get_subclass(self, *args, **kwargs):
def get_subclass(self, *args: Type[models.Model], **kwargs) -> Optional[models.Model]:
return self.get_queryset().get_subclass(*args, **kwargs)

def instance_of(self, *models):
def instance_of(self, *models: Type[models.Model]) -> QuerySet:
return self.get_queryset().instance_of(*models)


Expand All @@ -231,11 +237,11 @@ def __init__(self, *args, **kwargs):
self._order_by = None
super().__init__()

def order_by(self, *args):
def order_by(self, *args) -> Self:
self._order_by = args
return self

def get_queryset(self):
def get_queryset(self) -> QuerySet:
qs = super().get_queryset().filter(self._q)
if self._order_by is not None:
return qs.order_by(*self._order_by)
Expand All @@ -252,7 +258,7 @@ class SoftDeletableQuerySetMixin:
its ``is_removed`` field to True.
"""

def delete(self):
def delete(self) -> None:
"""
Soft delete objects from queryset (set their ``is_removed``
field to True)
Expand All @@ -275,7 +281,7 @@ def __init__(self, *args, _emit_deprecation_warnings=False, **kwargs):
self.emit_deprecation_warnings = _emit_deprecation_warnings
super().__init__(*args, **kwargs)

def get_queryset(self):
def get_queryset(self) -> QuerySet:
"""
Return queryset limited to not removed entries.
"""
Expand Down Expand Up @@ -318,7 +324,7 @@ def get_quoted_query(self, query):
params = tuple(params)
return query % params

def join(self, qs=None):
def join(self, qs=None) -> QuerySet:
'''
Join one queryset together with another using a temporary table. If
no queryset is used, it will use the current queryset and join that
Expand Down
5 changes: 3 additions & 2 deletions model_utils/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Sequence
from django.core.exceptions import ImproperlyConfigured
from django.db import models, router, transaction
from django.db.models.functions import Now
Expand Down Expand Up @@ -192,8 +193,8 @@ def save(self, signals_to_disable=None, *args, **kwargs):

super().save(*args, **kwargs)

def save_base(self, raw=False, force_insert=False,
force_update=False, using=None, update_fields=None):
def save_base(self, raw: bool=False, force_insert: bool=False,
force_update: bool=False, using=None, update_fields: Sequence[str]=None):
"""
Copied from base class for a minor change.
This is an ugly overwriting but since Django's ``save_base`` method
Expand Down
Empty file added model_utils/py.typed
Empty file.
17 changes: 9 additions & 8 deletions model_utils/tracker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from copy import deepcopy
from functools import wraps
from typing import Any, Dict, Sequence

from django.core.exceptions import FieldError
from django.db import models
Expand Down Expand Up @@ -204,10 +205,10 @@ def __call__(self, *fields):
def deferred_fields(self):
return self.instance.get_deferred_fields()

def get_field_value(self, field):
def get_field_value(self, field: str):
return getattr(self.instance, self.field_map[field])

def set_saved_fields(self, fields=None):
def set_saved_fields(self, fields: Sequence[str]=None) -> None:
if not self.instance.pk:
self.saved_data = {}
elif fields is None:
Expand All @@ -219,7 +220,7 @@ def set_saved_fields(self, fields=None):
for field, field_value in self.saved_data.items():
self.saved_data[field] = lightweight_deepcopy(field_value)

def current(self, fields=None):
def current(self, fields: Sequence[str]=None) -> Dict[str, Any]:
"""Returns dict of current values for all tracked fields"""
if fields is None:
deferred_fields = self.deferred_fields
Expand All @@ -233,7 +234,7 @@ def current(self, fields=None):

return {f: self.get_field_value(f) for f in fields}

def has_changed(self, field):
def has_changed(self, field: str) -> bool:
"""Returns ``True`` if field has changed from currently saved value"""
if field in self.fields:
# deferred fields haven't changed
Expand All @@ -243,7 +244,7 @@ def has_changed(self, field):
else:
raise FieldError('field "%s" not tracked' % field)

def previous(self, field):
def previous(self, field: str):
"""Returns currently saved value of given field"""

# handle deferred fields that have not yet been loaded from the database
Expand All @@ -263,7 +264,7 @@ def previous(self, field):

return self.saved_data.get(field)

def changed(self):
def changed(self) -> Dict[str, Any]:
"""Returns dict of fields that changed since save (with old values)"""
return {
field: self.previous(field)
Expand Down Expand Up @@ -385,7 +386,7 @@ def __get__(self, instance, owner):

class ModelInstanceTracker(FieldInstanceTracker):

def has_changed(self, field):
def has_changed(self, field) -> bool:
"""Returns ``True`` if field has changed from currently saved value"""
if not self.instance.pk:
return True
Expand All @@ -394,7 +395,7 @@ def has_changed(self, field):
else:
raise FieldError('field "%s" not tracked' % field)

def changed(self):
def changed(self) -> Dict[str, Any]:
"""Returns dict of fields that changed since save (with old values)"""
if not self.instance.pk:
return {}
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def long_desc(root_path):
zip_safe=False,
package_data={
'model_utils': [
'locale/*/LC_MESSAGES/django.po', 'locale/*/LC_MESSAGES/django.mo'
'locale/*/LC_MESSAGES/django.po',
'locale/*/LC_MESSAGES/django.mo',
'py.typed',
],
},
)