Skip to content

Commit

Permalink
Add --use_cuda_nvcc flag to build.py to enable compilation of CUDA …
Browse files Browse the repository at this point in the history
…code with clang.

If `--use_cuda_nvcc` is passed, the NVCC compiler will be used to build CUDA code (the default case); otherwise, if `--nouse_cuda_nvcc` is passed, Clang will be used to build CUDA code.

Mark `--use_clang` flag as deprecated.

Refactor `.bazelrc` configs to match the new flag and to cleanup all previous confusing names.

PiperOrigin-RevId: 676660938
  • Loading branch information
Google-ML-Automation committed Sep 21, 2024
1 parent d63afd8 commit 262aa94
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 33 deletions.
46 changes: 24 additions & 22 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ build:native_arch_posix --host_copt=-march=native

build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1

build:clang --action_env=CC="/usr/lib/llvm-18/bin/clang"
# Disable clang extension that rejects type definitions within offsetof.
# This was added in clang-16 by https://reviews.llvm.org/D133574.
# Can be removed once upb is updated, since a type definition is used within
# offsetof in the current version of upb.
# See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183.
build:clang --copt=-Wno-gnu-offsetof-extensions
# Disable clang extension that rejects unknown arguments.
build:clang --copt=-Qunused-arguments

build:cuda --repo_env TF_NEED_CUDA=1
build:cuda --repo_env TF_NCCL_USE_STUB=1
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
Expand All @@ -68,14 +78,6 @@ build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true
# Default hermetic CUDA and CUDNN versions.
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
# This flag is needed to include CUDA libraries for bazel tests.
test:cuda --@local_config_cuda//cuda:include_cuda_libs=true

# Requires MSVC and LLVM to be installed
build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl
build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl
build:win_clang --compiler=clang-cl

# Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries,
# ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to
# point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA
Expand All @@ -89,23 +91,18 @@ build:win_clang --compiler=clang-cl
# acceptable, because the workaround is "remove the nvidia-..." pip packages.
# The list of CUDA pip packages that JAX depends on are present in setup.py.
build:cuda --linkopt=-Wl,--disable-new-dtags
# This flag is needed to include CUDA libraries for bazel tests.
test:cuda --@local_config_cuda//cuda:include_cuda_libs=true

build:cuda_clang --config=clang
build:cuda_clang --@local_config_cuda//:cuda_compiler=clang
build:cuda_clang --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
# Disable clang extension that rejects type definitions within offsetof.
# This was added in clang-16 by https://reviews.llvm.org/D133574.
# Can be removed once upb is updated, since a type definition is used within
# offsetof in the current version of upb.
# See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183.
build:cuda_clang --copt=-Wno-gnu-offsetof-extensions
# Disable clang extension that rejects unknown arguments.
build:cuda_clang --copt=-Qunused-arguments

# Build with nvcc for CUDA and clang for host
build:nvcc_clang --config=cuda
build:nvcc_clang --config=cuda_clang
build:nvcc_clang --action_env=TF_NVCC_CLANG="1"
build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc
build:cuda_nvcc --config=cuda
build:cuda_nvcc --config=clang
build:cuda_nvcc --action_env=TF_NVCC_CLANG="1"
build:cuda_nvcc --@local_config_cuda//:cuda_compiler=nvcc

build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
Expand All @@ -114,6 +111,11 @@ build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1

build:nonccl --define=no_nccl_support=true

# Requires MSVC and LLVM to be installed
build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl
build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl
build:win_clang --compiler=clang-cl

# Windows has a relatively short command line limit, which JAX has begun to hit.
# See https://docs.bazel.build/versions/main/windows.html
build:windows --features=compiler_param_file
Expand Down Expand Up @@ -200,7 +202,7 @@ build:rbe_linux --host_linkopt=-lm
# Use the GPU toolchain until the CPU one is ready.
# https://github.com/bazelbuild/bazel/issues/13623
build:rbe_cpu_linux_base --config=rbe_linux
build:rbe_cpu_linux_base --config=cuda_clang
build:rbe_cpu_linux_base --config=clang
build:rbe_cpu_linux_base --host_crosstool_top="@local_config_cuda//crosstool:toolchain"
build:rbe_cpu_linux_base --crosstool_top="@local_config_cuda//crosstool:toolchain"
build:rbe_cpu_linux_base --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64"
Expand All @@ -223,7 +225,7 @@ build:rbe_linux_cuda_base --config=cuda
build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1

build:rbe_linux_cuda12.3_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda12.3_nvcc_base --config=nvcc_clang
build:rbe_linux_cuda12.3_nvcc_base --config=cuda_nvcc
build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
build:rbe_linux_cuda12.3_nvcc_base --host_crosstool_top="@local_config_cuda//crosstool:toolchain"
Expand Down
35 changes: 25 additions & 10 deletions build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def get_clang_path_or_exit():
return str(pathlib.Path(which_clang_output).resolve())
else:
print(
"--use_clang set, but --clang_path is unset and clang cannot be found"
"--clang_path is unset and clang cannot be found"
" on the PATH. Please pass --clang_path directly."
)
sys.exit(-1)
Expand All @@ -241,8 +241,9 @@ def write_bazelrc(*, remote_build,
cpu, cuda_compute_capabilities,
rocm_amdgpu_targets, target_cpu_features,
wheel_cpu, enable_mkl_dnn, use_clang, clang_path,
clang_major_version, enable_cuda, enable_nccl, enable_rocm,
python_version):
clang_major_version, python_version,
enable_cuda, enable_nccl, enable_rocm,
use_cuda_nvcc):

with open("../.jax_configure.bazelrc", "w") as f:
if not remote_build:
Expand Down Expand Up @@ -286,8 +287,10 @@ def write_bazelrc(*, remote_build,
if not enable_nccl:
f.write("build --config=nonccl\n")
if use_clang:
f.write("build --config=nvcc_clang\n")
f.write("build --config=cuda_clang\n")
f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n")
if use_cuda_nvcc:
f.write("build --config=cuda_nvcc\n")
if cuda_version:
f.write("build --repo_env HERMETIC_CUDA_VERSION=\"{cuda_version}\"\n"
.format(cuda_version=cuda_version))
Expand Down Expand Up @@ -392,15 +395,15 @@ def main():
"use_clang",
default = "true",
help_str=(
"Should we build using clang as the host compiler? Requires "
"clang to be findable via the PATH, or a path to be given via "
"--clang_path."
"Should we build using clang as the host compiler? "
"DEPRECATED: This flag is redundant because clang is "
"always used as default compiler."
),
)
parser.add_argument(
"--clang_path",
help=(
"Path to clang binary to use if --use_clang is set. The default is "
"Path to clang binary to use. The default is "
"to find clang via the PATH."
),
)
Expand All @@ -413,7 +416,18 @@ def main():
add_boolean_argument(
parser,
"enable_cuda",
help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN.")
help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN."
)
add_boolean_argument(
parser,
"use_cuda_nvcc",
default=True,
help_str=(
"Should we build CUDA using NVCC as the compiler? "
"The default value is true. "
"If --nouse_cuda_nvcc flag is used then CUDA builds by clang compiler."
),
)
add_boolean_argument(
parser,
"build_gpu_plugin",
Expand Down Expand Up @@ -617,10 +631,11 @@ def main():
use_clang=args.use_clang,
clang_path=clang_path,
clang_major_version=clang_major_version,
python_version=python_version,
enable_cuda=args.enable_cuda,
enable_nccl=args.enable_nccl,
enable_rocm=args.enable_rocm,
python_version=python_version,
use_cuda_nvcc=args.use_cuda_nvcc,
)

if args.requirements_update or args.requirements_nightly_update:
Expand Down
3 changes: 2 additions & 1 deletion docs/developer.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ There are two ways to build `jaxlib` with CUDA support: (1) use
support, or (2) use
`python build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12`
to generate three wheels (jaxlib without cuda, jax-cuda-plugin, and
jax-cuda-pjrt).
jax-cuda-pjrt). By default, all CUDA compilation steps are performed by NVCC and
clang, but compilation can be restricted to clang via the `--nouse_cuda_nvcc` flag.

See `python build/build.py --help` for configuration options. Here
`python` should be the name of your Python 3 interpreter; on some systems, you
Expand Down

0 comments on commit 262aa94

Please sign in to comment.