Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add jetson packages #16905

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions third_party/tsl/third_party/absl/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ def repo():

# Attention: tools parse and update these lines.
# LINT.IfChange
ABSL_COMMIT = "fb3621f4f897824c0dbe0615fa94543df6192f30"
ABSL_SHA256 = "0320586856674d16b0b7a4d4afb22151bdc798490bb7f295eddd8f6a62b46fea"
ABSL_COMMIT = "67d126083c1584dd7dc584d700f853afaec365ca"
ABSL_SHA256 = "8652366395b2f20628281fd98c4413e9947d989fcb214f8bdc56351e8cd7e7d4"
# LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/abseil-cpp.cmake)

SYS_DIRS = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ OS_ARCH_DICT = {
_REDIST_ARCH_DICT = {
"linux-x86_64": "x86_64-unknown-linux-gnu",
"linux-sbsa": "aarch64-unknown-linux-gnu",
"linux-aarch64": "aarch64-unknown-linux-gnu",
}

SUPPORTED_ARCHIVE_EXTENSIONS = [
Expand Down Expand Up @@ -502,27 +503,55 @@ def cudnn_redist_init_repository(
cudnn_redist_path_prefix = cudnn_redist_path_prefix,
)

def detect_platform(repository_ctx):
"""Detect if the platform is x86_64, Jetson (aarch64), or generic ARM SBSA."""
# Use environment variables to distinguish between platforms
host_arch = repository_ctx.os.arch # Using Bazel's built-in context to detect architecture

if host_arch == "x86_64":
return "linux-x86_64" # x86_64 platform

if host_arch == "aarch64":
# Check if the build is targeting Jetson using an environment variable
is_jetson = repository_ctx.os.environ.get("JETSON_PLATFORM", None)
johnnynunez marked this conversation as resolved.
Show resolved Hide resolved
if is_jetson:
return "linux-aarch64" # Jetson platform
return "linux-sbsa" # Default to SBSA if not Jetson

# Default case, if something unexpected is encountered
fail("Unsupported architecture: {}".format(host_arch))

def cuda_redist_init_repositories(
cuda_redistributions,
cuda_redist_path_prefix = CUDA_REDIST_PATH_PREFIX,
redist_versions_to_build_templates = REDIST_VERSIONS_TO_BUILD_TEMPLATES):
# buildifier: disable=function-docstring-args
"""Initializes CUDA repositories."""
for redist_name, _ in redist_versions_to_build_templates.items():
cuda_redist_path_prefix=CUDA_REDIST_PATH_PREFIX,
redist_versions_to_build_templates=REDIST_VERSIONS_TO_BUILD_TEMPLATES):
"""Initializes CUDA repositories for multiple architectures, including Jetson and SBSA."""
for redist_name in redist_versions_to_build_templates:
if redist_name in ["cudnn", "cuda_nccl"]:
continue
if redist_name in cuda_redistributions.keys():

if redist_name in cuda_redistributions:
url_dict = _get_redistribution_urls(cuda_redistributions[redist_name])
else:
url_dict = {}

repo_data = redist_versions_to_build_templates[redist_name]
versions, templates = get_version_and_template_lists(
repo_data["version_to_template"],
repo_data["version_to_template"]
)

# Detect the platform architecture (x86_64, Jetson, or SBSA)
arch_key = detect_platform(repository_ctx)

# If the correct architecture is not found in url_dict, skip
if arch_key not in url_dict:
print(f"Platform {arch_key} is not supported for {redist_name}. Skipping.")
continue

cuda_repo(
name = repo_data["repo_name"],
versions = versions,
build_templates = templates,
url_dict = url_dict,
cuda_redist_path_prefix = cuda_redist_path_prefix,
name=repo_data["repo_name"],
versions=versions,
build_templates=templates,
url_dict=url_dict,
cuda_redist_path_prefix=cuda_redist_path_prefix,
)
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ CUDA_REDIST_JSON_DICT = {
"https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.6.0.json",
"87740b01676b3d18982982ab96ec7fa1a626d03a96df070a6b0f258d01ff5fab",
],
"12.6.1": [
"https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.6.1.json",
"87740b01676b3d18982982ab96ec7fa1a626d03a96df070a6b0f258d01ff5fab",
],
}

CUDNN_REDIST_JSON_DICT = {
Expand Down Expand Up @@ -93,6 +97,10 @@ CUDNN_REDIST_JSON_DICT = {
"https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.3.0.json",
"d17d9a7878365736758550294f03e633a0b023bec879bf173349bfb34781972e",
],
"9.4.0": [
"https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.4.0.json",
"d17d9a7878365736758550294f03e633a0b023bec879bf173349bfb34781972e",
],
}

# The versions are different for x86 and aarch64 architectures because only
Expand Down Expand Up @@ -130,6 +138,7 @@ CUDA_NCCL_WHEELS = {
"12.5.0": CUDA_12_NCCL_WHEEL_DICT,
"12.5.1": CUDA_12_NCCL_WHEEL_DICT,
"12.6.0": CUDA_12_NCCL_WHEEL_DICT,
"12.6.1": CUDA_12_NCCL_WHEEL_DICT,
}

REDIST_VERSIONS_TO_BUILD_TEMPLATES = {
Expand Down