Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 46 additions & 28 deletions aiter/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,13 +281,17 @@ def check_and_set_ninja_worker():
os.environ["MAX_JOBS"] = str(max_jobs)


def rename_cpp_to_cu(els, dst, recursive=False):
def rename_cpp_to_cu(els, dst, hipify, recursive=False):
def do_rename_and_mv(name, src, dst, ret):
newName = name
if name.endswith(".cpp") or name.endswith(".cu"):
newName = name.replace(".cpp", ".cu")
ret.append(f"{dst}/{newName}")
shutil.copy(f"{src}/{name}", f"{dst}/{newName}")
if hipify:
if name.endswith(".cpp") or name.endswith(".cu"):
newName = name.replace(".cpp", ".cu")
ret.append(f"{dst}/{newName}")
shutil.copy(f"{src}/{name}", f"{dst}/{newName}")
else:
if name.endswith(".cpp") or name.endswith(".cu"):
ret.append(f"{src}/{newName}")

ret = []
for el in els:
Expand All @@ -298,7 +302,9 @@ def do_rename_and_mv(name, src, dst, ret):
for entry in os.listdir(el):
if os.path.isdir(f"{el}/{entry}"):
if recursive:
ret += rename_cpp_to_cu([f"{el}/{entry}"], dst, recursive)
ret += rename_cpp_to_cu(
[f"{el}/{entry}"], dst, hipify, recursive
)
continue
do_rename_and_mv(entry, el, dst, ret)
else:
Expand Down Expand Up @@ -375,7 +381,7 @@ def build_module(
is_python_module,
is_standalone,
torch_exclude,
hipify=True,
hipify=False,
prebuild=0,
):
lock_path = f"{bd_dir}/lock_{md_name}"
Expand All @@ -400,10 +406,12 @@ def MainFunc():
os.remove(f"{get_user_jit_dir()}/{target_name}")

if prebuild != 2:
sources = rename_cpp_to_cu(srcs, src_dir)
sources = rename_cpp_to_cu(srcs, src_dir, hipify)
else:
sources = rename_cpp_to_cu(
[get_user_jit_dir() + "/../../csrc/rocm_ops.cpp"], opbd_dir + "/srcs"
[get_user_jit_dir() + "/../../csrc/rocm_ops.cpp"],
src_dir,
hipify,
)

flags_cc = ["-O3", "-std=c++20"]
Expand Down Expand Up @@ -466,7 +474,7 @@ def exec_blob(blob_gen_cmd, op_dir, src_dir, sources):
if AITER_LOG_MORE:
logger.info(f"exec_blob ---> {PY} {blob_gen_cmd.format(blob_dir)}")
os.system(f"{PY} {blob_gen_cmd.format(blob_dir)}")
sources += rename_cpp_to_cu([blob_dir], src_dir, recursive=True)
sources += rename_cpp_to_cu([blob_dir], src_dir, hipify, recursive=True)
return sources

if prebuild != 2:
Expand All @@ -476,30 +484,40 @@ def exec_blob(blob_gen_cmd, op_dir, src_dir, sources):
else:
sources = exec_blob(blob_gen_cmd, op_dir, src_dir, sources)

# TODO: Move all torch api into torch folder
old_bd_include_dir = f"{op_dir}/build/include"
os.makedirs(old_bd_include_dir, exist_ok=True)
rename_cpp_to_cu(
[f"{AITER_CSRC_DIR}/include"] + extra_include, old_bd_include_dir
)

if not is_standalone:
bd_include_dir = f"{op_dir}/build/include/torch"
os.makedirs(bd_include_dir, exist_ok=True)
rename_cpp_to_cu(
[f"{AITER_CSRC_DIR}/include/torch"] + extra_include, bd_include_dir
)

extra_include_paths = [
f"{CK_DIR}/include",
f"{CK_DIR}/library/include",
f"{old_bd_include_dir}",
]
if not hipify:
extra_include_paths += [
f"{AITER_CSRC_DIR}/include",
f"{op_dir}/blob",
] + extra_include
if not is_standalone:
extra_include_paths += [f"{AITER_CSRC_DIR}/include/torch"]
else:
old_bd_include_dir = f"{op_dir}/build/include"
extra_include_paths.append(old_bd_include_dir)
os.makedirs(old_bd_include_dir, exist_ok=True)
rename_cpp_to_cu(
[f"{AITER_CSRC_DIR}/include"] + extra_include,
old_bd_include_dir,
hipify,
)

if not is_standalone:
bd_include_dir = f"{op_dir}/build/include/torch"
os.makedirs(bd_include_dir, exist_ok=True)
rename_cpp_to_cu(
[f"{AITER_CSRC_DIR}/include/torch"],
bd_include_dir,
hipify,
)

try:
_jit_compile(
md_name,
sources,
sorted(set(sources)),
extra_cflags=flags_cc,
extra_cuda_cflags=flags_hip,
extra_ldflags=extra_ldflags,
Expand Down Expand Up @@ -550,7 +568,7 @@ def exec_blob(blob_gen_cmd, op_dir, src_dir, sources):

def FinalFunc():
logger.info(
f"finish build [{md_name}], cost {time.perf_counter()-startTS:.8f}s"
f"\033[32mfinish build [{md_name}], cost {time.perf_counter()-startTS:.1f}s \033[0m"
)

mp_lock(lockPath=lock_path, MainFunc=MainFunc, FinalFunc=FinalFunc)
Expand Down Expand Up @@ -846,7 +864,7 @@ def wrapper(*args, custom_build_args={}, **kwargs):
is_python_module = d_args["is_python_module"]
is_standalone = d_args["is_standalone"]
torch_exclude = d_args["torch_exclude"]
hipify = d_args.get("hipify", True)
hipify = d_args.get("hipify", False)
hip_clang_path = d_args.get("hip_clang_path", None)
prev_hip_clang_path = None
if hip_clang_path is not None and os.path.exists(hip_clang_path):
Expand Down
Loading
Loading