diff --git a/libshortfin/python/CMakeLists.txt b/libshortfin/python/CMakeLists.txt index f0048d6a7..f651484f9 100644 --- a/libshortfin/python/CMakeLists.txt +++ b/libshortfin/python/CMakeLists.txt @@ -35,18 +35,43 @@ target_link_libraries(shortfin_python_extension PRIVATE ${SHORTFIN_LINK_LIBRARY_NAME} ) -if (SHORTFIN_ENABLE_TRACING) +function(shortfin_python_stubs build_type) nanobind_add_stub( shortfin_python_extension_stub - MODULE _shortfin_tracy.lib - OUTPUT _shortfin_tracy/lib.pyi + MODULE _shortfin_${build_type}.lib + OUTPUT _shortfin_${build_type}/lib.pyi DEPENDS shortfin_python_extension ) -else() + +endfunction() + +function(shortfin_python_stubs build_variant output_root) + file(MAKE_DIRECTORY ${output_root}) nanobind_add_stub( - shortfin_python_extension_stub + shortfin_python_extension_stub_lib_${build_variant} MODULE _shortfin_default.lib - OUTPUT _shortfin_default/lib.pyi + OUTPUT ${output_root}/lib/__init__.pyi DEPENDS shortfin_python_extension ) + + nanobind_add_stub( + shortfin_python_extension_stub_array_${build_variant} + MODULE _shortfin_default.lib.array + OUTPUT ${output_root}/lib/array.pyi + DEPENDS shortfin_python_extension + ) + + nanobind_add_stub( + shortfin_python_extension_stub_local_${build_variant} + MODULE _shortfin_default.lib.local + OUTPUT ${output_root}/lib/local.pyi + DEPENDS shortfin_python_extension + ) +endfunction() + +# Generate the same stubs against the default build for each variant but +# output the files to the right package. +shortfin_python_stubs(default ${CMAKE_CURRENT_BINARY_DIR}/_shortfin_default) +if (SHORTFIN_ENABLE_TRACING) + shortfin_python_stubs(tracy ${CMAKE_CURRENT_BINARY_DIR}/_shortfin_tracy) endif() diff --git a/libshortfin/python/_shortfin/__init__.py b/libshortfin/python/_shortfin/__init__.py index 79fbd924a..9bfa3a497 100644 --- a/libshortfin/python/_shortfin/__init__.py +++ b/libshortfin/python/_shortfin/__init__.py @@ -7,23 +7,28 @@ # The proper way to import this package is via: # from _shortfin import lib as sfl +from typing import TYPE_CHECKING + import os import sys import warnings -variant = os.getenv("SHORTFIN_PY_RUNTIME", "default") - -if variant == "tracy": - try: - from _shortfin_tracy import lib - except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "Shortfin Tracy runtime requested via SHORTFIN_PY_RUNTIME but it is not enabled in this build" - ) - print("-- Using Tracy runtime (SHORTFIN_PY_RUNTIME=tracy)", file=sys.stderr) -else: - if variant != "default": - warnings.warn( - f"Unknown value for SHORTFIN_PY_RUNTIME env var ({variant}): Using default" - ) +if TYPE_CHECKING: from _shortfin_default import lib +else: + variant = os.getenv("SHORTFIN_PY_RUNTIME", "default") + + if variant == "tracy": + try: + from _shortfin_tracy import lib + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "Shortfin Tracy runtime requested via SHORTFIN_PY_RUNTIME but it is not enabled in this build" + ) + print("-- Using Tracy runtime (SHORTFIN_PY_RUNTIME=tracy)", file=sys.stderr) + else: + if variant != "default": + warnings.warn( + f"Unknown value for SHORTFIN_PY_RUNTIME env var ({variant}): Using default" + ) + from _shortfin_default import lib