Skip to content

Commit

Permalink
feat: support recursive stubgen (#44)
Browse files Browse the repository at this point in the history
Implements the `recursive` and `output_directory` options of nanobind's stubgen, to allow stub generation for submodules in extensions.

Left to clarify is how to make generated stubs have the correct name in recursive mode - to that end, a shutil-based workaround was introduced in the stubgen script, fixing up the path name and copying the stub within the bindir.

This likely has to be done upstream, but requires a better understanding of file name generation in recursive mode, particularly depending on the working directory of the stubgen invoker.
  • Loading branch information
cemlyn007 authored Jan 6, 2025
1 parent 3fdbc38 commit 67b9fa1
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 14 deletions.
22 changes: 21 additions & 1 deletion build_defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,13 @@ def nanobind_stubgen(
name,
module,
output_file = None,
output_directory = None,
imports = [],
pattern_file = None,
marker_file = None,
include_private_members = False,
exclude_docstrings = False):
exclude_docstrings = False,
recursive = False):
"""Creates a stub file containing Python type annotations for a nanobind extension.
Args:
Expand All @@ -165,6 +167,10 @@ def nanobind_stubgen(
Output file path for the generated stub, relative to $(BINDIR).
If none is given, the stub will be placed under the same location
as the module in your source tree.
output_directory: str or None
Output directory for the generated stub, relative to $(BINDIR).
If none is given, the stub will be placed under the same location
as the module in your source tree.
imports: list
List of modules to import for stub generation.
pattern_file: Label or None
Expand All @@ -180,6 +186,8 @@ def nanobind_stubgen(
exclude_docstrings: bool
Whether to exclude all docstrings of all module members from the generated
stub file.
recursive: bool
Whether to perform stub generation on submodules as well.
"""
STUBGEN_WRAPPER = Label("@nanobind_bazel//:stubgen_wrapper.py")
loc = "$(rlocationpath {})"
Expand All @@ -196,6 +204,18 @@ def nanobind_stubgen(

args = ["-m " + loc.format(module)]

if recursive and output_file:
fail("Cannot specify an output file if recursive stubgen is requested")

if recursive and not output_directory:
fail("Must specify an output directory for recursive stubgen")

if recursive:
args.append("-r")

if output_directory:
args.append("-O {}".format(output_directory))

# to be searchable by path expansion, a file must be
# declared by a rule beforehand. This might not be the
# case for a generated stub, so we just give the raw name here
Expand Down
40 changes: 27 additions & 13 deletions stubgen_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import sys

import shutil
from pathlib import Path
from typing import Union

Expand All @@ -9,6 +9,7 @@
DEBUG = bool(os.getenv("DEBUG"))
RLOCATION_ROOT = Path("_main") # the Python path root under the script's runfiles.


def get_runfiles_dir(path: Union[str, os.PathLike]):
"""Obtain the runfiles root from the Python script path."""
ppath = Path(path)
Expand Down Expand Up @@ -74,18 +75,20 @@ def wrapper():
print(f"bindir = {bindir}")
fname = ""
for i, arg in enumerate(args):
if arg.startswith("-m"):
fname = args.pop(i + 1)
if not fname.endswith((".so", ".pyd")):
raise ValueError(
f"invalid extension file {fname!r}: "
"only shared object files with extensions "
".so, .abi3.so, or .pyd are supported"
)
modname = convert_path_to_module(fname)
args.insert(i + 1, modname)

if "-o" not in args:
if arg.startswith("-m"):
fname = args.pop(i + 1)
if not fname.endswith((".so", ".pyd")):
raise ValueError(
f"invalid extension file {fname!r}: "
"only shared object files with extensions "
".so, .abi3.so, or .pyd are supported"
)
modname = convert_path_to_module(fname)
args.insert(i + 1, modname)

if "-r" in args:
pass
elif "-o" not in args:
ext_path = runfiles_dir / fname
if DEBUG:
print(f"ext_path = {ext_path}")
Expand All @@ -108,13 +111,24 @@ def wrapper():
idx = args.index("-o")
args[idx + 1] = str(bindir / args[idx + 1])

if "-O" in args:
idx = args.index("-O")
args[idx + 1] = str(bindir / args[idx + 1])

if "-M" in args:
# fix up the path to the marker file relative to $(BINDIR).
idx = args.index("-M")
args[idx + 1] = str(bindir / args[idx + 1])

main(args)

if "-O" in args:
from_path = os.path.join(
*modname.split(".")[1:-1], ".".join(modname.split(".")[:-1]) + ".pyi"
)
to_path = os.path.join(*modname.split(".")[1:]) + ".pyi"
shutil.move(bindir / from_path, bindir / to_path)


if __name__ == "__main__":
wrapper()

0 comments on commit 67b9fa1

Please sign in to comment.