Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 12 additions & 39 deletions aiter/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,6 @@ def mp_lock(
return ret


PREBUILD_KERNELS = False
if os.path.exists(os.path.dirname(os.path.abspath(__file__)) + "/aiter_.so"):
aiter_ = importlib.import_module(f"{__package__}.aiter_")
PREBUILD_KERNELS = True
logger = logging.getLogger("aiter")

PY = sys.executable
Expand Down Expand Up @@ -446,7 +442,6 @@ def build_module(
is_standalone,
torch_exclude,
hipify=False,
prebuild=0,
):
lock_path = f"{bd_dir}/lock_{md_name}"
startTS = time.perf_counter()
Expand All @@ -467,14 +462,7 @@ def MainFunc():
if os.path.exists(f"{get_user_jit_dir()}/{target_name}"):
os.remove(f"{get_user_jit_dir()}/{target_name}")

if prebuild != 2:
sources = rename_cpp_to_cu(srcs, src_dir, hipify)
else:
sources = rename_cpp_to_cu(
[get_user_jit_dir() + "/../../csrc/rocm_ops.cpp"],
src_dir,
hipify,
)
sources = rename_cpp_to_cu(srcs, src_dir, hipify)

flags_cc = ["-O3", "-std=c++20"]
flags_hip = [
Expand Down Expand Up @@ -539,12 +527,11 @@ def exec_blob(blob_gen_cmd, op_dir, src_dir, sources):
sources += rename_cpp_to_cu([blob_dir], src_dir, hipify, recursive=True)
return sources

if prebuild != 2:
if isinstance(blob_gen_cmd, list):
for s_blob_gen_cmd in blob_gen_cmd:
sources = exec_blob(s_blob_gen_cmd, op_dir, src_dir, sources)
else:
sources = exec_blob(blob_gen_cmd, op_dir, src_dir, sources)
if isinstance(blob_gen_cmd, list):
for s_blob_gen_cmd in blob_gen_cmd:
sources = exec_blob(s_blob_gen_cmd, op_dir, src_dir, sources)
else:
sources = exec_blob(blob_gen_cmd, op_dir, src_dir, sources)

extra_include_paths = [
f"{CK_HELPER_DIR}",
Expand Down Expand Up @@ -592,23 +579,9 @@ def exec_blob(blob_gen_cmd, op_dir, src_dir, sources):
is_standalone=is_standalone,
torch_exclude=torch_exclude,
hipify=hipify,
prebuild=prebuild,
)
if is_python_module and not is_standalone:
if prebuild == 1:
shutil.copy(
f"{opbd_dir}/{target_name}",
f"{get_user_jit_dir()}/build/aiter_/build",
)
elif prebuild == 2:
from pathlib import Path

src_dir = Path(opbd_dir)
dst_dir = Path(get_user_jit_dir())
for src_file in src_dir.glob("*.so"):
shutil.move(str(src_file), str(dst_dir / src_file.name))
else:
shutil.copy(f"{opbd_dir}/{target_name}", f"{get_user_jit_dir()}")
shutil.copy(f"{opbd_dir}/{target_name}", f"{get_user_jit_dir()}")
else:
shutil.copy(
f"{opbd_dir}/{target_name}", f"{AITER_ROOT_DIR}/op_tests/cpp/mha"
Expand Down Expand Up @@ -746,15 +719,15 @@ def wrapper(*args, custom_build_args={}, **kwargs):
module = None
if gen_func is not None:
custom_build_args.update(gen_func(*args, **kwargs))
if PREBUILD_KERNELS:
if hasattr(aiter_, loadName):
module = aiter_
elif AITER_REBUILD and md_name not in rebuilded_list:
rebuilded_list.append(md_name)
raise ModuleNotFoundError("start rebuild")
if module is None:
md = custom_build_args.get("md_name", md_name)
module = get_module(md)
try:
module = get_module(md_name)
except Exception as e:
md = custom_build_args.get("md_name", md_name)
module = get_module(md)
except ModuleNotFoundError:
d_args = get_args_of_build(md_name)
d_args.update(custom_build_args)
Expand Down
66 changes: 16 additions & 50 deletions aiter/jit/utils/cpp_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,6 @@ def _jit_compile(
keep_intermediates=True,
torch_exclude=False,
hipify=True,
prebuild=0,
) -> None:
if is_python_module and is_standalone:
raise ValueError(
Expand Down Expand Up @@ -1237,7 +1236,6 @@ def _jit_compile(
is_python_module=is_python_module,
is_standalone=is_standalone,
torch_exclude=torch_exclude,
prebuild=prebuild,
)
elif verbose:
print(
Expand Down Expand Up @@ -1320,7 +1318,6 @@ def _write_ninja_file_and_build_library(
is_python_module: bool,
is_standalone: bool = False,
torch_exclude: bool = False,
prebuild: int = 0,
) -> None:
verify_ninja_availability()

Expand All @@ -1329,7 +1326,7 @@ def _write_ninja_file_and_build_library(
if with_cuda is None:
with_cuda = any(map(_is_cuda_file, sources))
extra_ldflags = _prepare_ldflags(
extra_ldflags or [], with_cuda, verbose, is_standalone, torch_exclude, prebuild
extra_ldflags or [], with_cuda, verbose, is_standalone, torch_exclude
)
build_file_path = os.path.join(build_directory, "build.ninja")
if verbose:
Expand All @@ -1348,7 +1345,6 @@ def _write_ninja_file_and_build_library(
is_python_module=is_python_module,
is_standalone=is_standalone,
torch_exclude=torch_exclude,
prebuild=prebuild,
)

if verbose:
Expand All @@ -1374,9 +1370,7 @@ def verify_ninja_availability():
raise RuntimeError("Ninja is required to load C++ extensions")


def _prepare_ldflags(
extra_ldflags, with_cuda, verbose, is_standalone, torch_exclude, prebuild
):
def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone, torch_exclude):
extra_ldflags.append("-mcmodel=large")
extra_ldflags.append("-ffunction-sections")
extra_ldflags.append("-fdata-sections ")
Expand All @@ -1388,18 +1382,15 @@ def _prepare_ldflags(
_TORCH_PATH = os.path.join(os.path.dirname(torch.__file__))
TORCH_LIB_PATH = os.path.join(_TORCH_PATH, "lib")
extra_ldflags.append(f"-L{TORCH_LIB_PATH}")
if prebuild != 1:
extra_ldflags.append("-lc10")
if with_cuda:
extra_ldflags.append("-lc10_hip" if IS_HIP_EXTENSION else "-lc10_cuda")
extra_ldflags.append("-ltorch_cpu")
if with_cuda:
extra_ldflags.append(
"-ltorch_hip" if IS_HIP_EXTENSION else "-ltorch_cuda"
)
extra_ldflags.append("-ltorch")
if not is_standalone:
extra_ldflags.append("-ltorch_python")
extra_ldflags.append("-lc10")
if with_cuda:
extra_ldflags.append("-lc10_hip" if IS_HIP_EXTENSION else "-lc10_cuda")
extra_ldflags.append("-ltorch_cpu")
if with_cuda:
extra_ldflags.append("-ltorch_hip" if IS_HIP_EXTENSION else "-ltorch_cuda")
extra_ldflags.append("-ltorch")
if not is_standalone:
extra_ldflags.append("-ltorch_python")

if is_standalone:
extra_ldflags.append(f"-Wl,-rpath,{TORCH_LIB_PATH}")
Expand All @@ -1409,8 +1400,7 @@ def _prepare_ldflags(
print("Detected CUDA files, patching ldflags", file=sys.stderr)

extra_ldflags.append(f'-L{_join_rocm_home("lib")}')
if prebuild != 1:
extra_ldflags.append("-lamdhip64")
extra_ldflags.append("-lamdhip64")
return extra_ldflags


Expand Down Expand Up @@ -1538,7 +1528,6 @@ def _write_ninja_file_to_build_library(
is_python_module,
is_standalone,
torch_exclude,
prebuild=0,
) -> None:
extra_cflags = [flag.strip() for flag in extra_cflags]
extra_cuda_cflags = [flag.strip() for flag in extra_cuda_cflags]
Expand Down Expand Up @@ -1571,10 +1560,7 @@ def _write_ninja_file_to_build_library(
user_includes = [os.path.abspath(file) for file in extra_include_paths]

if not torch_exclude:
if prebuild == 0:
common_cflags.append(f"-DTORCH_EXTENSION_NAME={name}")
else:
common_cflags.append(f"-DTORCH_EXTENSION_NAME=aiter_")
common_cflags.append(f"-DTORCH_EXTENSION_NAME={name}")
# common_cflags.append("-DTORCH_API_INCLUDE_EXTENSION_H")
# common_cflags += [f"{x}" for x in _get_pybind11_abi_build_flags()]
# common_cflags += [f"{x}" for x in _get_glibcxx_abi_build_flags()]
Expand All @@ -1589,8 +1575,6 @@ def _write_ninja_file_to_build_library(
cuda_flags = ["-DWITH_HIP"] + cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS
cuda_flags += extra_cuda_cflags
cuda_flags += _get_rocm_arch_flags(cuda_flags)
if prebuild == 1:
cuda_flags += ["-fvisibility=default -DEXPORT_SYMBOLS"]

def object_file_path(source_file: str) -> str:
# '/path/to/file.cpp' -> 'file'
Expand All @@ -1608,8 +1592,6 @@ def object_file_path(source_file: str) -> str:

ext = EXEC_EXT if is_standalone else LIB_EXT
library_target = f"{name}{ext}"
if prebuild == 2:
library_target = "aiter_.so"

_write_ninja_file(
path=path,
Expand All @@ -1623,7 +1605,6 @@ def object_file_path(source_file: str) -> str:
ldflags=ldflags,
library_target=library_target,
with_cuda=with_cuda,
prebuild=prebuild,
)


Expand All @@ -1639,7 +1620,6 @@ def _write_ninja_file(
ldflags,
library_target,
with_cuda,
prebuild=0,
) -> None:
r"""Write a ninja file that does the desired compiling and linking.

Expand Down Expand Up @@ -1719,15 +1699,6 @@ def sanitize_flags(flags):
source_file = source_file.replace(" ", "$ ")
object_file = object_file.replace(" ", "$ ")
build.append(f"build {object_file}: {rule} {source_file}")
if prebuild == 2:
o_path = path.split("build/aiter_")[0]
ldflags.append(f"-Wl,-rpath={o_path}")

for root, dirs, files in os.walk(o_path):
for file in files:
mid_file_dir = o_path + file
if file.endswith(".so") and file not in objects:
objects.append(file)

flags.append(f'ldflags = {" ".join(ldflags)}')
if cuda_dlink_post_cflags:
Expand All @@ -1742,14 +1713,9 @@ def sanitize_flags(flags):
if library_target is not None:
link_rule = ["rule link"]

if prebuild == 2:
link_rule.append(
f" command = $cxx @$out.rsp $ldflags -Wl,-rpath,'$$ORIGIN' -o $out\n rspfile = $out.rsp\n rspfile_content = $in"
)
else:
link_rule.append(
" command = $cxx @$out.rsp $ldflags -o $out\n rspfile = $out.rsp\n rspfile_content = $in"
)
link_rule.append(
" command = $cxx @$out.rsp $ldflags -o $out\n rspfile = $out.rsp\n rspfile_content = $in"
)

link = [f'build {library_target}: link {" ".join(objects)}']

Expand Down
4 changes: 3 additions & 1 deletion csrc/include/custom_all_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,10 @@ namespace aiter
}
((P *)result)[idx] = write_reg;
}
__syncthreads();
}
end_sync<ngpus, true>(sg, self_sg, rank);
// maybe do not need device sync
// end_sync<ngpus, true>(sg, self_sg, rank);
}

template <typename T, int ngpus>
Expand Down
44 changes: 2 additions & 42 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,17 +100,15 @@ def build_one_module(one_opt_args):
core.build_module(
md_name=one_opt_args["md_name"],
srcs=one_opt_args["srcs"],
flags_extra_cc=one_opt_args["flags_extra_cc"] + ["-DPREBUILD_KERNELS"],
flags_extra_hip=one_opt_args["flags_extra_hip"]
+ ["-DPREBUILD_KERNELS"],
flags_extra_cc=one_opt_args["flags_extra_cc"],
flags_extra_hip=one_opt_args["flags_extra_hip"],
blob_gen_cmd=one_opt_args["blob_gen_cmd"],
extra_include=one_opt_args["extra_include"],
extra_ldflags=None,
verbose=False,
is_python_module=True,
is_standalone=False,
torch_exclude=False,
prebuild=1,
)

# step 1, build *.cu -> module*.so
Expand All @@ -126,44 +124,6 @@ def build_one_module(one_opt_args):
with ThreadPoolExecutor(max_workers=prebuid_thread_num) as executor:
list(executor.map(build_one_module, all_opts_args_build))

ck_batched_gemm_folders = [
f"{this_dir}/csrc/{name}/include"
for name in os.listdir(f"{this_dir}/csrc")
if os.path.isdir(os.path.join(f"{this_dir}/csrc", name))
and name.startswith("ck_batched_gemm")
]
ck_gemm_folders = [
f"{this_dir}/csrc/{name}/include"
for name in os.listdir(f"{this_dir}/csrc")
if os.path.isdir(os.path.join(f"{this_dir}/csrc", name))
and name.startswith("ck_gemm_a")
]
ck_gemm_inc = ck_batched_gemm_folders + ck_gemm_folders
for src in ck_gemm_inc:
dst = f"{prebuild_dir}/include"
shutil.copytree(src, dst, dirs_exist_ok=True)

shutil.copytree(
f"{this_dir}/csrc/include", f"{prebuild_dir}/include", dirs_exist_ok=True
)

# step 2, link module*.so -> aiter_.so
core.build_module(
md_name="aiter_",
srcs=[f"{prebuild_dir}/srcs/rocm_ops.cu"],
flags_extra_cc=prebuild_link_param["flags_extra_cc"]
+ ["-DPREBUILD_KERNELS"],
flags_extra_hip=prebuild_link_param["flags_extra_hip"]
+ ["-DPREBUILD_KERNELS"],
blob_gen_cmd=prebuild_link_param["blob_gen_cmd"],
extra_include=prebuild_link_param["extra_include"],
extra_ldflags=None,
verbose=False,
is_python_module=True,
is_standalone=False,
torch_exclude=False,
prebuild=2,
)
else:
raise NotImplementedError("Only ROCM is supported")

Expand Down
Loading