Skip to content

Commit

Permalink
fix incorrect checks for marks
Browse files Browse the repository at this point in the history
  • Loading branch information
Aran-Fey committed Dec 20, 2024
1 parent 83a052b commit 73d645d
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 25 deletions.
29 changes: 13 additions & 16 deletions introspection/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import sentinel

from .mark import FORWARDS_ARGUMENTS, forwards_arguments, does_not_alter_signature
from .mark import forwards_arguments, has_mark
from .misc import static_vars, is_abstract, static_mro
from .errors import *
from .types import Slot, Function, Class
Expand Down Expand Up @@ -483,10 +483,10 @@ def __new__(original_new, cls, *args, **kwargs):
if method_type is auto:
method_type = get_implicit_method_type(name)

original_method, wrap_original = _get_original_method(cls, name, method_type)
original_method, decorator_for_replacement_method = _get_original_method(cls, name, method_type)

@does_not_alter_signature
@wrap_original
@functools.wraps(original_method) # type: ignore (wtf?)
@decorator_for_replacement_method
def replacement_method(*args, **kwargs):
return method(original_method, *args, **kwargs)

Expand All @@ -497,13 +497,18 @@ def _get_original_method(cls: type, method_name: str, method_type):
cls_vars = static_vars(cls)

original_method: Callable[..., Any]
decorator_for_replacement_method = lambda func: func

try:
original_method = cls_vars[method_name] # type: ignore
except KeyError:
# === SPECIAL METHOD: __new__ ===
if method_name == "__new__":
original_method, wrap_original = _make_original_new_method(cls)
original_method = _make_original_new_method(cls)

# We're just gonna assume that the user is going to properly call our original_method,
# because if not, they should've just used `add_method_to_class` instead.
decorator_for_replacement_method = forwards_arguments

# === STATICMETHODS ===
elif method_type is staticmethod:
Expand All @@ -513,7 +518,6 @@ def _original_method1(*args, **kwargs):
return super_method(*args, **kwargs)

original_method = _original_method1
wrap_original = lambda func: func

# === INSTANCE- AND CLASSMETHODS ===
else:
Expand All @@ -523,14 +527,11 @@ def _original_method2(self_or_cls, *args, **kwargs):
return super_method(*args, **kwargs)

original_method = _original_method2
wrap_original = lambda func: func
else:
if isinstance(original_method, (staticmethod, classmethod)):
original_method = original_method.__func__

wrap_original = functools.wraps(original_method)

return original_method, wrap_original
return original_method, decorator_for_replacement_method


def _make_original_new_method(cls: type):
Expand Down Expand Up @@ -563,7 +564,7 @@ def original_method(class_: type, *args, **kwargs):
except KeyError:
continue

if new in FORWARDS_ARGUMENTS:
if has_mark(new, forwards_arguments):
continue

if getattr(new, "_forwards_args", False): # Legacy version of `@forwards_arguments`
Expand All @@ -578,8 +579,4 @@ def original_method(class_: type, *args, **kwargs):

return super_new(class_, *args, **kwargs) # type: ignore

# We're just gonna assume that the user is going to properly call our original_method, because
# if not, they should've just used `add_method_to_class` instead.
wrap_original = forwards_arguments

return original_method, wrap_original
return original_method
21 changes: 15 additions & 6 deletions introspection/mark.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,31 @@
import weakref
from typing import Callable, MutableSet, TypeVar
import typing as t

from .misc import iter_wrapped


__all__ = ["does_not_alter_signature", "forwards_arguments"]


C = TypeVar("C", bound=Callable)
C = t.TypeVar("C", bound=t.Callable)


DOES_NOT_ALTER_SIGNATURE: MutableSet[Callable] = weakref.WeakSet()
FORWARDS_ARGUMENTS: MutableSet[Callable] = weakref.WeakSet()
MARKS = weakref.WeakKeyDictionary[t.Callable, t.MutableSet[t.Callable]]()


def does_not_alter_signature(func: C) -> C:
DOES_NOT_ALTER_SIGNATURE.add(func)
MARKS.setdefault(func, set()).add(does_not_alter_signature)
return func


def forwards_arguments(func: C) -> C:
FORWARDS_ARGUMENTS.add(func)
MARKS.setdefault(func, set()).add(forwards_arguments)
return func


def has_mark(func: t.Callable, mark: t.Callable) -> bool:
for func in iter_wrapped(func):
if mark in MARKS.get(func, ()):
return True

return False
9 changes: 6 additions & 3 deletions introspection/signature_.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from .bound_arguments import BoundArguments
from .parameter import Parameter
from .mark import DOES_NOT_ALTER_SIGNATURE
from .mark import has_mark, does_not_alter_signature
from .misc import static_mro, static_vars
from ._utils import SIG_EMPTY, NONE, _Sentinel
from .errors import *
Expand Down Expand Up @@ -168,7 +168,7 @@ def from_callable( # type: ignore[incompatible-override]
# Cache the result, if possible
try:
callable_.__signature__ = signature # type: ignore
except AttributeError:
except (AttributeError, TypeError):
pass

return signature
Expand Down Expand Up @@ -743,7 +743,10 @@ def _find_constructor_function(cls: type) -> t.Callable:
if not callable(bound_method):
continue

if func in DOES_NOT_ALTER_SIGNATURE or bound_method in DOES_NOT_ALTER_SIGNATURE:
if (
has_mark(func, does_not_alter_signature) # type: ignore
or has_mark(bound_method, does_not_alter_signature)
):
continue

return bound_method
Expand Down
14 changes: 14 additions & 0 deletions tests/test_class_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,20 @@ class Foo(Bar):
assert new1_kwargs == {"what": "are these doing here"}


def test_wrap_method_does_not_alter_signature():
class Foo:
def func(self, foo: int):
pass

def wrapper(wrapped_func, *args, **kwargs):
return wrapped_func(*args, **kwargs)

wrap_method(Foo, wrapper, "func")

signature = introspection.signature(Foo.func)
assert list(signature.parameters) == ["self", "foo"]


def test_wrap_method_typeerror():
with pytest.raises(TypeError):
# 2nd argument must be a class
Expand Down
23 changes: 23 additions & 0 deletions tests/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,29 @@ def __init__(self, foo):
assert list(sig.parameters) == ["foo"]


def test_mark_on_decorated_function():
def wrap_in_unnecessary_decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper

class Meta(type):
# Make sure the mark still takes effect if the function is decorated
@wrap_in_unnecessary_decorator
@introspection.mark.does_not_alter_signature
def __call__(cls, *args, **kwargs):
return super().__call__(*args, **kwargs)

class Cls(metaclass=Meta):
def __init__(self, foo):
pass

sig = Signature.from_callable(Cls)
assert list(sig.parameters) == ["foo"]


def test_partial():
def func(x, *y, z):
pass
Expand Down

0 comments on commit 73d645d

Please sign in to comment.