diff --git a/build/build.py b/build/build.py index 41994c4b4f08..671b0235425b 100755 --- a/build/build.py +++ b/build/build.py @@ -383,6 +383,24 @@ def main(): "plugin is still experimental and is not ready for use yet." ), ) + add_boolean_argument( + parser, + "build_cuda_kernel_plugin", + default=False, + help_str=( + "Are we building the cuda kernel plugin? jaxlib will not be built " + "when this flag is True." + ), + ) + add_boolean_argument( + parser, + "build_cuda_pjrt_plugin", + default=False, + help_str=( + "Are we building the cuda pjrt plugin? jaxlib will not be built " + "when this flag is True." + ), + ) parser.add_argument( "--gpu_plugin_cuda_version", choices=["11", "12"], @@ -560,19 +578,20 @@ def main(): print("\nBuilding XLA and installing it in the jaxlib source tree...") - command = ([bazel_path] + args.bazel_startup_options + - ["run", "--verbose_failures=true"] + - ["//jaxlib/tools:build_wheel", "--", - f"--output_path={output_path}", - f"--cpu={wheel_cpu}"]) - if args.build_gpu_plugin: - command.append("--include_gpu_plugin_extension") - if args.editable: - command += ["--editable"] - print(" ".join(command)) - shell(command) - - if args.build_gpu_plugin: + if not args.build_cuda_kernel_plugin and not args.build_cuda_pjrt_plugin: + command = ([bazel_path] + args.bazel_startup_options + + ["run", "--verbose_failures=true"] + + ["//jaxlib/tools:build_wheel", "--", + f"--output_path={output_path}", + f"--cpu={wheel_cpu}"]) + if args.build_gpu_plugin: + command.append("--include_gpu_plugin_extension") + if args.editable: + command += ["--editable"] + print(" ".join(command)) + shell(command) + + if args.build_gpu_plugin or args.build_cuda_kernel_plugin: build_cuda_kernels_command = ([bazel_path] + args.bazel_startup_options + ["run", "--verbose_failures=true"] + ["//jaxlib/tools:build_cuda_kernels_wheel", "--", @@ -580,10 +599,11 @@ def main(): f"--cpu={wheel_cpu}", f"--cuda_version={args.gpu_plugin_cuda_version}"]) if args.editable: - command.append("--editable") + build_cuda_kernels_command.append("--editable") print(" ".join(build_cuda_kernels_command)) shell(build_cuda_kernels_command) + if args.build_gpu_plugin or args.build_cuda_pjrt_plugin: build_pjrt_plugin_command = ([bazel_path] + args.bazel_startup_options + ["run", "--verbose_failures=true"] + ["//jaxlib/tools:build_gpu_plugin_wheel", "--", @@ -591,7 +611,7 @@ def main(): f"--cpu={wheel_cpu}", f"--cuda_version={args.gpu_plugin_cuda_version}"]) if args.editable: - command.append("--editable") + build_pjrt_plugin_command.append("--editable") print(" ".join(build_pjrt_plugin_command)) shell(build_pjrt_plugin_command)