Skip to content

Commit

Permalink
Add multi-platform torch example
Browse files Browse the repository at this point in the history
This has been a useful starting point for me locally to repro pypi
downloader issues we see in our project.
  • Loading branch information
keith committed Nov 27, 2024
1 parent 077eec6 commit 4320360
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 0 deletions.
1 change: 1 addition & 0 deletions examples/platform_specific_deps/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
bazel-*
30 changes: 30 additions & 0 deletions examples/platform_specific_deps/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
load("@pip//:requirements.bzl", "requirement")
load("@rules_python//python:defs.bzl", "py_test")
load("@rules_uv//uv:pip.bzl", "pip_compile")

[
pip_compile(
name = "generate_{}_requirements".format(arch),
args = [
"--emit-index-url",
"--index-strategy=unsafe-best-match", # NOTE: required because torch's index contains requests and that is preferred over pypi
],
python_platform = arch,
requirements_in = "requirements.in",
requirements_txt = "{}-requirements.txt".format(arch),
)
for arch in [
"aarch64-unknown-linux-gnu",
"x86_64-unknown-linux-gnu",
"aarch64-apple-darwin",
]
]

py_test(
name = "test",
srcs = ["test_load.py"],
main = "test_load.py",
deps = [
requirement("torch"),
],
)
32 changes: 32 additions & 0 deletions examples/platform_specific_deps/MODULE.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
module(
name = "platform_specific_deps",
version = "0.0.0",
compatibility_level = 1,
)

bazel_dep(name = "bazel_skylib", version = "1.7.1")
bazel_dep(name = "rules_python", version = "0.0.0")
# TODO: Replace with builtin uv support if it supports platform specific requirements output
bazel_dep(name = "rules_uv", version = "0.42.0")

local_path_override(
module_name = "rules_python",
path = "../..",
)

python = use_extension("@rules_python//python/extensions:python.bzl", "python")
python.toolchain(
python_version = "3.11",
)

pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip")
pip.parse(
hub_name = "pip",
python_version = "3.11",
requirements_by_platform = {
"//:x86_64-unknown-linux-gnu-requirements.txt": "linux_x86_64",
"//:aarch64-unknown-linux-gnu-requirements.txt": "linux_aarch64",
},
requirements_lock = "aarch64-apple-darwin-requirements.txt",
)
use_repo(pip, "pip")
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# This file was autogenerated by uv via the following command:
# bazel run @@//:generate_aarch64-apple-darwin_requirements
--index-url https://pypi.org/simple
--extra-index-url https://download.pytorch.org/whl/cpu

filelock==3.16.1
# via torch
fsspec==2024.10.0
# via torch
jinja2==3.1.4
# via torch
markupsafe==3.0.1
# via jinja2
mpmath==1.3.0
# via sympy
networkx==3.4.2
# via torch
sympy==1.13.3
# via torch
torch==2.4.1
# via -r requirements.in
typing-extensions==4.12.2
# via torch
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# This file was autogenerated by uv via the following command:
# bazel run @@//:generate_aarch64-unknown-linux-gnu_requirements
--index-url https://pypi.org/simple
--extra-index-url https://download.pytorch.org/whl/cpu

filelock==3.16.1
# via torch
fsspec==2024.10.0
# via torch
jinja2==3.1.4
# via torch
markupsafe==3.0.1
# via jinja2
mpmath==1.3.0
# via sympy
networkx==3.4.2
# via torch
sympy==1.13.3
# via torch
torch==2.4.1
# via -r requirements.in
typing-extensions==4.12.2
# via torch
3 changes: 3 additions & 0 deletions examples/platform_specific_deps/requirements.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.4.1; platform_machine != "x86_64"
torch==2.4.1+cpu; platform_machine == "x86_64"
3 changes: 3 additions & 0 deletions examples/platform_specific_deps/test_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import torch # This verifies the deps were loaded as expected for the current platform

print("worked!")
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# This file was autogenerated by uv via the following command:
# bazel run @@//:generate_x86_64-unknown-linux-gnu_requirements
--index-url https://pypi.org/simple
--extra-index-url https://download.pytorch.org/whl/cpu

filelock==3.16.1
# via torch
fsspec==2024.10.0
# via torch
jinja2==3.1.4
# via torch
markupsafe==3.0.1
# via jinja2
mpmath==1.3.0
# via sympy
networkx==3.4.2
# via torch
sympy==1.13.3
# via torch
torch==2.4.1+cpu
# via -r requirements.in
typing-extensions==4.12.2
# via torch

0 comments on commit 4320360

Please sign in to comment.