Skip to content

Commit

Permalink
Add NDArray protocol class for nd-array annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
yosh-matsuda committed Aug 17, 2024
1 parent 8e5fd14 commit 579a69e
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 92 deletions.
132 changes: 112 additions & 20 deletions src/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,31 @@ class and repeatedly call ``.put()`` to register modules or contents within the
import argparse
import builtins
import enum
from inspect import Signature, Parameter, signature, ismodule, getmembers
import textwrap
import importlib
import importlib.machinery
import importlib.util
import re
import sys
import textwrap
import types
import typing
from dataclasses import dataclass
from typing import Dict, Sequence, List, Optional, Tuple, cast, Generator, Any, Callable, Union, Protocol, Literal
from inspect import Parameter, Signature, getmembers, ismodule, signature
from pathlib import Path
import re
import sys
from typing import (
Any,
Callable,
Dict,
Generator,
List,
Literal,
Optional,
Protocol,
Sequence,
Tuple,
Union,
cast,
)

if sys.version_info < (3, 9):
from typing import Match, Pattern
Expand Down Expand Up @@ -250,11 +263,28 @@ def __init__(
+ sep_after
)

# Precompile RE to extract nanobind nd-arrays
self.ndarray_re = re.compile(
sep_before + r"(numpy\.ndarray|ndarray|torch\.Tensor)\[([^\]]*)\]"
# Precompile RE to extract known nd-arrays
self.known_ndarray_re = re.compile(
sep_before
+ "("
+ "|".join(
[
r"numpy\.ndarray",
r"torch\.Tensor",
r"tensorflow\.python\.framework\.ops\.EagerTensor",
r"jaxlib\.xla_extension\.DeviceArray",
]
)
+ ")"
+ r"\[([^\]]*)\]"
)

# Precompile RE to extract nanobind nd-arrays
self.nb_ndarray_re = re.compile(sep_before + "(ndarray)" + r"\[([^\]]*)\]")

# Insert ndarray class
self.ndarray_class = False

# Types which moved from typing.* to collections.abc in Python 3.9
self.abc_re = re.compile(
'typing.(AsyncGenerator|AsyncIterable|AsyncIterator|Awaitable|Callable|'
Expand Down Expand Up @@ -606,7 +636,10 @@ def simplify_types(self, s: str) -> str:
- "NoneType" -> "None"
- "ndarray[...]" -> "Annotated[ArrayLike, dict(...)]"
- "<numpy|torch|tensorflow|jax array>[...]" -> "Annotated[<array>, dict(...)]"
- "ndarray[...]" -> "Annotated[NDArray, dict(...)]"
(with array protocol class added at top)
- "collections.abc.X" -> "X"
(with "from collections.abc import X" added at top)
Expand All @@ -616,22 +649,62 @@ def simplify_types(self, s: str) -> str:
changed to 'collections.abc' on newer Python versions)
"""

# Process nd-array type annotations so that MyPy accepts them
def process_ndarray(m: Match[str]) -> str:
s = m.group(2)
# Process nd-array type annotations with metadata
def process_known_ndarray(m: Match[str]) -> str:
ndarray_type = m.group(1)
meta = m.group(2)

ndarray = self.import_object("numpy.typing", "ArrayLike")
assert ndarray
s = re.sub(r"dtype=([\w]*)\b", r"dtype='\g<1>'", s)
s = s.replace("*", "None")
if not meta:
return ndarray_type

if s:
if ndarray_type == "numpy.ndarray":
dm = re.search(r"dtype=([\w]*)\b", meta)
if dm and dm.group(1):
dtype = dm.group(1).replace("bool", "bool_")
ndarray_type = f"numpy.typing.NDArray[numpy.{dtype}]"

meta = re.sub(r"dtype=([\w]*)\b", r"dtype='\g<1>'", meta)
meta = meta.replace("*", "None")

if sys.version_info >= (3, 9, 0):
annotated = self.import_object("typing", "Annotated")
return f"{annotated}[{ndarray}, dict({s})]"
else:
return ndarray
annotated = self.import_object("typing_extensions", "Annotated")
return f"{annotated}[{ndarray_type}, dict({meta})]"

s = self.known_ndarray_re.sub(process_known_ndarray, s)

# Process nb-ndarray type annotations with metadata
def process_nb_ndarray(m: Match[str]) -> str:
ndarray_type = "NDArray"
meta = m.group(2)

s = self.ndarray_re.sub(process_ndarray, s)
self.ndarray_class = True

self.import_object("typing", "Protocol")
if sys.version_info >= (3, 12, 0):
self.import_object("collections.abc", "Buffer")
else:
self.import_object("typing_extensions", "Buffer")
if sys.version_info >= (3, 10, 0):
self.import_object("typing", "TypeAlias")
else:
self.import_object("typing", "Union")
self.import_object("typing_extensions", "TypeAlias")

if not meta:
return ndarray_type

meta = re.sub(r"dtype=([\w]*)\b", r"dtype='\g<1>'", meta)
meta = meta.replace("*", "None")

if sys.version_info >= (3, 9, 0):
annotated = self.import_object("typing", "Annotated")
else:
annotated = self.import_object("typing_extensions", "Annotated")
return f"{annotated}[{ndarray_type}, dict({meta})]"

s = self.nb_ndarray_re.sub(process_nb_ndarray, s)

if sys.version_info >= (3, 9, 0):
s = self.abc_re.sub(r'collections.abc.\1', s)
Expand Down Expand Up @@ -1143,12 +1216,31 @@ def get(self) -> str:
s += items_v0 if len(items_v0) <= 70 else items_v1

s += "\n\n"
s += self.put_ndarray_class()

# Append the main generated stub
s += self.output

return s.rstrip() + "\n"

def put_ndarray_class(self) -> str:
    """
    Return the declarations backing the ``NDArray`` annotation.

    The result is an empty string unless ``self.ndarray_class`` is set.
    Otherwise it consists of a ``DLPackBuffer`` protocol class followed by
    an ``NDArray`` alias that unions ``Buffer`` with ``DLPackBuffer``. The
    alias syntax is selected from the version of the interpreter that is
    running the generator.
    """
    # Nothing to emit when the helper was never requested
    if not self.ndarray_class:
        return ""

    # Choose the alias spelling supported by the running interpreter
    running = sys.version_info
    if running >= (3, 12, 0):
        alias_line = "type NDArray = Buffer | DLPackBuffer\n"
    elif running >= (3, 10, 0):
        alias_line = "NDArray: TypeAlias = Buffer | DLPackBuffer\n"
    else:
        alias_line = "NDArray: TypeAlias = Union[Buffer, DLPackBuffer]\n"

    # Protocol class, blank line, alias, trailing blank line
    pieces = (
        "class DLPackBuffer(Protocol):\n",
        " def __dlpack__(self) -> object: ...\n",
        "\n",
        alias_line,
        "\n",
    )
    return "".join(pieces)

def parse_options(args: List[str]) -> argparse.Namespace:
parser = argparse.ArgumentParser(
prog="python -m nanobind.stubgen",
Expand Down
Loading

0 comments on commit 579a69e

Please sign in to comment.