Skip to content

Commit e2d73e2

Browse files
committed
Add site-package example
1 parent 6442564 commit e2d73e2

File tree

7 files changed

+1138
-0
lines changed

7 files changed

+1138
-0
lines changed

examples/site-packages/.bazelrc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
common --@rules_python//python/config_settings:bootstrap_impl=script
2+
common --@rules_python//python/config_settings:venvs_site_packages=yes
3+
# See https://github.com/bazel-contrib/rules_python/issues/2864#issuecomment-2859325467
4+
# common --@rules_python//python/config_settings:venvs_use_declare_symlink=no
5+
6+
build --incompatible_default_to_explicit_init_py=true
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
8.2.1

examples/site-packages/BUILD.bazel

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
load("@rules_python//python:defs.bzl", "py_test")
2+
3+
4+
py_test(
5+
name = "test",
6+
srcs = ["test.py"],
7+
deps = [
8+
"@pypi//mujoco",
9+
# "@pypi//mujoco_mjx",
10+
# "@pypi//tensorflow",
11+
# "@pypi//importlib_resources",
12+
# "@pypi//typing_extensions", # ModuleNotFoundError: No module named 'typing_extensions'
13+
"@pypi//jax",
14+
"@pypi//jax_cuda12_pjrt",
15+
"@pypi//jax_cuda12_plugin",
16+
"@pypi//nvidia_cublas_cu12",
17+
"@pypi//nvidia_cuda_cupti_cu12",
18+
"@pypi//nvidia_cuda_nvcc_cu12",
19+
"@pypi//nvidia_cuda_nvrtc_cu12",
20+
"@pypi//nvidia_cuda_runtime_cu12",
21+
"@pypi//nvidia_cudnn_cu12",
22+
"@pypi//nvidia_cufft_cu12",
23+
"@pypi//nvidia_cusolver_cu12",
24+
"@pypi//nvidia_cusparse_cu12",
25+
"@pypi//nvidia_nccl_cu12",
26+
"@pypi//nvidia_nvjitlink_cu12",
27+
"@pypi//nvidia_nvshmem_cu12",
28+
],
29+
)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""Site-packages example."""
2+
3+
module(
4+
name = "site-packages-example",
5+
version = "0.0.1",
6+
)
7+
8+
bazel_dep(name = "rules_python", version = "0.0.0")
9+
bazel_dep(name = "platforms", version = "0.0.11")
10+
11+
local_path_override(
12+
module_name = "rules_python",
13+
path = "../..",
14+
)
15+
16+
python = use_extension("@rules_python//python/extensions:python.bzl", "python")
17+
python.toolchain(
18+
is_default = True,
19+
python_version = "3.12",
20+
)
21+
python.override(minor_mapping = {"3.12": "3.12.7"})
22+
23+
pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip")
24+
pip.parse(
25+
experimental_index_url = "https://pypi.org/simple",
26+
hub_name = "pypi",
27+
python_version = "3.12",
28+
requirements_lock = "requirements.txt",
29+
# enable_implicit_namespace_pkgs = False,
30+
)
31+
32+
use_repo(pip, "pypi")
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
mujoco
2+
# mujoco-mjx
3+
jax[cuda12]==0.7.1
4+
jaxlib==0.7.1
5+
tensorflow

examples/site-packages/requirements.txt

Lines changed: 1045 additions & 0 deletions
Large diffs are not rendered by default.

examples/site-packages/test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""A simple site-packages test."""
2+
3+
import jax
4+
import mujoco
5+
# `mjx` does not work because "it's a prefix of the other."
6+
# from mujoco import mjx
7+
import nvidia
8+
9+
10+
def main() -> None:
11+
print(f"Hello, {nvidia=}!")
12+
print(f"Hello, {jax=}!")
13+
print(f"{mujoco=}!")
14+
# print(f"{mjx=}")
15+
16+
print(f"{jax.devices('gpu')=}")
17+
18+
19+
if __name__ == "__main__":
20+
main()

0 commit comments

Comments
 (0)