diff --git a/src/stubgen.py b/src/stubgen.py index be01133d..81c16b73 100755 --- a/src/stubgen.py +++ b/src/stubgen.py @@ -59,6 +59,7 @@ class and repeatedly call ``.put()`` to register modules or contents within the import textwrap import importlib import importlib.machinery +import importlib.util import types import typing from dataclasses import dataclass @@ -1089,13 +1090,43 @@ def type_str(self, tp: Union[List[Any], Tuple[Any, ...], Dict[Any, Any], Any]) - result = repr(tp) return self.simplify_types(result) + def check_party(self, module: str) -> Literal[0, 1, 2]: + """ + Check source of module + 0 = From stdlib + 1 = From 3rd party package + 2 = From the package being built + """ + if module.startswith(".") or module == self.module.__name__.split('.')[0]: + return 2 + + try: + spec = importlib.util.find_spec(module) + except ModuleNotFoundError: + return 1 + + if spec: + if spec.origin and "site-packages" in spec.origin: + return 1 + else: + return 0 + else: + return 1 + def get(self) -> str: """Generate the final stub output""" s = "" + last_party = None - for module in sorted(self.imports): + for module in sorted(self.imports, key=lambda i: str(self.check_party(i)) + i): imports = self.imports[module] items: List[str] = [] + party = self.check_party(module) + + if party != last_party: + if last_party is not None: + s += "\n" + last_party = party for (k, v1), v2 in imports.items(): if k is None: @@ -1108,15 +1139,16 @@ def get(self) -> str: items.append(f"{k} as {v2}") else: items.append(k) - + + items = sorted(items) if items: items_v0 = ", ".join(items) items_v0 = f"from {module} import {items_v0}\n" items_v1 = "(\n " + ",\n ".join(items) + "\n)" items_v1 = f"from {module} import {items_v1}\n" s += items_v0 if len(items_v0) <= 70 else items_v1 - if s: - s += "\n" + + s += "\n\n" s += self.put_abstract_enum_class() # Append the main generated stub @@ -1335,7 +1367,6 @@ def add_pattern(query: str, lines: List[str]): def main(args: Optional[List[str]] = None) -> None: import sys - import os # Ensure that the current directory is on the path if "" not in sys.path and "." not in sys.path: diff --git a/tests/common.py b/tests/common.py index 769e6576..2f8ccafa 100644 --- a/tests/common.py +++ b/tests/common.py @@ -5,9 +5,9 @@ is_pypy = platform.python_implementation() == 'PyPy' is_darwin = platform.system() == 'Darwin' -def collect(): +def collect() -> None: if is_pypy: - for i in range(3): + for _ in range(3): gc.collect() else: gc.collect() diff --git a/tests/py_stub_test.pyi.ref b/tests/py_stub_test.pyi.ref index e8273cd8..91d01468 100644 --- a/tests/py_stub_test.pyi.ref +++ b/tests/py_stub_test.pyi.ref @@ -1,5 +1,6 @@ from collections.abc import Callable -from typing import overload, TypeVar +from typing import TypeVar, overload + class AClass: __annotations__: dict = {'STATIC_VAR' : int} diff --git a/tests/test_classes_ext.pyi.ref b/tests/test_classes_ext.pyi.ref index 6eadc0a8..494ba479 100644 --- a/tests/test_classes_ext.pyi.ref +++ b/tests/test_classes_ext.pyi.ref @@ -1,5 +1,6 @@ from typing import overload + class A: def __init__(self, arg: int, /) -> None: ... diff --git a/tests/test_enum_ext.pyi.ref b/tests/test_enum_ext.pyi.ref index 5fbfd743..87f17d71 100644 --- a/tests/test_enum_ext.pyi.ref +++ b/tests/test_enum_ext.pyi.ref @@ -1,5 +1,6 @@ from typing import overload + class _Enum: def __init__(self, arg: object, /) -> None: ... def __repr__(self, /) -> str: ... diff --git a/tests/test_functions_ext.pyi.ref b/tests/test_functions_ext.pyi.ref index 450dd1e1..2898b3ea 100644 --- a/tests/test_functions_ext.pyi.ref +++ b/tests/test_functions_ext.pyi.ref @@ -1,6 +1,7 @@ from collections.abc import Callable import types -from typing import overload, Annotated, Any +from typing import Annotated, Any, overload + def call_guard_value() -> int: ... diff --git a/tests/test_make_iterator_ext.pyi.ref b/tests/test_make_iterator_ext.pyi.ref index a46055b6..80d9fb88 100644 --- a/tests/test_make_iterator_ext.pyi.ref +++ b/tests/test_make_iterator_ext.pyi.ref @@ -1,6 +1,7 @@ from collections.abc import Iterator, Mapping from typing import overload + class IdentityMap: def __init__(self) -> None: ... diff --git a/tests/test_ndarray_ext.pyi.ref b/tests/test_ndarray_ext.pyi.ref index f9407592..91ee0784 100644 --- a/tests/test_ndarray_ext.pyi.ref +++ b/tests/test_ndarray_ext.pyi.ref @@ -1,6 +1,8 @@ -from numpy.typing import ArrayLike from typing import Annotated, overload +from numpy.typing import ArrayLike + + class Cls: def __init__(self) -> None: ... diff --git a/tests/test_stl_ext.pyi.ref b/tests/test_stl_ext.pyi.ref index 5641a3db..971e08e8 100644 --- a/tests/test_stl_ext.pyi.ref +++ b/tests/test_stl_ext.pyi.ref @@ -1,8 +1,9 @@ -from collections.abc import Sequence, Callable, Mapping, Set +from collections.abc import Callable, Mapping, Sequence, Set import os import pathlib from typing import overload + class ClassWithMovableField: def __init__(self) -> None: ... diff --git a/tests/test_typing_ext.pyi.ref b/tests/test_typing_ext.pyi.ref index 0f03e31d..5918990f 100644 --- a/tests/test_typing_ext.pyi.ref +++ b/tests/test_typing_ext.pyi.ref @@ -1,7 +1,9 @@ +from collections.abc import Iterable +from typing import Generic, Optional, Self, TypeAlias, TypeVar + from . import submodule as submodule from .submodule import F as F, f as f2 -from collections.abc import Iterable -from typing import Self, Optional, TypeAlias, TypeVar, Generic + # a prefix