Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion aiter/dist/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,14 @@

import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup
try:
from torch.distributed import Backend, ProcessGroup
except ImportError:
# torch.distributed is not available on all Windows ROCm builds.
# Set to None so the module can be imported; code paths that actually
# use these names are guarded by distributed-availability checks.
Backend = None # type: ignore[assignment]
ProcessGroup = None # type: ignore[assignment]

import os
from aiter import logger
Expand Down
146 changes: 124 additions & 22 deletions aiter/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,13 @@ def get_user_jit_dir() -> str:
def validate_and_update_archs():
archs = os.getenv("GPU_ARCHS", "native").split(";")
archs = [arch.strip() for arch in archs]

if "native" in archs and sys.platform == "win32":
# hipcc on Windows does not support --offload-arch=native;
# resolve it to the actual gfx string detected from hipinfo/rocminfo.
resolved = get_gfx()
archs = [resolved if a == "native" else a for a in archs]

# List of allowed architectures
allowed_archs = [
"native",
Expand Down Expand Up @@ -398,18 +405,56 @@ def validate_and_update_archs():
@functools.lru_cache()
def hip_flag_checker(flag_hip: str) -> bool:
import subprocess
import tempfile
from cpp_extension import get_cxx_compiler, ROCM_HOME

cmd = (
["hipcc"]
+ flag_hip.split()
+ ["-x", "hip", "-E", "-P", "/dev/null", "-o", "/dev/null"]
)
try:
subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
except subprocess.CalledProcessError:
logger.warning(f"Current hipcc not support: {flag_hip}, skip it.")
return False
return True
hipcc_bin = get_cxx_compiler()

# On Windows hipcc.exe uses the HIP_PATH environment variable to find
# HIP headers and device libs — not --rocm-path compiler flags.
# Inject it into the subprocess environment if not already set.
extra_env = {}
if sys.platform == "win32" and ROCM_HOME and "HIP_PATH" not in os.environ:
extra_env["HIP_PATH"] = ROCM_HOME

if sys.platform == "win32":
tmp_in = None
tmp_out = None
try:
with tempfile.NamedTemporaryFile(suffix=".cu", delete=False, mode="w") as f:
f.write("// flag check\n")
tmp_in = f.name
tmp_out = tmp_in.replace(".cu", ".o")
env = {**os.environ, **extra_env}
cmd = (
[hipcc_bin]
+ flag_hip.split()
+ ["-x", "hip", "-c", tmp_in, "-o", tmp_out]
)
subprocess.check_output(cmd, stderr=subprocess.DEVNULL, env=env)
return True
except subprocess.CalledProcessError:
logger.warning(f"Current hipcc not support: {flag_hip}, skip it.")
return False
finally:
for f in filter(None, [tmp_in, tmp_out]):
try:
os.unlink(f)
except OSError:
pass
else:
null_dev = "/dev/null"
cmd = (
[hipcc_bin]
+ flag_hip.split()
+ ["-x", "hip", "-E", "-P", null_dev, "-o", null_dev]
)
try:
subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
except subprocess.CalledProcessError:
logger.warning(f"Current hipcc not support: {flag_hip}, skip it.")
return False
return True


@functools.lru_cache()
Expand All @@ -421,14 +466,53 @@ def check_LLVM_MAIN_REVISION():
#else
#define CK_TILE_HOST_DEVICE_EXTERN"""
import subprocess

cmd = """echo "#include <tuple>
import tempfile
from cpp_extension import get_cxx_compiler

src = "#include <tuple>\n__host__ __device__ void func(){std::tuple<int, int> t = std::tuple(1, 1);}\n"
hipcc_bin = get_cxx_compiler()
if sys.platform == "win32":
# On Windows we can't use bash pipes; write to a temp file instead.
with tempfile.NamedTemporaryFile(suffix=".cu", delete=False, mode="w") as f:
f.write(src)
tmp = f.name
try:
subprocess.check_output(
[
hipcc_bin,
"-x",
"hip",
"-P",
"-c",
"-Wno-unused-command-line-argument",
tmp,
],
stderr=subprocess.STDOUT,
text=True,
)
return 554785 - 1
except subprocess.CalledProcessError:
return 554785
finally:
try:
os.unlink(tmp)
except OSError:
pass
obj = tmp.replace(".cu", ".obj")
try:
os.unlink(obj)
except OSError:
pass
else:
cmd = """echo "#include <tuple>
__host__ __device__ void func(){std::tuple<int, int> t = std::tuple(1, 1);}" | hipcc -x hip -P -c -Wno-unused-command-line-argument -o /dev/null -"""
try:
subprocess.check_output(cmd, shell=True, text=True, stderr=subprocess.STDOUT)
except subprocess.CalledProcessError:
return 554785
return 554785 - 1
try:
subprocess.check_output(
cmd, shell=True, text=True, stderr=subprocess.STDOUT
)
except subprocess.CalledProcessError:
return 554785
return 554785 - 1


def check_and_set_ninja_worker():
Expand Down Expand Up @@ -489,6 +573,8 @@ def do_rename_and_mv(name, src, dst, ret):

@torch_compile_guard()
def check_numa_custom_op() -> None:
if sys.platform == "win32":
return # /proc/sys/kernel/numa_balancing does not exist on Windows
numa_balance_set = os.popen("cat /proc/sys/kernel/numa_balancing").read().strip()
if numa_balance_set == "1":
logger.warning(
Expand Down Expand Up @@ -654,11 +740,20 @@ def check_git_version(required_major, required_minor):


def rm_module(md_name):
os.system(f"rm -rf {get_user_jit_dir()}/{md_name}.so")
import glob as _glob

_lib_ext = ".pyd" if sys.platform == "win32" else ".so"
for _f in _glob.glob(f"{get_user_jit_dir()}/{md_name}{_lib_ext}"):
try:
os.remove(_f)
except OSError:
pass


def clear_build(md_name):
os.system(f"rm -rf {bd_dir}/{md_name}")
import shutil as _shutil

_shutil.rmtree(f"{bd_dir}/{md_name}", ignore_errors=True)


def build_module(
Expand All @@ -679,7 +774,8 @@ def build_module(
os.makedirs(bd_dir, exist_ok=True)
lock_path = f"{bd_dir}/lock_{md_name}"
startTS = time.perf_counter()
target_name = f"{md_name}.so" if not is_standalone else md_name
_lib_ext = ".pyd" if sys.platform == "win32" else ".so"
target_name = f"{md_name}{_lib_ext}" if not is_standalone else md_name

for tp in third_party:
clone_3rdparty(tp)
Expand Down Expand Up @@ -1255,6 +1351,7 @@ def wrapper(*args, custom_build_args={}, **kwargs):
d_args = get_args_of_build(md_name)
d_args.update(custom_build_args)

# update module name if we have a custom build
md_name = custom_build_args.get("md_name", md_name)

srcs = d_args["srcs"]
Expand Down Expand Up @@ -1324,6 +1421,7 @@ def check_args():
doc_str = doc_str.replace("collections.abc.Sequence[", "List[")
doc_str = doc_str.replace("typing.SupportsInt", "int")
doc_str = doc_str.replace("typing.SupportsFloat", "float")
# A|None --> Optional[A]
pattern = r"([\w\.]+(?:\[[^\]]+\])?)\s*\|\s*None"
doc_str = re.sub(pattern, r"Optional[\1]", doc_str)
for el in enum_types:
Expand Down Expand Up @@ -1353,14 +1451,18 @@ def check_args():

if origin is None:
if not isinstance(arg, expected_type) and not (
# aiter_enum can be int
any(el in str(expected_type) for el in enum_types)
and isinstance(arg, int)
):
raise TypeError(
f"{loadName}: {el} needs to be {expected_type} but got {got_type}"
)
elif origin is list:
if not isinstance(arg, list):
if (
not isinstance(arg, list)
# or not all(isinstance(i, sub_t) for i in arg)
):
raise TypeError(
f"{loadName}: {el} needs to be List[{sub_t}] but got {arg}"
)
Expand Down
Loading