From 0d191b47219351fef2a102dd7461cc3e2f6dffac Mon Sep 17 00:00:00 2001 From: laggykiller Date: Sun, 3 Mar 2024 20:45:40 +0800 Subject: [PATCH] Improve typing in python scripts --- cmake/collect-symbols-pypy.py | 2 +- cmake/collect-symbols.py | 2 +- src/__main__.py | 2 +- src/stubgen.py | 178 +++++++++++++++++++++++----------- 4 files changed, 123 insertions(+), 61 deletions(-) diff --git a/cmake/collect-symbols-pypy.py b/cmake/collect-symbols-pypy.py index 4f32af890..805490d0c 100644 --- a/cmake/collect-symbols-pypy.py +++ b/cmake/collect-symbols-pypy.py @@ -2,7 +2,7 @@ import tarfile import subprocess -funcs = set() +funcs: "set[str]" = set() files = [ ('https://downloads.python.org/pypy/pypy3.9-v7.3.11-macos_arm64.tar.bz2', 'pypy3.9-v7.3.11-macos_arm64/bin/libpypy3.9-c.dylib') diff --git a/cmake/collect-symbols.py b/cmake/collect-symbols.py index abf42fe38..ad92da4c1 100644 --- a/cmake/collect-symbols.py +++ b/cmake/collect-symbols.py @@ -8,7 +8,7 @@ from urllib.request import urlopen import re -funcs = set() +funcs: "set[str]" = set() for ver in ['3.7', '3.8', '3.9']: url = f'https://raw.githubusercontent.com/python/cpython/{ver}/PC/python3.def' diff --git a/src/__main__.py b/src/__main__.py index f239099b7..78b8feca4 100644 --- a/src/__main__.py +++ b/src/__main__.py @@ -1,6 +1,6 @@ import argparse import sys -import sysconfig +import sysconfig # type: ignore from . import __version__, include_dir, cmake_dir diff --git a/src/stubgen.py b/src/stubgen.py index b374416d7..d6156a1de 100755 --- a/src/stubgen.py +++ b/src/stubgen.py @@ -52,6 +52,7 @@ class and repeatedly call ``.put()`` to register modules or contents within the """ import argparse +import builtins from inspect import Signature, Parameter, signature, ismodule, getmembers import textwrap import importlib @@ -59,10 +60,15 @@ class and repeatedly call ``.put()`` to register modules or contents within the import types import typing from dataclasses import dataclass -from typing import Dict, Sequence, List, Optional, Callable, Tuple, cast +from typing import Dict, Sequence, List, Optional, Tuple, cast, Generator, Any, Callable, Union, Protocol, Literal import re import sys +if sys.version_info < (3, 9): + from typing import Match, Pattern +else: + from re import Match, Pattern + # Standard operations supported by arithmetic enumerations # fmt: off ENUM_OPS = [ @@ -81,11 +87,57 @@ class and repeatedly call ``.put()`` to register modules or contents within the ] # fmt: on +# This type is used to track per-module imports (``import name as desired_name``) +# during stub generation. The actual name in the stub is given by the value element. +# (name, desired_as_name) -> actual_as_name +ImportDict = Dict[Tuple[Optional[str], Optional[str]], Optional[str]] + +# This type maps a module name to an `ImportDict` tuple that tracks the +# import declarations from that module. +# package_name -> ((name, desired_as_name) -> actual_as_name) +PackagesDict = Dict[str, ImportDict] + +# Type of an entry of the ``__nb_signature__`` tuple of nanobind functions. +# It stores a function signature string, docstring, and a tuple of default function values. +# (signature_str, doc_str, (default_arg_1, default_arg_2, ...)) +NbSignature = Tuple[Optional[str], Optional[str]] + +# Type of an entry of the ``__nb_signature__`` tuple of nanobind functions. +NbFunctionSignature = Tuple[Optional[str], Optional[str], Optional[Tuple[Any, ...]]] + +# Type of an entry of the ``__nb_signature__`` tuple of nanobind getters and setters. +NbGetterSetterSignature = Tuple[str, str] + +class NbFunction(Protocol): + __module__: Literal["nanobind"] + __name__: Literal["nb_func", "nb_method"] + __nb_signature__: Tuple[NbFunctionSignature, ...] + __call__: Callable[..., Any] + + +@typing.runtime_checkable +class NbGetterSetter(Protocol): + __nb_signature__: Tuple[NbGetterSetterSignature, ...] + + +class NbStaticProperty(Protocol): + __module__: Literal["nanobind"] + __name__: Literal["nb_static_property"] + fget: NbGetterSetter + fset: NbGetterSetter + + +class NbType(Protocol): + __module__: Literal["nanobind"] + __name__: Literal["nb_type"] + __nb_signature__: str + __bases__: Tuple[Any, ...] + @dataclass -class Pattern: +class ReplacePattern: # A replacement patterns as produced by ``load_pattern_file()`` below - query: re.Pattern + query: Pattern[str] lines: List[str] matches: int @@ -97,7 +149,7 @@ def __init__( include_docstrings: bool = True, include_private: bool = False, max_expr_length: int = 50, - patterns: List[Pattern] = [], + patterns: List[ReplacePattern] = [], ) -> None: # Module to check for name conflicts when adding helper imports self.module = module @@ -130,9 +182,7 @@ def __init__( # Dictionary to keep track of import directives added by the stub generator # Maps package_name -> ((name, desired_as_name) -> actual_as_name) - self.imports: Dict[ - str, Dict[Tuple[Optional[str], Optional[str]], Optional[str]] - ] = {} + self.imports: PackagesDict = {} # ---------- Regular expressions ---------- @@ -192,7 +242,7 @@ def put_docstr(self, docstr: str) -> None: docstr = f'{raw_str}"""{docstr}"""\n' self.write_par(docstr) - def put_nb_overload(self, fn, sig: List, name: Optional[str] = None) -> None: + def put_nb_overload(self, fn: NbFunction, sig: NbFunctionSignature, name: Optional[str] = None) -> None: """ The ``put_nb_func()`` repeatedly calls this method to render the individual method overloads. @@ -212,15 +262,15 @@ def put_nb_overload(self, fn, sig: List, name: Optional[str] = None) -> None: if default_args: for index, arg in enumerate(default_args): pos = -1 - custom_signature = False pattern = None + arg_str = None # First, handle the case where the user overrode the default value signature if isinstance(arg, str): pattern = f"\\={index}" pos = sig_str.find(pattern, start) if pos >= 0: - custom_signature = True + arg_str = arg # General case if pos < 0: @@ -232,14 +282,11 @@ def put_nb_overload(self, fn, sig: List, name: Optional[str] = None) -> None: "Could not locate default argument in function signature" ) - if custom_signature: - arg_str = arg - else: + if not arg_str: # Call expr_str to convert the default value to a string. # Abbreviate with '...' if it is too long. - arg_str = self.expr_str(arg, abbrev=True) - if arg_str is None: - arg_str = "..." + expr = self.expr_str(arg, abbrev=True) + arg_str = expr if expr else "..." assert ( "\n" not in arg_str @@ -266,7 +313,7 @@ def put_nb_overload(self, fn, sig: List, name: Optional[str] = None) -> None: self.depth -= 1 self.write("\n") - def put_nb_func(self, fn, name: Optional[str] = None) -> None: + def put_nb_func(self, fn: NbFunction, name: Optional[str] = None) -> None: """Append a nanobind function binding to the stub""" sigs = fn.__nb_signature__ count = len(sigs) @@ -281,7 +328,7 @@ def put_nb_func(self, fn, name: Optional[str] = None) -> None: self.write_ln(f"@{overload}") self.put_nb_overload(fn, s, name) - def put_function(self, fn, name: Optional[str] = None, parent=None): + def put_function(self, fn: Callable[..., Any], name: Optional[str] = None, parent: Optional[object] = None): """Append a function of an arbitrary type to the stub""" # Don't generate a constructor for nanobind classes that aren't constructible if name == "__init__" and type(parent).__name__.startswith("nb_type"): @@ -309,13 +356,15 @@ def put_function(self, fn, name: Optional[str] = None, parent=None): # Special handling for nanobind functions with overloads if type(fn).__module__ == "nanobind": + fn = cast(NbFunction, fn) self.put_nb_func(fn, name) return if name is None: name = fn.__name__ + assert name - overloads: Sequence[Callable] = [] + overloads: Sequence[Callable[..., Any]] = [] if hasattr(fn, "__module__"): if sys.version_info >= (3, 11, 0): overloads = typing.get_overloads(fn) @@ -352,7 +401,7 @@ def put_function(self, fn, name: Optional[str] = None, parent=None): self.depth -= 1 self.write("\n") - def put_property(self, prop, name): + def put_property(self, prop: property, name: Optional[str]): """Append a Python 'property' object""" fget, fset = prop.fget, prop.fset self.write_ln("@property") @@ -360,10 +409,7 @@ def put_property(self, prop, name): if fset: self.write_ln(f"@{name}.setter") docstrings_backup = self.include_docstrings - if ( - type(fget).__module__ == "nanobind" - and type(fset).__module__ == "nanobind" - ): + if isinstance(fget, NbGetterSetter) and isinstance(fset, NbGetterSetter): doc1 = fget.__nb_signature__[0][1] doc2 = fset.__nb_signature__[0][1] if doc1 and doc2 and doc1 == doc2: @@ -371,7 +417,7 @@ def put_property(self, prop, name): self.put(prop.fset, name=name) self.include_docstrings = docstrings_backup - def put_nb_static_property(self, name, prop): + def put_nb_static_property(self, name: Optional[str], prop: NbStaticProperty): """Append a 'nb_static_property' object""" getter_sig = prop.fget.__nb_signature__[0][0] getter_sig = getter_sig[getter_sig.find("/) -> ") + 6 :] @@ -380,7 +426,7 @@ def put_nb_static_property(self, name, prop): self.put_docstr(prop.__doc__) self.write("\n") - def put_type(self, tp, name): + def put_type(self, tp: NbType, name: Optional[str]): """Append a 'nb_type' type object""" if name and (name != tp.__name__ or self.module.__name__ != tp.__module__): if self.module.__name__ == tp.__module__: @@ -394,7 +440,7 @@ def put_type(self, tp, name): is_enum = self.is_enum(tp) docstr = tp.__doc__ tp_dict = dict(tp.__dict__) - tp_bases = None + tp_bases: Union[List[str], Tuple[Any, ...], None] = None if is_enum: # Rewrite enumerations so that they derive from a helper type @@ -449,11 +495,11 @@ def put_type(self, tp, name): self.write_ln("pass\n") self.depth -= 1 - def is_enum(self, tp: type) -> bool: + def is_enum(self, tp: object) -> bool: """Check if the given type is an enumeration""" return hasattr(tp, "@entries") - def is_function(self, tp) -> bool: + def is_function(self, tp: type) -> bool: """ Test if this is one of the many types of built-in functions supported by Python, or if it is a nanobind ``nb_func``. @@ -468,7 +514,7 @@ def is_function(self, tp) -> bool: or (tp.__module__ == "nanobind" and tp.__name__ == "nb_func") ) - def put_value(self, value: object, name: str, parent=None, abbrev=True) -> None: + def put_value(self, value: object, name: str, parent: Optional[object] = None, abbrev: bool = True) -> None: """ Render a ``name: type = value`` assignment at the module, class, or enum scope. @@ -486,7 +532,8 @@ def put_value(self, value: object, name: str, parent=None, abbrev=True) -> None: self.write("\n") elif self.is_function(tp) or isinstance(value, type): # This is a function or a type, import it from its actual source - self.import_object(value.__module__, value.__name__, name) # type: ignore + value = cast(type, value) + self.import_object(value.__module__, value.__name__, name) else: value_str = self.expr_str(value, abbrev) @@ -495,7 +542,7 @@ def put_value(self, value: object, name: str, parent=None, abbrev=True) -> None: # Catch a few different typing.* constructs if issubclass(tp, typing.TypeVar) or ( - hasattr(typing, "TypeVarTuple") and issubclass(tp, typing.TypeVarTuple) + sys.version_info >= (3, 11) and issubclass(tp, typing.TypeVarTuple) ): types = "" elif typing.get_origin(value): @@ -529,10 +576,11 @@ def simplify_types(self, s: str) -> str: """ # Process nd-array type annotations so that MyPy accepts them - def process_ndarray(m: re.Match) -> str: + def process_ndarray(m: Match[str]) -> str: s = m.group(2) ndarray = self.import_object("numpy.typing", "ArrayLike") + assert ndarray s = re.sub(r"dtype=([\w]*)\b", r"dtype='\g<1>'", s) s = s.replace("*", "None") @@ -545,7 +593,7 @@ def process_ndarray(m: re.Match) -> str: s = self.ndarray_re.sub(process_ndarray, s) # Process other type names and add suitable import statements - def process_general(m: re.Match) -> str: + def process_general(m: Match[str]) -> str: full_name, mod_name, cls_name = m.group(0), m.group(1)[:-1], m.group(2) if mod_name == "builtins": @@ -566,7 +614,7 @@ def process_general(m: re.Match) -> str: return s - def apply_pattern(self, value: object, pattern: Pattern, match: re.Match) -> None: + def apply_pattern(self, value: object, pattern: ReplacePattern, match: Match[str]) -> None: """ Called when ``value`` matched an entry of a pattern file """ @@ -580,7 +628,8 @@ def apply_pattern(self, value: object, pattern: Pattern, match: re.Match) -> Non "nb_func", "nb_method", ): - for tp_i in value.__nb_signature__: # type: ignore + value = cast(NbFunction, value) + for tp_i in value.__nb_signature__: doc = tp_i[1] if doc: break @@ -614,7 +663,7 @@ def apply_pattern(self, value: object, pattern: Pattern, match: re.Match) -> Non line = line.replace(f"\\{k}", v) self.write_ln(line) - def put(self, value: object, name: Optional[str] = None, parent=None) -> None: + def put(self, value: object, name: Optional[str] = None, parent: Optional[object] = None) -> None: old_prefix = self.prefix if value in self.stack: @@ -667,16 +716,21 @@ def put(self, value: object, name: Optional[str] = None, parent=None) -> None: for name, child in getmembers(value): self.put(child, name=name, parent=value) elif self.is_function(tp): + value = cast(NbFunction, value) self.put_function(value, name, parent) elif issubclass(tp, type): + value = cast(NbType, value) self.put_type(value, name) elif tp_mod == "nanobind": if tp_name == "nb_method": + value = cast(NbFunction, value) self.put_nb_func(value, name) elif tp_name == "nb_static_property": + value = cast(NbStaticProperty, value) self.put_nb_static_property(name, value) elif tp_mod == "builtins": if tp is property: + value = cast(property, value) self.put_property(value, name) else: assert name is not None @@ -712,7 +766,7 @@ def import_object( module_short = module # Query a cache of previously imported objects - imports_module = self.imports.get(module_short, None) + imports_module: Optional[ImportDict] = self.imports.get(module_short, None) if not imports_module: imports_module = {} self.imports[module_short] = imports_module @@ -753,14 +807,14 @@ def import_object( imports_module[key] = final_name return final_name if final_name else "" - def expr_str(self, e, abbrev=True): + def expr_str(self, e: Any, abbrev: bool = True) -> Optional[str]: """ Attempt to convert a value into valid Python syntax that regenerates that value. When ``abbrev`` is True, give up and replace with '...' if the expression is too complicated to be included in the stubs """ tp = type(e) - for t in [bool, int, type(None), type(...)]: + for t in [bool, int, type(None), type(builtins.Ellipsis)]: if issubclass(tp, t): return repr(e) if issubclass(tp, float): @@ -775,20 +829,22 @@ def expr_str(self, e, abbrev=True): return self.type_str(e) elif issubclass(tp, typing.ForwardRef): return f'"{e.__forward_arg__}"' - elif hasattr(typing, "TypeVarTuple") and issubclass(tp, typing.TypeVarTuple): + elif sys.version_info >= (3, 11) and issubclass(tp, typing.TypeVarTuple): tv = self.import_object("typing", "TypeVarTuple") return f'{tv}("{e.__name__}")' elif issubclass(tp, typing.TypeVar): tv = self.import_object("typing", "TypeVar") s = f'{tv}("{e.__name__}"' for v in getattr(e, "__constraints__", ()): - s += ", " + self.expr_str(v) + v = self.expr_str(v) + assert v + s += ", " + v for k in ["contravariant", "covariant", "bound", "infer_variance"]: v = getattr(e, f"__{k}__", None) if v: v = self.expr_str(v) if v is None: - return + return None s += f", {k}=" + v s += ")" return s @@ -824,10 +880,10 @@ def expr_str(self, e, abbrev=True): pass return None - def signature_str(self, s: Signature): + def signature_str(self, s: Signature) -> str: """Convert an inspect.Signature to into valid Python syntax""" posonly_sep, kwonly_sep = False, True - params = [] + params: List[str] = [] # Logic for placing '*' and '/' based on the # signature.Signature implementation @@ -855,7 +911,7 @@ def signature_str(self, s: Signature): result += " -> " + self.type_str(s.return_annotation) return result - def param_str(self, p: Parameter): + def param_str(self, p: Parameter) -> str: result = "" if p.kind == Parameter.VAR_POSITIONAL: result += "*" @@ -869,10 +925,12 @@ def param_str(self, p: Parameter): result += ": " + self.type_str(p.annotation) if has_def: result += " = " if has_type else "=" - result += self.expr_str(p.default) + p_default_str = self.expr_str(p.default) + assert p_default_str + result += p_default_str return result - def type_str(self, tp): + def type_str(self, tp: Union[List[Any], Tuple[Any, ...], Dict[Any, Any], Any]) -> str: """Attempt to convert a type into a Python expression which reproduces it""" origin, args = typing.get_origin(tp), typing.get_args(tp) @@ -881,20 +939,24 @@ def type_str(self, tp): elif isinstance(tp, typing.ForwardRef): return repr(tp.__forward_arg__) elif isinstance(tp, list): - return "[" + ", ".join(self.type_str(a) for a in tp) + "]" + list_gen: Generator[str, None, None] = (self.type_str(a) for a in tp) + return "[" + ", ".join(list_gen) + "]" elif isinstance(tp, tuple): - return "(" + ", ".join(self.type_str(a) for a in tp) + ")" + tuple_gen: Generator[str, None, None] = (self.type_str(a) for a in tp) + return "(" + ", ".join(tuple_gen) + ")" elif isinstance(tp, dict): + dict_gen: Generator[str, None, None] = (repr(k) + ": " + self.type_str(v) for k, v in tp.items()) return ( "{" - + ", ".join(repr(k) + ": " + self.type_str(v) for k, v in tp.items()) + + ", ".join(dict_gen) + "}" ) elif origin and args: + args_gen: Generator[str, None, None] = (self.type_str(a) for a in args) result = ( self.type_str(origin) + "[" - + ", ".join(self.type_str(a) for a in args) + + ", ".join(args_gen) + "]" ) elif tp is types.ModuleType: @@ -911,7 +973,7 @@ def get(self) -> str: for module in sorted(self.imports): imports = self.imports[module] - items = [] + items: List[str] = [] for (k, v1), v2 in imports.items(): if k is None: @@ -1080,13 +1142,13 @@ def parse_options(args: List[str]) -> argparse.Namespace: return opt -def load_pattern_file(fname: str) -> List[Pattern]: +def load_pattern_file(fname: str) -> List[ReplacePattern]: with open(fname, "r") as f: f_lines = f.readlines() - patterns: List[Pattern] = [] + patterns: List[ReplacePattern] = [] - def add_pattern(query, lines): + def add_pattern(query: str, lines: List[str]): # Exactly 1 empty line at the end while lines and lines[-1].isspace(): lines.pop() @@ -1095,7 +1157,7 @@ def add_pattern(query, lines): # Identify deletions (replacement by only whitespace) if all((p.isspace() or len(p) == 0 for p in lines)): lines = [] - patterns.append(Pattern(re.compile(query[:-1]), lines, 0)) + patterns.append(ReplacePattern(re.compile(query[:-1]), lines, 0)) lines: List[str] lines, query, dedent = [], None, 0 @@ -1139,7 +1201,7 @@ def main(args: Optional[List[str]] = None) -> None: opt = parse_options(sys.argv[1:] if args is None else args) - patterns: List[Pattern] + patterns: List[ReplacePattern] if opt.pattern_file: if not opt.quiet: print('Using pattern file "%s" ..' % opt.pattern_file)