Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 86 additions & 50 deletions external-builds/pytorch/build_prod_wheels.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,48 +347,14 @@ def find_dir_containing(file_name: str, *possible_paths: Path) -> Path:
raise ValueError(f"No directory contains {file_name}: {possible_paths}")


def do_build(args: argparse.Namespace):
if args.install_rocm:
do_install_rocm(args)

if not args.version_suffix:
args.version_suffix = get_version_suffix_for_installed_rocm_package()

triton_dir: Path | None = args.triton_dir
pytorch_dir: Path | None = args.pytorch_dir
pytorch_audio_dir: Path | None = args.pytorch_audio_dir
pytorch_vision_dir: Path | None = args.pytorch_vision_dir
apex_dir: Path | None = args.apex_dir

rocm_sdk_version = get_rocm_sdk_version()
cmake_prefix = get_rocm_path("cmake")
bin_dir = get_rocm_path("bin")
rocm_dir = get_rocm_path("root")

print(f"rocm version {rocm_sdk_version}:")
print(f" PYTHON VERSION: {sys.version}")
print(f" CMAKE_PREFIX_PATH = {cmake_prefix}")
print(f" BIN = {bin_dir}")
print(f" ROCM_HOME = {rocm_dir}")

system_path = str(bin_dir) + os.path.pathsep + os.environ.get("PATH", "")
print(f" PATH = {system_path}")

pytorch_rocm_arch = args.pytorch_rocm_arch
if pytorch_rocm_arch is None:
pytorch_rocm_arch = get_rocm_sdk_targets()
print(
f" Using default PYTORCH_ROCM_ARCH from rocm-sdk targets: {pytorch_rocm_arch}"
)
else:
print(f" Using provided PYTORCH_ROCM_ARCH: {pytorch_rocm_arch}")

if not pytorch_rocm_arch:
raise ValueError(
"No --pytorch-rocm-arch provided and rocm-sdk targets returned empty. "
"Please specify --pytorch-rocm-arch (e.g., gfx942)."
)

def _setup_common_build_env(
cmake_prefix: Path,
rocm_dir: Path,
pytorch_rocm_arch: str,
triton_dir: Path | None,
is_windows: bool,
) -> dict[str, str]:
"""Construct the common environment dict shared by all wheel builds."""
env: dict[str, str] = {
"PYTHONUTF8": "1", # Some build files use utf8 characters, force IO encoding
"CMAKE_PREFIX_PATH": str(cmake_prefix),
Expand All @@ -398,12 +364,6 @@ def do_build(args: argparse.Namespace):
"USE_KINETO": os.environ.get("USE_KINETO", "ON" if not is_windows else "OFF"),
}

if args.use_ccache:
print("Building with ccache, clearing stats first")
env["CMAKE_C_COMPILER_LAUNCHER"] = "ccache"
env["CMAKE_CXX_COMPILER_LAUNCHER"] = "ccache"
run_command(["ccache", "--zero-stats"], cwd=tempfile.gettempdir())

# GLOO enabled for only Linux
if not is_windows:
env["USE_GLOO"] = "ON"
Expand Down Expand Up @@ -456,7 +416,7 @@ def do_build(args: argparse.Namespace):
# on directory layout.
# Obviously, this should be completely burned with fire once the root causes
# are eliminted.
hip_device_lib_path = get_rocm_path("root") / "lib" / "llvm" / "amdgcn" / "bitcode"
hip_device_lib_path = rocm_dir / "lib" / "llvm" / "amdgcn" / "bitcode"
if not hip_device_lib_path.exists():
print(
"WARNING: Default location of device libs not found. Relying on "
Expand All @@ -466,7 +426,7 @@ def do_build(args: argparse.Namespace):
env["HIP_DEVICE_LIB_PATH"] = str(hip_device_lib_path)

# OpenBLAS path setup
host_math_path = get_rocm_path("root") / "lib" / "host-math"
host_math_path = rocm_dir / "lib" / "host-math"
if not host_math_path.exists():
print(
"WARNING: Default location of host-math not found. "
Expand All @@ -477,6 +437,19 @@ def do_build(args: argparse.Namespace):
env["OpenBLAS_HOME"] = str(host_math_path)
env["OpenBLAS_LIB_NAME"] = "rocm-openblas"

return env


def _do_build_wheels_core(
args: argparse.Namespace,
env: dict[str, str],
triton_dir: Path | None,
pytorch_dir: Path | None,
pytorch_audio_dir: Path | None,
pytorch_vision_dir: Path | None,
apex_dir: Path | None,
) -> None:
"""Execute all wheel builds (triton, pytorch, audio, vision, apex)."""
# Build triton.
triton_requirement = None
if args.build_triton or (args.build_triton is None and triton_dir):
Expand Down Expand Up @@ -524,6 +497,69 @@ def do_build(args: argparse.Namespace):

print("--- Builds all completed")


def do_build(args: argparse.Namespace):
if args.install_rocm:
do_install_rocm(args)

if not args.version_suffix:
args.version_suffix = get_version_suffix_for_installed_rocm_package()

triton_dir: Path | None = args.triton_dir
pytorch_dir: Path | None = args.pytorch_dir
pytorch_audio_dir: Path | None = args.pytorch_audio_dir
pytorch_vision_dir: Path | None = args.pytorch_vision_dir
apex_dir: Path | None = args.apex_dir

rocm_sdk_version = get_rocm_sdk_version()
cmake_prefix = get_rocm_path("cmake")
bin_dir = get_rocm_path("bin")
rocm_dir = get_rocm_path("root")

print(f"rocm version {rocm_sdk_version}:")
print(f" PYTHON VERSION: {sys.version}")
print(f" CMAKE_PREFIX_PATH = {cmake_prefix}")
print(f" BIN = {bin_dir}")
print(f" ROCM_HOME = {rocm_dir}")

system_path = str(bin_dir) + os.path.pathsep + os.environ.get("PATH", "")
print(f" PATH = {system_path}")

pytorch_rocm_arch = args.pytorch_rocm_arch
if pytorch_rocm_arch is None:
pytorch_rocm_arch = get_rocm_sdk_targets()
print(
f" Using default PYTORCH_ROCM_ARCH from rocm-sdk targets: {pytorch_rocm_arch}"
)
else:
print(f" Using provided PYTORCH_ROCM_ARCH: {pytorch_rocm_arch}")

if not pytorch_rocm_arch:
raise ValueError(
"No --pytorch-rocm-arch provided and rocm-sdk targets returned empty. "
"Please specify --pytorch-rocm-arch (e.g., gfx942)."
)

env = _setup_common_build_env(
cmake_prefix, rocm_dir, pytorch_rocm_arch, triton_dir, is_windows
)

if args.use_ccache:
print("Building with ccache, clearing stats first")
env["CMAKE_C_COMPILER_LAUNCHER"] = "ccache"
env["CMAKE_CXX_COMPILER_LAUNCHER"] = "ccache"
run_command(["ccache", "--zero-stats"], cwd=tempfile.gettempdir())

_do_build_wheels_core(
args,
env,
triton_dir,
pytorch_dir,
pytorch_audio_dir,
pytorch_vision_dir,
apex_dir,
)

if args.use_ccache:
ccache_stats_output = capture(
["ccache", "--show-stats"], cwd=tempfile.gettempdir()
Expand Down
Loading