Skip to content
5 changes: 5 additions & 0 deletions src/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ if(BUILD_WHEEL)
${WHEEL_TARGET_NAME}
COMMENT "Copying files to wheel directory: ${WHEEL_TARGET_NAME}"
)
# After the `python` target is built, run stubgen.py from the wheel staging
# directory so type stubs are generated for the package named by
# PACKAGE_DIR_NAME. `-v` forwards the onnxruntime version the script may need
# to pip-install. VERBATIM makes argument escaping identical on every
# generator/platform; path arguments are quoted so they stay single arguments.
add_custom_command(TARGET python POST_BUILD
  COMMAND "${PYTHON_EXECUTABLE}" "${CMAKE_CURRENT_SOURCE_DIR}/stubgen.py" -r ${PACKAGE_DIR_NAME} -v ${ORT_VERSION}
  WORKING_DIRECTORY "${WHEEL_FILES_DIR}"
  COMMENT "Generate type stubs for onnxruntime_genai"
  VERBATIM
)
set(auditwheel_exclude_list
"libcublas.so.11"
"libcublas.so.12"
Expand Down
Empty file added src/python/py/py.typed
Empty file.
19 changes: 19 additions & 0 deletions src/python/python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,25 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
});
m.add_object("_cleanup", cleanup);

// Expose the C API enum OgaElementType to Python as `ElementType` on module `m`.
// Each .value() pairs the Python-side member name with the matching
// OgaElementType_* constant. export_values() is not called, so members are
// reached through the enum itself (e.g. ElementType.float32).
pybind11::enum_<OgaElementType>(m, "ElementType")
.value("undefined", OgaElementType_undefined)
.value("float32", OgaElementType_float32)
.value("uint8", OgaElementType_uint8)
.value("int8", OgaElementType_int8)
.value("uint16", OgaElementType_uint16)
.value("int16", OgaElementType_int16)
.value("int32", OgaElementType_int32)
.value("int64", OgaElementType_int64)
.value("string", OgaElementType_string)
.value("bool", OgaElementType_bool)
.value("float16", OgaElementType_float16)
.value("float64", OgaElementType_float64)
.value("uint32", OgaElementType_uint32)
.value("uint64", OgaElementType_uint64)
.value("complex64", OgaElementType_complex64)
.value("complex128", OgaElementType_complex128)
.value("bfloat16", OgaElementType_bfloat16);

pybind11::class_<PyGeneratorParams>(m, "GeneratorParams")
.def(pybind11::init<const OgaModel&>())
.def("try_graph_capture_with_max_batch_size", &PyGeneratorParams::TryGraphCaptureWithMaxBatchSize)
Expand Down
83 changes: 83 additions & 0 deletions src/python/stubgen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
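'''
Generate type stubs (.pyi files) for the wheel package with mypy's stubgen.

The build invokes this script through the Python interpreter so that stubgen is
located via that interpreter's own package installation. Dependencies that are
missing (onnxruntime, mypy) are installed with pip for the duration of the run.
'''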

from pathlib import Path
from argparse import ArgumentParser
import subprocess
import sys
import platform
import shutil

def install_package(package_name: str, version: "str | None" = None) -> bool:
    """Make sure ``package_name`` is importable, installing it with pip if needed.

    Args:
        package_name: Name used both for the import probe and for the pip
            requirement.
        version: Exact version to pin (``name==version``). A falsy value
            installs the unpinned package.

    Returns:
        False if the package could already be imported, meaning nothing was
        installed; True if this call installed it.

    Raises:
        subprocess.CalledProcessError: If pip fails both with and without
            ``--user``.
    """
    # The ``version`` annotation is quoted on purpose: an unquoted
    # ``str | None`` is evaluated when the function is defined and raises
    # TypeError on interpreters older than Python 3.10.

    # Check if the package is already installed; if so, return False, meaning
    # we did not install it.
    try:
        __import__(package_name)
        return False
    except ImportError:
        pass

    # First try to install with --user, otherwise retry without --user
    # (pip rejects --user inside a virtual environment, for example).
    package_spec = f"{package_name}=={version}" if version else package_name
    print(f"Installing package: {package_spec}")
    try:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--user', package_spec])
    except subprocess.CalledProcessError:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', package_spec])
    return True

def try_uninstall_package(package_name: str) -> None:
    """Best-effort ``pip uninstall -y`` of ``package_name``.

    A non-zero exit status from pip is swallowed; this is cleanup only.
    """
    pip_command = [sys.executable, '-m', 'pip', 'uninstall', '-y', package_name]
    try:
        subprocess.check_call(pip_command)
    except subprocess.CalledProcessError:
        # Deliberately ignored: failing to uninstall must not fail the caller.
        return

def generate_stubs(package: str, output: Path) -> None:
    """Run mypy's stubgen for ``package`` in a child interpreter.

    Args:
        package: Importable package name passed to stubgen's ``-p`` option.
        output: Directory passed to stubgen's ``-o`` option (in POSIX form).

    Raises:
        subprocess.CalledProcessError: If the child interpreter exits non-zero.
    """
    stubgen_args = ["-p", package, "-o", output.as_posix()]
    # Embed the arguments with repr() so that quotes or other special
    # characters in the package name or path cannot break (or inject into)
    # the generated one-line script.
    script = f'from mypy.stubgen import main; main({stubgen_args!r})'
    # Call in a subprocess to update sys.path correctly.
    subprocess.check_call([sys.executable, '-c', script])

def fix_stubs(root: Path) -> None:
    """Remove every occurrence of the text ``Oga`` from ``.pyi`` files under ``root``.

    All ``.pyi`` files found recursively are processed, though only the one
    generated from the native library needs fixing.

    Args:
        root: Directory searched recursively for stub files.
    """
    for stub_file in root.rglob('*.pyi'):
        # Pin the encoding instead of relying on the locale default, which is
        # not UTF-8 on every platform.
        # NOTE(review): assumes stubgen emits UTF-8 stubs -- confirm.
        content = stub_file.read_text(encoding='utf-8')
        fixed_content = content.replace("Oga", "")
        # Leave files that need no fix untouched.
        if fixed_content != content:
            stub_file.write_text(fixed_content, encoding='utf-8')

def clean_up_pycache(root: Path) -> None:
    """Delete every ``__pycache__`` directory found anywhere below ``root``."""
    for cache_dir in root.glob('**/__pycache__'):
        # ignore_errors=True keeps this best-effort: a directory that cannot
        # be removed is skipped silently.
        shutil.rmtree(cache_dir, ignore_errors=True)

def main():
    """Parse the command line, generate the stubs, and undo temporary installs.

    onnxruntime and mypy are installed on demand. Whatever this run installed
    is uninstalled again before returning -- including when a later install
    or the stub generation itself fails. A failed dependency install is
    reported and skips stub generation without raising.
    """
    parser = ArgumentParser()
    parser.add_argument('-r', '--wheel_root', type=Path, required=True, help='Wheel root directory containing __init__.py')
    parser.add_argument('-v', '--ort-version', type=str, help='onnxruntime version to install, if not specified, the latest version will be installed')
    args = parser.parse_args()

    # Names of the packages this run installed, in installation order.
    installed_by_us = []
    try:
        try:
            for name, version in (('onnxruntime', args.ort_version), ('mypy', None)):
                if install_package(name, version):
                    installed_by_us.append(name)
        except subprocess.CalledProcessError as e:
            print(f"Failed to install dependencies on this platform:\n{e}")
            print("Skipping type stub generation.")
            return

        wheel_root = Path(args.wheel_root).resolve()
        # stubgen receives the package name plus the directory that contains it.
        generate_stubs(wheel_root.name, wheel_root.parent)
        fix_stubs(wheel_root)
        clean_up_pycache(wheel_root)
    finally:
        # Runs on success, on the early return above, and on any exception,
        # so temporary dependencies never outlive this script.
        for name in installed_by_us:
            try_uninstall_package(name)

if __name__ == '__main__':
    # Explicitly skip Windows ARM64.
    # A CMake + MSVC issue makes the build fail whenever the pip install
    # subprocess fails, even if the exception is handled and this script
    # exits with 0.
    if sys.platform == 'win32' and platform.machine().lower() == 'arm64':
        print('Stub generation is not supported on Windows ARM64 due to lack of onnxruntime package.')
        # sys.exit instead of the bare exit(): the latter is a helper added by
        # the site module and is not available when site is disabled.
        sys.exit(0)
    main()
Loading