diff --git a/external-builds/pytorch/build_prod_wheels.py b/external-builds/pytorch/build_prod_wheels.py index 10e878f800e..d97ac03b170 100755 --- a/external-builds/pytorch/build_prod_wheels.py +++ b/external-builds/pytorch/build_prod_wheels.py @@ -347,48 +347,14 @@ def find_dir_containing(file_name: str, *possible_paths: Path) -> Path: raise ValueError(f"No directory contains {file_name}: {possible_paths}") -def do_build(args: argparse.Namespace): - if args.install_rocm: - do_install_rocm(args) - - if not args.version_suffix: - args.version_suffix = get_version_suffix_for_installed_rocm_package() - - triton_dir: Path | None = args.triton_dir - pytorch_dir: Path | None = args.pytorch_dir - pytorch_audio_dir: Path | None = args.pytorch_audio_dir - pytorch_vision_dir: Path | None = args.pytorch_vision_dir - apex_dir: Path | None = args.apex_dir - - rocm_sdk_version = get_rocm_sdk_version() - cmake_prefix = get_rocm_path("cmake") - bin_dir = get_rocm_path("bin") - rocm_dir = get_rocm_path("root") - - print(f"rocm version {rocm_sdk_version}:") - print(f" PYTHON VERSION: {sys.version}") - print(f" CMAKE_PREFIX_PATH = {cmake_prefix}") - print(f" BIN = {bin_dir}") - print(f" ROCM_HOME = {rocm_dir}") - - system_path = str(bin_dir) + os.path.pathsep + os.environ.get("PATH", "") - print(f" PATH = {system_path}") - - pytorch_rocm_arch = args.pytorch_rocm_arch - if pytorch_rocm_arch is None: - pytorch_rocm_arch = get_rocm_sdk_targets() - print( - f" Using default PYTORCH_ROCM_ARCH from rocm-sdk targets: {pytorch_rocm_arch}" - ) - else: - print(f" Using provided PYTORCH_ROCM_ARCH: {pytorch_rocm_arch}") - - if not pytorch_rocm_arch: - raise ValueError( - "No --pytorch-rocm-arch provided and rocm-sdk targets returned empty. " - "Please specify --pytorch-rocm-arch (e.g., gfx942)." - ) - +def _setup_common_build_env( + cmake_prefix: Path, + rocm_dir: Path, + pytorch_rocm_arch: str, + triton_dir: Path | None, + is_windows: bool, +) -> dict[str, str]: + """Construct the common environment dict shared by all wheel builds.""" env: dict[str, str] = { "PYTHONUTF8": "1", # Some build files use utf8 characters, force IO encoding "CMAKE_PREFIX_PATH": str(cmake_prefix), @@ -398,12 +364,6 @@ def do_build(args: argparse.Namespace): "USE_KINETO": os.environ.get("USE_KINETO", "ON" if not is_windows else "OFF"), } - if args.use_ccache: - print("Building with ccache, clearing stats first") - env["CMAKE_C_COMPILER_LAUNCHER"] = "ccache" - env["CMAKE_CXX_COMPILER_LAUNCHER"] = "ccache" - run_command(["ccache", "--zero-stats"], cwd=tempfile.gettempdir()) - # GLOO enabled for only Linux if not is_windows: env["USE_GLOO"] = "ON" @@ -456,7 +416,7 @@ def do_build(args: argparse.Namespace): # on directory layout. # Obviously, this should be completely burned with fire once the root causes # are eliminted. - hip_device_lib_path = get_rocm_path("root") / "lib" / "llvm" / "amdgcn" / "bitcode" + hip_device_lib_path = rocm_dir / "lib" / "llvm" / "amdgcn" / "bitcode" if not hip_device_lib_path.exists(): print( "WARNING: Default location of device libs not found. Relying on " @@ -466,7 +426,7 @@ def do_build(args: argparse.Namespace): env["HIP_DEVICE_LIB_PATH"] = str(hip_device_lib_path) # OpenBLAS path setup - host_math_path = get_rocm_path("root") / "lib" / "host-math" + host_math_path = rocm_dir / "lib" / "host-math" if not host_math_path.exists(): print( "WARNING: Default location of host-math not found. " @@ -477,6 +437,19 @@ def do_build(args: argparse.Namespace): env["OpenBLAS_HOME"] = str(host_math_path) env["OpenBLAS_LIB_NAME"] = "rocm-openblas" + return env + + +def _do_build_wheels_core( + args: argparse.Namespace, + env: dict[str, str], + triton_dir: Path | None, + pytorch_dir: Path | None, + pytorch_audio_dir: Path | None, + pytorch_vision_dir: Path | None, + apex_dir: Path | None, +) -> None: + """Execute all wheel builds (triton, pytorch, audio, vision, apex).""" # Build triton. triton_requirement = None if args.build_triton or (args.build_triton is None and triton_dir): @@ -524,6 +497,69 @@ def do_build(args: argparse.Namespace): print("--- Builds all completed") + +def do_build(args: argparse.Namespace): + if args.install_rocm: + do_install_rocm(args) + + if not args.version_suffix: + args.version_suffix = get_version_suffix_for_installed_rocm_package() + + triton_dir: Path | None = args.triton_dir + pytorch_dir: Path | None = args.pytorch_dir + pytorch_audio_dir: Path | None = args.pytorch_audio_dir + pytorch_vision_dir: Path | None = args.pytorch_vision_dir + apex_dir: Path | None = args.apex_dir + + rocm_sdk_version = get_rocm_sdk_version() + cmake_prefix = get_rocm_path("cmake") + bin_dir = get_rocm_path("bin") + rocm_dir = get_rocm_path("root") + + print(f"rocm version {rocm_sdk_version}:") + print(f" PYTHON VERSION: {sys.version}") + print(f" CMAKE_PREFIX_PATH = {cmake_prefix}") + print(f" BIN = {bin_dir}") + print(f" ROCM_HOME = {rocm_dir}") + + system_path = str(bin_dir) + os.path.pathsep + os.environ.get("PATH", "") + print(f" PATH = {system_path}") + + pytorch_rocm_arch = args.pytorch_rocm_arch + if pytorch_rocm_arch is None: + pytorch_rocm_arch = get_rocm_sdk_targets() + print( + f" Using default PYTORCH_ROCM_ARCH from rocm-sdk targets: {pytorch_rocm_arch}" + ) + else: + print(f" Using provided PYTORCH_ROCM_ARCH: {pytorch_rocm_arch}") + + if not pytorch_rocm_arch: + raise ValueError( + "No --pytorch-rocm-arch provided and rocm-sdk targets returned empty. " + "Please specify --pytorch-rocm-arch (e.g., gfx942)." + ) + + env = _setup_common_build_env( + cmake_prefix, rocm_dir, pytorch_rocm_arch, triton_dir, is_windows + ) + + if args.use_ccache: + print("Building with ccache, clearing stats first") + env["CMAKE_C_COMPILER_LAUNCHER"] = "ccache" + env["CMAKE_CXX_COMPILER_LAUNCHER"] = "ccache" + run_command(["ccache", "--zero-stats"], cwd=tempfile.gettempdir()) + + _do_build_wheels_core( + args, + env, + triton_dir, + pytorch_dir, + pytorch_audio_dir, + pytorch_vision_dir, + apex_dir, + ) + if args.use_ccache: ccache_stats_output = capture( ["ccache", "--show-stats"], cwd=tempfile.gettempdir()