From e1cb67001cca16189827426b2d103e8c9e5700fb Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Sun, 3 Mar 2024 20:29:23 +0100 Subject: [PATCH] stubgen: don't include aliases to functions/classes of different packages --- src/stubgen.py | 64 +++++++++++++++++++++++++++++++++++-------- tests/py_stub_test.py | 4 +++ 2 files changed, 57 insertions(+), 11 deletions(-) diff --git a/src/stubgen.py b/src/stubgen.py index b27ba6f9..59b1ee57 100755 --- a/src/stubgen.py +++ b/src/stubgen.py @@ -110,7 +110,17 @@ class and repeatedly call ``.put()`` to register modules or contents within the # Type of an entry of the ``__nb_signature__`` tuple of nanobind getters and setters. NbGetterSetterSignature = Tuple[str, str] +class NamedObject(Protocol): + """ + Typing protocol representing a an object with __name__ and __module__ members + """ + __module__: str + __name__: str + class NbFunction(Protocol): + """ + Typing protocol representing a nanobind function with its __nb_signature__ property + """ __module__: Literal["nanobind"] __name__: Literal["nb_func", "nb_method"] __nb_signature__: Tuple[NbFunctionSignature, ...] @@ -123,6 +133,7 @@ class NbGetterSetter(Protocol): class NbStaticProperty(Protocol): + """Typing protocol representing a nanobind static property""" __module__: Literal["nanobind"] __name__: Literal["nb_static_property"] fget: NbGetterSetter @@ -130,6 +141,7 @@ class NbStaticProperty(Protocol): class NbType(Protocol): + """typing protocol representing a nanobind type object""" __module__: Literal["nanobind"] __name__: Literal["nb_type"] __nb_signature__: str @@ -155,6 +167,8 @@ def __init__( module: types.ModuleType, include_docstrings: bool = True, include_private: bool = False, + include_internal_imports: bool = True, + include_external_imports: bool = False, max_expr_length: int = 50, patterns: List[ReplacePattern] = [], ) -> None: @@ -167,6 +181,12 @@ def __init__( # Include private members that start or end with a single underscore? self.include_private = include_private + # Include types and functions imported from the same package (but a different module) + self.include_internal_imports = include_internal_imports + + # Include types and functions imported from external packages? + self.include_external_imports = include_external_imports + # Maximal length (in characters) before an expression gets abbreviated as '...' self.max_expr_length = max_expr_length @@ -435,12 +455,18 @@ def put_nb_static_property(self, name: Optional[str], prop: NbStaticProperty): def put_type(self, tp: NbType, name: Optional[str]): """Append a 'nb_type' type object""" - if name and (name != tp.__name__ or self.module.__name__ != tp.__module__): - if self.module.__name__ == tp.__module__: - # This is an alias of a type in the same module + tp_name, tp_mod_name = tp.__name__, tp.__module__ + mod_name = self.module.__name__ + + if name and (name != tp_name or mod_name != tp_mod_name): + same_module = tp_mod_name == mod_name + same_toplevel_module = tp_mod_name.split(".")[0] == mod_name.split(".")[0] + + if same_module: + # This is an alias of a type in the same module or same top-level module alias_tp = self.import_object("typing", "TypeAlias") - self.write_ln(f"{name}: {alias_tp} = {tp.__name__}\n") - else: + self.write_ln(f"{name}: {alias_tp} = {tp_name}\n") + elif self.include_external_imports or (same_toplevel_module and self.include_internal_imports): # Import from a different module self.put_value(tp, name) else: @@ -475,7 +501,7 @@ def put_type(self, tp: NbType, name: Optional[str]): self.write_ln(self.simplify_types(s)) self.output = self.output[:-1] + ":\n" else: - self.write_ln(f"class {tp.__name__}:") + self.write_ln(f"class {tp_name}:") if tp_bases is None: tp_bases = getattr(tp, "__orig_bases__", None) if tp_bases is None: @@ -531,6 +557,14 @@ def put_value(self, value: object, name: str, parent: Optional[object] = None, a """ tp = type(value) + # Ignore module imports of non-type values like 'from typing import Optional' + if ( + not self.include_external_imports + and tp.__module__ == "typing" + and str(value) == f"typing.{name}" + ): + return + if isinstance(parent, type) and issubclass(tp, parent) and self.is_enum(parent): # This is an entry of an enumeration self.write_ln(f"{name}: {self.type_str(tp)}") @@ -538,9 +572,12 @@ def put_value(self, value: object, name: str, parent: Optional[object] = None, a self.put_docstr(value.__doc__) self.write("\n") elif self.is_function(tp) or isinstance(value, type): - # This is a function or a type, import it from its actual source - value = cast(type, value) - self.import_object(value.__module__, value.__name__, name) + named_value = cast(NamedObject, value) + same_toplevel_module = named_value.__module__.split(".")[0] == self.module.__name__.split(".")[0] + + if self.include_external_imports or (same_toplevel_module and self.include_internal_imports): + # This is a function or a type, import it from its actual source + self.import_object(named_value.__module__, named_value.__name__, name) else: value_str = self.expr_str(value, abbrev) @@ -728,6 +765,10 @@ def put(self, value: object, name: Optional[str] = None, parent: Optional[object if ismodule(value): if len(self.stack) != 1: + is_external = value.__name__.split(".")[0] != self.module.__name__.split(".")[0] + if not self.include_external_imports and is_external: + return + # Do not recurse into submodules, but include a directive to import them self.import_object(value.__name__, name=None, as_name=name) return @@ -831,8 +872,9 @@ def import_object( def expr_str(self, e: Any, abbrev: bool = True) -> Optional[str]: """ Attempt to convert a value into valid Python syntax that regenerates - that value. When ``abbrev`` is True, give up and replace with '...' if - the expression is too complicated to be included in the stubs + that value. When ``abbrev`` is True, the implementation gives up and + returns ``None`` when the expression is considered to be too + complicated. """ tp = type(e) for t in [bool, int, type(None), type(builtins.Ellipsis)]: diff --git a/tests/py_stub_test.py b/tests/py_stub_test.py index f00bb244..bfa2c20b 100644 --- a/tests/py_stub_test.py +++ b/tests/py_stub_test.py @@ -6,6 +6,10 @@ else: import typing +# Ignore a type and a function from elsewhere. These shouldn't be included in +# the stub by default +from os import PathLike, getcwd + del sys C = 123