Skip to content

Commit

Permalink
Add two flags to support only building cuda kernel plugin or cuda pjr…
Browse files Browse the repository at this point in the history
…t plugin.

PiperOrigin-RevId: 591274120
  • Loading branch information
Jieying Luo authored and jax authors committed Dec 15, 2023
1 parent 4153112 commit c8b3567
Showing 1 changed file with 35 additions and 15 deletions.
50 changes: 35 additions & 15 deletions build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,24 @@ def main():
"plugin is still experimental and is not ready for use yet."
),
)
add_boolean_argument(
parser,
"build_cuda_kernel_plugin",
default=False,
help_str=(
"Are we building the cuda kernel plugin? jaxlib will not be built "
"when this flag is True."
),
)
add_boolean_argument(
parser,
"build_cuda_pjrt_plugin",
default=False,
help_str=(
"Are we building the cuda pjrt plugin? jaxlib will not be built "
"when this flag is True."
),
)
parser.add_argument(
"--gpu_plugin_cuda_version",
choices=["11", "12"],
Expand Down Expand Up @@ -560,38 +578,40 @@ def main():

print("\nBuilding XLA and installing it in the jaxlib source tree...")

command = ([bazel_path] + args.bazel_startup_options +
["run", "--verbose_failures=true"] +
["//jaxlib/tools:build_wheel", "--",
f"--output_path={output_path}",
f"--cpu={wheel_cpu}"])
if args.build_gpu_plugin:
command.append("--include_gpu_plugin_extension")
if args.editable:
command += ["--editable"]
print(" ".join(command))
shell(command)

if args.build_gpu_plugin:
if not args.build_cuda_kernel_plugin and not args.build_cuda_pjrt_plugin:
command = ([bazel_path] + args.bazel_startup_options +
["run", "--verbose_failures=true"] +
["//jaxlib/tools:build_wheel", "--",
f"--output_path={output_path}",
f"--cpu={wheel_cpu}"])
if args.build_gpu_plugin:
command.append("--include_gpu_plugin_extension")
if args.editable:
command += ["--editable"]
print(" ".join(command))
shell(command)

if args.build_gpu_plugin or args.build_cuda_kernel_plugin:
build_cuda_kernels_command = ([bazel_path] + args.bazel_startup_options +
["run", "--verbose_failures=true"] +
["//jaxlib/tools:build_cuda_kernels_wheel", "--",
f"--output_path={output_path}",
f"--cpu={wheel_cpu}",
f"--cuda_version={args.gpu_plugin_cuda_version}"])
if args.editable:
command.append("--editable")
build_cuda_kernels_command.append("--editable")
print(" ".join(build_cuda_kernels_command))
shell(build_cuda_kernels_command)

if args.build_gpu_plugin or args.build_cuda_pjrt_plugin:
build_pjrt_plugin_command = ([bazel_path] + args.bazel_startup_options +
["run", "--verbose_failures=true"] +
["//jaxlib/tools:build_gpu_plugin_wheel", "--",
f"--output_path={output_path}",
f"--cpu={wheel_cpu}",
f"--cuda_version={args.gpu_plugin_cuda_version}"])
if args.editable:
command.append("--editable")
build_pjrt_plugin_command.append("--editable")
print(" ".join(build_pjrt_plugin_command))
shell(build_pjrt_plugin_command)

Expand Down

0 comments on commit c8b3567

Please sign in to comment.