From 67b9fa1a269a99ef171abba9ead00dc5c479c8f1 Mon Sep 17 00:00:00 2001 From: Cemlyn Waters Date: Mon, 6 Jan 2025 14:51:52 +0000 Subject: [PATCH] feat: support recursive stubgen (#44) Implements the `recursive` and `output_directory` options of nanobind's stubgen, to allow stub generation for submodules in extensions. Left to clarify is how to make generated stubs have the correct name in recursive mode - to that end, a shutil-based workaround was introduced in the stubgen script, fixing up the path name and copying the stub within the bindir. This likely has to be done upstream, but requires a better understanding of file name generation in recursive mode, particularly depending on the working directory of the stubgen invoker. --- build_defs.bzl | 22 +++++++++++++++++++++- stubgen_wrapper.py | 40 +++++++++++++++++++++++++++------------- 2 files changed, 48 insertions(+), 14 deletions(-) diff --git a/build_defs.bzl b/build_defs.bzl index 56f87b5..861d444 100644 --- a/build_defs.bzl +++ b/build_defs.bzl @@ -148,11 +148,13 @@ def nanobind_stubgen( name, module, output_file = None, + output_directory = None, imports = [], pattern_file = None, marker_file = None, include_private_members = False, - exclude_docstrings = False): + exclude_docstrings = False, + recursive = False): """Creates a stub file containing Python type annotations for a nanobind extension. Args: @@ -165,6 +167,10 @@ def nanobind_stubgen( Output file path for the generated stub, relative to $(BINDIR). If none is given, the stub will be placed under the same location as the module in your source tree. + output_directory: str or None + Output directory for the generated stub, relative to $(BINDIR). + If none is given, the stub will be placed under the same location + as the module in your source tree. imports: list List of modules to import for stub generation. pattern_file: Label or None @@ -180,6 +186,8 @@ def nanobind_stubgen( exclude_docstrings: bool Whether to exclude all docstrings of all module members from the generated stub file. + recursive: bool + Whether to perform stub generation on submodules as well. """ STUBGEN_WRAPPER = Label("@nanobind_bazel//:stubgen_wrapper.py") loc = "$(rlocationpath {})" @@ -196,6 +204,18 @@ def nanobind_stubgen( args = ["-m " + loc.format(module)] + if recursive and output_file: + fail("Cannot specify an output file if recursive stubgen is requested") + + if recursive and not output_directory: + fail("Must specify an output directory for recursive stubgen") + + if recursive: + args.append("-r") + + if output_directory: + args.append("-O {}".format(output_directory)) + # to be searchable by path expansion, a file must be # declared by a rule beforehand. This might not be the # case for a generated stub, so we just give the raw name here diff --git a/stubgen_wrapper.py b/stubgen_wrapper.py index 049adc2..4fedce8 100644 --- a/stubgen_wrapper.py +++ b/stubgen_wrapper.py @@ -1,6 +1,6 @@ import os import sys - +import shutil from pathlib import Path from typing import Union @@ -9,6 +9,7 @@ DEBUG = bool(os.getenv("DEBUG")) RLOCATION_ROOT = Path("_main") # the Python path root under the script's runfiles. + def get_runfiles_dir(path: Union[str, os.PathLike]): """Obtain the runfiles root from the Python script path.""" ppath = Path(path) @@ -74,18 +75,20 @@ def wrapper(): print(f"bindir = {bindir}") fname = "" for i, arg in enumerate(args): - if arg.startswith("-m"): - fname = args.pop(i + 1) - if not fname.endswith((".so", ".pyd")): - raise ValueError( - f"invalid extension file {fname!r}: " - "only shared object files with extensions " - ".so, .abi3.so, or .pyd are supported" - ) - modname = convert_path_to_module(fname) - args.insert(i + 1, modname) - - if "-o" not in args: + if arg.startswith("-m"): + fname = args.pop(i + 1) + if not fname.endswith((".so", ".pyd")): + raise ValueError( + f"invalid extension file {fname!r}: " + "only shared object files with extensions " + ".so, .abi3.so, or .pyd are supported" + ) + modname = convert_path_to_module(fname) + args.insert(i + 1, modname) + + if "-r" in args: + pass + elif "-o" not in args: ext_path = runfiles_dir / fname if DEBUG: print(f"ext_path = {ext_path}") @@ -108,6 +111,10 @@ def wrapper(): idx = args.index("-o") args[idx + 1] = str(bindir / args[idx + 1]) + if "-O" in args: + idx = args.index("-O") + args[idx + 1] = str(bindir / args[idx + 1]) + if "-M" in args: # fix up the path to the marker file relative to $(BINDIR). idx = args.index("-M") @@ -115,6 +122,13 @@ def wrapper(): main(args) + if "-O" in args: + from_path = os.path.join( + *modname.split(".")[1:-1], ".".join(modname.split(".")[:-1]) + ".pyi" + ) + to_path = os.path.join(*modname.split(".")[1:]) + ".pyi" + shutil.move(bindir / from_path, bindir / to_path) + if __name__ == "__main__": wrapper()