diff --git a/aiter/jit/core.py b/aiter/jit/core.py index 85ba4bd094..ec7ddebcdb 100644 --- a/aiter/jit/core.py +++ b/aiter/jit/core.py @@ -281,13 +281,17 @@ def check_and_set_ninja_worker(): os.environ["MAX_JOBS"] = str(max_jobs) -def rename_cpp_to_cu(els, dst, recursive=False): +def rename_cpp_to_cu(els, dst, hipify, recursive=False): def do_rename_and_mv(name, src, dst, ret): newName = name - if name.endswith(".cpp") or name.endswith(".cu"): - newName = name.replace(".cpp", ".cu") - ret.append(f"{dst}/{newName}") - shutil.copy(f"{src}/{name}", f"{dst}/{newName}") + if hipify: + if name.endswith(".cpp") or name.endswith(".cu"): + newName = name.replace(".cpp", ".cu") + ret.append(f"{dst}/{newName}") + shutil.copy(f"{src}/{name}", f"{dst}/{newName}") + else: + if name.endswith(".cpp") or name.endswith(".cu"): + ret.append(f"{src}/{newName}") ret = [] for el in els: @@ -298,7 +302,9 @@ def do_rename_and_mv(name, src, dst, ret): for entry in os.listdir(el): if os.path.isdir(f"{el}/{entry}"): if recursive: - ret += rename_cpp_to_cu([f"{el}/{entry}"], dst, recursive) + ret += rename_cpp_to_cu( + [f"{el}/{entry}"], dst, hipify, recursive + ) continue do_rename_and_mv(entry, el, dst, ret) else: @@ -375,7 +381,7 @@ def build_module( is_python_module, is_standalone, torch_exclude, - hipify=True, + hipify=False, prebuild=0, ): lock_path = f"{bd_dir}/lock_{md_name}" @@ -400,10 +406,12 @@ def MainFunc(): os.remove(f"{get_user_jit_dir()}/{target_name}") if prebuild != 2: - sources = rename_cpp_to_cu(srcs, src_dir) + sources = rename_cpp_to_cu(srcs, src_dir, hipify) else: sources = rename_cpp_to_cu( - [get_user_jit_dir() + "/../../csrc/rocm_ops.cpp"], opbd_dir + "/srcs" + [get_user_jit_dir() + "/../../csrc/rocm_ops.cpp"], + src_dir, + hipify, ) flags_cc = ["-O3", "-std=c++20"] @@ -466,7 +474,7 @@ def exec_blob(blob_gen_cmd, op_dir, src_dir, sources): if AITER_LOG_MORE: logger.info(f"exec_blob ---> {PY} {blob_gen_cmd.format(blob_dir)}") os.system(f"{PY} {blob_gen_cmd.format(blob_dir)}") - sources += rename_cpp_to_cu([blob_dir], src_dir, recursive=True) + sources += rename_cpp_to_cu([blob_dir], src_dir, hipify, recursive=True) return sources if prebuild != 2: @@ -476,30 +484,40 @@ def exec_blob(blob_gen_cmd, op_dir, src_dir, sources): else: sources = exec_blob(blob_gen_cmd, op_dir, src_dir, sources) - # TODO: Move all torch api into torch folder - old_bd_include_dir = f"{op_dir}/build/include" - os.makedirs(old_bd_include_dir, exist_ok=True) - rename_cpp_to_cu( - [f"{AITER_CSRC_DIR}/include"] + extra_include, old_bd_include_dir - ) - - if not is_standalone: - bd_include_dir = f"{op_dir}/build/include/torch" - os.makedirs(bd_include_dir, exist_ok=True) - rename_cpp_to_cu( - [f"{AITER_CSRC_DIR}/include/torch"] + extra_include, bd_include_dir - ) - extra_include_paths = [ f"{CK_DIR}/include", f"{CK_DIR}/library/include", - f"{old_bd_include_dir}", ] + if not hipify: + extra_include_paths += [ + f"{AITER_CSRC_DIR}/include", + f"{op_dir}/blob", + ] + extra_include + if not is_standalone: + extra_include_paths += [f"{AITER_CSRC_DIR}/include/torch"] + else: + old_bd_include_dir = f"{op_dir}/build/include" + extra_include_paths.append(old_bd_include_dir) + os.makedirs(old_bd_include_dir, exist_ok=True) + rename_cpp_to_cu( + [f"{AITER_CSRC_DIR}/include"] + extra_include, + old_bd_include_dir, + hipify, + ) + + if not is_standalone: + bd_include_dir = f"{op_dir}/build/include/torch" + os.makedirs(bd_include_dir, exist_ok=True) + rename_cpp_to_cu( + [f"{AITER_CSRC_DIR}/include/torch"], + bd_include_dir, + hipify, + ) try: _jit_compile( md_name, - sources, + sorted(set(sources)), extra_cflags=flags_cc, extra_cuda_cflags=flags_hip, extra_ldflags=extra_ldflags, @@ -550,7 +568,7 @@ def exec_blob(blob_gen_cmd, op_dir, src_dir, sources): def FinalFunc(): logger.info( - f"finish build [{md_name}], cost {time.perf_counter()-startTS:.8f}s" + f"\033[32mfinish build [{md_name}], cost {time.perf_counter()-startTS:.1f}s \033[0m" ) mp_lock(lockPath=lock_path, MainFunc=MainFunc, FinalFunc=FinalFunc) @@ -846,7 +864,7 @@ def wrapper(*args, custom_build_args={}, **kwargs): is_python_module = d_args["is_python_module"] is_standalone = d_args["is_standalone"] torch_exclude = d_args["torch_exclude"] - hipify = d_args.get("hipify", True) + hipify = d_args.get("hipify", False) hip_clang_path = d_args.get("hip_clang_path", None) prev_hip_clang_path = None if hip_clang_path is not None and os.path.exists(hip_clang_path): diff --git a/aiter/jit/optCompilerConfig.json b/aiter/jit/optCompilerConfig.json index 996c15ae2b..34f4dec6d2 100644 --- a/aiter/jit/optCompilerConfig.json +++ b/aiter/jit/optCompilerConfig.json @@ -182,107 +182,89 @@ }, "module_batched_gemm_bf16": { "srcs": [ - "f'{AITER_CSRC_DIR}/ck_batched_gemm_bf16/include'", "f'{AITER_CSRC_DIR}/pybind/batched_gemm_bf16_pybind.cu'", "f'{AITER_CSRC_DIR}/ck_batched_gemm_bf16/batched_gemm_bf16.cu'" ], - "flags_extra_cc": [], - "flags_extra_hip": [], - "extra_ldflags": "None", - "extra_include": [], - "verbose": "False", + "extra_include": [ + "f'{AITER_CSRC_DIR}/ck_batched_gemm_bf16/include'" + ], "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_batched_gemm_bf16/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIG_BF16_BATCHED_GEMM_FILE}'" }, "module_batched_gemm_a8w8": { "srcs": [ - "f'{AITER_CSRC_DIR}/ck_batched_gemm_a8w8/include'", "f'{AITER_CSRC_DIR}/pybind/batched_gemm_a8w8_pybind.cu'", "f'{AITER_CSRC_DIR}/ck_batched_gemm_a8w8/batched_gemm_a8w8.cu'" ], - "flags_extra_cc": [], - "flags_extra_hip": [], - "extra_ldflags": "None", - "extra_include": [], - "verbose": "False", + "extra_include": [ + "f'{AITER_CSRC_DIR}/ck_batched_gemm_a8w8/include'" + ], "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_batched_gemm_a8w8/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIG_A8W8_BATCHED_GEMM_FILE}'" }, "module_gemm_a8w8": { "srcs": [ - "f'{AITER_CSRC_DIR}/ck_gemm_a8w8/include'", "f'{AITER_CSRC_DIR}/pybind/gemm_a8w8_pybind.cu'", "f'{AITER_CSRC_DIR}/ck_gemm_a8w8/gemm_a8w8.cu'" ], - "flags_extra_cc": [], - "flags_extra_hip": [], - "extra_ldflags": "None", - "extra_include": [], - "verbose": "False", + "extra_include": [ + "f'{AITER_CSRC_DIR}/ck_gemm_a8w8/include'" + ], "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIG_GEMM_A8W8_FILE}'" }, "module_gemm_a8w8_blockscale": { "srcs": [ - "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale/include'", "f'{AITER_CSRC_DIR}/pybind/gemm_a8w8_blockscale_pybind.cu'", "f'{AITER_CSRC_DIR}/py_itfs_cu/gemm_common.cu'", "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu'" ], - "flags_extra_cc": [], + "extra_include": [ + "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale/include'" + ], "flags_extra_hip": [ "'-mllvm -greedy-reverse-local-assignment=1'", "'-mllvm --amdgpu-use-amdgpu-trackers=1'" ], - "extra_ldflags": "None", - "extra_include": [], - "verbose": "False", "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_FILE}'" }, "module_gemm_a8w8_blockscale_bpreshuffle": { "srcs": [ - "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale_bpreshuffle/include'", "f'{AITER_CSRC_DIR}/pybind/gemm_a8w8_blockscale_bpreshuffle_pybind.cu'", "f'{AITER_CSRC_DIR}/py_itfs_cu/gemm_common.cu'", "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale_bpreshuffle/gemm_a8w8_blockscale_bpreshuffle.cu'" ], - "flags_extra_cc": [], + "extra_include": [ + "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale_bpreshuffle/include'" + ], "flags_extra_hip": [ "'-mllvm -greedy-reverse-local-assignment=1'", "'-mllvm --amdgpu-use-amdgpu-trackers=1'" ], - "extra_ldflags": "None", - "extra_include": [], - "verbose": "False", "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale_bpreshuffle/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIG_GEMM_A8W8_BLOCKSCALE_BPRESHUFFLE_FILE}'" }, "module_gemm_a4w4_blockscale": { "srcs": [ - "f'{AITER_CSRC_DIR}/ck_gemm_a4w4_blockscale/include'", "f'{AITER_CSRC_DIR}/pybind/gemm_a4w4_blockscale_pybind.cu'", "f'{AITER_CSRC_DIR}/py_itfs_cu/gemm_common.cu'", "f'{AITER_CSRC_DIR}/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale.cu'" ], - "flags_extra_cc": [], + "extra_include": [ + "f'{AITER_CSRC_DIR}/ck_gemm_a4w4_blockscale/include'" + ], "flags_extra_hip": [ "'-mllvm -greedy-reverse-local-assignment=1'", "'-mllvm --amdgpu-use-amdgpu-trackers=1'" ], - "extra_ldflags": "None", - "extra_include": [], - "verbose": "False", "hip_clang_path": "os.environ.get('GEMM_A4W4_BLOCKWISE_HIP_CLANG_PATH')", "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a4w4_blockscale/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIG_GEMM_A4W4_FILE}'" }, "module_gemm_a8w8_bpreshuffle": { "srcs": [ - "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_bpreshuffle/include'", "f'{AITER_CSRC_DIR}/pybind/gemm_a8w8_bpreshuffle_pybind.cu'", "f'{AITER_CSRC_DIR}/py_itfs_cu/gemm_common.cu'", "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle.cu'" ], - "flags_extra_cc": [], - "flags_extra_hip": [], - "extra_ldflags": "None", - "extra_include": [], - "verbose": "False", + "extra_include": [ + "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_bpreshuffle/include'" + ], "is_python_module": "True", "is_standalone": "False", "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_bpreshuffle/gen_instances.py --working_path {{}} --tune_file {AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE}'" @@ -373,11 +355,10 @@ "module_moe_ck2stages": { "srcs": [ "f'{AITER_CSRC_DIR}/pybind/moe_ck_2stages_pybind.cu'", - "f'{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu'", - "f'{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.cuh'", - "f'{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_mxfp4.cuh'", - "f'{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common_blockscale.cuh'", - "f'{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.h'" + "f'{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu'" + ], + "extra_include": [ + "f'{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen'" ], "md_name": "'module_moe_ck2stages'", "flags_extra_cc": [ @@ -390,7 +371,6 @@ "'-mllvm --misched-prera-direction=bottomup' if int(torch.version.hip.split('.')[0]) >= 7 else '-mllvm --misched-bottomup=1'" ], "extra_ldflags": "None", - "extra_include": [], "verbose": "False", "hip_clang_path": "os.environ.get('GEMM_A4W4_BLOCKWISE_HIP_CLANG_PATH')", "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_moe_2stages_codegen/gen_instances.py --working_path {{}}'" @@ -475,96 +455,76 @@ "module_batched_gemm_bf16_tune": { "srcs": [ "f'{AITER_CSRC_DIR}/pybind/batched_gemm_bf16_tune_pybind.cu'", - "f'{AITER_CSRC_DIR}/ck_batched_gemm_bf16/batched_gemm_bf16_tune.cu'", + "f'{AITER_CSRC_DIR}/ck_batched_gemm_bf16/batched_gemm_bf16_tune.cu'" + ], + "extra_include": [ "f'{AITER_CSRC_DIR}/ck_batched_gemm_bf16/include'" ], - "flags_extra_cc": [], - "flags_extra_hip": [], - "extra_ldflags": "None", - "extra_include": [], - "verbose": "False", "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_batched_gemm_bf16/gen_instances.py --working_path {{}} --tune'" }, "module_batched_gemm_a8w8_tune": { "srcs": [ "f'{AITER_CSRC_DIR}/pybind/batched_gemm_a8w8_tune_pybind.cu'", - "f'{AITER_CSRC_DIR}/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.cu'", + "f'{AITER_CSRC_DIR}/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.cu'" + ], + "extra_include": [ "f'{AITER_CSRC_DIR}/ck_batched_gemm_a8w8/include'" ], - "flags_extra_cc": [], - "flags_extra_hip": [], - "extra_ldflags": "None", - "extra_include": [], - "verbose": "False", "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_batched_gemm_a8w8/gen_instances.py --working_path {{}} --tune'" }, "module_gemm_a8w8_tune": { "srcs": [ "f'{AITER_CSRC_DIR}/pybind/gemm_a8w8_tune_pybind.cu'", - "f'{AITER_CSRC_DIR}/ck_gemm_a8w8/gemm_a8w8_tune.cu'", + "f'{AITER_CSRC_DIR}/ck_gemm_a8w8/gemm_a8w8_tune.cu'" + ], + "extra_include": [ "f'{AITER_CSRC_DIR}/ck_gemm_a8w8/include'" ], - "flags_extra_cc": [], - "flags_extra_hip": [], - "extra_ldflags": "None", - "extra_include": [], - "verbose": "False", "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8/gen_instances.py --working_path {{}} --tune'" }, "module_gemm_a8w8_blockscale_tune": { "srcs": [ "f'{AITER_CSRC_DIR}/pybind/gemm_a8w8_blockscale_tune_pybind.cu'", - "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.cu'", + "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.cu'" + ], + "extra_include": [ "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale/include'" ], - "flags_extra_cc": [], - "flags_extra_hip": [], - "extra_ldflags": "None", - "extra_include": [], - "verbose": "False", "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale/gen_instances.py --working_path {{}} --tune'" }, "module_gemm_a8w8_blockscale_bpreshuffle_tune": { "srcs": [ "f'{AITER_CSRC_DIR}/pybind/gemm_a8w8_blockscale_bpreshuffle_tune_pybind.cu'", - "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale_bpreshuffle/gemm_a8w8_blockscale_bpreshuffle_tune.cu'", + "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale_bpreshuffle/gemm_a8w8_blockscale_bpreshuffle_tune.cu'" + ], + "extra_include": [ "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale_bpreshuffle/include'" ], - "flags_extra_cc": [], - "flags_extra_hip": [], - "extra_ldflags": "None", - "extra_include": [], - "verbose": "False", "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_blockscale_bpreshuffle/gen_instances.py --working_path {{}} --tune'" }, "module_gemm_a4w4_blockscale_tune": { "srcs": [ "f'{AITER_CSRC_DIR}/pybind/gemm_a4w4_blockscale_tune_pybind.cu'", - "f'{AITER_CSRC_DIR}/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale_tune.cu'", + "f'{AITER_CSRC_DIR}/ck_gemm_a4w4_blockscale/gemm_a4w4_blockscale_tune.cu'" + ], + "extra_include": [ "f'{AITER_CSRC_DIR}/ck_gemm_a4w4_blockscale/include'" ], - "flags_extra_cc": [], "flags_extra_hip": [ "'-mllvm -greedy-reverse-local-assignment=1'", "'-mllvm --amdgpu-use-amdgpu-trackers=1'" ], - "extra_ldflags": "None", - "extra_include": [], - "verbose": "False", "hip_clang_path": "os.environ.get('GEMM_A4W4_BLOCKWISE_HIP_CLANG_PATH')", "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a4w4_blockscale/gen_instances.py --working_path {{}} --tune'" }, "module_gemm_a8w8_bpreshuffle_tune": { "srcs": [ "f'{AITER_CSRC_DIR}/pybind/gemm_a8w8_bpreshuffle_tune_pybind.cu'", - "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.cu'", + "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.cu'" + ], + "extra_include": [ "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_bpreshuffle/include'" ], - "flags_extra_cc": [], - "flags_extra_hip": [], - "extra_ldflags": "None", - "extra_include": [], - "verbose": "False", "is_python_module": "True", "is_standalone": "False", "blob_gen_cmd": "f'{AITER_CSRC_DIR}/ck_gemm_a8w8_bpreshuffle/gen_instances.py --working_path {{}} --tune'" @@ -942,4 +902,4 @@ "verbose": "False", "blob_gen_cmd": "''" } -} +} \ No newline at end of file diff --git a/aiter/jit/utils/cpp_extension.py b/aiter/jit/utils/cpp_extension.py index 8901bd1036..5f66a477a9 100644 --- a/aiter/jit/utils/cpp_extension.py +++ b/aiter/jit/utils/cpp_extension.py @@ -1782,6 +1782,7 @@ def sanitize_flags(flags): def _is_cuda_file(path: str) -> bool: + return True valid_ext = [".cu", ".cuh"] if IS_HIP_EXTENSION: valid_ext.append(".hip") diff --git a/aiter/ops/communication.py b/aiter/ops/communication.py index f419039f38..192a2e945f 100644 --- a/aiter/ops/communication.py +++ b/aiter/ops/communication.py @@ -1,22 +1,24 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import logging + import torch -from torch import Tensor import torch.distributed as dist +from torch import Tensor + +# from ..dist.utils import get_open_port, get_distributed_init_method, get_ip +import aiter + from ..dist.parallel_state import ( + destroy_distributed_environment, + destroy_model_parallel, ensure_model_parallel_initialized, + get_tp_group, init_distributed_environment, set_custom_all_reduce, - get_tp_group, - destroy_model_parallel, - destroy_distributed_environment, ) -# from ..dist.utils import get_open_port, get_distributed_init_method, get_ip -import aiter -import logging - logger = logging.getLogger("aiter") @@ -63,7 +65,7 @@ def all_reduce_asm(inp: torch.Tensor): return torch.empty_like(inp) else: # note: outside of cuda graph context, - # custom allreduce incurs a cost of cudaMemcpy, which should + # custom allreduce incurs a cost of hipMemcpy, which should # be small(<=1% of overall latency) compared to the performance # gains of using custom kernels return aiter.all_reduce_asm_( diff --git a/aiter/ops/custom_all_reduce.py b/aiter/ops/custom_all_reduce.py index 3eb254c4a6..fd7dc3045e 100644 --- a/aiter/ops/custom_all_reduce.py +++ b/aiter/ops/custom_all_reduce.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -import torch from typing import List -from ..jit.core import ( - compile_ops, -) + +import torch + +from ..jit.core import compile_ops MD_NAME = "module_custom_all_reduce" @@ -155,7 +155,7 @@ def register_buffer( # def gen_get_graph_buffer_ipc_meta_fake_tensors(_fa: int) -> List[torch.Tensor]: -# handle_sz = 64 # sizeof(cudaIpcMemHandle_t) is 64 byte +# handle_sz = 64 # sizeof(hipIpcMemHandle_t) is 64 byte # num_buffers = 4 # ??? # handles = torch.empty((handle_sz * num_buffers,), dtype=torch.uint8, device="cuda") diff --git a/aiter/utility/base_tuner.py b/aiter/utility/base_tuner.py index 9f2376a2a8..3bc566d563 100644 --- a/aiter/utility/base_tuner.py +++ b/aiter/utility/base_tuner.py @@ -151,7 +151,7 @@ def tune(self, untunedf, tunedf, args): @abstractmethod def getKernelName(self, kernel_id): - """获取kernel name""" + """??kernel name""" pass @abstractmethod diff --git a/csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py b/csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py index 5fd254b313..1a70e0772b 100644 --- a/csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py +++ b/csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. import os import aiter import pandas as pd diff --git a/csrc/ck_batched_gemm_a8w8/gen_instances.py b/csrc/ck_batched_gemm_a8w8/gen_instances.py index b269fa019c..b8d572b3fd 100644 --- a/csrc/ck_batched_gemm_a8w8/gen_instances.py +++ b/csrc/ck_batched_gemm_a8w8/gen_instances.py @@ -1,12 +1,13 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import argparse import os +import shutil from pathlib import Path + import pandas as pd -import argparse -import shutil import torch -from batched_gemm_a8w8_common import kernelInstance, kernels_list, default_kernels_dict +from batched_gemm_a8w8_common import default_kernels_dict, kernelInstance, kernels_list class batched_gemm_a8w8_fwd_codegen: @@ -127,7 +128,7 @@ def gen_instance(self, k: kernelInstance): INSTANCE_template = """// SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. -#include "{name}.cuh" +#include "impl/{name}.cuh" template torch::Tensor {name}<{dtypes}>( diff --git a/csrc/ck_batched_gemm_a8w8/include/batched_gemm_a8w8_common.cuh b/csrc/ck_batched_gemm_a8w8/include/batched_gemm_a8w8_common.cuh index 2b844ed84e..062ee54c8c 100644 --- a/csrc/ck_batched_gemm_a8w8/include/batched_gemm_a8w8_common.cuh +++ b/csrc/ck_batched_gemm_a8w8/include/batched_gemm_a8w8_common.cuh @@ -1,6 +1,6 @@ #pragma once // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #ifdef USE_ROCM @@ -14,9 +14,9 @@ #include #include -#include -#include -#include +#include +#include +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" @@ -272,7 +272,7 @@ __forceinline__ torch::Tensor batched_gemm_a8w8_rowwise_impl( int BatchStrideB = N * K; int BatchStrideE = M * N; - const at::cuda::OptionalCUDAGuard device_guard(device_of(XQ)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(XQ)); auto device_gemm = DeviceGemmInstance{}; auto invoker = device_gemm.MakeInvoker(); @@ -307,7 +307,7 @@ __forceinline__ torch::Tensor batched_gemm_a8w8_rowwise_impl( TORCH_CHECK(device_gemm.IsSupportedArgument(argument), "This GEMM is not supported!"); - invoker.Run(argument, StreamConfig{at::cuda::getCurrentCUDAStream().stream()}); + invoker.Run(argument, StreamConfig{at::hip::getCurrentHIPStream()}); return Y; } diff --git a/csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py b/csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py index 333da01d7a..dfd8aa65b9 100644 --- a/csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py +++ b/csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. import os import aiter import pandas as pd diff --git a/csrc/ck_batched_gemm_bf16/gen_instances.py b/csrc/ck_batched_gemm_bf16/gen_instances.py index 92c4185c1b..0f1a11ed1d 100644 --- a/csrc/ck_batched_gemm_bf16/gen_instances.py +++ b/csrc/ck_batched_gemm_bf16/gen_instances.py @@ -1,12 +1,13 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import argparse import os +import shutil from pathlib import Path + import pandas as pd -import argparse -import shutil import torch -from batched_gemm_bf16_common import kernelInstance, kernels_list, default_kernels_dict +from batched_gemm_bf16_common import default_kernels_dict, kernelInstance, kernels_list class batched_gemm_bf16_fwd_codegen: @@ -127,7 +128,7 @@ def gen_instance(self, k: kernelInstance): INSTANCE_template = """// SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. -#include "{name}.cuh" +#include "impl/{name}.cuh" torch::Tensor {name}( diff --git a/csrc/ck_batched_gemm_bf16/include/batched_gemm_bf16_common.cuh b/csrc/ck_batched_gemm_bf16/include/batched_gemm_bf16_common.cuh index 6555eb9565..a1e416715d 100644 --- a/csrc/ck_batched_gemm_bf16/include/batched_gemm_bf16_common.cuh +++ b/csrc/ck_batched_gemm_bf16/include/batched_gemm_bf16_common.cuh @@ -1,6 +1,6 @@ #pragma once // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #ifdef USE_ROCM @@ -14,9 +14,9 @@ #include #include -#include -#include -#include +#include +#include +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" @@ -151,7 +151,7 @@ __forceinline__ torch::Tensor batched_gemm_bf16_impl( int BatchStrideB = N * K; int BatchStrideE = M * N; - const at::cuda::OptionalCUDAGuard device_guard(device_of(XQ)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(XQ)); auto device_gemm = DeviceGemmInstance{}; auto invoker = device_gemm.MakeInvoker(); @@ -184,7 +184,7 @@ __forceinline__ torch::Tensor batched_gemm_bf16_impl( TORCH_CHECK(device_gemm.IsSupportedArgument(argument), "This GEMM is not supported!"); - invoker.Run(argument, StreamConfig{at::cuda::getCurrentCUDAStream().stream()}); + invoker.Run(argument, StreamConfig{at::hip::getCurrentHIPStream()}); return Y; } diff --git a/csrc/ck_gemm_a4w4_blockscale/gen_instances.py b/csrc/ck_gemm_a4w4_blockscale/gen_instances.py index 3b929f5888..9e9da1346a 100755 --- a/csrc/ck_gemm_a4w4_blockscale/gen_instances.py +++ b/csrc/ck_gemm_a4w4_blockscale/gen_instances.py @@ -1,18 +1,18 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import argparse import os +import shutil from pathlib import Path + import pandas as pd -import argparse -import shutil import torch from gemm_a4w4_blockscale_common import ( + default_kernels_dict, kernelInstance, kernels_list, - default_kernels_dict, ) - """ a4w4_blockscale_gemm instance gen @@ -102,7 +102,7 @@ def gen_instance(self, k: kernelInstance): INSTANCE_template = """// SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -#include "{name}.cuh" +#include "impl/{name}.cuh" template torch::Tensor {name}<{dtypes}>( diff --git a/csrc/ck_gemm_a4w4_blockscale/include/gemm_a4w4_blockscale_common.cuh b/csrc/ck_gemm_a4w4_blockscale/include/gemm_a4w4_blockscale_common.cuh index 36515fb40e..1322798f28 100755 --- a/csrc/ck_gemm_a4w4_blockscale/include/gemm_a4w4_blockscale_common.cuh +++ b/csrc/ck_gemm_a4w4_blockscale/include/gemm_a4w4_blockscale_common.cuh @@ -1,168 +1,168 @@ -#pragma once -// SPDX-License-Identifier: MIT -// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -#ifdef USE_ROCM - -#undef __HIP_NO_HALF_OPERATORS__ -#undef __HIP_NO_HALF_CONVERSIONS__ - -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/utility/blkgemmpipe_scheduler.hpp" -#include "ck/utility/data_type.hpp" -#include "ck/utility/sequence.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/fill.hpp" -#include "ck/library/utility/host_tensor.hpp" - -template -using S = ck::Sequence; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; -using MFMA = ck::tensor_layout::gemm::MFMA; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using F4PK = ck::f4x2_pk_t; -using F16 = ck::half_t; -using B16 = ck::bhalf_t; -using F32 = float; -using E8M0PK = int32_t; - -using ADataType = F4PK; -using BDataType = F4PK; -using XPackedDataType = E8M0PK; - -using AccDataType = float; - -using ALayout = Row; -// using BLayout = MFMA; -using CLayout = Row; - -using AElementOp = PassThrough; // elementwise transformation for A matrix -using BElementOp = PassThrough; // elementwise transformation for B matrix -using CElementOp = PassThrough; // elementwise transformation for C matrix - -constexpr ck::index_t DataPackedSize = 2; // Packed representation of data -constexpr ck::index_t ScaleBlockSize = 32; // scaling block size -constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 - -static constexpr auto Intrawave = ck::BlockGemmPipelineScheduler::Intrawave; -static constexpr auto Interwave = ck::BlockGemmPipelineScheduler::Interwave; - -template -using DeviceGemmHelperF4BlockScale = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3 - // clang-format off - , S<1, 0, 2>, - 2, AK1, AK1, - true, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - S<1, 0, 2>, S<1, 0, 2>, - 2, BK1, BK1, - true, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CShuffleBlockTransferScalarPerVector_NPerBlock, - BlkGemmPipeSched, - BlkGemmPipelineVer, - ADataType, BDataType>; -// clang-format on - -template -__forceinline__ torch::Tensor gemm_a4w4_blockscale_impl( - torch::Tensor &A, - torch::Tensor &B, - torch::Tensor &a_scale, - torch::Tensor &b_scale, - torch::Tensor &C, - int splitK) -{ - int M = A.size(0); - int N = B.size(0); - int K = A.size(1) * 2; // always fp4_x2 - - // TODO: support batch gemm - int KBatch = std::pow(2, splitK); - - int StrideA = A.stride(-2) * 2; // always fp4_x2 - int StrideB = B.stride(-2) * 2; // always fp4_x2 - int StrideC = C.stride(-2); - int Scale_Stride_A = a_scale.stride(-2); - int Scale_Stride_B = b_scale.stride(-2); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto c_element_op = CElementOp{}; - - - // do GEMM - auto device_gemm = DeviceGemmInstance{}; - auto invoker = device_gemm.MakeInvoker(); - auto argument = device_gemm.MakeArgument(static_cast(A.data_ptr()), - static_cast(a_scale.data_ptr()), - static_cast(B.data_ptr()), - static_cast(b_scale.data_ptr()), - static_cast(C.data_ptr()), - M, - N, - K, - StrideA, - Scale_Stride_A, - StrideB, - Scale_Stride_B, - StrideC, - KBatch, - a_element_op, - b_element_op, - c_element_op); - - TORCH_CHECK(device_gemm.IsSupportedArgument(argument), "This GEMM is not supported!"); - - invoker.Run(argument, StreamConfig{at::cuda::getCurrentCUDAStream().stream()}); - return C; -} - -#endif // USE_ROCM +#pragma once +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#ifdef USE_ROCM + +#undef __HIP_NO_HALF_OPERATORS__ +#undef __HIP_NO_HALF_CONVERSIONS__ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using MFMA = ck::tensor_layout::gemm::MFMA; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using F4PK = ck::f4x2_pk_t; +using F16 = ck::half_t; +using B16 = ck::bhalf_t; +using F32 = float; +using E8M0PK = int32_t; + +using ADataType = F4PK; +using BDataType = F4PK; +using XPackedDataType = E8M0PK; + +using AccDataType = float; + +using ALayout = Row; +// using BLayout = MFMA; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 + +static constexpr auto Intrawave = ck::BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = ck::BlockGemmPipelineScheduler::Interwave; + +template +using DeviceGemmHelperF4BlockScale = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3 + // clang-format off + , S<1, 0, 2>, + 2, AK1, AK1, + true, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + S<1, 0, 2>, S<1, 0, 2>, + 2, BK1, BK1, + true, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ADataType, BDataType>; +// clang-format on + +template +__forceinline__ torch::Tensor gemm_a4w4_blockscale_impl( + torch::Tensor &A, + torch::Tensor &B, + torch::Tensor &a_scale, + torch::Tensor &b_scale, + torch::Tensor &C, + int splitK) +{ + int M = A.size(0); + int N = B.size(0); + int K = A.size(1) * 2; // always fp4_x2 + + // TODO: support batch gemm + int KBatch = std::pow(2, splitK); + + int StrideA = A.stride(-2) * 2; // always fp4_x2 + int StrideB = B.stride(-2) * 2; // always fp4_x2 + int StrideC = C.stride(-2); + int Scale_Stride_A = a_scale.stride(-2); + int Scale_Stride_B = b_scale.stride(-2); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + + // do GEMM + auto device_gemm = DeviceGemmInstance{}; + auto invoker = device_gemm.MakeInvoker(); + auto argument = device_gemm.MakeArgument(static_cast(A.data_ptr()), + static_cast(a_scale.data_ptr()), + static_cast(B.data_ptr()), + static_cast(b_scale.data_ptr()), + static_cast(C.data_ptr()), + M, + N, + K, + StrideA, + Scale_Stride_A, + StrideB, + Scale_Stride_B, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + TORCH_CHECK(device_gemm.IsSupportedArgument(argument), "This GEMM is not supported!"); + + invoker.Run(argument, StreamConfig{at::hip::getCurrentHIPStream()}); + return C; +} + +#endif // USE_ROCM diff --git a/csrc/ck_gemm_a8w8/gen_instances.py b/csrc/ck_gemm_a8w8/gen_instances.py index ad68d3f64a..cd0010ecaf 100644 --- a/csrc/ck_gemm_a8w8/gen_instances.py +++ b/csrc/ck_gemm_a8w8/gen_instances.py @@ -1,12 +1,13 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import argparse import os +import shutil from pathlib import Path + import pandas as pd -import argparse -import shutil import torch -from gemm_a8w8_common import kernelInstance, kernels_list, default_kernels_dict +from gemm_a8w8_common import default_kernels_dict, kernelInstance, kernels_list class gemm_a8w8_fwd_codegen: @@ -160,7 +161,7 @@ def gen_instance(self, k: kernelInstance): INSTANCE_template = """// SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. -#include "{name}.cuh" +#include "impl/{name}.cuh" template torch::Tensor {name}<{dtypes}>( diff --git a/csrc/ck_gemm_a8w8/include/gemm_a8w8_common.cuh b/csrc/ck_gemm_a8w8/include/gemm_a8w8_common.cuh index 6342e8abcc..b8fa035b7a 100644 --- a/csrc/ck_gemm_a8w8/include/gemm_a8w8_common.cuh +++ b/csrc/ck_gemm_a8w8/include/gemm_a8w8_common.cuh @@ -1,6 +1,6 @@ #pragma once // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #ifdef USE_ROCM @@ -14,9 +14,9 @@ #include #include -#include -#include -#include +#include +#include +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" @@ -129,7 +129,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu template < typename ABDataType, typename AccDataType, - typename DDataType, + typename DDataType, typename EDataType, typename CDEElementOp, int BLOCK_SIZE, @@ -218,7 +218,7 @@ __forceinline__ torch::Tensor gemm_a8w8_rowwise_impl( int StrideB = K; int StrideE = N; - const at::cuda::OptionalCUDAGuard device_guard(device_of(XQ)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(XQ)); auto device_gemm = DeviceGemmInstance{}; auto invoker = device_gemm.MakeInvoker(); @@ -249,7 +249,7 @@ __forceinline__ torch::Tensor gemm_a8w8_rowwise_impl( cde_element_op); TORCH_CHECK(device_gemm.IsSupportedArgument(argument), "This GEMM is not supported!"); - invoker.Run(argument, StreamConfig{at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream()}); + invoker.Run(argument, StreamConfig{at::hip::getCurrentHIPStream()}); } else { @@ -274,7 +274,7 @@ __forceinline__ torch::Tensor gemm_a8w8_rowwise_impl( cde_element_op); TORCH_CHECK(device_gemm.IsSupportedArgument(argument), "This GEMM is not supported!"); - invoker.Run(argument, StreamConfig{at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream()}); + invoker.Run(argument, StreamConfig{at::hip::getCurrentHIPStream()}); } return Y; } diff --git a/csrc/ck_gemm_a8w8_blockscale/gen_instances.py b/csrc/ck_gemm_a8w8_blockscale/gen_instances.py index 61590665ac..83f8505118 100755 --- a/csrc/ck_gemm_a8w8_blockscale/gen_instances.py +++ b/csrc/ck_gemm_a8w8_blockscale/gen_instances.py @@ -1,18 +1,18 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import argparse import os +import shutil from pathlib import Path + import pandas as pd -import argparse -import shutil import torch from gemm_a8w8_blockscale_common import ( + default_kernels_dict, kernelInstance, kernels_list, - default_kernels_dict, ) - """ a8w8_blockscale_gemm instance gen @@ -112,7 +112,7 @@ def gen_instance(self, k: kernelInstance): INSTANCE_template = """// SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. -#include "{name}.cuh" +#include "impl/{name}.cuh" template torch::Tensor {name}<{dtypes}>( diff --git a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh index 1c909a4705..17ecaffeb1 100755 --- a/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh +++ b/csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale_common.cuh @@ -1,166 +1,166 @@ -#pragma once -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -#ifdef USE_ROCM - -#undef __HIP_NO_HALF_OPERATORS__ -#undef __HIP_NO_HALF_CONVERSIONS__ - -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" - -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -#include "ck/library/utility/check_err.hpp" - -#include "ck/utility/blkgemmpipe_scheduler.hpp" - -template -using S = ck::Sequence; - -using BF16 = ck::bhalf_t; -using B16 = ck::bhalf_t; -using FP8 = ck::f8_t; -using F32 = float; -using I8 = int8_t; -using I32 = int; -using F16 = ck::half_t; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using A0DataType = FP8; -using A1DataType = F32; -using B0DataType = FP8; -using B1DataType = F32; -using AccDataType = F32; -using CShuffleDataType = F32; -using DsDataType = ck::Tuple<>; -using EDataType = BF16; - -using A0Layout = Row; -using B0Layout = Col; -using D0Layout = Row; -using D1Layout = Col; -using DsLayout = ck::Tuple<>; -using ELayout = Row; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using AElementOp = PassThrough; -using BElementOp = PassThrough; -using CDEElementOp = PassThrough; - -// static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; - -// static constexpr ck::index_t Scale_Block_M = 1; -// static constexpr ck::index_t Scale_Block_N = 128; -// static constexpr ck::index_t Scale_Block_K = 128; - -template -using DeviceGemmHelperF8BlockScale = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 - // clang-format off - , S<1, 0, 2>, - 2, AK1, AK1, 0, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - S<1, 0, 2>, S<1, 0, 2>, - 2, BK1, BK1, 0, - CSHUFFLE_MX_PER_WAVE_PERSHUFFLE, - CSHUFFLE_NX_PER_WAVE_PERSHUFFLE, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEShuffleBlockTransferScalarPerVectors, - BlkGemmPipeSched, - BlkGemmPipelineVer, A0DataType>; - // clang-format on - -template -__forceinline__ torch::Tensor gemm_a8w8_blockscale_impl( - torch::Tensor& XQ, - torch::Tensor& WQ, - torch::Tensor& x_scale, - torch::Tensor& w_scale, - torch::Tensor& Y) -{ - int M = XQ.size(0); - int N = WQ.size(0); - int K = XQ.size(1); - - int StrideA = XQ.stride(-2); - int StrideB = WQ.stride(-2); - int StrideE = N; - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; - - constexpr ck::index_t NumDTensor = DsDataType::Size(); - - // do GEMM - auto device_gemm = DeviceGemmInstance{}; - auto invoker = device_gemm.MakeInvoker(); - auto argument = device_gemm.MakeArgument(XQ.data_ptr(), - WQ.data_ptr(), - std::array{}, - reinterpret_cast(Y.data_ptr()), - M, - N, - K, - StrideA, - StrideB, - std::array{}, - StrideE, - reinterpret_cast(x_scale.data_ptr()), - reinterpret_cast(w_scale.data_ptr()), - a_element_op, - b_element_op, - cde_element_op); - - TORCH_CHECK(device_gemm.IsSupportedArgument(argument), "This GEMM is not supported!"); - - invoker.Run(argument, StreamConfig{at::cuda::getCurrentCUDAStream().stream()}); - return Y; -} - -#endif // USE_ROCM +#pragma once +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#ifdef USE_ROCM + +#undef __HIP_NO_HALF_OPERATORS__ +#undef __HIP_NO_HALF_CONVERSIONS__ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using B16 = ck::bhalf_t; +using FP8 = ck::f8_t; +using F32 = float; +using I8 = int8_t; +using I32 = int; +using F16 = ck::half_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = FP8; +using A1DataType = F32; +using B0DataType = FP8; +using B1DataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +// static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +// static constexpr ck::index_t Scale_Block_M = 1; +// static constexpr ck::index_t Scale_Block_N = 128; +// static constexpr ck::index_t Scale_Block_K = 128; + +template +using DeviceGemmHelperF8BlockScale = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 + // clang-format off + , S<1, 0, 2>, + 2, AK1, AK1, 0, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + S<1, 0, 2>, S<1, 0, 2>, + 2, BK1, BK1, 0, + CSHUFFLE_MX_PER_WAVE_PERSHUFFLE, + CSHUFFLE_NX_PER_WAVE_PERSHUFFLE, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, A0DataType>; + // clang-format on + +template +__forceinline__ torch::Tensor gemm_a8w8_blockscale_impl( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) +{ + int M = XQ.size(0); + int N = WQ.size(0); + int K = XQ.size(1); + + int StrideA = XQ.stride(-2); + int StrideB = WQ.stride(-2); + int StrideE = N; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + // do GEMM + auto device_gemm = DeviceGemmInstance{}; + auto invoker = device_gemm.MakeInvoker(); + auto argument = device_gemm.MakeArgument(XQ.data_ptr(), + WQ.data_ptr(), + std::array{}, + reinterpret_cast(Y.data_ptr()), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideE, + reinterpret_cast(x_scale.data_ptr()), + reinterpret_cast(w_scale.data_ptr()), + a_element_op, + b_element_op, + cde_element_op); + + TORCH_CHECK(device_gemm.IsSupportedArgument(argument), "This GEMM is not supported!"); + + invoker.Run(argument, StreamConfig{at::hip::getCurrentHIPStream()}); + return Y; +} + +#endif // USE_ROCM diff --git a/csrc/ck_gemm_a8w8_blockscale_bpreshuffle/gen_instances.py b/csrc/ck_gemm_a8w8_blockscale_bpreshuffle/gen_instances.py index 43553eddb8..ba0c08a0e2 100755 --- a/csrc/ck_gemm_a8w8_blockscale_bpreshuffle/gen_instances.py +++ b/csrc/ck_gemm_a8w8_blockscale_bpreshuffle/gen_instances.py @@ -112,7 +112,7 @@ def gen_instance(self, k: kernelInstance): INSTANCE_template = """// SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. -#include "{name}.cuh" +#include "impl/{name}.cuh" template torch::Tensor {name}<{dtypes}>( diff --git a/csrc/ck_gemm_a8w8_blockscale_bpreshuffle/include/gemm_a8w8_blockscale_bpreshuffle_common.cuh b/csrc/ck_gemm_a8w8_blockscale_bpreshuffle/include/gemm_a8w8_blockscale_bpreshuffle_common.cuh index 52950f1b31..dc5838f0e8 100755 --- a/csrc/ck_gemm_a8w8_blockscale_bpreshuffle/include/gemm_a8w8_blockscale_bpreshuffle_common.cuh +++ b/csrc/ck_gemm_a8w8_blockscale_bpreshuffle/include/gemm_a8w8_blockscale_bpreshuffle_common.cuh @@ -1,6 +1,6 @@ #pragma once // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #ifdef USE_ROCM @@ -13,9 +13,9 @@ #include #include -#include -#include -#include +#include +#include +#include #include #include "ck/ck.hpp" @@ -93,24 +93,24 @@ using DeviceGemmHelperF8BlockScaleBPreshuffle = ck::tensor_operation::device::DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle // clang-format off , S<1, 0, 2>, - 2, AK1, AK1, 0, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - S<1, 0, 2>, S<1, 0, 2>, - 2, BK1, BK1, 0, + BlockSize, Scale_Block_M, Scale_Block_N, Scale_Block_K, + MPerBlock, NPerBlock, KPerBlock, + AK1, BK1, + MPerXDL, NPerXDL, + MXdlPerWave, NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + S<1, 0, 2>, S<1, 0, 2>, + 2, AK1, AK1, 0, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + S<1, 0, 2>, S<1, 0, 2>, + 2, BK1, BK1, 0, CSHUFFLE_MX_PER_WAVE_PERSHUFFLE, - CSHUFFLE_NX_PER_WAVE_PERSHUFFLE, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEShuffleBlockTransferScalarPerVectors, - BlkGemmPipeSched, + CSHUFFLE_NX_PER_WAVE_PERSHUFFLE, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, BlkGemmPipelineVer, A0DataType>; // clang-format on @@ -159,7 +159,7 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_bpreshuffle_impl(torch::Tenso TORCH_CHECK(device_gemm.IsSupportedArgument(argument), "This GEMM is not supported!"); - invoker.Run(argument, StreamConfig{at::cuda::getCurrentCUDAStream().stream()}); + invoker.Run(argument, StreamConfig{at::hip::getCurrentHIPStream()}); return Y; } diff --git a/csrc/ck_gemm_a8w8_bpreshuffle/gen_instances.py b/csrc/ck_gemm_a8w8_bpreshuffle/gen_instances.py index 8db3c74b29..8f4268fb38 100755 --- a/csrc/ck_gemm_a8w8_bpreshuffle/gen_instances.py +++ b/csrc/ck_gemm_a8w8_bpreshuffle/gen_instances.py @@ -1,18 +1,18 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import argparse import os +import shutil from pathlib import Path + import pandas as pd -import argparse -import shutil import torch from gemm_a8w8_bpreshuffle_common import ( + default_kernels_dict, kernelInstance, kernels_list, - default_kernels_dict, ) - """ a8w8_bpreshuffle_gemm instance gen @@ -111,7 +111,7 @@ def gen_instance(self, k: kernelInstance): INSTANCE_template = """// SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. -#include "{name}.cuh" +#include "impl/{name}.cuh" template torch::Tensor {name}<{dtypes}>( diff --git a/csrc/ck_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_common.cuh b/csrc/ck_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_common.cuh index 4cb7725824..b3369e63f3 100644 --- a/csrc/ck_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_common.cuh +++ b/csrc/ck_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_common.cuh @@ -13,9 +13,9 @@ #include #include -#include -#include -#include +#include +#include +#include #include #include "ck/ck.hpp" @@ -140,23 +140,23 @@ using DeviceGemmHelperF8Flatmm = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle // clang-format off , S<1, 0, 2>, - 2, AK1, AK1, 0, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - S<1, 0, 2>, S<1, 0, 2>, - 2, BK1, BK1, 0, + MPerBlock, NPerBlock, KPerBlock, + AK1, BK1, + MPerXDL, NPerXDL, + MXdlPerWave, NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + S<1, 0, 2>, S<1, 0, 2>, + 2, AK1, AK1, 0, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + S<1, 0, 2>, S<1, 0, 2>, + 2, BK1, BK1, 0, CSHUFFLE_MX_PER_WAVE_PERSHUFFLE, - CSHUFFLE_NX_PER_WAVE_PERSHUFFLE, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEShuffleBlockTransferScalarPerVectors, - BlkGemmPipeSched, + CSHUFFLE_NX_PER_WAVE_PERSHUFFLE, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, BlkGemmPipelineVer, A0DataType>; template @@ -205,7 +205,7 @@ __forceinline__ torch::Tensor gemm_a8w8_bpreshuffle_impl( TORCH_CHECK(device_gemm.IsSupportedArgument(argument), "This GEMM is not supported!"); - invoker.Run(argument, StreamConfig{at::cuda::getCurrentCUDAStream().stream()}); + invoker.Run(argument, StreamConfig{at::hip::getCurrentHIPStream()}); return Y; } diff --git a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu index 49d4570475..051125faa3 100644 --- a/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu +++ b/csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include -#include -#include +#include +#include #include "gemm_moe_ck2stages_lookup.h" #include "gemm_moe_ck2stages.h" #include "ck2stages_moe_stage1_heuristic_dispatch.hpp" @@ -58,8 +58,8 @@ void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token int quant_type = 0, int activation = 0) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(out)); - at::cuda::getCurrentCUDAStream().stream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(out)); + at::hip::getCurrentHIPStream(); TORCH_CHECK(out.dtype() == at::ScalarType::BFloat16 || out.dtype() == at::ScalarType::Half, "Out dtype only support BFloat16/Float16!") @@ -97,7 +97,7 @@ void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token auto kernel = moe_dispatch<1>(kernelName, MPerBlock, N, hidden_states.dtype().toScalarType(), w1.dtype().toScalarType(), out.dtype().toScalarType(), activation, quant_type, MulRoutedWeight); - kernel(at::cuda::getCurrentCUDAStream().stream(), + kernel(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); } @@ -153,7 +153,7 @@ void ck_moe_stage2(torch::Tensor &inter_states, // [m, k], input token activation = !activation; auto kernel = moe_dispatch<2>(kernelName, MPerBlock, K, inter_states.dtype().toScalarType(), w1.dtype().toScalarType(), out.dtype().toScalarType(), activation, quant_type, MulRoutedWeight); - kernel(at::cuda::getCurrentCUDAStream().stream(), + kernel(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); } \ No newline at end of file diff --git a/csrc/cpp_itfs/pa/pa.py b/csrc/cpp_itfs/pa/pa.py index b4d26597bf..400e248e75 100644 --- a/csrc/cpp_itfs/pa/pa.py +++ b/csrc/cpp_itfs/pa/pa.py @@ -1,8 +1,9 @@ -from jinja2 import Template -from csrc.cpp_itfs.utils import compile_template_op, AITER_CORE_DIR import ctypes import math +from jinja2 import Template + +from csrc.cpp_itfs.utils import AITER_CORE_DIR, compile_template_op MD_NAME = "pa" @@ -76,12 +77,14 @@ def paged_attention_rocm( query_scale=None, ): import torch + from csrc.cpp_itfs.torch_utils import torch_to_c_types dtype_map = { torch.bfloat16: "__hip_bfloat16", torch.float16: "_Float16", torch.float8_e4m3fnuz: "uint8_t", + torch.float8_e4m3fn: "uint8_t", } warpSize = torch.cuda.get_device_properties(out.device).warp_size diff --git a/csrc/include/binary_operator.cuh b/csrc/include/binary_operator.cuh index 01051c0363..c1955f8975 100644 --- a/csrc/include/binary_operator.cuh +++ b/csrc/include/binary_operator.cuh @@ -1,5 +1,5 @@ /* - * Copyright © Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) Advanced Micro Devices, Inc. All rights reserved. * Copyright (C) 2024-2025, The vLLM team. * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,19 +16,15 @@ */ #pragma once #include -#include -#include +#include +#include #include "hip_compat.h" #include "dispatch_utils.h" #include -#ifdef USE_ROCM #include +#include typedef __hip_bfloat16 nv_bfloat16; -#else -#include -#endif -#include namespace aiter { @@ -1268,7 +1264,7 @@ struct BinaryOperationPattern<1, Operation, _T0, _T1> const int grid_x = M * ((N + BIG_TILE_SIZE_N - 1) / BIG_TILE_SIZE_N) * ((K + BIG_TILE_SIZE_K - 1) / BIG_TILE_SIZE_K); const dim3 grid_dim(grid_x, 1, 1); const dim3 block_dim(256, 1, 1); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); bool types_match = typeid(_T0) == typeid(_T1); if (order_flag) @@ -1310,7 +1306,7 @@ struct BinaryOperationPattern<2, Operation, _T0, _T1> N = shape[0] * shape[1] * shape[2]; K = shape[3]; } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); bool types_match = typeid(_T0) == typeid(_T1); const uint32_t rows = 8; @@ -1457,7 +1453,7 @@ struct BinaryOperationPattern<3, Operation, _T0, _T1> int grid_x = M * ((N + BIG_TILE_SIZE_N - 1) / BIG_TILE_SIZE_N) * ((K + BIG_TILE_SIZE_K - 1) / BIG_TILE_SIZE_K); const dim3 grid_dim(grid_x, 1, 1); const dim3 block_dim(256, 1, 1); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); bool types_match = typeid(_T0) == typeid(_T1); constexpr int rows = 8; int vec_size = 16 / output.element_size(); @@ -1529,7 +1525,7 @@ struct BinaryOperationPattern<5, Operation, _T0, _T1> int num_elements = output.numel(); int vec_size = 16 / output.element_size(); constexpr uint32_t row = 8; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); bool types_match = typeid(_T0) == typeid(_T1); // optimize kernel @@ -1607,7 +1603,7 @@ struct BinaryOperationPattern<6, Operation, _T0, _T1> int num_elements = output.numel(); int vec_size = 16 / output.element_size(); constexpr uint32_t row = 8; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); bool types_match = typeid(_T0) == typeid(_T1); // optimize kernel @@ -1686,7 +1682,7 @@ struct BinaryOperationPattern<7, Operation, _T0, _T1> int num_elements = output.numel(); int vec_size = 16 / output.element_size(); constexpr uint32_t row = 8; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); bool types_match = typeid(_T0) == typeid(_T1); // optimize kernel @@ -1857,7 +1853,7 @@ struct BinaryOperationPattern<4, Operation, _T0, _T1> } } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); bool types_match = typeid(_T0) == typeid(_T1); int vec = 16 / output.element_size(); hipDevice_t dev; diff --git a/csrc/include/custom_all_reduce.cuh b/csrc/include/custom_all_reduce.cuh index b8f2b38e93..19eccb8f76 100644 --- a/csrc/include/custom_all_reduce.cuh +++ b/csrc/include/custom_all_reduce.cuh @@ -1,6 +1,6 @@ #pragma once /* - * Copyright © Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) Advanced Micro Devices, Inc. All rights reserved. * Copyright (C) 2024-2025, The vLLM team. * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,36 +15,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include -#ifdef USE_ROCM +#include "aiter_hip_common.h" +#include "ck_tile/core.hpp" +#include "communication_asm.h" +#include "hip_float8.h" #include -typedef __hip_bfloat16 nv_bfloat16; -#else -#include -#endif -#include -#include - +#include +#include #include #include #include #include #include -#include "communication_asm.h" -#include "hip_float8.h" -#include "ck_tile/core.hpp" -#define CUDACHECK(cmd) \ - do \ - { \ - cudaError_t e = cmd; \ - if (e != cudaSuccess) \ - { \ - printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ - cudaGetErrorString(e)); \ - exit(EXIT_FAILURE); \ - } \ - } while (0) namespace aiter { @@ -117,13 +100,13 @@ namespace aiter DINLINE float &assign_add(float &a, float b) { return a += b; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) - DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } + DINLINE float upcast_s(__hip_bfloat16 val) { return __bfloat162float(val); } template <> - DINLINE nv_bfloat16 downcast_s(float val) + DINLINE __hip_bfloat16 downcast_s(float val) { return __float2bfloat16(val); } - DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) + DINLINE __hip_bfloat16 &assign_add(__hip_bfloat16 &a, __hip_bfloat16 b) { a = __hadd(a, b); return a; @@ -823,9 +806,9 @@ namespace aiter } } - using IPC_KEY = std::array; - static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t)); - static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t)); + using IPC_KEY = std::array; + static_assert(sizeof(IPC_KEY) == sizeof(hipIpcMemHandle_t)); + static_assert(alignof(IPC_KEY) == alignof(hipIpcMemHandle_t)); class CustomAllreduce { @@ -855,7 +838,7 @@ namespace aiter * are passed in from the constructor */ CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz, - const cudaIpcMemHandle_t *handles, + const hipIpcMemHandle_t *handles, const std::vector &offsets, int rank, bool full_nvlink = true) : rank_(rank), @@ -889,9 +872,9 @@ namespace aiter if (new_handle) { char *ipc_ptr; - CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr, - *((const cudaIpcMemHandle_t *)ipc_handle), - cudaIpcMemLazyEnablePeerAccess)); + HIP_CALL(hipIpcOpenMemHandle((void **)&ipc_ptr, + *((const hipIpcMemHandle_t *)ipc_handle), + hipIpcMemLazyEnablePeerAccess)); it->second = ipc_ptr; } return it->second; @@ -901,7 +884,7 @@ namespace aiter get_graph_buffer_ipc_meta() { auto num_buffers = graph_unreg_buffers_.size(); - auto handle_sz = sizeof(cudaIpcMemHandle_t); + auto handle_sz = sizeof(hipIpcMemHandle_t); std::vector handles(handle_sz * num_buffers, 0); std::vector offsets(num_buffers); for (int i = 0; i < num_buffers; i++) @@ -910,16 +893,16 @@ namespace aiter void *base_ptr; // note: must share the base address of each allocation, or we get wrong // address - if (cuPointerGetAttribute(&base_ptr, + if (hipPointerGetAttribute(&base_ptr, #ifdef USE_ROCM HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR, #else CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, #endif - (CUdeviceptr)ptr) != CUDA_SUCCESS) + (hipDeviceptr_t)ptr) != CUDA_SUCCESS) throw std::runtime_error("failed to get pointer attr"); - CUDACHECK(cudaIpcGetMemHandle( - (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr)); + HIP_CALL(hipIpcGetMemHandle( + (hipIpcMemHandle_t *)&handles[i * handle_sz], base_ptr)); offsets[i] = ((char *)ptr) - ((char *)base_ptr); } return std::make_pair(handles, offsets); @@ -942,7 +925,7 @@ namespace aiter { if (i != rank_) { - cudaIpcMemHandle_t* ipc_handle_ptr = (cudaIpcMemHandle_t*)handles[i].data_ptr(); + hipIpcMemHandle_t* ipc_handle_ptr = (hipIpcMemHandle_t*)handles[i].data_ptr(); char *handle = open_ipc_handle((void*)ipc_handle_ptr); handle += offsets[i]; data.ptrs[i] = handle; @@ -953,12 +936,12 @@ namespace aiter } } auto d_data = d_rank_data_base_++; - CUDACHECK( - cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice)); + HIP_CALL( + hipMemcpy(d_data, &data, sizeof(RankData), hipMemcpyHostToDevice)); buffers_[self] = d_data; } - RankData *get_buffer_RD(cudaStream_t stream, void *input) + RankData *get_buffer_RD(hipStream_t stream, void *input) { RankData *ptrs; auto it = buffers_.find(input); @@ -968,9 +951,9 @@ namespace aiter } else { - cudaStreamCaptureStatus status; - CUDACHECK(cudaStreamIsCapturing(stream, &status)); - if (status == cudaStreamCaptureStatusActive) + hipStreamCaptureStatus status; + HIP_CALL(hipStreamIsCapturing(stream, &status)); + if (status == hipStreamCaptureStatusActive) { ptrs = d_rank_data_base_ + graph_unreg_buffers_.size(); graph_unreg_buffers_.push_back(input); @@ -1009,7 +992,7 @@ namespace aiter { if (j != rank_) { - cudaIpcMemHandle_t* ipc_handle_ptr = (cudaIpcMemHandle_t*)handles[j].data_ptr() + i; + hipIpcMemHandle_t* ipc_handle_ptr = (hipIpcMemHandle_t*)handles[j].data_ptr() + i; char *handle = open_ipc_handle(ipc_handle_ptr); handle += *((int64_t*)offsets[j].data_ptr() + i); rd.ptrs[j] = handle; @@ -1020,9 +1003,9 @@ namespace aiter } } } - CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(), + HIP_CALL(hipMemcpy(d_rank_data_base_, rank_data.data(), sizeof(RankData) * num_buffers, - cudaMemcpyHostToDevice)); + hipMemcpyHostToDevice)); d_rank_data_base_ += num_buffers; graph_unreg_buffers_.clear(); } @@ -1035,7 +1018,7 @@ namespace aiter * should quant scale match hidden_dim when hidden_dim less than 128? * */ template - void runFp8QuantKernel(cudaStream_t stream, T* input, T* output, int size) + void runFp8QuantKernel(hipStream_t stream, T* input, T* output, int size) { RankData *ptrs = get_buffer_RD(stream, input); // 32 block 512 thread or 64 block 256 thread @@ -1099,7 +1082,7 @@ namespace aiter * will cause contention on NVLink bus. */ template - void allreduce(cudaStream_t stream, T *input, T *output, int size, + void allreduce(hipStream_t stream, T *input, T *output, int size, #ifndef USE_ROCM int threads = 512, int block_limit = 20){ #else @@ -1199,14 +1182,14 @@ namespace aiter { for (auto [_, ptr] : ipc_handles_) { - CUDACHECK(cudaIpcCloseMemHandle(ptr)); + HIP_CALL(hipIpcCloseMemHandle(ptr)); } } }; // namespace aiter /** * To inspect PTX/SASS, copy paste this header file to compiler explorer and add a template instantiation: - * template void aiter::CustomAllreduce::allreduce(cudaStream_t, half *, + * template void aiter::CustomAllreduce::allreduce(hipStream_t, half *, half *, int, int, int); */ } // namespace aiter diff --git a/csrc/include/dtype_bfloat16.cuh b/csrc/include/dtype_bfloat16.cuh index 15fed672d8..1b4f175ef7 100644 --- a/csrc/include/dtype_bfloat16.cuh +++ b/csrc/include/dtype_bfloat16.cuh @@ -1,11 +1,11 @@ /* - * Copyright © Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) Advanced Micro Devices, Inc. All rights reserved. * Adapted from * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * and * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h - * Copyright (c) 2023, The vLLM team. - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (C) 2023-2025, The vLLM team. + * Copyright (C) 2020-2025, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,16 +24,10 @@ #include "attention_generic.cuh" #include "dtype_float32.cuh" -#ifndef USE_ROCM - #include - #include -#else - #include - #include - +#include +#include typedef __hip_bfloat162 __nv_bfloat162; typedef __hip_bfloat16 __nv_bfloat16; -#endif #include diff --git a/csrc/include/mha_common.h b/csrc/include/mha_common.h index 120d55fbe5..211fde00e7 100644 --- a/csrc/include/mha_common.h +++ b/csrc/include/mha_common.h @@ -1,30 +1,33 @@ #pragma once // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers. -#include +// Include these 2 headers instead of torch/extension.h since we don't need all of the torch +// headers. +#include +#include +#include #include -#include -#include - -#ifdef OLD_GENERATOR_PATH -#include -#else -#include -#endif +#include #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_SHAPE(x, ...) \ + TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \ + #x " must have shape (" #__VA_ARGS__ ")") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") namespace aiter { __global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, uint64_t* rng_state); -inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) { +inline int +num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) +{ // If we have enough to almost fill the SMs, then just use 1 split - if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; } - max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + if(batch_nheads_mblocks >= 0.8f * num_SMs) + { + return 1; + } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); float max_efficiency = 0.f; std::vector efficiency; efficiency.reserve(max_splits); @@ -34,22 +37,35 @@ inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int nu // (i.e. it's 11 splits anyway). // So we check if the number of blocks per split is the same as the previous num_splits. auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { - return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); + return num_splits == 1 || + ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); }; - for (int num_splits = 1; num_splits <= max_splits; num_splits++) { - if (!is_split_eligible(num_splits)) { + for(int num_splits = 1; num_splits <= max_splits; num_splits++) + { + if(!is_split_eligible(num_splits)) + { efficiency.push_back(0.f); - } else { + } + else + { float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; - float eff = n_waves / ceil(n_waves); + float eff = n_waves / ceil(n_waves); // printf("num_splits = %d, eff = %f\n", num_splits, eff); - if (eff > max_efficiency) { max_efficiency = eff; } + if(eff > max_efficiency) + { + max_efficiency = eff; + } efficiency.push_back(eff); } } - for (int num_splits = 1; num_splits <= max_splits; num_splits++) { - if (!is_split_eligible(num_splits)) { continue; } - if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { + for(int num_splits = 1; num_splits <= max_splits; num_splits++) + { + if(!is_split_eligible(num_splits)) + { + continue; + } + if(efficiency[num_splits - 1] >= 0.85 * max_efficiency) + { // printf("num_splits chosen = %d\n", num_splits); return num_splits; } @@ -57,7 +73,8 @@ inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int nu return 1; } -inline int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits) +inline int override_num_splits_if_necessary( + int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits) { int device; auto status = hipGetDevice(&device); @@ -83,7 +100,7 @@ inline int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen return num_splits; } -template +template inline void print_fmha_fwd_args(ARG args) { printf("seqlen_q = %d\n", args.seqlen_q); @@ -124,4 +141,4 @@ inline void print_fmha_fwd_args(ARG args) printf("s_randval = %d\n", args.s_randval); } -} // namespace aiter \ No newline at end of file +} // namespace aiter diff --git a/csrc/include/quick_all_reduce.cuh b/csrc/include/quick_all_reduce.cuh index c26a85373c..96f77c0111 100644 --- a/csrc/include/quick_all_reduce.cuh +++ b/csrc/include/quick_all_reduce.cuh @@ -189,11 +189,11 @@ struct CodecQ4 : public CodecBase { int32_t int16_2 = (qw >> (i * 4)) & kMask000F; int16_t low = static_cast(int16_2 & 0xFFFF); int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); - nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); - nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + __hip_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + __hip_bfloat16 bf_high = __float2bfloat16(static_cast(high)); nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); int32_t packed_bf16 = *reinterpret_cast(&bf2); - w[i] = packed_add(packed_bf16, kRangeMin); + w[i] = packed_add<__hip_bfloat16>(packed_bf16, kRangeMin); } } } @@ -365,11 +365,11 @@ struct CodecQ6 : public CodecBase { int16_t low = static_cast(int16_2 & 0xFFFF); int16_t high = static_cast((int16_2 >> 16) & 0xFFFF); - nv_bfloat16 bf_low = __float2bfloat16(static_cast(low)); - nv_bfloat16 bf_high = __float2bfloat16(static_cast(high)); + __hip_bfloat16 bf_low = __float2bfloat16(static_cast(low)); + __hip_bfloat16 bf_high = __float2bfloat16(static_cast(high)); nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); int32_t packed_bf16 = *reinterpret_cast(&bf2); - w[i] = packed_add(packed_bf16, kRangeMin); + w[i] = packed_add<__hip_bfloat16>(packed_bf16, kRangeMin); } } } @@ -534,7 +534,7 @@ struct CodecFP8 : public CodecBase { : "v"(wf1[0]), "v"(wf1[1])); } } else { - nv_bfloat16* wbf = reinterpret_cast(&w); + __hip_bfloat16* wbf = reinterpret_cast<__hip_bfloat16*>(&w); for (int i = 0; i < 2; i++) { fp32x2_t wf0_vec = __builtin_amdgcn_cvt_pk_f32_fp8(qw[i], 0); fp32x2_t wf1_vec = __builtin_amdgcn_cvt_pk_f32_fp8(qw[i], 1); @@ -898,11 +898,11 @@ struct DeviceComms { TWOSHOT_DISPATCH(CodecFP) break; } - HIP_CHECK(cudaGetLastError()); + HIP_CHECK(hipGetLastError()); // Rotate the flag color. flag_color += divceil(N, grid); } }; -} // namespace aiter \ No newline at end of file +} // namespace aiter diff --git a/csrc/include/quick_all_reduce_base.h b/csrc/include/quick_all_reduce_base.h index 1b06ad5775..ee661a4132 100644 --- a/csrc/include/quick_all_reduce_base.h +++ b/csrc/include/quick_all_reduce_base.h @@ -142,7 +142,7 @@ __quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, int32x4 } template <> -__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, int32x4_t* B) +__quickreduce_device_inline__ void packed_assign_add<__hip_bfloat16>(int32x4_t* A, int32x4_t* B) { nv_bfloat162* tA = reinterpret_cast(A); nv_bfloat162* tB = reinterpret_cast(B); @@ -165,7 +165,7 @@ __quickreduce_device_inline__ int packed_max(int a, int b) } template <> -__quickreduce_device_inline__ int packed_max(int a, int b) +__quickreduce_device_inline__ int packed_max<__hip_bfloat16>(int a, int b) { bf162_int_union A, B, R; A.i = a; @@ -186,7 +186,7 @@ __quickreduce_device_inline__ int packed_min(int a, int b) } template <> -__quickreduce_device_inline__ int packed_min(int a, int b) +__quickreduce_device_inline__ int packed_min<__hip_bfloat16>(int a, int b) { bf162_int_union A, B, R; A.i = a; @@ -211,7 +211,7 @@ __quickreduce_device_inline__ int packed_abs_max(int a, int b) } template <> -__quickreduce_device_inline__ int packed_abs_max(int a, int b) +__quickreduce_device_inline__ int packed_abs_max<__hip_bfloat16>(int a, int b) { bf162_int_union A, B, R; A.i = a; @@ -233,7 +233,7 @@ __quickreduce_device_inline__ int packed_add(int a, int b) } template <> -__quickreduce_device_inline__ int packed_add(int a, int b) +__quickreduce_device_inline__ int packed_add<__hip_bfloat16>(int a, int b) { bf162_int_union A, B, R; A.i = a; @@ -264,7 +264,7 @@ __quickreduce_device_inline__ int packed_sub(int a, int b) } template <> -__quickreduce_device_inline__ int packed_sub(int a, int b) +__quickreduce_device_inline__ int packed_sub<__hip_bfloat16>(int a, int b) { bf162_int_union A, B, R; A.i = a; @@ -285,7 +285,7 @@ __quickreduce_device_inline__ int packed_mul(int a, int b) } template <> -__quickreduce_device_inline__ int packed_mul(int a, int b) +__quickreduce_device_inline__ int packed_mul<__hip_bfloat16>(int a, int b) { nv_bfloat162* tA = reinterpret_cast(&a); nv_bfloat162* tB = reinterpret_cast(&b); @@ -303,7 +303,7 @@ __quickreduce_device_inline__ int packed_rcp(int a) } template <> -__quickreduce_device_inline__ int packed_rcp(int a) +__quickreduce_device_inline__ int packed_rcp<__hip_bfloat16>(int a) { bf162_int_union A, R; A.i = a; @@ -314,7 +314,7 @@ __quickreduce_device_inline__ int packed_rcp(int a) // changes dtype __quickreduce_device_inline__ float T2float_cast(half a) { return __half2float(a); } -__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) { return __bfloat162float(a); } +__quickreduce_device_inline__ float T2float_cast(__hip_bfloat16 a) { return __bfloat162float(a); } template __quickreduce_device_inline__ int group_abs_max(int32x4_t atom) diff --git a/csrc/kernels/activation_kernels.cu b/csrc/kernels/activation_kernels.cu index e8cf7da395..6c7da88327 100644 --- a/csrc/kernels/activation_kernels.cu +++ b/csrc/kernels/activation_kernels.cu @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. -#include -#include +#include +#include #include #include @@ -164,8 +164,8 @@ static constexpr int nextPow2(unsigned int num) num_wave = num_wave > max_wave_num ? max_wave_num : num_wave; \ dim3 grid(num_tokens); \ dim3 block(num_wave * 64); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); \ + const hipStream_t stream = at::hip::getCurrentHIPStream(); \ AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "act_and_mul_kernel", [&] { \ using input_dtype = typename t2ck::type; \ AITER_DISPATCH_CASE_VEC_SIZE( \ @@ -186,8 +186,8 @@ static constexpr int nextPow2(unsigned int num) num_wave = num_wave > max_wave_num ? max_wave_num : num_wave; \ dim3 grid(num_tokens); \ dim3 block(num_wave * 64); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); \ + const hipStream_t stream = at::hip::getCurrentHIPStream(); \ AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "scaled_act_and_mul_kernel", [&] { \ using input_dtype = typename t2ck::type; \ AITER_DISPATCH_CASE_VEC_SIZE( \ @@ -253,8 +253,8 @@ __global__ void activation_kernel(scalar_t* __restrict__ out, // [..., d int64_t num_tokens = input.numel() / d; \ dim3 grid(num_tokens); \ dim3 block(std::min(d, 1024)); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); \ + const hipStream_t stream = at::hip::getCurrentHIPStream(); \ AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "activation_kernel", [&] { \ aiter::activation_kernel> \ <<>>(out.data_ptr(), input.data_ptr(), d); \ diff --git a/csrc/kernels/attention.cu b/csrc/kernels/attention.cu index 1228ee6c44..30d4a2ff52 100644 --- a/csrc/kernels/attention.cu +++ b/csrc/kernels/attention.cu @@ -16,8 +16,8 @@ */ #include -#include -#include +#include +#include #include #include "hip_compat.h" #include "attention.h" @@ -2249,8 +2249,8 @@ void paged_attention_custom_launcher( constexpr int NTHR = 256; //PARTITION_SIZE; dim3 grid(num_seqs, max_num_partitions, num_kv_heads); dim3 block(NTHR); - const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); switch (gqa_ratio) { case 1: //LAUNCH_CUSTOM_ATTENTION(1); diff --git a/csrc/kernels/attention_ragged.cu b/csrc/kernels/attention_ragged.cu index 206adb72d2..9d9f0486a1 100644 --- a/csrc/kernels/attention_ragged.cu +++ b/csrc/kernels/attention_ragged.cu @@ -1,15 +1,11 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. -#include #include #include -#include "attention_ragged.h" -#include "attention_common.cuh" - -#if defined(__HIPCC__) && \ - (defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)) +#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) || defined(__gfx950__)) #define __HIP__MI3XX_MI250__ #endif @@ -75,9 +71,37 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_ { return; } - const int64_t query_loc = static_cast(seq_idx); + const int64_t query_loc = static_cast(seq_idx); const int* block_table_seq = kv_page_indices + kv_indptr[seq_idx]; - _paged_attention_kernel(block_table_seq, query_loc, context_len, partition_start_token_idx, q, k_cache, v_cache, scale, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_seq_stride, exp_sums, max_logits, out, logits_soft_cap, logits_soft_cap_rcp, k_scale_ptr, v_scale_ptr, variant); + _paged_attention_kernel(block_table_seq, + query_loc, + context_len, + partition_start_token_idx, + q, + k_cache, + v_cache, + scale, + alibi_slopes, + q_stride, + kv_block_stride, + kv_head_stride, + kv_seq_stride, + exp_sums, + max_logits, + out, + logits_soft_cap, + logits_soft_cap_rcp, + k_scale_ptr, + v_scale_ptr, + variant); } // Grid: (num_heads, num_seqs). @@ -116,12 +140,23 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kern context_len = kv_indptr[seq_idx + 1] - kv_indptr[seq_idx]; } const int64_t query_loc = static_cast(seq_idx); - _paged_attention_ll4mi_reduce_kernel(query_loc, context_len, out, exp_sums, max_logits, tmp_out, max_num_partitions, fp8_out_scale_ptr); + _paged_attention_ll4mi_reduce_kernel(query_loc, + context_len, + out, + exp_sums, + max_logits, + tmp_out, + max_num_partitions, + fp8_out_scale_ptr); } #else // !defined(__HIP__MI3XX_MI250__) TODO: Add NAVI support - template -#include -#include +#include +#include #include "attention_v1.h" #include "attention_common.cuh" @@ -291,8 +291,8 @@ void paged_attention_custom_launcher(torch::Tensor& out, dim3 grid(num_seqs, max_num_partitions, num_kv_heads); dim3 block(NTHR); - const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); // mfma4 kernel is faster than mfma16 for gqa_ratio <= 4 switch(gqa_ratio) diff --git a/csrc/kernels/cache_kernels.cu b/csrc/kernels/cache_kernels.cu index a0fc08b018..5c7dd50643 100644 --- a/csrc/kernels/cache_kernels.cu +++ b/csrc/kernels/cache_kernels.cu @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -#include -#include +#include +#include #include #include "dispatch_utils.h" @@ -26,20 +26,20 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, const torch::Tensor& bl { torch::Device src_device = src.device(); torch::Device dst_device = dst.device(); - cudaMemcpyKind memcpy_type; + hipMemcpyKind memcpy_type; if(src_device.is_cuda() && dst_device.is_cuda()) { TORCH_CHECK(src_device.index() == dst_device.index(), "src and dst must be on the same GPU"); - memcpy_type = cudaMemcpyDeviceToDevice; + memcpy_type = hipMemcpyDeviceToDevice; } else if(src_device.is_cuda() && dst_device.is_cpu()) { - memcpy_type = cudaMemcpyDeviceToHost; + memcpy_type = hipMemcpyDeviceToHost; } else if(src_device.is_cpu() && dst_device.is_cuda()) { - memcpy_type = cudaMemcpyHostToDevice; + memcpy_type = hipMemcpyHostToDevice; } else { @@ -55,8 +55,8 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, const torch::Tensor& bl char* dst_ptr = static_cast(dst.data_ptr()); const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); - const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(src_device.is_cuda() ? src_device : dst_device); + const hipStream_t stream = at::hip::getCurrentHIPStream(); // NOTE(woosuk): This can be slow if the number of blocks is large. const int64_t num_blocks = block_mapping.size(0); for(size_t i = 0; i < num_blocks; i++) @@ -65,7 +65,7 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, const torch::Tensor& bl int64_t dst_block_number = block_mapping[i][1].item(); int64_t src_offset = src_block_number * block_size_in_bytes; int64_t dst_offset = dst_block_number * block_size_in_bytes; - cudaMemcpyAsync( + hipMemcpyAsync( dst_ptr + dst_offset, src_ptr + src_offset, block_size_in_bytes, memcpy_type, stream); } } @@ -149,8 +149,8 @@ void copy_blocks(std::vector const& key_caches, const int numel_per_block = key_caches[0][0].numel(); dim3 grid(num_layers, num_pairs); dim3 block(std::min(1024, numel_per_block)); - const at::cuda::OptionalCUDAGuard device_guard(cache_device); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(cache_device); + const hipStream_t stream = at::hip::getCurrentHIPStream(); VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { aiter::copy_blocks_kernel <<>>( @@ -1034,8 +1034,8 @@ void reshape_and_cache( dim3 grid(num_tokens); dim3 block(std::min(num_heads * head_size, 512)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(key)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); if(asm_layout) { @@ -1092,8 +1092,8 @@ void reshape_and_cache_flash( dim3 grid(num_tokens); dim3 block(std::min(num_heads * head_size, 512)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(key)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE_FLASH); } @@ -1262,8 +1262,8 @@ void reshape_and_cache_with_pertoken_quant( dim3 grid(num_tokens, num_heads); dim3 block(64); - const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(key)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); using dequant_scale_t = float; // should align with k_dequant_scales/v_dequant_scales dtype @@ -1343,8 +1343,8 @@ void reshape_and_cache_with_block_quant( dim3 grid(batch_size, (seq_len + block_size - 1) / block_size + 1, num_heads); dim3 block(blockDimx); - const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(key)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); using dequant_scale_t = float; // should align with k_dequant_scales/v_dequant_scales dtype @@ -1432,8 +1432,8 @@ void reshape_and_cache_with_block_quant_for_asm_pa( int blockDimx = (ori_block_size + 255) / 256 * 256; dim3 grid(batch_size, (seq_len + ori_block_size - 1) / ori_block_size + 1, num_heads); dim3 block(blockDimx); - const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(key)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); using dequant_scale_t = float; // should align with k_dequant_scales/v_dequant_scales dtype diff --git a/csrc/kernels/custom_all_reduce.cu b/csrc/kernels/custom_all_reduce.cu index 58f436a787..d28ad53ece 100644 --- a/csrc/kernels/custom_all_reduce.cu +++ b/csrc/kernels/custom_all_reduce.cu @@ -14,44 +14,48 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include -#include -#include +#include +#include #include #include "custom_all_reduce.cuh" // fake pointer type, must match fptr_t type in ops.h using fptr_t = int64_t; -static_assert(sizeof(void *) == sizeof(fptr_t)); +static_assert(sizeof(void*) == sizeof(fptr_t)); namespace aiter { -fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, - const std::vector &handles, - const std::vector &offsets, int64_t rank, +fptr_t init_custom_ar(torch::Tensor& meta, + torch::Tensor& rank_data, + const std::vector& handles, + const std::vector& offsets, + int64_t rank, bool full_nvlink) { - int world_size = offsets.size(); - if (world_size > 8) - throw std::invalid_argument("world size > 8 is not supported"); - if (world_size % 2 != 0) - throw std::invalid_argument("Odd num gpus is not supported for now"); - if (world_size != handles.size()) - throw std::invalid_argument( - "handles length should equal to offsets length"); - if (rank < 0 || rank >= world_size) - throw std::invalid_argument("invalid rank passed in"); - - cudaIpcMemHandle_t ipc_handles[8]; - for (int i = 0; i < world_size; i++) - { - cudaIpcMemHandle_t* ipc_handle_ptr = (cudaIpcMemHandle_t*)handles[i].data_ptr(); - std::memcpy(&ipc_handles[i], ipc_handle_ptr, sizeof(cudaIpcMemHandle_t)); - } - return (fptr_t) new aiter::CustomAllreduce( - reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), - rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); + int world_size = offsets.size(); + if(world_size > 8) + throw std::invalid_argument("world size > 8 is not supported"); + if(world_size % 2 != 0) + throw std::invalid_argument("Odd num gpus is not supported for now"); + if(world_size != handles.size()) + throw std::invalid_argument("handles length should equal to offsets length"); + if(rank < 0 || rank >= world_size) + throw std::invalid_argument("invalid rank passed in"); + + hipIpcMemHandle_t ipc_handles[8]; + for(int i = 0; i < world_size; i++) + { + hipIpcMemHandle_t* ipc_handle_ptr = (hipIpcMemHandle_t*)handles[i].data_ptr(); + std::memcpy(&ipc_handles[i], ipc_handle_ptr, sizeof(hipIpcMemHandle_t)); + } + return (fptr_t) new aiter::CustomAllreduce(reinterpret_cast(meta.data_ptr()), + rank_data.data_ptr(), + rank_data.numel(), + ipc_handles, + offsets, + rank, + full_nvlink); } /** @@ -70,155 +74,150 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, * 5. A[None].expand(2, -1, -1, -1): Not OK * 6. A[:, 1:, 1:]: Not OK */ -bool _is_weak_contiguous(torch::Tensor &t) +bool _is_weak_contiguous(torch::Tensor& t) { - return t.is_contiguous() || - (t.storage().nbytes() - t.storage_offset() * t.element_size() == - t.numel() * t.element_size()); + return t.is_contiguous() || (t.storage().nbytes() - t.storage_offset() * t.element_size() == + t.numel() * t.element_size()); } -void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, - cudaStream_t stream, bool open_fp8_quant) +void _all_reduce( + fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, hipStream_t stream, bool open_fp8_quant) { - auto fa = reinterpret_cast(_fa); - TORCH_CHECK(_is_weak_contiguous(out)); - switch (out.scalar_type()) - { - case at::ScalarType::Float: - { - fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), - out.numel()); - break; - } - case at::ScalarType::Half: - { - /* - * By default, hidden_dim is a multiple of 128 - * Obvious effects can only be achieved when the data scale reaches a certain level - * */ - if (open_fp8_quant && out.numel() >= 128 * 2048) + auto fa = reinterpret_cast(_fa); + TORCH_CHECK(_is_weak_contiguous(out)); + switch(out.scalar_type()) { - fa->runFp8QuantKernel(stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), out.numel()); + case at::ScalarType::Float: { + fa->allreduce(stream, + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel()); + break; } - else - { - fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), out.numel()); + case at::ScalarType::Half: { + /* + * By default, hidden_dim is a multiple of 128 + * Obvious effects can only be achieved when the data scale reaches a certain level + * */ + if(open_fp8_quant && out.numel() >= 128 * 2048) + { + fa->runFp8QuantKernel(stream, + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel()); + } + else + { + fa->allreduce(stream, + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel()); + } + break; } - break; - } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) - case at::ScalarType::BFloat16: - { - fa->allreduce( - stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), out.numel()); - break; - } + case at::ScalarType::BFloat16: { + fa->allreduce<__hip_bfloat16>(stream, + reinterpret_cast<__hip_bfloat16*>(inp.data_ptr()), + reinterpret_cast<__hip_bfloat16*>(out.data_ptr()), + out.numel()); + break; + } #endif - default: - throw std::runtime_error( - "custom allreduce only supports float32, float16 and bfloat16"); - } + default: + throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16"); + } } -void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, bool open_fp8_quant) +void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, bool open_fp8_quant) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); - auto stream = c10::cuda::getCurrentCUDAStream().stream(); - TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); - TORCH_CHECK_EQ(inp.numel(), out.numel()); - _all_reduce(_fa, inp, out, stream, open_fp8_quant); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp)); + auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + _all_reduce(_fa, inp, out, stream, open_fp8_quant); } -void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, - torch::Tensor &out) +void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); - auto stream = c10::cuda::getCurrentCUDAStream().stream(); - - auto input_size = inp.numel() * inp.element_size(); - TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); - TORCH_CHECK_EQ(inp.numel(), out.numel()); - TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(), - "registered buffer is too small to contain the input"); - AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(), - input_size, cudaMemcpyDeviceToDevice, stream)); - _all_reduce(_fa, reg_buffer, out, stream, false); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp)); + auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + + auto input_size = inp.numel() * inp.element_size(); + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(), + "registered buffer is too small to contain the input"); + HIP_CALL(hipMemcpyAsync( + reg_buffer.data_ptr(), inp.data_ptr(), input_size, hipMemcpyDeviceToDevice, stream)); + _all_reduce(_fa, reg_buffer, out, stream, false); } void dispose(fptr_t _fa) { - auto fa = reinterpret_cast(_fa); - delete fa; + auto fa = reinterpret_cast(_fa); + delete fa; } int64_t meta_size() { return sizeof(aiter::Signal); } -void register_buffer(fptr_t _fa, torch::Tensor &t, - const std::vector &handles, - const std::vector &offsets) +void register_buffer(fptr_t _fa, + torch::Tensor& t, + const std::vector& handles, + const std::vector& offsets) { - auto fa = reinterpret_cast(_fa); - fa->register_buffer(handles, offsets, t.data_ptr()); + auto fa = reinterpret_cast(_fa); + fa->register_buffer(handles, offsets, t.data_ptr()); } -std::vector get_graph_buffer_ipc_meta( - fptr_t _fa) +std::vector get_graph_buffer_ipc_meta(fptr_t _fa) { - auto fa = reinterpret_cast(_fa); - auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta(); - auto options = - torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); - auto handles = - torch::empty({static_cast(handle_bytes.size())}, options); - std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size()); - - torch::Tensor offset_tensor = torch::from_blob(offsets.data(), {static_cast(offsets.size())}, torch::kInt64).clone(); - return {handles, offset_tensor}; + auto fa = reinterpret_cast(_fa); + auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta(); + auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto handles = torch::empty({static_cast(handle_bytes.size())}, options); + std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size()); + + torch::Tensor offset_tensor = + torch::from_blob(offsets.data(), {static_cast(offsets.size())}, torch::kInt64) + .clone(); + return {handles, offset_tensor}; } -void register_graph_buffers(fptr_t _fa, const std::vector &handles, - const std::vector &offsets) +void register_graph_buffers(fptr_t _fa, + const std::vector& handles, + const std::vector& offsets) { - auto fa = reinterpret_cast(_fa); - fa->register_graph_buffers(handles, offsets); + auto fa = reinterpret_cast(_fa); + fa->register_graph_buffers(handles, offsets); } #ifdef USE_ROCM -void free_meta_buffer(void *buffer) { CUDACHECK(cudaFree(buffer)); } +void free_meta_buffer(void* buffer) { HIP_CALL(hipFree(buffer)); } -torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor &inp) +torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp) { - auto options = - torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); - auto data_handle = - torch::empty({static_cast(sizeof(cudaIpcMemHandle_t))}, options); - CUDACHECK(cudaIpcGetMemHandle((cudaIpcMemHandle_t *)data_handle.data_ptr(), - inp.data_ptr())); - return data_handle; + auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto data_handle = torch::empty({static_cast(sizeof(hipIpcMemHandle_t))}, options); + HIP_CALL(hipIpcGetMemHandle((hipIpcMemHandle_t*)data_handle.data_ptr(), inp.data_ptr())); + return data_handle; } torch::Tensor allocate_meta_buffer(int64_t size) { - auto device_index = c10::cuda::current_device(); - at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index)); - void *buffer; - cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; - auto stream = c10::cuda::getCurrentCUDAStream().stream(); - AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode)); - AT_CUDA_CHECK( - hipExtMallocWithFlags((void **)&buffer, size, hipDeviceMallocUncached)); - AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream)); - AT_CUDA_CHECK(cudaStreamSynchronize(stream)); - AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode)); - auto options = torch::TensorOptions() - .dtype(torch::kI8) - .device(torch::kCUDA, device_index); - return torch::from_blob(buffer, {size}, free_meta_buffer, options); + auto device_index = c10::hip::current_device(); + at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index)); + void* buffer; + hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed; + auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + HIP_CALL(hipThreadExchangeStreamCaptureMode(&mode)); + HIP_CALL(hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached)); + HIP_CALL(hipMemsetAsync(buffer, 0, size, stream)); + HIP_CALL(hipStreamSynchronize(stream)); + HIP_CALL(hipThreadExchangeStreamCaptureMode(&mode)); + auto options = torch::TensorOptions().dtype(torch::kI8).device(torch::kCUDA, device_index); + return torch::from_blob(buffer, {size}, free_meta_buffer, options); } #endif diff --git a/csrc/kernels/custom_kernels.cu b/csrc/kernels/custom_kernels.cu index 4158b9a5e4..35390c28d8 100644 --- a/csrc/kernels/custom_kernels.cu +++ b/csrc/kernels/custom_kernels.cu @@ -1,12 +1,12 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "hip_compat.h" -#include +#include +#include #include -#include -#include -#include -#include +#include +#include +#include #include #include @@ -383,7 +383,7 @@ void LLGemm1(void* in_a, void* out_c, const int M, const int K, - cudaStream_t stream, + hipStream_t stream, const int rows_per_block, const c10::ScalarType scalar_type) { @@ -1658,7 +1658,7 @@ int mindiv(int N, int div1, int div2) constexpr int MAX_N = 16; template void launch_wv_splitk_small_fp16_bf16_kernel( - cudaStream_t stream, int K_in, int M_in, fptype* af4, const fptype* bf4, fptype* c, int CuCount) + hipStream_t stream, int K_in, int M_in, fptype* af4, const fptype* bf4, fptype* c, int CuCount) { dim3 grid(CuCount); dim3 block(64, 1); @@ -1679,7 +1679,7 @@ void launch_wv_splitk_small_fp16_bf16_kernel( } template -using KernelFuncPtr = void (*)(cudaStream_t, int, int, fptype*, const fptype*, fptype*, int); +using KernelFuncPtr = void (*)(hipStream_t, int, int, fptype*, const fptype*, fptype*, int); // generate jump table during compilation (1~MAX_N) template @@ -1695,7 +1695,7 @@ void wv_splitk_small_fp16_bf16(void* in_a, const int M_in, const int K_in, const int N_in, - cudaStream_t stream, + hipStream_t stream, const int CuCount, const c10::ScalarType scalar_type) { @@ -1724,7 +1724,7 @@ void wvSplitK_(void* in_a, const int M_in, const int K_in, const int N_in, - cudaStream_t stream, + hipStream_t stream, const int CuCount, const c10::ScalarType scalar_type) { @@ -2220,7 +2220,7 @@ void wvSplitKQ_(void* in_a, const int K_in, const int Kp_in, const int N_in, - cudaStream_t stream, + hipStream_t stream, const int CuCount, const c10::ScalarType a_scalar_type, const c10::ScalarType c_scalar_type) @@ -2352,7 +2352,7 @@ void LLGemmZZ(void* in_a, void* out_c, const int M, const int K, - cudaStream_t stream, + hipStream_t stream, const int solidx = 0) { // m -> M, n-> K @@ -2398,14 +2398,14 @@ void LLGemmZZ(void* in_a, reinterpret_cast(in_b), reinterpret_cast<_Float16*>(out_c)); } - cudaError_t err = cudaGetLastError(); - if(cudaSuccess != err) + hipError_t err = hipGetLastError(); + if(hipSuccess != err) throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); } // instantiate the kernel template for T=float: // template void AddGPUKernel(float *in_a, float *in_b, float *out_c, -// const int M, const int K, cudaStream_t stream); +// const int M, const int K, hipStream_t stream); const unsigned int TILE_WIDTH = 32; // Compute C = A * B __global__ void matrixMultiplyShared(float* A, @@ -2466,7 +2466,7 @@ void MMGPUKernel(float* in_a, int numBColumns, int numCRows, int numCColumns, - cudaStream_t stream) + hipStream_t stream) { // Initialize the grid and block dimensions dim3 dimBlock(TILE_WIDTH, TILE_WIDTH, 1); @@ -2475,8 +2475,8 @@ void MMGPUKernel(float* in_a, matrixMultiplyShared<<>>( in_a, in_b, out_c, numARows, numAColumns, numBRows, numBColumns, numCRows, numCColumns); - cudaError_t err = cudaGetLastError(); - if(cudaSuccess != err) + hipError_t err = hipGetLastError(); + if(hipSuccess != err) throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); } } // namespace aiter diff --git a/csrc/kernels/fused_kernels.cu b/csrc/kernels/fused_kernels.cu index fad4889805..5c21d1f61b 100644 --- a/csrc/kernels/fused_kernels.cu +++ b/csrc/kernels/fused_kernels.cu @@ -1,6 +1,6 @@ /* * Copyright © Advanced Micro Devices, Inc. All rights reserved. - * Copyright (c) 2024, The vLLM team. + * Copyright (C) 2024-2025, The vLLM team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,10 +14,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include -#include -#include #include +#include +#include +#include +#include constexpr int WARP_SIZE = 64; @@ -192,7 +193,7 @@ __global__ void LLGemm_Silu_kernel(float4 *af4, __half2 *bf4, _Float16 *c, // define the kernel calling code: // template void LLGemm_Silu(void *in_a, void *in_b, void *out_c, const int M, const int K, - cudaStream_t stream, const int rows_per_block = 4) + hipStream_t stream, const int rows_per_block = 4) { float4 *af4 = reinterpret_cast(in_a); auto *bf4 = reinterpret_cast<__half2 *>(in_b); @@ -227,7 +228,7 @@ void LLGemm_Silu(void *in_a, void *in_b, void *out_c, const int M, const int K, <<>>(af4, bf4, c, d); } - cudaError_t err = cudaGetLastError(); - if (cudaSuccess != err) + hipError_t err = hipGetLastError(); + if (hipSuccess != err) throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); } diff --git a/csrc/kernels/generate_binaryop.py b/csrc/kernels/generate_binaryop.py index ecdb810f72..b79d16f595 100644 --- a/csrc/kernels/generate_binaryop.py +++ b/csrc/kernels/generate_binaryop.py @@ -3,9 +3,9 @@ # generate kernel instances to speed up compilation import argparse -from pathlib import Path -from typing import List, Any from dataclasses import dataclass +from pathlib import Path +from typing import Any, List def get_if_str(idx, total, last_else=True): @@ -54,7 +54,7 @@ class BinaryOpCodegen: template void binary_op_impl(torch::Tensor &input, torch::Tensor &other, torch::Tensor &output) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); int dim = input.dim(); bool is_support = false; diff --git a/csrc/kernels/moe_align_block_size_kernels.cu b/csrc/kernels/moe_align_block_size_kernels.cu index b403144685..9d60134c1f 100644 --- a/csrc/kernels/moe_align_block_size_kernels.cu +++ b/csrc/kernels/moe_align_block_size_kernels.cu @@ -14,54 +14,50 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include +#include #include -#include -#include #include -#include -#include "hip_compat.h" +#include "aiter_hip_common.h" #include "dispatch_utils.h" +#include "hip_compat.h" #define CEILDIV(x, y) (((x) + (y) - 1) / (y)) -namespace vllm -{ +namespace vllm { - namespace - { - __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, - int32_t col) - { - // don't worry about overflow because num_experts is relatively small - return row * total_col + col; - } - } // namespace - - template - __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, - int32_t *sorted_token_ids, - int32_t *expert_ids, - int32_t *token_nums, - int32_t *total_tokens_post_pad, - int32_t num_experts, - int32_t block_size, size_t numel) - { +namespace { +__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) +{ + // don't worry about overflow because num_experts is relatively small + return row * total_col + col; +} +} // namespace + +template +__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, + int32_t* sorted_token_ids, + int32_t* expert_ids, + int32_t* token_nums, + int32_t* total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, + size_t numel) +{ const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); - const size_t start_idx = threadIdx.x * tokens_per_thread; + const size_t start_idx = threadIdx.x * tokens_per_thread; extern __shared__ int32_t shared_mem[]; - int32_t *tokens_cnts = - shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) - int32_t *cumsum = - shared_mem + (num_experts + 1) * - num_experts; // 1d tensor with shape (num_experts + 1) + int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) + int32_t* cumsum = + shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1) - for (int i = 0; i < num_experts; ++i) + for(int i = 0; i < num_experts; ++i) { - tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; + tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; } /** @@ -69,34 +65,33 @@ namespace vllm * which counts how many tokens in the token shard of thread_index are * assigned to expert expert_index. */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) + for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; + ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; } __syncthreads(); // For each expert we accumulate the token counts from the different threads. tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; - for (int i = 1; i <= blockDim.x; ++i) + for(int i = 1; i <= blockDim.x; ++i) { - tokens_cnts[index(num_experts, i, threadIdx.x)] += - tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; + tokens_cnts[index(num_experts, i, threadIdx.x)] += + tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; } __syncthreads(); // We accumulate the token counts of all experts in thread 0. - if (threadIdx.x == 0) + if(threadIdx.x == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) - { - cumsum[i] = cumsum[i - 1] + - CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], - block_size); - } - *total_tokens_post_pad = cumsum[num_experts] * block_size; + cumsum[0] = 0; + for(int i = 1; i <= num_experts; ++i) + { + cumsum[i] = cumsum[i - 1] + + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size); + } + *total_tokens_post_pad = cumsum[num_experts] * block_size; } __syncthreads(); @@ -106,11 +101,11 @@ namespace vllm * blocks and stores the corresponding expert_id for each block. */ auto num = tokens_cnts[index(num_experts, blockDim.x, threadIdx.x)]; - for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i++) + for(int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i++) { - expert_ids[i] = threadIdx.x; - token_nums[i] = num; - num -= block_size; + expert_ids[i] = threadIdx.x; + token_nums[i] = num; + num -= block_size; } /** @@ -120,53 +115,54 @@ namespace vllm * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a * padding value(preset in python). */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) + for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - int32_t expert_id = topk_ids[i]; - /** The cumsum[expert_id] stores the starting index of the tokens that the - * expert with expert_id needs to process, and - * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens - * processed by the expert with expert_id within the current thread's token - * shard. - */ - int32_t rank_post_pad = - tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + - cumsum[expert_id] * block_size; - sorted_token_ids[rank_post_pad] = i; - ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; + int32_t expert_id = topk_ids[i]; + /** The cumsum[expert_id] stores the starting index of the tokens that the + * expert with expert_id needs to process, and + * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens + * processed by the expert with expert_id within the current thread's token + * shard. + */ + int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + + cumsum[expert_id] * block_size; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; } - } +} } // namespace vllm namespace aiter { -void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, - int64_t block_size, torch::Tensor sorted_token_ids, +void moe_align_block_size(torch::Tensor topk_ids, + int64_t num_experts, + int64_t block_size, + torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor token_nums, torch::Tensor num_tokens_post_pad) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(topk_ids)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_INTEGRAL_TYPES( - topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] - { + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(topk_ids)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); + VLLM_DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { // calc needed amount of shared mem for `tokens_cnts` and `cumsum` // tensors const int32_t shared_mem = - ((num_experts + 1) * num_experts + (num_experts + 1)) * - sizeof(int32_t); + ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); // set dynamic shared mem auto kernel = vllm::moe_align_block_size_kernel; - AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( - (void*)kernel, shared_mem)); - kernel<<<1, num_experts, shared_mem, stream>>>( - topk_ids.data_ptr(), sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - token_nums.data_ptr(), - num_tokens_post_pad.data_ptr(), num_experts, block_size, - topk_ids.numel()); }); + HIP_CALL( + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void*)kernel, shared_mem)); + kernel<<<1, num_experts, shared_mem, stream>>>(topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + token_nums.data_ptr(), + num_tokens_post_pad.data_ptr(), + num_experts, + block_size, + topk_ids.numel()); + }); } } // namespace aiter diff --git a/csrc/kernels/moe_fused_gate.cu b/csrc/kernels/moe_fused_gate.cu index aff0c36409..a2534afce4 100644 --- a/csrc/kernels/moe_fused_gate.cu +++ b/csrc/kernels/moe_fused_gate.cu @@ -15,21 +15,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include -#include +#include "hip_compat.h" +#include "hip_reduce.h" +#include "vec_convert.h" +#include +#include +#include #include +#include #include #include - -#include #include -#include "hip_compat.h" -#include "hip_reduce.h" -#include "vec_convert.h" -#include -#include - /// Aligned array type template moe_fused_gate(at::Tensor& input, int ROWS_PER_WARP = std::max(1, WARP_SIZE / num_expert_group); size_t shared_mem_size = ((topk * sizeof(float) + topk * sizeof(int)) * ROWS_PER_WARP + 255) & ~255; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); dim3 block_dim(WARP_SIZE, WARPS_PER_CTA); // Check 1: Ensure that num_experts is a power of 2. diff --git a/csrc/kernels/pos_encoding_kernels.cu b/csrc/kernels/pos_encoding_kernels.cu index f8d568c97e..d7e73fccfa 100644 --- a/csrc/kernels/pos_encoding_kernels.cu +++ b/csrc/kernels/pos_encoding_kernels.cu @@ -1,6 +1,6 @@ /* * Copyright © Advanced Micro Devices, Inc. All rights reserved. - * Copyright (c) 2024, The vLLM team. + * Copyright (C) 2024-2025, The vLLM team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,8 +15,8 @@ * limitations under the License. */ #include -#include -#include +#include +#include #include "hip_compat.h" #include "dispatch_utils.h" @@ -179,8 +179,8 @@ void rotary_embedding( dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { if (is_neox) { @@ -242,8 +242,8 @@ void batched_rotary_embedding( dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { if (is_neox) { diff --git a/csrc/kernels/quant_kernels.cu b/csrc/kernels/quant_kernels.cu index bd41880579..685b1825c1 100644 --- a/csrc/kernels/quant_kernels.cu +++ b/csrc/kernels/quant_kernels.cu @@ -7,7 +7,7 @@ #include "quant_common.cuh" #include "rocprim/rocprim.hpp" #include "vec_convert.h" -#include +#include #include const int32_t BlockSize = 256; @@ -563,8 +563,8 @@ void static_per_tensor_quant(torch::Tensor& out, // [..., d] int rows = input.numel() / cols; dim3 grid(rows); dim3 block(BlockSize); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); if(out.dtype() == torch_fp8) { AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "scaled_quant_kernel", [&] { @@ -632,8 +632,8 @@ void dynamic_per_tensor_quant(torch::Tensor& out, // [..., d] int rows = input.numel() / cols; dim3 grid(rows); dim3 block(BlockSize); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); if(out.dtype() == torch_fp8) { AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "scaled_quant_kernel", [&] { @@ -685,8 +685,8 @@ void dynamic_per_token_scaled_quant(torch::Tensor& out, // [..., d] int const rows = input.numel() / cols; int32_t* num_rows_ptr = num_rows.has_value() ? num_rows->data_ptr() : nullptr; - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); if(cols == 32 || cols == 64 || cols == 128) { @@ -826,8 +826,8 @@ void dynamic_per_group_scaled_quant_fp4(torch::Tensor& out, // [..., d] TORCH_CHECK(cols % group_size == 0, __func__, " cols is not divisible by group_size"); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); int thread_data_size = 32; int num_thread_per_group = group_size / thread_data_size; @@ -924,8 +924,8 @@ void smooth_per_token_scaled_quant( int32_t input_stride0 = input.stride(0); int32_t input_stride1 = input.dim() > 2 ? input.stride(1) : cols; - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); dim3 const grid(rows); dim3 const block(BlockSize); @@ -1001,8 +1001,8 @@ void partial_transpose(torch::Tensor& out, // [rows, d] int const rows = input.numel() / cols; int32_t* num_rows_ptr = num_rows.data_ptr(); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); if(cols <= 1024) { diff --git a/csrc/kernels/quick_all_reduce.cu b/csrc/kernels/quick_all_reduce.cu index 0897011ecb..f3ded02c68 100644 --- a/csrc/kernels/quick_all_reduce.cu +++ b/csrc/kernels/quick_all_reduce.cu @@ -1,8 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include +#include +#include #include #ifdef USE_ROCM @@ -61,8 +60,8 @@ void qr_open_handles(fptr_t _fa, void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half) { auto fa = reinterpret_cast(_fa); - const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); - auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp)); + auto stream = at::hip::getCurrentHIPStream(); TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); TORCH_CHECK_EQ(inp.numel(), out.numel()); @@ -77,9 +76,9 @@ void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, reinterpret_cast(out.data_ptr()), out.numel(), quant_level, stream); } else { - fa->allreduce( - reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), + fa->allreduce<__hip_bfloat16, false>( + reinterpret_cast<__hip_bfloat16*>(inp.data_ptr()), + reinterpret_cast<__hip_bfloat16*>(out.data_ptr()), out.numel(), quant_level, stream); } } else { @@ -98,14 +97,14 @@ int64_t qr_max_size() { template struct AllReduceTwoshot, cast_bf2half>; \ template struct AllReduceTwoshot, cast_bf2half>; \ -INSTANTIATE_FOR_WORLDSIZE(nv_bfloat16, CodecFP, false) -INSTANTIATE_FOR_WORLDSIZE(nv_bfloat16, CodecQ4, false) -INSTANTIATE_FOR_WORLDSIZE(nv_bfloat16, CodecQ6, false) -INSTANTIATE_FOR_WORLDSIZE(nv_bfloat16, CodecFP8, false) -INSTANTIATE_FOR_WORLDSIZE(nv_bfloat16, CodecFP, true) -INSTANTIATE_FOR_WORLDSIZE(nv_bfloat16, CodecQ4, true) -INSTANTIATE_FOR_WORLDSIZE(nv_bfloat16, CodecQ6, true) -INSTANTIATE_FOR_WORLDSIZE(nv_bfloat16, CodecFP8, true) +INSTANTIATE_FOR_WORLDSIZE(__hip_bfloat16, CodecFP, false) +INSTANTIATE_FOR_WORLDSIZE(__hip_bfloat16, CodecQ4, false) +INSTANTIATE_FOR_WORLDSIZE(__hip_bfloat16, CodecQ6, false) +INSTANTIATE_FOR_WORLDSIZE(__hip_bfloat16, CodecFP8, false) +INSTANTIATE_FOR_WORLDSIZE(__hip_bfloat16, CodecFP, true) +INSTANTIATE_FOR_WORLDSIZE(__hip_bfloat16, CodecQ4, true) +INSTANTIATE_FOR_WORLDSIZE(__hip_bfloat16, CodecQ6, true) +INSTANTIATE_FOR_WORLDSIZE(__hip_bfloat16, CodecFP8, true) INSTANTIATE_FOR_WORLDSIZE(half, CodecFP, false) INSTANTIATE_FOR_WORLDSIZE(half, CodecQ4, false) diff --git a/csrc/kernels/rmsnorm_kernels.cu b/csrc/kernels/rmsnorm_kernels.cu index d8bd1459f8..4f5092c662 100644 --- a/csrc/kernels/rmsnorm_kernels.cu +++ b/csrc/kernels/rmsnorm_kernels.cu @@ -1,6 +1,6 @@ /* * Copyright © Advanced Micro Devices, Inc. All rights reserved. - * Copyright (c) 2024, The vLLM team. + * Copyright (C) 2024-2025, The vLLM team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,27 +15,17 @@ * limitations under the License. */ #include -#include -#include +#include +#include #include "dispatch_utils.h" -// #include "attention/attention_dtypes.h" -#ifndef USE_ROCM -#include -#include -#include -#include -#else #include #include -#include #include -// #include "quantization/fp8/amd/hip_float8.h" -// #include "quantization/fp8/amd/quant_utils.cuh" +#include using __nv_bfloat16 = __hip_bfloat16; using __nv_bfloat162 = __hip_bfloat162; -#endif namespace vllm { @@ -110,10 +100,10 @@ namespace vllm } float v8_variance_sum = v8_variance.sum(); - using BlockReduce = cub::BlockReduce; + using BlockReduce = hipcub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; float variance = - BlockReduce(reduceStore).Reduce(v8_variance_sum, cub::Sum{}, blockDim.x); + BlockReduce(reduceStore).Reduce(v8_variance_sum, hipcub::Sum{}, blockDim.x); if (threadIdx.x == 0) { @@ -145,9 +135,9 @@ namespace vllm // variance += x * x; // } - // using BlockReduce = cub::BlockReduce; + // using BlockReduce = hipcub::BlockReduce; // __shared__ typename BlockReduce::TempStorage reduceStore; - // variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + // variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x); // if (threadIdx.x == 0) { // s_variance = rsqrtf(variance / hidden_size + epsilon); @@ -384,9 +374,9 @@ namespace vllm residual_v[id] = temp; } - using BlockReduce = cub::BlockReduce; + using BlockReduce = hipcub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x); if (threadIdx.x == 0) { @@ -427,9 +417,9 @@ namespace vllm residual[blockIdx.x * hidden_size + idx] = z; } - using BlockReduce = cub::BlockReduce; + using BlockReduce = hipcub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; - variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x); if (threadIdx.x == 0) { @@ -502,9 +492,9 @@ namespace vllm // residual_v[id] = temp; // } - // using BlockReduce = cub::BlockReduce; + // using BlockReduce = hipcub::BlockReduce; // __shared__ typename BlockReduce::TempStorage reduceStore; - // variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + // variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x); // if (threadIdx.x == 0) { // s_variance = rsqrtf(variance / hidden_size + epsilon); @@ -545,9 +535,9 @@ namespace vllm // residual[blockIdx.x * hidden_size + idx] = z; // } - // using BlockReduce = cub::BlockReduce; + // using BlockReduce = hipcub::BlockReduce; // __shared__ typename BlockReduce::TempStorage reduceStore; - // variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + // variance = BlockReduce(reduceStore).Reduce(variance, hipcub::Sum{}, blockDim.x); // if (threadIdx.x == 0) { // s_variance = rsqrtf(variance / hidden_size + epsilon); @@ -574,8 +564,8 @@ void rms_norm(torch::Tensor &out, // [..., hidden_size] dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { vllm::rms_norm_kernel<<>>( out.data_ptr(), input.data_ptr(), @@ -591,8 +581,8 @@ void rms_norm(torch::Tensor &out, // [..., hidden_size] // dim3 grid(num_tokens); // dim3 block(std::min(hidden_size, 1024)); -// const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); -// const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); +// const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); +// const hipStream_t stream = at::hip::getCurrentHIPStream(); // VLLM_DISPATCH_FLOATING_TYPES( // input.scalar_type(), "scaled_rms_norm_kernel", [&] { // vllm::scaled_rms_norm_kernel<<>>( @@ -625,8 +615,8 @@ void fused_add_rms_norm(torch::Tensor &input, // [..., hidden_size] hiding on global mem ops. */ const int max_block_size = (num_tokens < 256) ? 1024 : 256; dim3 block(std::min(hidden_size, max_block_size)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); /*If the tensor types are FP16/BF16, try to use the optimized kernel with packed + vectorized ops. Max optimization is achieved with a width-8 vector of FP16/BF16s @@ -675,8 +665,8 @@ void fused_add_rms_norm(torch::Tensor &input, // [..., hidden_size] // hiding on global mem ops. */ // const int max_block_size = (num_tokens < 256) ? 1024 : 256; // dim3 block(std::min(hidden_size, max_block_size)); -// const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); -// const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); +// const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); +// const hipStream_t stream = at::hip::getCurrentHIPStream(); // /*If the tensor types are FP16/BF16, try to use the optimized kernel // with packed + vectorized ops. // Max optimization is achieved with a width-8 vector of FP16/BF16s diff --git a/csrc/kernels/rope/general_bwd_kernels.cu b/csrc/kernels/rope/general_bwd_kernels.cu index 482e5c6323..e846538b5d 100644 --- a/csrc/kernels/rope/general_bwd_kernels.cu +++ b/csrc/kernels/rope/general_bwd_kernels.cu @@ -32,7 +32,7 @@ void rope_bwd_impl( const int32_t stride_i_h = input_grads.stride(2); const int32_t stride_i_d = input_grads.stride(3); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input_grads)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input_grads)); DISPATCH_ROPE_TYPES_PARAMS( output_grads.scalar_type(), freqs.scalar_type(), @@ -86,7 +86,7 @@ void rope_2c_bwd_impl( const int32_t stride_iy_h = input_grads_y.stride(2); const int32_t stride_iy_d = input_grads_y.stride(3); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input_grads_x)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input_grads_x)); DISPATCH_ROPE_TYPES_PARAMS( output_grads_x.scalar_type(), freqs.scalar_type(), @@ -134,7 +134,7 @@ void rope_cached_bwd_impl( const int32_t stride_i_h = input_grads.stride(2); const int32_t stride_i_d = input_grads.stride(3); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input_grads)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input_grads)); DISPATCH_ROPE_TYPES_PARAMS( output_grads.scalar_type(), cos.scalar_type(), @@ -190,7 +190,7 @@ void rope_cached_2c_bwd_impl( const int32_t stride_iy_h = input_grads_y.stride(2); const int32_t stride_iy_d = input_grads_y.stride(3); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input_grads_x)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input_grads_x)); DISPATCH_ROPE_TYPES_PARAMS( output_grads_x.scalar_type(), cos.scalar_type(), @@ -238,7 +238,7 @@ void rope_thd_bwd_impl( const int32_t stride_i_h = input_grads.stride(1); const int32_t stride_i_d = input_grads.stride(2); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input_grads)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input_grads)); DISPATCH_ROPE_TYPES_PARAMS( output_grads.scalar_type(), freqs.scalar_type(), @@ -288,7 +288,7 @@ void rope_2d_bwd_impl( TORCH_CHECK(size_s == img_height * img_width, "rope_2d_fwd_impl - input tensor shape doesn't match image size."); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input_grads)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input_grads)); DISPATCH_ROPE_TYPES_PARAMS( output_grads.scalar_type(), cos_h.scalar_type(), diff --git a/csrc/kernels/rope/general_fwd_kernels.cu b/csrc/kernels/rope/general_fwd_kernels.cu index 708036a293..8379a3a3aa 100644 --- a/csrc/kernels/rope/general_fwd_kernels.cu +++ b/csrc/kernels/rope/general_fwd_kernels.cu @@ -32,7 +32,7 @@ void rope_fwd_impl( const int32_t stride_o_h = output.stride(2); const int32_t stride_o_d = output.stride(3); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); DISPATCH_ROPE_TYPES_PARAMS( input.scalar_type(), freqs.scalar_type(), @@ -86,7 +86,7 @@ void rope_2c_fwd_impl( const int32_t stride_oy_h = output_y.stride(2); const int32_t stride_oy_d = output_y.stride(3); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input_x)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input_x)); DISPATCH_ROPE_TYPES_PARAMS( input_x.scalar_type(), freqs.scalar_type(), @@ -134,7 +134,7 @@ void rope_cached_fwd_impl( const int32_t stride_o_h = output.stride(2); const int32_t stride_o_d = output.stride(3); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); DISPATCH_ROPE_TYPES_PARAMS( input.scalar_type(), cos.scalar_type(), @@ -190,7 +190,7 @@ void rope_cached_2c_fwd_impl( const int32_t stride_oy_h = output_y.stride(2); const int32_t stride_oy_d = output_y.stride(3); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input_x)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input_x)); DISPATCH_ROPE_TYPES_PARAMS( input_x.scalar_type(), cos.scalar_type(), @@ -237,7 +237,7 @@ void rope_thd_fwd_impl( const int32_t stride_o_h = output.stride(1); const int32_t stride_o_d = output.stride(2); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); DISPATCH_ROPE_TYPES_PARAMS( input.scalar_type(), freqs.scalar_type(), @@ -287,7 +287,7 @@ void rope_2d_fwd_impl( TORCH_CHECK(size_s == img_height * img_width, "rope_2d_fwd_impl - input tensor shape doesn't match image size."); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); DISPATCH_ROPE_TYPES_PARAMS( input.scalar_type(), cos_h.scalar_type(), diff --git a/csrc/kernels/rope/pos_fwd_kernels.cu b/csrc/kernels/rope/pos_fwd_kernels.cu index 8208b893ba..506681e7b9 100644 --- a/csrc/kernels/rope/pos_fwd_kernels.cu +++ b/csrc/kernels/rope/pos_fwd_kernels.cu @@ -54,7 +54,7 @@ void rope_cached_positions_fwd_impl( assert(1 == positions.stride(1) && 2 == positions.dim()); const int32_t max_position = cos.size(0); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); DISPATCH_ROPE_TYPES_PARAMS( input.scalar_type(), cos.scalar_type(), @@ -136,7 +136,7 @@ void rope_cached_positions_2c_fwd_impl( assert(1 == positions.stride(1) && 2 == positions.dim()); const int32_t max_position = cos.size(0); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input_x)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input_x)); DISPATCH_ROPE_TYPES_PARAMS( input_x.scalar_type(), cos.scalar_type(), @@ -212,7 +212,7 @@ void rope_cached_positions_offsets_fwd_impl( assert(1 == offsets.stride(1) && 2 == offsets.dim()); const int32_t max_position = cos.size(0); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); DISPATCH_ROPE_TYPES_PARAMS( input.scalar_type(), cos.scalar_type(), @@ -299,7 +299,7 @@ void rope_cached_positions_offsets_2c_fwd_impl( assert(1 == offsets.stride(1) && 2 == offsets.dim()); const int32_t max_position = cos.size(0); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input_x)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input_x)); DISPATCH_ROPE_TYPES_PARAMS( input_x.scalar_type(), cos.scalar_type(), diff --git a/csrc/kernels/rope/rope_common.h b/csrc/kernels/rope/rope_common.h index ac415b4396..85941e09c4 100644 --- a/csrc/kernels/rope/rope_common.h +++ b/csrc/kernels/rope/rope_common.h @@ -4,7 +4,7 @@ #pragma once #include "dispatch_utils.h" -#include +#include // ===================================================================================================================== // Keyword interpretation @@ -2670,7 +2670,7 @@ void dispatch_1c_sbhd_uncached(scalar_t* __restrict__ p_output, const int32_t stride_o_h, const int32_t stride_o_d) { - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); const dim3 grid(size_s, size_b); const dim3 block(C10_WARP_SIZE, size_h < 16 ? 4 : 8); @@ -2761,7 +2761,7 @@ void dispatch_2c_sbhd_uncached(scalar_t* __restrict__ p_output_x, const int32_t stride_oy_h, const int32_t stride_oy_d) { - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); const dim3 grid(size_s, size_b); const dim3 block(C10_WARP_SIZE, size_h_x < 16 ? 4 : 8); @@ -2870,7 +2870,7 @@ void dispatch_1c_sbhd_cached(scalar_t* __restrict__ p_output, const int32_t stride_o_h, const int32_t stride_o_d) { - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); const dim3 grid(size_s, size_b); const dim3 block(C10_WARP_SIZE, size_h < 16 ? 4 : 8); @@ -2964,7 +2964,7 @@ void dispatch_2c_sbhd_cached(scalar_t* __restrict__ p_output_x, const int32_t stride_oy_h, const int32_t stride_oy_d) { - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); const dim3 grid(size_s, size_b); const dim3 block(C10_WARP_SIZE, size_h_x < 16 ? 4 : 8); @@ -3076,7 +3076,7 @@ void dispatch_1c_sbhd_cached_indirect(scalar_t* __restrict__ p_output, const int32_t stride_o_h, const int32_t stride_o_d) { - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); const dim3 grid(size_s, size_b); const dim3 block(C10_WARP_SIZE, size_h < 16 ? 4 : 8); @@ -3177,7 +3177,7 @@ void dispatch_2c_sbhd_cached_indirect(scalar_t* __restrict__ p_output_x, const int32_t stride_oy_h, const int32_t stride_oy_d) { - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); const dim3 grid(size_s, size_b); const dim3 block(C10_WARP_SIZE, size_h_x < 16 ? 4 : 8); @@ -3295,7 +3295,7 @@ void dispatch_1c_sbhd_cached_indirect2(scalar_t* __restrict__ p_output, const int32_t stride_o_h, const int32_t stride_o_d) { - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); const dim3 grid(size_s, size_b); const dim3 block(C10_WARP_SIZE, size_h < 16 ? 4 : 8); @@ -3400,7 +3400,7 @@ void dispatch_2c_sbhd_cached_indirect2(scalar_t* __restrict__ p_output_x, const int32_t stride_oy_h, const int32_t stride_oy_d) { - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); const dim3 grid(size_s, size_b); const dim3 block(C10_WARP_SIZE, size_h_x < 16 ? 4 : 8); @@ -3516,7 +3516,7 @@ void dispatch_1c_thd_uncached(scalar_t* __restrict__ p_output, const int32_t stride_o_h, const int32_t stride_o_d) { - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); const dim3 grid(size_max_s, size_b); const dim3 block(C10_WARP_SIZE, size_h < 16 ? 4 : 8); @@ -3597,7 +3597,7 @@ void dispatch_1c_2d_cached(scalar_t* __restrict__ p_output, const int32_t stride_o_h, const int32_t stride_o_d) { - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); const dim3 grid(img_height, img_width, size_b); const dim3 block(C10_WARP_SIZE, size_h < 16 ? 4 : 8); diff --git a/csrc/kernels/sample_kernels.cu b/csrc/kernels/sample_kernels.cu index 1ad97bf0a5..8d313c7d18 100644 --- a/csrc/kernels/sample_kernels.cu +++ b/csrc/kernels/sample_kernels.cu @@ -8,8 +8,8 @@ #include "rocprim/rocprim.hpp" #include "vec_convert.h" #include -#include -#include +#include +#include #include #include #include @@ -295,8 +295,8 @@ __global__ void mix_sample_kernel(const DTYPE_I* input, void greedy_sample(torch::Tensor& out, torch::Tensor& input) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(out)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(out)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); int M = input.size(0); int N = input.size(1); @@ -324,8 +324,8 @@ void random_sample(torch::Tensor& out, std::optional generator = std::nullopt, float eps = 1e-10) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(out)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(out)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); auto gen = get_generator_or_default( generator, at::cuda::detail::getDefaultCUDAGenerator()); @@ -385,8 +385,8 @@ void mixed_sample(torch::Tensor& out, std::optional generator = std::nullopt, float eps = 1e-10) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(out)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(out)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); auto gen = get_generator_or_default( generator, at::cuda::detail::getDefaultCUDAGenerator()); @@ -476,8 +476,8 @@ __global__ void exponential_kernel(DTYPE_O* output, vec_o vec_cur; for(int i = 0; i < vec_size_o; i++) { - float u = transform_func((&rand.x)[i]) + eps; - vec_cur[i] = ck_tile::type_convert(u); + float u = transform_func((&rand.x)[i]) + eps; + vec_cur[i] = ck_tile::type_convert(u); } buffer_o.template set(k, 0, true, vec_cur.template get_as()); } @@ -488,8 +488,8 @@ void exponential(torch::Tensor& out, std::optional generator = std::nullopt, float eps = 1e-10) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(out)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(out)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); auto gen = get_generator_or_default( generator, at::cuda::detail::getDefaultCUDAGenerator()); diff --git a/csrc/kernels/topk_softmax_kernels.cu b/csrc/kernels/topk_softmax_kernels.cu index 21c7d2d714..e5137557bf 100644 --- a/csrc/kernels/topk_softmax_kernels.cu +++ b/csrc/kernels/topk_softmax_kernels.cu @@ -23,17 +23,12 @@ #include "hip_reduce.h" #include "py_itfs_common.h" #include "vec_convert.h" -#include -#include +#include +#include #include -#ifndef USE_ROCM -#include -#include -#else #include #include -#endif #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -59,7 +54,7 @@ template __launch_bounds__(TPB) __global__ void moeSoftmax(const DTYPE* input, const bool* finished, float* output, const int num_cols) { - using BlockReduce = cub::BlockReduce; + using BlockReduce = hipcub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmpStorage; __shared__ float normalizing_factor; @@ -67,7 +62,7 @@ __launch_bounds__(TPB) __global__ const int thread_row_offset = blockIdx.x * num_cols; - cub::Sum sum; + hipcub::Sum sum; float threadData(-FLT_MAX); // Don't touch finished rows. @@ -82,7 +77,7 @@ __launch_bounds__(TPB) __global__ threadData = max(static_cast(input[idx]), threadData); } - const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, hipcub::Max()); if(threadIdx.x == 0) { float_max = maxElem; @@ -126,12 +121,12 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax const bool need_renorm) { - using cub_kvp = cub::KeyValuePair; - using BlockReduce = cub::BlockReduce; + using cub_kvp = hipcub::KeyValuePair; + using BlockReduce = hipcub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmpStorage; cub_kvp thread_kvp; - cub::ArgMax arg_max; + hipcub::ArgMax arg_max; const int num_rows = gridDim.x; const int block_row = blockIdx.x; @@ -491,7 +486,7 @@ void topkGatingSoftmaxLauncherHelper(const DTYPE* input, const int output_stride, const int indices_stride, const bool need_renorm, - cudaStream_t stream) + hipStream_t stream) { static constexpr std::size_t MAX_BYTES_PER_LDG = 32; @@ -562,7 +557,7 @@ void topkGatingSoftmaxKernelLauncher(const DTYPE* gating_output, const int topk_weights_stride, const int topk_id_stride, const bool need_renorm, - cudaStream_t stream) + hipStream_t stream) { static constexpr int WARPS_PER_TB = 8; switch(num_experts) @@ -637,8 +632,8 @@ void topk_softmax(torch::Tensor& topk_weights, // [num_tokens, topk] const bool needs_workspace = !is_pow_2 || num_experts > 256; const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0; - const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(gating_output)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options().dtype(torch::kFloat32)); VLLM_DISPATCH_FLOATING_TYPES(gating_output.scalar_type(), "topk_softmax", [&] { @@ -669,8 +664,8 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(output)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); switch(topk) { diff --git a/csrc/kernels/topk_softmax_kernels_group.cu b/csrc/kernels/topk_softmax_kernels_group.cu index df06a6c978..629c763e3f 100644 --- a/csrc/kernels/topk_softmax_kernels_group.cu +++ b/csrc/kernels/topk_softmax_kernels_group.cu @@ -1258,7 +1258,7 @@ void biased_grouped_topk(torch::Tensor& gating_output, // [num_tokens, num_exp ); const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(gating_output)); - const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); LAUNCH_KERNEL() } @@ -1292,7 +1292,7 @@ void grouped_topk(torch::Tensor& gating_output, // [num_tokens, num_experts] ~255; const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(gating_output)); - const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); LAUNCH_KERNEL() } diff --git a/csrc/kernels/unary_operator.cu b/csrc/kernels/unary_operator.cu index 17401a370c..42f079701f 100644 --- a/csrc/kernels/unary_operator.cu +++ b/csrc/kernels/unary_operator.cu @@ -1,6 +1,6 @@ /* * Copyright © Advanced Micro Devices, Inc. All rights reserved. - * Copyright (c) 2024, The vLLM team. + * Copyright (C) 2024-2025, The vLLM team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,20 +15,16 @@ * limitations under the License. */ #include -#include -#include +#include +#include #include "hip_compat.h" #include "dispatch_utils.h" #include #include -#ifdef USE_ROCM #include +#include typedef __hip_bfloat16 nv_bfloat16; -#else -#include -#endif -#include namespace aiter { @@ -154,7 +150,7 @@ torch::Tensor unary_operation(torch::Tensor &input) void *buf_c = reinterpret_cast(output.data_ptr()); void *buf_a = reinterpret_cast(input.data_ptr()); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const hipStream_t stream = at::hip::getCurrentHIPStream(); int elements = N * K; constexpr uint32_t wg = 256; diff --git a/csrc/py_itfs_ck/attention_kernels.cu b/csrc/py_itfs_ck/attention_kernels.cu index 69a82071b2..f9b018c339 100644 --- a/csrc/py_itfs_ck/attention_kernels.cu +++ b/csrc/py_itfs_ck/attention_kernels.cu @@ -1,9 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include -#include -#include +#include +#include #include "py_itfs_common.h" #include "ck_tile/ref/naive_attention.hpp" @@ -25,8 +25,8 @@ torch::Tensor pa_fwd_naive(torch::Tensor &Q, // [num_seqs, num_head const int quant_algo, // 0: no quant, 1: per-token FP8 quant std::optional &out_) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(Q)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(Q)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); TORCH_CHECK(block_tables.dtype() == torch::kInt32, "block_tables must be int32"); TORCH_CHECK(context_lens.dtype() == torch::kInt32, "context_lens must be int32"); torch::Tensor out = out_.value_or(torch::empty_like(Q)); diff --git a/csrc/py_itfs_ck/mha_batch_prefill_kernels.cu b/csrc/py_itfs_ck/mha_batch_prefill_kernels.cu index 86bb6e31fc..52750fa8d6 100644 --- a/csrc/py_itfs_ck/mha_batch_prefill_kernels.cu +++ b/csrc/py_itfs_ck/mha_batch_prefill_kernels.cu @@ -1,12 +1,11 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include "py_itfs_common.h" #include "mha_common.h" - #include "mha_fwd.h" +#include "py_itfs_common.h" +#include +#include namespace aiter { namespace torch_itfs { @@ -297,7 +296,7 @@ mha_batch_prefill(at::Tensor& q, // [total_q, hq, d] } // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard{q.device()}; bool has_lse = return_softmax_lse; bool has_dropout = p_dropout > 0.0f; @@ -350,7 +349,7 @@ mha_batch_prefill(at::Tensor& q, // [total_q, hq, d] if(max_seqlen_k > 0) { - auto stream = at::cuda::getCurrentHIPStream().stream(); + auto stream = at::hip::getCurrentHIPStream(); ck_tile::stream_config stream_config{stream}; auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1); diff --git a/csrc/py_itfs_ck/mha_bwd_kernels.cu b/csrc/py_itfs_ck/mha_bwd_kernels.cu index 25029da6ff..704188a29b 100644 --- a/csrc/py_itfs_ck/mha_bwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_bwd_kernels.cu @@ -2,7 +2,7 @@ // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include -#include +#include #include "py_itfs_common.h" #include "mha_common.h" @@ -230,7 +230,7 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] if (is_causal) { window_size_right = 0; } bool is_dropout = p_dropout > 0.0; - auto stream = at::cuda::getCurrentHIPStream().stream(); + auto stream = at::hip::getCurrentHIPStream(); auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, @@ -322,7 +322,7 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] dv = torch::empty_like(v); } - at::cuda::CUDAGuard device_guard{q.device()}; + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard{q.device()}; auto opts = q.options(); auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); diff --git a/csrc/py_itfs_ck/mha_fwd_kernels.cu b/csrc/py_itfs_ck/mha_fwd_kernels.cu index a8ead2049b..1bdfde270b 100644 --- a/csrc/py_itfs_ck/mha_fwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_fwd_kernels.cu @@ -2,7 +2,7 @@ // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include -#include +#include #include "py_itfs_common.h" #include "mha_common.h" @@ -265,7 +265,7 @@ mha_fwd(at::Tensor &q, // [b, sq, hq, d] } // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard{q.device()}; bool has_lse = return_softmax_lse; bool has_dropout = p_dropout > 0.0f; @@ -303,7 +303,7 @@ mha_fwd(at::Tensor &q, // [b, sq, hq, d] if (seqlen_k > 0) { auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1); - auto stream = at::cuda::getCurrentHIPStream().stream(); + auto stream = at::hip::getCurrentHIPStream(); ck_tile::stream_config stream_config{stream}; auto args = diff --git a/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu b/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu index 84f70ac3f3..1fd6fb9063 100644 --- a/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu @@ -2,7 +2,7 @@ // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include -#include +#include #include "py_itfs_common.h" #include "mha_common.h" @@ -212,7 +212,7 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v] if (is_causal) { window_size_right = 0; } bool is_dropout = p_dropout > 0.0; - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream = at::hip::getCurrentHIPStream(); auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, @@ -311,7 +311,7 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v] dv = torch::empty_like(v); } - at::cuda::CUDAGuard device_guard{q.device()}; + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard{q.device()}; auto opts = q.options(); auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); diff --git a/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu b/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu index 4c9d71b089..712f2e7791 100644 --- a/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu @@ -2,7 +2,7 @@ // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include -#include +#include #include "py_itfs_common.h" #include "mha_common.h" @@ -485,7 +485,7 @@ mha_varlen_fwd(at::Tensor &q, // [total_q, hq, d] } // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard{q.device()}; bool has_lse = return_softmax_lse; bool has_dropout = p_dropout > 0.0f; @@ -532,7 +532,7 @@ mha_varlen_fwd(at::Tensor &q, // [total_q, hq, d] std::optional seqlens_k = std::nullopt; if (max_seqlen_k > 0) { - auto stream = at::cuda::getCurrentHIPStream().stream(); + auto stream = at::hip::getCurrentHIPStream(); ck_tile::stream_config stream_config{stream}; if (paged_KV) diff --git a/csrc/py_itfs_ck/moe_ck_2stages_kernel.cu b/csrc/py_itfs_ck/moe_ck_2stages_kernel.cu index 915713f1e2..1de76026e7 100644 --- a/csrc/py_itfs_ck/moe_ck_2stages_kernel.cu +++ b/csrc/py_itfs_ck/moe_ck_2stages_kernel.cu @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include -#include -#include +#include +#include #include "py_itfs_common.h" #include "moe_ck_gemm.hpp" @@ -15,26 +15,26 @@ { \ if (K % (256 / sizeof(A0DataType)) == 0) \ { \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ else \ { \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ } \ else if (MPerBlock == 64) \ { \ if (K % (256 / sizeof(A0DataType)) == 0) \ { \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ else \ { \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ } \ else if (MPerBlock == 128) \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ else \ { \ @@ -42,26 +42,26 @@ { \ if (K % (256 / sizeof(A0DataType)) == 0) \ { \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ else \ { \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ } \ else if (MPerBlock == 64) \ { \ if (K % (256 / sizeof(A0DataType)) == 0) \ { \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ else \ { \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ } \ else if (MPerBlock == 128) \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ }\ }\ else if (ActOP == 1) \ @@ -72,26 +72,26 @@ { \ if (K % (256 / sizeof(A0DataType)) == 0) \ { \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ else \ { \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ } \ else if (MPerBlock == 64) \ { \ if (K % (256 / sizeof(A0DataType)) == 0) \ { \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ else \ { \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ } \ else if (MPerBlock == 128) \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ else \ { \ @@ -99,26 +99,26 @@ { \ if (K % (256 / sizeof(A0DataType)) == 0) \ { \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ else \ { \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ } \ else if (MPerBlock == 64) \ { \ if (K % (256 / sizeof(A0DataType)) == 0) \ { \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ else \ { \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ } \ else if (MPerBlock == 128) \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ }\ } @@ -128,20 +128,20 @@ if (isPerTensorQuant) \ { \ if (MPerBlock == 32) \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ else if (MPerBlock == 64) \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ else if (MPerBlock == 128) \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ else \ { \ if (MPerBlock == 32) \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ else if (MPerBlock == 64) \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ else if (MPerBlock == 128) \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ }\ else if (ActOP == 1) \ @@ -149,20 +149,20 @@ if (isPerTensorQuant) \ { \ if (MPerBlock == 32) \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ else if (MPerBlock == 64) \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ else if (MPerBlock == 128) \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ else \ { \ if (MPerBlock == 32) \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ else if (MPerBlock == 64) \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ else if (MPerBlock == 128) \ - ck_moe_stage1_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ + ck_moe_stage1_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, hidden_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w1_scale_ptr, a1_scale_ptr); \ } \ } \ @@ -180,7 +180,7 @@ void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token std::optional sorted_weights = std::nullopt, std::optional act_op = 0) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(out)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(out)); // TORCH_CHECK(hidden_states.dtype() == w1.dtype(), // "Weights and activations should both be same dtype!"); @@ -361,20 +361,20 @@ void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token if (isPerTensorQuant) \ { \ if (MPerBlock == 32) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ else if (MPerBlock == 64) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ else if (MPerBlock == 128) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ } \ else \ { \ if (MPerBlock == 32) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ else if (MPerBlock == 64) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ else if (MPerBlock == 128) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ } \ } \ else \ @@ -382,20 +382,20 @@ void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token if (isPerTensorQuant) \ { \ if (MPerBlock == 32) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ else if (MPerBlock == 64) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ else if (MPerBlock == 128) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ } \ else \ { \ if (MPerBlock == 32) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ else if (MPerBlock == 64) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ else if (MPerBlock == 128) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ } \ } @@ -405,20 +405,20 @@ void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token if (isPerTensorQuant) \ { \ if (MPerBlock == 32) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ else if (MPerBlock == 64) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ else if (MPerBlock == 128) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ } \ else \ { \ if (MPerBlock == 32) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ else if (MPerBlock == 64) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ else if (MPerBlock == 128) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ } \ } \ else \ @@ -426,20 +426,20 @@ void ck_moe_stage1(torch::Tensor &hidden_states, // [m, k], input token if (isPerTensorQuant) \ { \ if (MPerBlock == 32) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ else if (MPerBlock == 64) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ else if (MPerBlock == 128) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ } \ else \ { \ if (MPerBlock == 32) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ else if (MPerBlock == 64) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ else if (MPerBlock == 128) \ - ck_moe_stage2_gemm(at::cuda::getCurrentCUDAStream().stream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ + ck_moe_stage2_gemm(at::hip::getCurrentHIPStream(), tokens, sorted_size, N, K, topk, inter_states_ptr, w1_ptr, w2_ptr, sorted_token_ids_ptr, sorted_expert_ids_ptr, sorted_weights_ptr, num_valid_ids_ptr, out_ptr, w2_scale_ptr, a2_scale_ptr); \ } \ } @@ -456,7 +456,7 @@ void ck_moe_stage2(torch::Tensor &inter_states, // [m, k], input token std::optional block_m = 32, std::optional sorted_weights = std::nullopt) // [max_num_tokens_padded]) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(out)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(out)); // TORCH_CHECK(inter_states.dtype() == w2.dtype(), // "Weights and activations should both be same dtype!"); // diff --git a/csrc/py_itfs_ck/moe_sorting_kernels.cu b/csrc/py_itfs_ck/moe_sorting_kernels.cu index 15731607ab..cba3416406 100644 --- a/csrc/py_itfs_ck/moe_sorting_kernels.cu +++ b/csrc/py_itfs_ck/moe_sorting_kernels.cu @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "py_itfs_common.h" -#include -#include +#include +#include #include #include "moe_sorting_api.hpp" @@ -29,8 +29,8 @@ void moe_sorting_fwd(torch::Tensor& topk_ids, // [m, topk] auto dtype_str = torchDTypeToStr(topk_ids.dtype()); int num_tokens = topk_ids.size(0); int topk = topk_ids.size(1); - const at::cuda::OptionalCUDAGuard device_guard(device_of(topk_ids)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(topk_ids)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); int workspace_size = moe_sorting_get_workspace_size(num_tokens, num_experts, topk, dispatch_policy); void* ws_ptr = nullptr; diff --git a/csrc/py_itfs_ck/norm_kernels.cu b/csrc/py_itfs_ck/norm_kernels.cu index c76d2af2ec..98b87162e8 100644 --- a/csrc/py_itfs_ck/norm_kernels.cu +++ b/csrc/py_itfs_ck/norm_kernels.cu @@ -1,9 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include -#include -#include +#include +#include #include "py_itfs_common.h" #include "layernorm2d_fwd.hpp" @@ -27,8 +27,8 @@ void layernorm2d(torch::Tensor &out, // [m, n] int y_stride = out.stride(0); int yr_stride = -1; bool SaveMeanVar = false; - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); layernorm2d_fwd({ dtype_str, // input precision @@ -86,8 +86,8 @@ void layernorm2d_with_add(torch::Tensor &out, // [m ,n] int y_stride = out.stride(0); int yr_stride = residual_out.stride(0); bool SaveMeanVar = false; - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); layernorm2d_fwd({ dtype_str, // input precision @@ -139,8 +139,8 @@ void layernorm2d_with_smoothquant(torch::Tensor &out, // [m ,n] int y_stride = out.stride(0); int yr_stride = -1; bool SaveMeanVar = false; - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); layernorm2d_fwd({ dtype_str, // input precision @@ -194,8 +194,8 @@ void layernorm2d_with_add_smoothquant(torch::Tensor &out, // [m ,n] int y_stride = out.stride(0); int yr_stride = residual_out.stride(0); bool SaveMeanVar = false; - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); layernorm2d_fwd({ dtype_str, // input precision @@ -245,8 +245,8 @@ void layernorm2d_with_dynamicquant(torch::Tensor &out, // [m ,n] int y_stride = out.stride(0); int yr_stride = -1; bool SaveMeanVar = false; - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); layernorm2d_fwd({ dtype_str, // input precision @@ -298,8 +298,8 @@ void layernorm2d_with_add_dynamicquant(torch::Tensor &out, // [m ,n] int y_stride = out.stride(0); int yr_stride = residual_out.stride(0); bool SaveMeanVar = false; - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); layernorm2d_fwd({ dtype_str, // input precision diff --git a/csrc/py_itfs_ck/rmsnorm_ck_kernels.cu b/csrc/py_itfs_ck/rmsnorm_ck_kernels.cu index 5701d70e84..2a0de07d2d 100644 --- a/csrc/py_itfs_ck/rmsnorm_ck_kernels.cu +++ b/csrc/py_itfs_ck/rmsnorm_ck_kernels.cu @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "py_itfs_common.h" -#include -#include +#include +#include #include #include "rmsnorm2d_fwd.hpp" @@ -26,8 +26,8 @@ void rmsnorm2d( int y_stride = out.stride(0); int yr_stride = -1; bool SaveRms = false; - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); rmsnorm2d_fwd({dtype_str, // input precision dtype_str, // output precision @@ -90,8 +90,8 @@ void rmsnorm2d_with_add( int y_stride = out.stride(0); int yr_stride = residual_out.stride(0); bool SaveRms = false; - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); rmsnorm2d_fwd({dtype_str, // input precision dtype_str, // output precision @@ -145,8 +145,8 @@ void rmsnorm2d_with_smoothquant( int y_stride = out.stride(0); int yr_stride = -1; bool SaveRms = false; - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); rmsnorm2d_fwd({dtype_str, // input precision out_dtype_str, // output precision @@ -203,8 +203,8 @@ void rmsnorm2d_with_add_smoothquant( int y_stride = out.stride(0); int yr_stride = residual_out.stride(0); bool SaveRms = false; - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); rmsnorm2d_fwd({dtype_str, // input precision out_dtype_str, // output precision @@ -257,8 +257,8 @@ void rmsnorm2d_with_dynamicquant( int y_stride = out.stride(0); int yr_stride = -1; bool SaveRms = false; - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); rmsnorm2d_fwd({dtype_str, // input precision out_dtype_str, // output precision @@ -312,8 +312,8 @@ void rmsnorm2d_with_add_dynamicquant( int y_stride = out.stride(0); int yr_stride = residual_out.stride(0); bool SaveRms = false; - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); rmsnorm2d_fwd({dtype_str, // input precision out_dtype_str, // output precision diff --git a/csrc/py_itfs_ck/smoothquant_kernels.cu b/csrc/py_itfs_ck/smoothquant_kernels.cu index 221ea806fe..b4f25b5855 100644 --- a/csrc/py_itfs_ck/smoothquant_kernels.cu +++ b/csrc/py_itfs_ck/smoothquant_kernels.cu @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include -#include -#include +#include +#include #include "py_itfs_common.h" #include "smoothquant.hpp" @@ -22,8 +22,8 @@ void smoothquant_fwd(torch::Tensor &out, // [m ,n] int m = input.numel() / n; int stride = n; int out_stride = out.stride(0); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); smoothquant({ dtype_str // input dtype @@ -57,8 +57,8 @@ void moe_smoothquant_fwd(torch::Tensor &out, // [topk * tokens, hidden_size int experts = x_scale.size(0); int topk = topk_ids.size(1); int stride = n; - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); moe_smoothquant({ dtype_str, // input dtype diff --git a/csrc/py_itfs_cu/asm_communication.cu b/csrc/py_itfs_cu/asm_communication.cu index 7f029a9b85..89b8e355e3 100644 --- a/csrc/py_itfs_cu/asm_communication.cu +++ b/csrc/py_itfs_cu/asm_communication.cu @@ -1,59 +1,64 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -#include +#include +#include #include +#include #include -#include -#include -#include "communication_asm.h" #include "aiter_hip_common.h" +#include "communication_asm.h" #include "custom_all_reduce.cuh" -torch::Tensor all_reduce_asm(torch::Tensor &input, +torch::Tensor all_reduce_asm(torch::Tensor& input, int64_t _ca, - torch::Tensor ®_sig, torch::Tensor ®_buffer, bool isGraph) + torch::Tensor& reg_sig, + torch::Tensor& reg_buffer, + bool isGraph) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - auto stream = c10::cuda::getCurrentCUDAStream().stream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); auto input_size = input.numel() * input.element_size(); - void *inp_ptr = input.data_ptr(); - if (!isGraph) + void* inp_ptr = input.data_ptr(); + if(!isGraph) { TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(), - "registered buffer is too small to contain the input", input_size, ">", reg_buffer.numel() * reg_buffer.element_size()); - AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp_ptr, - input_size, cudaMemcpyDeviceToDevice, stream)); + "registered buffer is too small to contain the input", + input_size, + ">", + reg_buffer.numel() * reg_buffer.element_size()); + HIP_CALL(hipMemcpyAsync( + reg_buffer.data_ptr(), inp_ptr, input_size, hipMemcpyDeviceToDevice, stream)); inp_ptr = reg_buffer.data_ptr(); } - auto ca = reinterpret_cast(_ca); + auto ca = reinterpret_cast(_ca); using RD = aiter::RankData; - RD *input_rd = ca->get_buffer_RD(stream, inp_ptr); - RD *sig_rd = ca->get_buffer_RD(stream, reg_sig.data_ptr()); + RD* input_rd = ca->get_buffer_RD(stream, inp_ptr); + RD* sig_rd = ca->get_buffer_RD(stream, reg_sig.data_ptr()); struct __attribute__((packed)) KernelArgs { - void *ptr_gpu0_data; + void* ptr_gpu0_data; p2 _p0; - void *ptr_gpu0_sig; + void* ptr_gpu0_sig; p2 _p8; - void *ptr_gpu1_sig; + void* ptr_gpu1_sig; p2 _p9; - void *ptr_gpu2_sig; + void* ptr_gpu2_sig; p2 _p10; - void *ptr_gpu3_sig; + void* ptr_gpu3_sig; p2 _p11; - void *ptr_gpu4_sig; + void* ptr_gpu4_sig; p2 _p12; - void *ptr_gpu5_sig; + void* ptr_gpu5_sig; p2 _p13; - void *ptr_gpu6_sig; + void* ptr_gpu6_sig; p2 _p14; - void *ptr_gpu7_sig; + void* ptr_gpu7_sig; p2 _p15; unsigned int gpuId; p3 _p16; @@ -67,30 +72,34 @@ torch::Tensor all_reduce_asm(torch::Tensor &input, p3 _p20; }; - int bdx = 256; - int gdx = 64; - int gdy = 1; - int gdz = 1; - int stride_GPU = input_size / ca->world_size_; // stride base on the pass in GPU id; gpu0 focus on 0~15; gpu1 focus on 16~31 - int stride_TG = stride_GPU / gdx; // stride base on TG id; 64 TGs, every TG focus on 16*8192/64=2048 elements - int stride_WV = stride_TG / (bdx / 64); // stride base on Wave id, 4 waves, every wave focus on 512 elements; 1024 bytes + int bdx = 256; + int gdx = 64; + int gdy = 1; + int gdz = 1; + int stride_GPU = input_size / ca->world_size_; // stride base on the pass in GPU id; gpu0 focus + // on 0~15; gpu1 focus on 16~31 + int stride_TG = stride_GPU / + gdx; // stride base on TG id; 64 TGs, every TG focus on 16*8192/64=2048 elements + int stride_WV = + stride_TG / + (bdx / 64); // stride base on Wave id, 4 waves, every wave focus on 512 elements; 1024 bytes KernelArgs args; - size_t arg_size = sizeof(args); - args.ptr_gpu0_data = reinterpret_cast(input_rd); - args.ptr_gpu0_sig = const_cast(sig_rd->ptrs[0]); - args.ptr_gpu1_sig = const_cast(sig_rd->ptrs[1]); - args.ptr_gpu2_sig = const_cast(sig_rd->ptrs[2]); - args.ptr_gpu3_sig = const_cast(sig_rd->ptrs[3]); - args.ptr_gpu4_sig = const_cast(sig_rd->ptrs[4]); - args.ptr_gpu5_sig = const_cast(sig_rd->ptrs[5]); - args.ptr_gpu6_sig = const_cast(sig_rd->ptrs[6]); - args.ptr_gpu7_sig = const_cast(sig_rd->ptrs[7]); - args.gpuId = ca->rank_; - args.stride_gpu = stride_GPU; - args.stride_tg = stride_TG; - args.stride_wave = stride_WV; - args.loopcnt = 10; + size_t arg_size = sizeof(args); + args.ptr_gpu0_data = reinterpret_cast(input_rd); + args.ptr_gpu0_sig = const_cast(sig_rd->ptrs[0]); + args.ptr_gpu1_sig = const_cast(sig_rd->ptrs[1]); + args.ptr_gpu2_sig = const_cast(sig_rd->ptrs[2]); + args.ptr_gpu3_sig = const_cast(sig_rd->ptrs[3]); + args.ptr_gpu4_sig = const_cast(sig_rd->ptrs[4]); + args.ptr_gpu5_sig = const_cast(sig_rd->ptrs[5]); + args.ptr_gpu6_sig = const_cast(sig_rd->ptrs[6]); + args.ptr_gpu7_sig = const_cast(sig_rd->ptrs[7]); + args.gpuId = ca->rank_; + args.stride_gpu = stride_GPU; + args.stride_tg = stride_TG; + args.stride_wave = stride_WV; + args.loopcnt = 10; static AiterAsmKernel impl("allreduce_kernel_func", "all_reduce.co"); impl.launch_kernel({&args, @@ -102,96 +111,98 @@ torch::Tensor all_reduce_asm(torch::Tensor &input, 1, // bdy 1, // bdz stream}); - auto options = torch::TensorOptions() - .dtype(input.dtype()) - .device(input.device()); + auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device()); return torch::from_blob(inp_ptr, {input.sizes()}, options); } -std::tuple all_reduce_rmsnorm(torch::Tensor &input, // [m ,n] - torch::Tensor &residual_in, // [m ,n] - torch::Tensor &weight, // [1 ,n] - torch::Tensor &bias, // [1 ,n] +std::tuple all_reduce_rmsnorm(torch::Tensor& input, // [m ,n] + torch::Tensor& residual_in, // [m ,n] + torch::Tensor& weight, // [1 ,n] + torch::Tensor& bias, // [1 ,n] float epsilon, // following are fused_allreduce args int64_t _ca, - torch::Tensor ®_sig, torch::Tensor ®_buffer, bool isGraph) + torch::Tensor& reg_sig, + torch::Tensor& reg_buffer, + bool isGraph) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - auto stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + auto stream = at::hip::getCurrentHIPStream(); auto size_input = input.numel() * input.element_size(); - auto size_pad = (size_input + 4095) & 0xfffff000; + auto size_pad = (size_input + 4095) & 0xfffff000; - void *inp_ptr = input.data_ptr(); + void* inp_ptr = input.data_ptr(); // reg_buffer contains input|out|res_out auto size_needed = size_pad * 3; TORCH_CHECK(size_needed <= reg_buffer.numel() * reg_buffer.element_size(), "registered buffer is too small to contain the input ", - size_needed, ">", reg_buffer.numel() * reg_buffer.element_size()); + size_needed, + ">", + reg_buffer.numel() * reg_buffer.element_size()); uint64_t out_offset = (uint64_t)size_pad; uint64_t res_offset = (uint64_t)size_pad * 2; - if (!isGraph) + if(!isGraph) { - AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp_ptr, - size_input, cudaMemcpyDeviceToDevice, stream)); + HIP_CALL(hipMemcpyAsync( + reg_buffer.data_ptr(), inp_ptr, size_input, hipMemcpyDeviceToDevice, stream)); inp_ptr = reg_buffer.data_ptr(); } - auto ca = reinterpret_cast(_ca); + auto ca = reinterpret_cast(_ca); using RD = aiter::RankData; - RD *sig_rd = ca->get_buffer_RD(stream, reg_sig.data_ptr()); - RD *reg_rd = ca->get_buffer_RD(stream, reg_buffer.data_ptr()); - RD *input_rd = ca->get_buffer_RD(stream, inp_ptr); + RD* sig_rd = ca->get_buffer_RD(stream, reg_sig.data_ptr()); + RD* reg_rd = ca->get_buffer_RD(stream, reg_buffer.data_ptr()); + RD* input_rd = ca->get_buffer_RD(stream, inp_ptr); - void *out_ptr; - void *res_ptr; + void* out_ptr; + void* res_ptr; uint64_t gpu_bufs[8 * 4]; - for (size_t i = 0; i < ca->world_size_; i++) + for(size_t i = 0; i < ca->world_size_; i++) { - gpu_bufs[i] = reinterpret_cast(input_rd->ptrs[i]); - gpu_bufs[i + 8] = reinterpret_cast(reg_rd->ptrs[i]) + out_offset; + gpu_bufs[i] = reinterpret_cast(input_rd->ptrs[i]); + gpu_bufs[i + 8] = reinterpret_cast(reg_rd->ptrs[i]) + out_offset; gpu_bufs[i + 16] = reinterpret_cast(reg_rd->ptrs[i]) + res_offset; - if (i == ca->rank_) + if(i == ca->rank_) { - out_ptr = reinterpret_cast(gpu_bufs[i + 8]); - res_ptr = reinterpret_cast(gpu_bufs[i + 16]); + out_ptr = reinterpret_cast(gpu_bufs[i + 8]); + res_ptr = reinterpret_cast(gpu_bufs[i + 16]); } } - uint64_t *gpu_addr_buf_in; + uint64_t* gpu_addr_buf_in; uint addr_buf_size = 8 * 4 * sizeof(uint64_t); HIP_CALL(hipMalloc(&gpu_addr_buf_in, addr_buf_size)); HIP_CALL(hipMemcpy(gpu_addr_buf_in, gpu_bufs, addr_buf_size, hipMemcpyHostToDevice)); struct __attribute__((packed)) KernelArgs { - void *ptr_gpu0_data; + void* ptr_gpu0_data; p2 _p0; - void *ptr_gpu0_sig; + void* ptr_gpu0_sig; p2 _p8; - void *ptr_gpu1_sig; + void* ptr_gpu1_sig; p2 _p9; - void *ptr_gpu2_sig; + void* ptr_gpu2_sig; p2 _p10; - void *ptr_gpu3_sig; + void* ptr_gpu3_sig; p2 _p11; - void *ptr_gpu4_sig; + void* ptr_gpu4_sig; p2 _p12; - void *ptr_gpu5_sig; + void* ptr_gpu5_sig; p2 _p13; - void *ptr_gpu6_sig; + void* ptr_gpu6_sig; p2 _p14; - void *ptr_gpu7_sig; + void* ptr_gpu7_sig; p2 _p15; - void *ptr_resi_in; + void* ptr_resi_in; p2 _p1; - void *ptr_weight_in; + void* ptr_weight_in; p2 _p2; - void *ptr_bias_in; + void* ptr_bias_in; p2 _p3; - void *ptr_xscale; + void* ptr_xscale; p2 _p4; unsigned int gpuId; p3 _p16; @@ -212,25 +223,25 @@ std::tuple all_reduce_rmsnorm(torch::Tensor &input int TGs = M / ca->world_size_; KernelArgs args; - size_t arg_size = sizeof(args); - args.ptr_gpu0_data = reinterpret_cast(gpu_addr_buf_in); - args.ptr_gpu0_sig = const_cast(sig_rd->ptrs[0]); - args.ptr_gpu1_sig = const_cast(sig_rd->ptrs[1]); - args.ptr_gpu2_sig = const_cast(sig_rd->ptrs[2]); - args.ptr_gpu3_sig = const_cast(sig_rd->ptrs[3]); - args.ptr_gpu4_sig = const_cast(sig_rd->ptrs[4]); - args.ptr_gpu5_sig = const_cast(sig_rd->ptrs[5]); - args.ptr_gpu6_sig = const_cast(sig_rd->ptrs[6]); - args.ptr_gpu7_sig = const_cast(sig_rd->ptrs[7]); - args.ptr_resi_in = const_cast(residual_in.data_ptr()); - args.ptr_weight_in = const_cast(weight.data_ptr()); - args.ptr_bias_in = const_cast(bias.data_ptr()); - args.gpuId = ca->rank_; - args.stride_gpu = size_input / ca->world_size_; - args.N = N; - args.epsilon = epsilon; - args.tgs = TGs; - args.loopcnt = 0; + size_t arg_size = sizeof(args); + args.ptr_gpu0_data = reinterpret_cast(gpu_addr_buf_in); + args.ptr_gpu0_sig = const_cast(sig_rd->ptrs[0]); + args.ptr_gpu1_sig = const_cast(sig_rd->ptrs[1]); + args.ptr_gpu2_sig = const_cast(sig_rd->ptrs[2]); + args.ptr_gpu3_sig = const_cast(sig_rd->ptrs[3]); + args.ptr_gpu4_sig = const_cast(sig_rd->ptrs[4]); + args.ptr_gpu5_sig = const_cast(sig_rd->ptrs[5]); + args.ptr_gpu6_sig = const_cast(sig_rd->ptrs[6]); + args.ptr_gpu7_sig = const_cast(sig_rd->ptrs[7]); + args.ptr_resi_in = const_cast(residual_in.data_ptr()); + args.ptr_weight_in = const_cast(weight.data_ptr()); + args.ptr_bias_in = const_cast(bias.data_ptr()); + args.gpuId = ca->rank_; + args.stride_gpu = size_input / ca->world_size_; + args.N = N; + args.epsilon = epsilon; + args.tgs = TGs; + args.loopcnt = 0; static AiterAsmKernel impl("allreduce_rmsnorm_N8192_kernel", "allreduce_rmsnorm_N8192.co"); @@ -244,99 +255,102 @@ std::tuple all_reduce_rmsnorm(torch::Tensor &input 1, // bdz stream}); - auto options = torch::TensorOptions() - .dtype(input.dtype()) - .device(input.device()); + auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device()); return {torch::from_blob(out_ptr, {input.sizes()}, options), torch::from_blob(res_ptr, {input.sizes()}, options)}; }; -std::tuple all_reduce_rmsnorm_quant(torch::Tensor &input, // [m ,n] - torch::Tensor &residual_in, // [m ,n] - torch::Tensor &xscale, // [1 ,n] - torch::Tensor &weight, // [1 ,n] - torch::Tensor &bias, // [1 ,n] - float epsilon, - // following are fused_allreduce args - int64_t _ca, - torch::Tensor ®_sig, torch::Tensor ®_buffer, bool isGraph) +std::tuple +all_reduce_rmsnorm_quant(torch::Tensor& input, // [m ,n] + torch::Tensor& residual_in, // [m ,n] + torch::Tensor& xscale, // [1 ,n] + torch::Tensor& weight, // [1 ,n] + torch::Tensor& bias, // [1 ,n] + float epsilon, + // following are fused_allreduce args + int64_t _ca, + torch::Tensor& reg_sig, + torch::Tensor& reg_buffer, + bool isGraph) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - auto stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + auto stream = at::hip::getCurrentHIPStream(); auto size_input = input.numel() * input.element_size(); - auto size_pad = (size_input + 4095) & 0xfffff000; + auto size_pad = (size_input + 4095) & 0xfffff000; - void *inp_ptr = input.data_ptr(); + void* inp_ptr = input.data_ptr(); // reg_buffer contains input|out|res_out auto size_needed = size_pad * 4; TORCH_CHECK(size_needed <= reg_buffer.numel() * reg_buffer.element_size(), "registered buffer is too small to contain the input ", - size_needed, ">", reg_buffer.numel() * reg_buffer.element_size()); + size_needed, + ">", + reg_buffer.numel() * reg_buffer.element_size()); - if (!isGraph) + if(!isGraph) { - AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp_ptr, - size_input, cudaMemcpyDeviceToDevice, stream)); + HIP_CALL(hipMemcpyAsync( + reg_buffer.data_ptr(), inp_ptr, size_input, hipMemcpyDeviceToDevice, stream)); inp_ptr = reg_buffer.data_ptr(); } - auto ca = reinterpret_cast(_ca); + auto ca = reinterpret_cast(_ca); using RD = aiter::RankData; - RD *sig_rd = ca->get_buffer_RD(stream, reg_sig.data_ptr()); - RD *reg_rd = ca->get_buffer_RD(stream, reg_buffer.data_ptr()); - RD *input_rd = ca->get_buffer_RD(stream, inp_ptr); + RD* sig_rd = ca->get_buffer_RD(stream, reg_sig.data_ptr()); + RD* reg_rd = ca->get_buffer_RD(stream, reg_buffer.data_ptr()); + RD* input_rd = ca->get_buffer_RD(stream, inp_ptr); - void *out_ptr; - void *res_ptr; - void *ys_ptr; + void* out_ptr; + void* res_ptr; + void* ys_ptr; uint64_t gpu_bufs[8 * 4]; - for (size_t i = 0; i < ca->world_size_; i++) + for(size_t i = 0; i < ca->world_size_; i++) { - gpu_bufs[i] = reinterpret_cast(input_rd->ptrs[i]); - gpu_bufs[i + 8] = reinterpret_cast(reg_rd->ptrs[i]) + size_pad; + gpu_bufs[i] = reinterpret_cast(input_rd->ptrs[i]); + gpu_bufs[i + 8] = reinterpret_cast(reg_rd->ptrs[i]) + size_pad; gpu_bufs[i + 16] = reinterpret_cast(reg_rd->ptrs[i]) + size_pad * 2; gpu_bufs[i + 24] = reinterpret_cast(reg_rd->ptrs[i]) + size_pad * 3; - if (i == ca->rank_) + if(i == ca->rank_) { - out_ptr = reinterpret_cast(gpu_bufs[i + 8]); - res_ptr = reinterpret_cast(gpu_bufs[i + 16]); - ys_ptr = reinterpret_cast(gpu_bufs[i + 24]); + out_ptr = reinterpret_cast(gpu_bufs[i + 8]); + res_ptr = reinterpret_cast(gpu_bufs[i + 16]); + ys_ptr = reinterpret_cast(gpu_bufs[i + 24]); } } - uint64_t *gpu_addr_buf_in; + uint64_t* gpu_addr_buf_in; uint addr_buf_size = 8 * 4 * sizeof(uint64_t); HIP_CALL(hipMalloc(&gpu_addr_buf_in, addr_buf_size)); HIP_CALL(hipMemcpy(gpu_addr_buf_in, gpu_bufs, addr_buf_size, hipMemcpyHostToDevice)); struct __attribute__((packed)) KernelArgs { - void *ptr_gpu0_data; + void* ptr_gpu0_data; p2 _p0; - void *ptr_gpu0_sig; + void* ptr_gpu0_sig; p2 _p8; - void *ptr_gpu1_sig; + void* ptr_gpu1_sig; p2 _p9; - void *ptr_gpu2_sig; + void* ptr_gpu2_sig; p2 _p10; - void *ptr_gpu3_sig; + void* ptr_gpu3_sig; p2 _p11; - void *ptr_gpu4_sig; + void* ptr_gpu4_sig; p2 _p12; - void *ptr_gpu5_sig; + void* ptr_gpu5_sig; p2 _p13; - void *ptr_gpu6_sig; + void* ptr_gpu6_sig; p2 _p14; - void *ptr_gpu7_sig; + void* ptr_gpu7_sig; p2 _p15; - void *ptr_resi_in; + void* ptr_resi_in; p2 _p1; - void *ptr_weight_in; + void* ptr_weight_in; p2 _p2; - void *ptr_bias_in; + void* ptr_bias_in; p2 _p3; - void *ptr_xscale; + void* ptr_xscale; p2 _p4; unsigned int gpuId; p3 _p16; @@ -357,28 +371,29 @@ std::tuple all_reduce_rmsnorm_quant int TGs = M / ca->world_size_; KernelArgs args; - size_t arg_size = sizeof(args); - args.ptr_gpu0_data = reinterpret_cast(gpu_addr_buf_in); - args.ptr_gpu0_sig = const_cast(sig_rd->ptrs[0]); - args.ptr_gpu1_sig = const_cast(sig_rd->ptrs[1]); - args.ptr_gpu2_sig = const_cast(sig_rd->ptrs[2]); - args.ptr_gpu3_sig = const_cast(sig_rd->ptrs[3]); - args.ptr_gpu4_sig = const_cast(sig_rd->ptrs[4]); - args.ptr_gpu5_sig = const_cast(sig_rd->ptrs[5]); - args.ptr_gpu6_sig = const_cast(sig_rd->ptrs[6]); - args.ptr_gpu7_sig = const_cast(sig_rd->ptrs[7]); - args.ptr_resi_in = const_cast(residual_in.data_ptr()); - args.ptr_weight_in = const_cast(weight.data_ptr()); - args.ptr_bias_in = const_cast(bias.data_ptr()); - args.ptr_xscale = xscale.data_ptr(); - args.gpuId = ca->rank_; - args.stride_gpu = size_input / ca->world_size_; - args.N = N; - args.epsilon = epsilon; - args.tgs = TGs; - args.loopcnt = 0; - - static AiterAsmKernel impl("allreduce_rmsnorm_qnt_N8192_kernel", "allreduce_rmsnorm_qnt_N8192.co"); + size_t arg_size = sizeof(args); + args.ptr_gpu0_data = reinterpret_cast(gpu_addr_buf_in); + args.ptr_gpu0_sig = const_cast(sig_rd->ptrs[0]); + args.ptr_gpu1_sig = const_cast(sig_rd->ptrs[1]); + args.ptr_gpu2_sig = const_cast(sig_rd->ptrs[2]); + args.ptr_gpu3_sig = const_cast(sig_rd->ptrs[3]); + args.ptr_gpu4_sig = const_cast(sig_rd->ptrs[4]); + args.ptr_gpu5_sig = const_cast(sig_rd->ptrs[5]); + args.ptr_gpu6_sig = const_cast(sig_rd->ptrs[6]); + args.ptr_gpu7_sig = const_cast(sig_rd->ptrs[7]); + args.ptr_resi_in = const_cast(residual_in.data_ptr()); + args.ptr_weight_in = const_cast(weight.data_ptr()); + args.ptr_bias_in = const_cast(bias.data_ptr()); + args.ptr_xscale = xscale.data_ptr(); + args.gpuId = ca->rank_; + args.stride_gpu = size_input / ca->world_size_; + args.N = N; + args.epsilon = epsilon; + args.tgs = TGs; + args.loopcnt = 0; + + static AiterAsmKernel impl("allreduce_rmsnorm_qnt_N8192_kernel", + "allreduce_rmsnorm_qnt_N8192.co"); impl.launch_kernel({&args, &arg_size, @@ -390,15 +405,9 @@ std::tuple all_reduce_rmsnorm_quant 1, // bdz stream}); - auto opt_out = torch::TensorOptions() - .dtype(torch::kInt8) - .device(input.device()); - auto opt_res = torch::TensorOptions() - .dtype(input.dtype()) - .device(input.device()); - auto opt_ys = torch::TensorOptions() - .dtype(torch::kFloat32) - .device(input.device()); + auto opt_out = torch::TensorOptions().dtype(torch::kInt8).device(input.device()); + auto opt_res = torch::TensorOptions().dtype(input.dtype()).device(input.device()); + auto opt_ys = torch::TensorOptions().dtype(torch::kFloat32).device(input.device()); return { torch::from_blob(out_ptr, {input.sizes()}, opt_out), torch::from_blob(res_ptr, {input.sizes()}, opt_res), diff --git a/csrc/py_itfs_cu/asm_flatmm_a8w8_blockscale.cu b/csrc/py_itfs_cu/asm_flatmm_a8w8_blockscale.cu index 3bac0f307f..d7edb91eba 100644 --- a/csrc/py_itfs_cu/asm_flatmm_a8w8_blockscale.cu +++ b/csrc/py_itfs_cu/asm_flatmm_a8w8_blockscale.cu @@ -3,8 +3,8 @@ #include #include #include -#include -#include +#include +#include #include "aiter_hip_common.h" #include "hip_float8.h" @@ -68,8 +68,8 @@ torch::Tensor flatmm_a8w8_blockscale_asm( args.intermediate_size = n; args.hidden_size = k; - const at::cuda::OptionalCUDAGuard device_guard(device_of(XQ)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(XQ)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); AiterAsmKernel *impl_ptr = nullptr; static AiterAsmKernel impl_kenrel("flatmm_uk_gfx9_f16f8_128x256x128_1x4x1_16x16x32", "flatmm_uk_gfx9_f16f8_128x256x128_1x4x1_16x16x32.co"); diff --git a/csrc/py_itfs_cu/asm_fmoe.cu b/csrc/py_itfs_cu/asm_fmoe.cu index 714131239a..a31552ba08 100755 --- a/csrc/py_itfs_cu/asm_fmoe.cu +++ b/csrc/py_itfs_cu/asm_fmoe.cu @@ -4,8 +4,8 @@ #include "asm_fmoe_configs.hpp" #include "moe_op.h" #include "py_itfs_common.h" -#include -#include +#include +#include #include #include #include @@ -236,8 +236,8 @@ class FMoeKernel // std::cout << "gdx: " << gdx << std::endl; // std::cout << "gdy: " << gdy << std::endl; - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); if constexpr(switchGxy) { HIP_CALL(hipModuleLaunchKernel( diff --git a/csrc/py_itfs_cu/asm_gemm_a16w16.cu b/csrc/py_itfs_cu/asm_gemm_a16w16.cu index 3bac562e39..9c4580cd25 100644 --- a/csrc/py_itfs_cu/asm_gemm_a16w16.cu +++ b/csrc/py_itfs_cu/asm_gemm_a16w16.cu @@ -3,8 +3,8 @@ #include "aiter_hip_common.h" #include "asm_bf16gemm_configs.hpp" #include "py_itfs_common.h" -#include -#include +#include +#include #include #include #include @@ -252,8 +252,8 @@ torch::Tensor gemm_a16w16_asm(torch::Tensor& A, // A:[M, K] bf16 TORCH_CHECK(false, __func__, " not find kernel~ " + selectedKernelName); // 3. launch kl - const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(A)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); int bdx = 256; int gdx = (Ndim + SUBN - 1) / SUBN; diff --git a/csrc/py_itfs_cu/asm_gemm_a4w4.cu b/csrc/py_itfs_cu/asm_gemm_a4w4.cu index 189e7fc1ee..5c5e279658 100644 --- a/csrc/py_itfs_cu/asm_gemm_a4w4.cu +++ b/csrc/py_itfs_cu/asm_gemm_a4w4.cu @@ -3,8 +3,8 @@ #include "aiter_hip_common.h" #include "asm_f4gemm_configs.hpp" #include "py_itfs_common.h" -#include -#include +#include +#include #include #include #include @@ -195,8 +195,8 @@ torch::Tensor gemm_a4w4_asm(torch::Tensor& A, // A:[M, K/2] f4x2 args.stride_ScaleB0 = B_scale.stride(0); args.log2_k_split = 0; - const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(A)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); CFG* config_map = get_cfg(A, out); using DictKey = std::tuple, std::optional>; struct SimpleHash diff --git a/csrc/py_itfs_cu/asm_gemm_a8w8.cu b/csrc/py_itfs_cu/asm_gemm_a8w8.cu index 45f0c29a54..09bbdc1b43 100644 --- a/csrc/py_itfs_cu/asm_gemm_a8w8.cu +++ b/csrc/py_itfs_cu/asm_gemm_a8w8.cu @@ -3,8 +3,8 @@ #include #include #include -#include -#include +#include +#include #include "aiter_hip_common.h" // start to prepare the input and output buffer @@ -80,8 +80,8 @@ torch::Tensor gemm_a8w8_asm(torch::Tensor &A, // A:[M, K] i8 args.ldc = stride_c; args.ks = ks; - const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(A)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); static AiterAsmKernel splitK_impl("gemm_kernel_func", "gemm_a8w8_m128_splitK.co"); static AiterAsmKernel noSplitK_impl("gemm_kernel_func", "gemm_a8w8_m128_noSplitK.co"); AiterAsmKernel *impl_ptr = &noSplitK_impl; diff --git a/csrc/py_itfs_cu/asm_layernorm.cu b/csrc/py_itfs_cu/asm_layernorm.cu index ae1578ad91..bb215feeb6 100644 --- a/csrc/py_itfs_cu/asm_layernorm.cu +++ b/csrc/py_itfs_cu/asm_layernorm.cu @@ -3,8 +3,8 @@ #include #include #include -#include -#include +#include +#include #include "aiter_hip_common.h" struct __attribute__((packed)) KernelArgs @@ -66,8 +66,8 @@ void layernorm2d_with_add_asm(torch::Tensor &out, // [m ,n] args.ptr_OutResidual = residual_out.data_ptr(); args.ptr_InResidual = residual_in.data_ptr(); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); int sub_M = 2; static AiterAsmKernel impl("layer_norm_kernel_func", "layer_norm.co"); @@ -119,8 +119,8 @@ void layernorm2d_with_add_smoothquant_asm(torch::Tensor &out, // [m ,n] args.ptr_OutYScale = yscale.data_ptr(); args.ptr_XScale = xscale.data_ptr(); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); int sub_M = 2; static AiterAsmKernel impl("layer_norm_qnt", "layer_norm_qnt.co"); diff --git a/csrc/py_itfs_cu/asm_mha_bwd.cu b/csrc/py_itfs_cu/asm_mha_bwd.cu index 69a83abc58..01efd1a46f 100644 --- a/csrc/py_itfs_cu/asm_mha_bwd.cu +++ b/csrc/py_itfs_cu/asm_mha_bwd.cu @@ -2,7 +2,7 @@ // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include -#include +#include #include "py_itfs_common.h" #include "mha_common.h" @@ -203,7 +203,7 @@ std::vector fmha_v3_bwd(const at::Tensor &dout, // [b, sq, h if (is_causal) { window_size_right = 0; } bool is_dropout = p_dropout > 0.0; - auto stream = at::cuda::getCurrentHIPStream().stream(); + auto stream = at::hip::getCurrentHIPStream(); auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, @@ -295,7 +295,7 @@ std::vector fmha_v3_bwd(const at::Tensor &dout, // [b, sq, h dv = torch::empty_like(v); } - at::cuda::CUDAGuard device_guard{q.device()}; + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard{q.device()}; auto opts = q.options(); auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); diff --git a/csrc/py_itfs_cu/asm_mha_fwd.cu b/csrc/py_itfs_cu/asm_mha_fwd.cu index 39878ec1b5..62354ded84 100644 --- a/csrc/py_itfs_cu/asm_mha_fwd.cu +++ b/csrc/py_itfs_cu/asm_mha_fwd.cu @@ -2,7 +2,7 @@ // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include -#include +#include #include "py_itfs_common.h" #include "mha_common.h" @@ -245,7 +245,7 @@ std::vector fmha_v3_fwd(at::Tensor &q, // [b, sq, hq, d] } // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard{q.device()}; bool has_lse = return_softmax_lse; bool has_dropout = p_dropout > 0.0f; @@ -283,7 +283,7 @@ std::vector fmha_v3_fwd(at::Tensor &q, // [b, sq, hq, d] if (seqlen_k > 0) { auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1); - auto stream = at::cuda::getCurrentHIPStream().stream(); + auto stream = at::hip::getCurrentHIPStream(); ck_tile::stream_config stream_config{stream}; auto args = diff --git a/csrc/py_itfs_cu/asm_mha_varlen_bwd.cu b/csrc/py_itfs_cu/asm_mha_varlen_bwd.cu index d0e1eb112c..81e36d4025 100644 --- a/csrc/py_itfs_cu/asm_mha_varlen_bwd.cu +++ b/csrc/py_itfs_cu/asm_mha_varlen_bwd.cu @@ -2,7 +2,7 @@ // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include -#include +#include #include "py_itfs_common.h" #include "mha_common.h" @@ -231,7 +231,7 @@ fmha_v3_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v if (is_causal) { window_size_right = 0; } bool is_dropout = p_dropout > 0.0; - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream = at::hip::getCurrentHIPStream(); auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, @@ -330,7 +330,7 @@ fmha_v3_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v dv = torch::empty_like(v); } - at::cuda::CUDAGuard device_guard{q.device()}; + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard{q.device()}; auto opts = q.options(); auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); diff --git a/csrc/py_itfs_cu/asm_mha_varlen_fwd.cu b/csrc/py_itfs_cu/asm_mha_varlen_fwd.cu index 9cc0b694ca..07cbf20a08 100644 --- a/csrc/py_itfs_cu/asm_mha_varlen_fwd.cu +++ b/csrc/py_itfs_cu/asm_mha_varlen_fwd.cu @@ -2,7 +2,7 @@ // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include -#include +#include #include "py_itfs_common.h" #include "mha_common.h" @@ -297,7 +297,7 @@ fmha_v3_varlen_fwd(at::Tensor &q, // [total_q, hq, d] } // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard{q.device()}; bool has_lse = return_softmax_lse; bool has_dropout = p_dropout > 0.0f; @@ -344,7 +344,7 @@ fmha_v3_varlen_fwd(at::Tensor &q, // [total_q, hq, d] std::optional seqlens_k = std::nullopt; if (max_seqlen_k > 0) { - auto stream = at::cuda::getCurrentHIPStream().stream(); + auto stream = at::hip::getCurrentHIPStream(); ck_tile::stream_config stream_config{stream}; TORCH_CHECK(cu_seqlens_k.has_value(), "cu_seqlens_k must be provided if paged_KV is false"); diff --git a/csrc/py_itfs_cu/asm_mi350_a8w8_blockscale.cu b/csrc/py_itfs_cu/asm_mi350_a8w8_blockscale.cu index e29fb6289c..6ce53d1017 100644 --- a/csrc/py_itfs_cu/asm_mi350_a8w8_blockscale.cu +++ b/csrc/py_itfs_cu/asm_mi350_a8w8_blockscale.cu @@ -3,8 +3,8 @@ #include #include #include -#include -#include +#include +#include #include "aiter_hip_common.h" #include "hip_float8.h" @@ -101,8 +101,8 @@ torch::Tensor mi350_a8w8_blockscale_asm( args.Cs = n * 2; args.splitk = 0; args.activation = 0; - const at::cuda::OptionalCUDAGuard device_guard(device_of(XQ)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(XQ)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); // printf("ptr_X: %p\n", args.ptr_X); // printf("ptr_GU: %p\n", args.ptr_GU); // printf("ptr_XQ: %p\n", args.ptr_XQ); diff --git a/csrc/py_itfs_cu/asm_mla.cu b/csrc/py_itfs_cu/asm_mla.cu index b10ce26ad6..5961ff0649 100644 --- a/csrc/py_itfs_cu/asm_mla.cu +++ b/csrc/py_itfs_cu/asm_mla.cu @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "aiter_hip_common.h" -#include -#include +#include +#include #include #include #include @@ -98,8 +98,8 @@ void mla_decode_stage1_asm_fwd( // std::cout << "s_log2_plen: " << args.s_log2_plen << std::endl; // std::cout << "ptr_QTP: " << args.ptr_QTP << std::endl; - const at::cuda::OptionalCUDAGuard device_guard(device_of(Q)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(Q)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); AiterAsmKernel* impl_ptr = nullptr; TORCH_CHECK(Q.is_contiguous(), __func__, ":only support Q.is_contiguous() for now"); @@ -203,8 +203,8 @@ void mla_prefill_asm_fwd( args.s_Bs = stride_Page; args.s_log2_plen = log2_page; - const at::cuda::OptionalCUDAGuard device_guard(device_of(Q)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(Q)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); AiterAsmKernel* impl_ptr = nullptr; TORCH_CHECK(Q.is_contiguous(), __func__, ":only support Q.is_contiguous() for now"); diff --git a/csrc/py_itfs_cu/asm_moe_2stage.cu b/csrc/py_itfs_cu/asm_moe_2stage.cu index b9f3444649..aa2e54d945 100644 --- a/csrc/py_itfs_cu/asm_moe_2stage.cu +++ b/csrc/py_itfs_cu/asm_moe_2stage.cu @@ -3,8 +3,8 @@ #include #include #include -#include -#include +#include +#include #include "aiter_hip_common.h" #include "moe_op.h" #include "asm_moe_2stage_configs.hpp" @@ -163,8 +163,8 @@ void moe_stage1_g1u1( std::optional sorted_weights = std::nullopt // [max_num_tokens_padded], do_weight==true need ) { - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); CFG *config_map = get_cfg(input, out, w1, quant_type, sorted_weights.has_value()); static std::unordered_map> impl_ptr_map; diff --git a/csrc/py_itfs_cu/asm_pa.cu b/csrc/py_itfs_cu/asm_pa.cu index f4490640e2..d80db86062 100644 --- a/csrc/py_itfs_cu/asm_pa.cu +++ b/csrc/py_itfs_cu/asm_pa.cu @@ -3,8 +3,8 @@ #include "aiter_hip_common.h" #include "asm_pa_configs.hpp" #include "py_itfs_common.h" -#include -#include +#include +#include #include #include #include @@ -156,8 +156,8 @@ torch::Tensor pa_fwd(torch::Tensor& Q, // [num_seqs, num_heads, head_size] // << " kv_nheads:" << args.kv_nheads << " Qs:" << args.Qs << " Bs:" << args.Bs // << " KVs:" << args.KVs << std::endl; - const at::cuda::OptionalCUDAGuard device_guard(device_of(Q)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(Q)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); std::string q_type; std::string kv_type; diff --git a/csrc/py_itfs_cu/asm_topksoftmax.cu b/csrc/py_itfs_cu/asm_topksoftmax.cu index 86597f6ff3..b82f2df13e 100644 --- a/csrc/py_itfs_cu/asm_topksoftmax.cu +++ b/csrc/py_itfs_cu/asm_topksoftmax.cu @@ -2,8 +2,8 @@ // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "aiter_hip_common.h" #include "py_itfs_common.h" -#include -#include +#include +#include #include struct __attribute__((packed)) KernelArgs @@ -122,8 +122,8 @@ void topk_softmax_asm(torch::Tensor& topk_weights, // [num_tokens, topk] topk); } - const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(gating_output)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); uint gdx = (num_tokens + SUBM - 1) / SUBM; TORCH_CHECK(gdx >> 31 == 0, "num_tokens too large: ", num_tokens); diff --git a/csrc/py_itfs_cu/custom.cu b/csrc/py_itfs_cu/custom.cu index 9378ba1f33..0cbc02655e 100644 --- a/csrc/py_itfs_cu/custom.cu +++ b/csrc/py_itfs_cu/custom.cu @@ -16,9 +16,9 @@ */ #include #include -#include -#include -#include +#include +#include +#include #include "py_itfs_common.h" namespace aiter { @@ -28,13 +28,13 @@ namespace aiter { // void LLGemm_Silu(void* in_a, void* in_b, void* out_c, const int M, const int // K, -// cudaStream_t stream, const int rows_per_block); +// hipStream_t stream, const int rows_per_block); // void LLMM_Silu(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, // const int64_t rows_per_block) { // auto M = in_a.size(0); // auto K = in_a.size(1); // LLGemm_Silu(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, -// at::cuda::getCurrentCUDAStream(), rows_per_block); +// at::hip::getCurrentHIPStream(), rows_per_block); // } void LLGemm1(void* in_a, @@ -42,7 +42,7 @@ void LLGemm1(void* in_a, void* out_c, const int M, const int K, - cudaStream_t stream, + hipStream_t stream, const int rows_per_block = 4, const c10::ScalarType scalar_type = c10::ScalarType::Half); // template @@ -63,13 +63,13 @@ void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, const int64_t TORCH_CHECK(in_b.dtype() == torch::kFloat16 || in_b.dtype() == torch::kBFloat16); // call the kernel function... - const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(in_a)); LLGemm1(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, - at::cuda::getCurrentCUDAStream(), + at::hip::getCurrentHIPStream(), rows_per_block, in_b.scalar_type()); } @@ -80,7 +80,7 @@ void wvSplitK_(void* in_a, const int M, const int K, const int N, - cudaStream_t stream, + hipStream_t stream, const int CuCount = 1, const c10::ScalarType scalar_type = c10::ScalarType::Half); void wvSpltK(at::Tensor& in_a, @@ -96,14 +96,14 @@ void wvSpltK(at::Tensor& in_a, TORCH_CHECK(K % 8 == 0, "k % 8 == 0"); TORCH_CHECK(in_a.dtype() == torch::kFloat16 || in_a.dtype() == torch::kBFloat16); - const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(in_a)); wvSplitK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N, - at::cuda::getCurrentCUDAStream(), + at::hip::getCurrentHIPStream(), CuCount, in_b.scalar_type()); } @@ -114,7 +114,7 @@ void wv_splitk_small_fp16_bf16(void* in_a, const int M, const int K, const int N, - cudaStream_t stream, + hipStream_t stream, const int CuCount = 1, const c10::ScalarType scalar_type = c10::ScalarType::Half); void wv_splitk_small_fp16_bf16_wrapper(at::Tensor& in_a, @@ -130,14 +130,14 @@ void wv_splitk_small_fp16_bf16_wrapper(at::Tensor& in_a, TORCH_CHECK(K % 8 == 0, "k % 8 == 0"); TORCH_CHECK(in_a.dtype() == torch::kFloat16 || in_a.dtype() == torch::kBFloat16); - const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(in_a)); wv_splitk_small_fp16_bf16(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N, - at::cuda::getCurrentCUDAStream(), + at::hip::getCurrentHIPStream(), CuCount, in_b.scalar_type()); } @@ -151,7 +151,7 @@ void wvSplitKQ_(void* in_a, const int K, const int Kp, const int N, - cudaStream_t stream, + hipStream_t stream, const int CuCount = 1, const c10::ScalarType a_scalar_type = c10::ScalarType::Float8_e4m3fnuz, const c10::ScalarType c_scalar_type = c10::ScalarType::Half); @@ -172,7 +172,7 @@ void wvSplitKQ(at::Tensor& in_a, auto scale_a_ptr = scale_a.data_ptr(); auto scale_b_ptr = scale_b.data_ptr(); - const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(in_a)); wvSplitKQ_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), @@ -182,7 +182,7 @@ void wvSplitKQ(at::Tensor& in_a, K, Kp, N, - at::cuda::getCurrentCUDAStream(), + at::hip::getCurrentHIPStream(), CuCount, in_a.scalar_type(), out_c.scalar_type()); @@ -193,20 +193,20 @@ void LLGemmZZ(void* in_a, void* out_c, const int M, const int K, - cudaStream_t stream, + hipStream_t stream, const int solidx); void LLZZ(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int64_t solidx = 0) { auto M = in_a.size(0); auto K = in_a.size(1); - const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(in_a)); LLGemmZZ(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, - at::cuda::getCurrentCUDAStream(), + at::hip::getCurrentHIPStream(), solidx); } // instantiate the CPP template for T=float: @@ -222,14 +222,14 @@ void MMGPUKernel(float* in_a, int numBColumns, int numCRows, int numCColumns, - cudaStream_t stream); + hipStream_t stream); void MMCustomGPU(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c) { auto matA_sizes{in_a.sizes()}; auto matB_sizes{in_b.sizes()}; auto matO_sizes{out_c.sizes()}; - const at::cuda::OptionalCUDAGuard device_guard(device_of(in_a)); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(in_a)); MMGPUKernel(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), @@ -239,6 +239,6 @@ void MMCustomGPU(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c) matB_sizes[1], matO_sizes[0], matO_sizes[1], - at::cuda::getCurrentCUDAStream()); + at::hip::getCurrentHIPStream()); } } // namespace aiter diff --git a/gradlib/csrc/grad_funcs.cu b/gradlib/csrc/grad_funcs.cu index 7d4d10ee80..39f294816a 100644 --- a/gradlib/csrc/grad_funcs.cu +++ b/gradlib/csrc/grad_funcs.cu @@ -1,37 +1,35 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. // #ifdef __gfx908__ -// // Uncomment ifdef and endif only if you need to undef the HIP_HALF ops below just for gfx908 and not for others +// // Uncomment ifdef and endif only if you need to undef the HIP_HALF ops below just for gfx908 and +// not for others // // below lines enable hip float to half conversion which are disabled by default in hip_fp16.h // #undef __HIP_NO_HALF_OPERATORS__ // #undef __HIP_NO_HALF_CONVERSIONS__ // #endif -#include -#include #include #include -#include -#include -#include -#include -// #include +#include +#include +#include + +#include +#include #include #include #include -#include #include // #include #include +#include #include #include #include #include #include -#include -#include "nvToolsExt.h" // #ifdef USE_ROCM // #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) @@ -39,62 +37,58 @@ // #endif // #ifdef __HIP_PLATFORM_HCC__ -// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) -// #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) -// #if USE_GEMM_FLAGS_FP16_ALT_IMPL -// #ifdef ROCM_BACKWARD_PASS_GUARD -// flag = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0; -// #endif -// #endif -// #endif +// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + +// ROCBLAS_VERSION_MINOR) #define USE_GEMM_FLAGS_FP16_ALT_IMPL +// (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) #if USE_GEMM_FLAGS_FP16_ALT_IMPL #ifdef +// ROCM_BACKWARD_PASS_GUARD flag = at::BackwardPassGuard::is_backward_pass() ? +// rocblas_gemm_flags_fp16_alt_impl : 0; #endif #endif #endif #ifndef CHECK_HIP_ERROR -#define CHECK_HIP_ERROR(error) \ - if (error != hipSuccess) \ - { \ - fprintf(stderr, \ - "Hip error: '%s'(%d) at %s:%d\n", \ - hipGetErrorString(error), \ - error, \ - __FILE__, \ - __LINE__); \ - exit(EXIT_FAILURE); \ - } +#define CHECK_HIP_ERROR(error) \ + if(error != hipSuccess) \ + { \ + fprintf(stderr, \ + "Hip error: '%s'(%d) at %s:%d\n", \ + hipGetErrorString(error), \ + error, \ + __FILE__, \ + __LINE__); \ + exit(EXIT_FAILURE); \ + } #endif #ifndef CHECK_HIPBLAS_ERROR -#define CHECK_HIPBLAS_ERROR(error) \ - if (error != HIPBLAS_STATUS_SUCCESS) \ - { \ - fprintf(stderr, \ - "hipBLAS error: '%s'(%d) at %s:%d\n", \ - hipblasStatusToString(error), \ - error, \ - __FILE__, \ - __LINE__); \ - exit(EXIT_FAILURE); \ - } +#define CHECK_HIPBLAS_ERROR(error) \ + if(error != HIPBLAS_STATUS_SUCCESS) \ + { \ + fprintf(stderr, \ + "hipBLAS error: '%s'(%d) at %s:%d\n", \ + hipblasStatusToString(error), \ + error, \ + __FILE__, \ + __LINE__); \ + exit(EXIT_FAILURE); \ + } #endif -namespace +namespace { +/*thread_local*/ hipStream_t weight_stream; +// BUG: DLM has event and stream on different devices error +// In multi-GPU scenerio, do names defined in this namespace exist on all devices? +// C++ keyword: thread_local <- maybe this can help? +/*thread_local*/ hipEvent_t event; + +// hipBLASLt +hipblasLtHandle_t hipblaslt_handle; +hipblasLtMatmulPreference_t preference; +uint64_t workspace_size = 32 * 1024 * 1024; +// uint64_t workspace_size = 0; +void* d_workspace; +int request_solutions = 1; +int returnedAlgoCount = 0; + +struct MatMulConfig { - /*thread_local*/ cudaStream_t weight_stream; - // BUG: DLM has event and stream on different devices error - // In multi-GPU scenerio, do names defined in this namespace exist on all devices? - // C++ keyword: thread_local <- maybe this can help? - /*thread_local*/ cudaEvent_t event; - - // hipBLASLt - hipblasLtHandle_t hipblaslt_handle; - hipblasLtMatmulPreference_t preference; - uint64_t workspace_size = 32 * 1024 * 1024; - // uint64_t workspace_size = 0; - void *d_workspace; - int request_solutions = 1; - int returnedAlgoCount = 0; - - struct MatMulConfig - { hipblasOperation_t op_A; hipblasOperation_t op_B; int M; @@ -102,356 +96,398 @@ namespace int K; hipblasDatatype_t dtype; - friend auto operator<(const MatMulConfig &left, const MatMulConfig &right) -> bool + friend auto operator<(const MatMulConfig& left, const MatMulConfig& right) -> bool { - return std::tie(left.op_A, left.op_B, left.M, left.N, left.K, left.dtype) < std::tie(right.op_A, right.op_B, right.M, right.N, right.K, right.dtype); + return std::tie(left.op_A, left.op_B, left.M, left.N, left.K, left.dtype) < + std::tie(right.op_A, right.op_B, right.M, right.N, right.K, right.dtype); } - }; +}; - // std::map, std::vector> heuristic_map; - std::map heuristic_map; +// std::map, std::vector> +// heuristic_map; +std::map heuristic_map; - hipEvent_t start, stop; - int bench_iters{1}; - int warmup_iters{1}; +hipEvent_t start, stop; +int bench_iters{1}; +int warmup_iters{1}; - bool cout_print = true; -} +bool cout_print = true; +} // namespace ///////////////////////////////////////////////////////////////////////////////////////////////////////// /** * hipBLASLt GEMM call */ -hipblasStatus_t hipblasLtMatmul_wrapper( - hipblasLtHandle_t handle, - hipblasOperation_t op_A, - hipblasOperation_t op_B, - int m, int n, int k, - const void *alpha, - const void *a, - int lda, - const void *b, - int ldb, - const void *beta, - void *c, - int ldc, - hipblasDatatype_t dtype, - hipStream_t &stream) +hipblasStatus_t hipblasLtMatmul_wrapper(hipblasLtHandle_t handle, + hipblasOperation_t op_A, + hipblasOperation_t op_B, + int m, + int n, + int k, + const void* alpha, + const void* a, + int lda, + const void* b, + int ldb, + const void* beta, + void* c, + int ldc, + hipblasDatatype_t dtype, + hipStream_t& stream) { - // TODO: flag is not supported for hipblasLt yet - int flag{0}; - if (dtype == HIPBLAS_R_16F) - { - // use fp16 alt impl for MI200 - // https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices - flag = rocblas_gemm_flags_fp16_alt_impl; - } - - nvtxRangePushA("hipBLASLt variables creation"); - hipblasLtMatrixLayout_t matA, matB, matC; - hipblasLtMatmulDesc_t matmul; - if (op_A == HIPBLAS_OP_N) - { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, m, k, lda)); - } - else - { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, k, m, lda)); - } - if (op_B == HIPBLAS_OP_N) - { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, k, n, ldb)); - } - else - { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, n, k, ldb)); - } - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, dtype, m, n, ldc)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( - matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( - matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(int32_t))); - nvtxRangePop(); - - // if heuristic does not exist in the map, do search and push into the map - auto gemm_key{MatMulConfig{op_A, op_B, m, n, k, dtype}}; - if (heuristic_map.count(gemm_key) <= 0) - { - nvtxRangePushA("hipblasLtMatmulAlgoGetHeuristic"); - if (cout_print) + // TODO: flag is not supported for hipblasLt yet + int flag{0}; + if(dtype == HIPBLAS_R_16F) { - std::cout << (op_A == HIPBLAS_OP_N ? "N" : "T") << (op_B == HIPBLAS_OP_N ? "N" : "T") - << " (" << m << ", " << n << ", " << k << "), dtype: " << dtype - << ", (lda, ldb, ldc): (" << lda << ", " << ldb << ", " << ldc << "), " << std::endl; + // use fp16 alt impl for MI200 + // https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + flag = rocblas_gemm_flags_fp16_alt_impl; } - std::vector heuristicResult(request_solutions); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( - handle, matmul, matA, matB, matC, matC, - preference, request_solutions, heuristicResult.data(), &returnedAlgoCount)); - if ((returnedAlgoCount != request_solutions) && cout_print) + + nvtxRangePushA("hipBLASLt variables creation"); + hipblasLtMatrixLayout_t matA, matB, matC; + hipblasLtMatmulDesc_t matmul; + if(op_A == HIPBLAS_OP_N) { - std::cout << "less solution found! request: " << request_solutions - << ", found: " << returnedAlgoCount << std::endl; + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, m, k, lda)); } - - if (returnedAlgoCount == 1) + else + { + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, k, m, lda)); + } + if(op_B == HIPBLAS_OP_N) { - heuristic_map[gemm_key] = heuristicResult[0]; + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, k, n, ldb)); } else { - // benchmark requested solutions and pick best one - int bestIndex{-1}; - double bestMs{std::numeric_limits::max()}; - for (int sol{0}; sol < returnedAlgoCount; ++sol) - { - // warm up - for (int iter{0}; iter < warmup_iters; ++iter) - { - CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, - alpha, - a, matA, - b, matB, - beta, - c, matC, - c, matC, // In case beta != 0, these runs can overwrite the values in c - // since c and d are the same - // TODO: allocates separate d memory for these runs - &heuristicResult[sol].algo, - d_workspace, workspace_size, - stream)); - } - // performance measuring - double eventMs; - CHECK_HIP_ERROR(hipEventRecord(start, stream)); - for (int iter{0}; iter < bench_iters; ++iter) + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, n, k, ldb)); + } + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, dtype, m, n, ldc)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(int32_t))); + nvtxRangePop(); + + // if heuristic does not exist in the map, do search and push into the map + auto gemm_key{MatMulConfig{op_A, op_B, m, n, k, dtype}}; + if(heuristic_map.count(gemm_key) <= 0) + { + nvtxRangePushA("hipblasLtMatmulAlgoGetHeuristic"); + if(cout_print) { - CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, - alpha, - a, matA, - b, matB, - beta, - c, matC, - c, matC, // In case beta != 0, these runs can overwrite the values in c - // since c and d are the same - // TODO: allocates separate d memory for these runs - &heuristicResult[sol].algo, - d_workspace, workspace_size, - stream)); + std::cout << (op_A == HIPBLAS_OP_N ? "N" : "T") << (op_B == HIPBLAS_OP_N ? "N" : "T") + << " (" << m << ", " << n << ", " << k << "), dtype: " << dtype + << ", (lda, ldb, ldc): (" << lda << ", " << ldb << ", " << ldc << "), " + << std::endl; } - CHECK_HIP_ERROR(hipEventRecord(stop, stream)); - CHECK_HIP_ERROR(hipEventSynchronize(stop)); - float temp; - CHECK_HIP_ERROR(hipEventElapsedTime(&temp, start, stop)); - eventMs = double(temp); - eventMs /= bench_iters; - - if (cout_print) + std::vector heuristicResult(request_solutions); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle, + matmul, + matA, + matB, + matC, + matC, + preference, + request_solutions, + heuristicResult.data(), + &returnedAlgoCount)); + if((returnedAlgoCount != request_solutions) && cout_print) { - std::cout << " Sol " << sol << ": average time per iter " << std::to_string(eventMs) << " ms"; + std::cout << "less solution found! request: " << request_solutions + << ", found: " << returnedAlgoCount << std::endl; } - if (bestMs > eventMs) + + if(returnedAlgoCount == 1) { - bestMs = eventMs; - bestIndex = sol; - if (cout_print) - { - std::cout << " *" << std::endl; - } + heuristic_map[gemm_key] = heuristicResult[0]; } else { - if (cout_print) - { - std::cout << std::endl; - } + // benchmark requested solutions and pick best one + int bestIndex{-1}; + double bestMs{std::numeric_limits::max()}; + for(int sol{0}; sol < returnedAlgoCount; ++sol) + { + // warm up + for(int iter{0}; iter < warmup_iters; ++iter) + { + CHECK_HIPBLAS_ERROR( + hipblasLtMatmul(handle, + matmul, + alpha, + a, + matA, + b, + matB, + beta, + c, + matC, + c, + matC, // In case beta != 0, these runs can overwrite the + // values in c since c and d are the same + // TODO: allocates separate d memory for these runs + &heuristicResult[sol].algo, + d_workspace, + workspace_size, + stream)); + } + // performance measuring + double eventMs; + CHECK_HIP_ERROR(hipEventRecord(start, stream)); + for(int iter{0}; iter < bench_iters; ++iter) + { + CHECK_HIPBLAS_ERROR( + hipblasLtMatmul(handle, + matmul, + alpha, + a, + matA, + b, + matB, + beta, + c, + matC, + c, + matC, // In case beta != 0, these runs can overwrite the + // values in c since c and d are the same + // TODO: allocates separate d memory for these runs + &heuristicResult[sol].algo, + d_workspace, + workspace_size, + stream)); + } + CHECK_HIP_ERROR(hipEventRecord(stop, stream)); + CHECK_HIP_ERROR(hipEventSynchronize(stop)); + float temp; + CHECK_HIP_ERROR(hipEventElapsedTime(&temp, start, stop)); + eventMs = double(temp); + eventMs /= bench_iters; + + if(cout_print) + { + std::cout << " Sol " << sol << ": average time per iter " + << std::to_string(eventMs) << " ms"; + } + if(bestMs > eventMs) + { + bestMs = eventMs; + bestIndex = sol; + if(cout_print) + { + std::cout << " *" << std::endl; + } + } + else + { + if(cout_print) + { + std::cout << std::endl; + } + } + } + heuristic_map[gemm_key] = heuristicResult[bestIndex]; } - } - heuristic_map[gemm_key] = heuristicResult[bestIndex]; + nvtxRangePop(); } + + hipblasStatus_t status = hipblasLtMatmul(handle, + matmul, + alpha, + a, + matA, + b, + matB, + beta, + c, + matC, + c, + matC, + &heuristic_map[gemm_key].algo, + d_workspace, + workspace_size, + stream); + + nvtxRangePushA("hipBLASLt variables deletion"); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescDestroy(matmul)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matA)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matB)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matC)); nvtxRangePop(); - } - - hipblasStatus_t status = hipblasLtMatmul(handle, matmul, - alpha, - a, matA, - b, matB, - beta, - c, matC, - c, matC, - &heuristic_map[gemm_key].algo, - d_workspace, workspace_size, - stream); - - nvtxRangePushA("hipBLASLt variables deletion"); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescDestroy(matmul)); - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matA)); - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matB)); - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matC)); - nvtxRangePop(); - - return status; + + return status; } ///////////////////////////////////////////////////////////////////////////////////////////////////////// -torch::Tensor hipBLASLtMm_( - const torch::Tensor &mat1, - const torch::Tensor &mat2) +torch::Tensor hipBLASLtMm_(const torch::Tensor& mat1, const torch::Tensor& mat2) { - auto mat1_strides{mat1.strides()}; - auto mat2_strides{mat2.strides()}; - auto mat1_sizes{mat1.sizes()}; - auto mat2_sizes{mat2.sizes()}; - // std::cout << " | mat1 info: size: " << mat1_sizes << " stride: " << mat1_strides << std::endl - // << " | mat2 info: size: " << mat2_sizes << " stride: " << mat2_strides << std::endl; - - TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK( - mat1.dtype() == mat2.dtype(), - "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()); - TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); - - auto abcType{mat1.options().dtype()}; - auto options{at::TensorOptions().dtype(abcType).device(at::kCUDA)}; - auto result{torch::empty({mat1_sizes[0], mat2_sizes[1]}, options)}; - // std::cout << " | result info: size: " << result.sizes() << " stride: " << result.strides() << std::endl; - - bool transpose_result = true; - bool transpose_mat1; - bool transpose_mat2; - if ((mat2_strides[0] == 1) && (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) - { - transpose_mat2 = false; - } - else if ((mat2_strides[1] == 1) && (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) - { - transpose_mat2 = true; - } - else - { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); - } - if ((mat1_strides[0] == 1) && (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) - { - transpose_mat1 = false; - } - else if ((mat1_strides[1] == 1) && (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) - { - transpose_mat1 = true; - } - else - { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); - } - - if (transpose_result) - { - bool tmp = transpose_mat1; - transpose_mat1 = !transpose_mat2; - transpose_mat2 = !tmp; - mat1_strides = mat2.strides(); - mat2_strides = mat1.strides(); - mat1_sizes = mat2.sizes(); - mat2_sizes = mat1.sizes(); - } - // std::cout << " | transpose_result: " << (transpose_result ? "true" : "false") << std::endl - // << " | transpose_A: " << (transpose_mat1 ? "true" : "false") << std::endl - // << " | transpose_B: " << (transpose_mat2 ? "true" : "false") << std::endl; - // std::cout << " | A matrix: size: " << mat1_sizes << " stride: " << mat1_strides << std::endl - // << " | B matrix: size: " << mat2_sizes << " stride: " << mat2_strides << std::endl; - - float one{1.0f}; - float zero{0.0f}; - int64_t m = mat1_sizes[transpose_result ? 1 : 0]; - int64_t k = mat1_sizes[transpose_result ? 0 : 1]; - int64_t n = mat2_sizes[transpose_result ? 0 : 1]; - int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0]; - int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; - int64_t result_ld = result.stride(transpose_result ? 0 : 1); - // std::cout << " | (m, n, k): " << m << ", " << n << ", " << k << std::endl - // << " | (lda, ldb, ldc): " << mat1_ld << ", " << mat2_ld << ", " << result_ld << std::endl; - - int flag{0}; - hipblasDatatype_t hipblasType; - if (abcType == at::kHalf) - { - hipblasType = HIPBLAS_R_16F; - } - else if (abcType == at::kBFloat16) - { - hipblasType = HIPBLAS_R_16B; - } - else if (abcType == at::kFloat) - { - hipblasType = HIPBLAS_R_32F; - } - else - { - assert(false && "Wrong datatype!"); - } - - void *ptrA{static_cast((transpose_result ? mat2 : mat1).data_ptr())}; - void *ptrB{static_cast((transpose_result ? mat1 : mat2).data_ptr())}; - void *ptrC{static_cast(result.data_ptr())}; - - auto current_stream{torch::hip::getCurrentHIPStream().stream()}; - - CHECK_HIPBLAS_ERROR(hipblasLtMatmul_wrapper( - hipblaslt_handle, - transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, - transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, - m, n, k, - &one, - ptrA, mat1_ld, - ptrB, mat2_ld, - &zero, - ptrC, result_ld, - hipblasType, - current_stream)); - - return result; + auto mat1_strides{mat1.strides()}; + auto mat2_strides{mat2.strides()}; + auto mat1_sizes{mat1.sizes()}; + auto mat2_sizes{mat2.sizes()}; + // std::cout << " | mat1 info: size: " << mat1_sizes << " stride: " << mat1_strides << std::endl + // << " | mat2 info: size: " << mat2_sizes << " stride: " << mat2_strides << + // std::endl; + + TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); + TORCH_CHECK(mat1.dtype() == mat2.dtype(), + "expected mat1 and mat2 to have the same dtype, but got: ", + mat1.dtype(), + " != ", + mat2.dtype()); + TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); + + auto abcType{mat1.options().dtype()}; + auto options{at::TensorOptions().dtype(abcType).device(at::kCUDA)}; + auto result{torch::empty({mat1_sizes[0], mat2_sizes[1]}, options)}; + // std::cout << " | result info: size: " << result.sizes() << " stride: " << result.strides() << + // std::endl; + + bool transpose_result = true; + bool transpose_mat1; + bool transpose_mat2; + if((mat2_strides[0] == 1) && (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) + { + transpose_mat2 = false; + } + else if((mat2_strides[1] == 1) && (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) + { + transpose_mat2 = true; + } + else + { + assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + } + if((mat1_strides[0] == 1) && (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) + { + transpose_mat1 = false; + } + else if((mat1_strides[1] == 1) && (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) + { + transpose_mat1 = true; + } + else + { + assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + } + + if(transpose_result) + { + bool tmp = transpose_mat1; + transpose_mat1 = !transpose_mat2; + transpose_mat2 = !tmp; + mat1_strides = mat2.strides(); + mat2_strides = mat1.strides(); + mat1_sizes = mat2.sizes(); + mat2_sizes = mat1.sizes(); + } + // std::cout << " | transpose_result: " << (transpose_result ? "true" : "false") << std::endl + // << " | transpose_A: " << (transpose_mat1 ? "true" : "false") << std::endl + // << " | transpose_B: " << (transpose_mat2 ? "true" : "false") << std::endl; + // std::cout << " | A matrix: size: " << mat1_sizes << " stride: " << mat1_strides << std::endl + // << " | B matrix: size: " << mat2_sizes << " stride: " << mat2_strides << std::endl; + + float one{1.0f}; + float zero{0.0f}; + int64_t m = mat1_sizes[transpose_result ? 1 : 0]; + int64_t k = mat1_sizes[transpose_result ? 0 : 1]; + int64_t n = mat2_sizes[transpose_result ? 0 : 1]; + int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0]; + int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; + int64_t result_ld = result.stride(transpose_result ? 0 : 1); + // std::cout << " | (m, n, k): " << m << ", " << n << ", " << k << std::endl + // << " | (lda, ldb, ldc): " << mat1_ld << ", " << mat2_ld << ", " << result_ld << + // std::endl; + + int flag{0}; + hipblasDatatype_t hipblasType; + if(abcType == at::kHalf) + { + hipblasType = HIPBLAS_R_16F; + } + else if(abcType == at::kBFloat16) + { + hipblasType = HIPBLAS_R_16B; + } + else if(abcType == at::kFloat) + { + hipblasType = HIPBLAS_R_32F; + } + else + { + assert(false && "Wrong datatype!"); + } + + void* ptrA{static_cast((transpose_result ? mat2 : mat1).data_ptr())}; + void* ptrB{static_cast((transpose_result ? mat1 : mat2).data_ptr())}; + void* ptrC{static_cast(result.data_ptr())}; + + auto current_stream{torch::hip::getCurrentHIPStream().stream()}; + + CHECK_HIPBLAS_ERROR(hipblasLtMatmul_wrapper(hipblaslt_handle, + transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, + m, + n, + k, + &one, + ptrA, + mat1_ld, + ptrB, + mat2_ld, + &zero, + ptrC, + result_ld, + hipblasType, + current_stream)); + + return result; } ///////////////////////////////////////////////////////////////////////////////////////////////////////// void create_extension() { - CHECK_HIP_ERROR(hipStreamCreate(&weight_stream)); - CHECK_HIP_ERROR(hipEventCreateWithFlags(&event, cudaEventDisableTiming)); - - // hipBLASLt - CHECK_HIPBLAS_ERROR(hipblasLtCreate(&hipblaslt_handle)); - CHECK_HIP_ERROR(hipMalloc(&d_workspace, workspace_size)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceCreate(&preference)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceSetAttribute( - preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size))); - - CHECK_HIP_ERROR(hipEventCreate(&start)); - CHECK_HIP_ERROR(hipEventCreate(&stop)); + CHECK_HIP_ERROR(hipStreamCreate(&weight_stream)); + CHECK_HIP_ERROR(hipEventCreateWithFlags(&event, cudaEventDisableTiming)); + + // hipBLASLt + CHECK_HIPBLAS_ERROR(hipblasLtCreate(&hipblaslt_handle)); + CHECK_HIP_ERROR(hipMalloc(&d_workspace, workspace_size)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceCreate(&preference)); + CHECK_HIPBLAS_ERROR( + hipblasLtMatmulPreferenceSetAttribute(preference, + HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_size, + sizeof(workspace_size))); + + CHECK_HIP_ERROR(hipEventCreate(&start)); + CHECK_HIP_ERROR(hipEventCreate(&stop)); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// void destroy_extension() { - CHECK_HIP_ERROR(hipStreamDestroy(weight_stream)); - CHECK_HIP_ERROR(hipEventDestroy(event)); + CHECK_HIP_ERROR(hipStreamDestroy(weight_stream)); + CHECK_HIP_ERROR(hipEventDestroy(event)); - // hipBLASLt - CHECK_HIPBLAS_ERROR(hipblasLtDestroy(hipblaslt_handle)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceDestroy(preference)); - CHECK_HIP_ERROR(hipFree(d_workspace)); + // hipBLASLt + CHECK_HIPBLAS_ERROR(hipblasLtDestroy(hipblaslt_handle)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceDestroy(preference)); + CHECK_HIP_ERROR(hipFree(d_workspace)); - CHECK_HIP_ERROR(hipEventDestroy(start)); - CHECK_HIP_ERROR(hipEventDestroy(stop)); + CHECK_HIP_ERROR(hipEventDestroy(start)); + CHECK_HIP_ERROR(hipEventDestroy(stop)); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("create_extension", &create_extension, "create_extension"); - m.def("destroy_extension", &destroy_extension, "destroy_extension"); - m.def("mm", &hipBLASLtMm_, "mm"); + m.def("create_extension", &create_extension, "create_extension"); + m.def("destroy_extension", &destroy_extension, "destroy_extension"); + m.def("mm", &hipBLASLtMm_, "mm"); } diff --git a/gradlib/csrc/hipbsolgemm.cu b/gradlib/csrc/hipbsolgemm.cu index 37a390de82..cbd83879cf 100644 --- a/gradlib/csrc/hipbsolgemm.cu +++ b/gradlib/csrc/hipbsolgemm.cu @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. // #ifdef __gfx908__ // // Uncomment ifdef and endif only if you need to undef the HIP_HALF ops below // just for gfx908 and not for others @@ -46,11 +46,11 @@ namespace { - /*thread_local*/ cudaStream_t weight_stream; + /*thread_local*/ hipStream_t weight_stream; // BUG: DLM has event and stream on different devices error // In multi-GPU scenerio, do names defined in this namespace exist on all // devices? C++ keyword: thread_local <- maybe this can help? - /*thread_local*/ cudaEvent_t event; + /*thread_local*/ hipEvent_t event; // hipBLASLt hipblasLtHandle_t hipblaslt_handle; diff --git a/gradlib/csrc/rocsolgemm.cu b/gradlib/csrc/rocsolgemm.cu index 5b1fedc4c6..c7b3c9a855 100644 --- a/gradlib/csrc/rocsolgemm.cu +++ b/gradlib/csrc/rocsolgemm.cu @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. // #ifdef __gfx908__ // // Uncomment ifdef and endif only if you need to undef the HIP_HALF ops below just for gfx908 and not for others // // below lines enable hip float to half conversion which are disabled by default in hip_fp16.h @@ -56,11 +56,11 @@ namespace { rocblas_handle r_handle; - /*thread_local*/ cudaStream_t weight_stream; + /*thread_local*/ hipStream_t weight_stream; // BUG: DLM has event and stream on different devices error // In multi-GPU scenerio, do names defined in this namespace exist on all devices? // C++ keyword: thread_local <- maybe this can help? - /*thread_local*/ cudaEvent_t event; + /*thread_local*/ hipEvent_t event; // hipBLASLt hipblasLtHandle_t hipblaslt_handle; diff --git a/gradlib/include/hipbsolgemm.cuh b/gradlib/include/hipbsolgemm.cuh index f457228e9d..49f1c19b31 100644 --- a/gradlib/include/hipbsolgemm.cuh +++ b/gradlib/include/hipbsolgemm.cuh @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. // #ifdef __gfx908__ // // Uncomment ifdef and endif only if you need to undef the HIP_HALF ops below // just for gfx908 and not for others @@ -7,51 +7,47 @@ // default in hip_fp16.h #undef __HIP_NO_HALF_OPERATORS__ #undef // __HIP_NO_HALF_CONVERSIONS__ #endif -#include -#include #include #include -#include -#include -#include -#include -// #include +#include +#include +#include +#include #include #include #include -#include #include -#include #include +#include -#include #include +#include +#include #include #include #include #include -#include -#include "nvToolsExt.h" void hipb_create_extension(); void hipb_destroy_extension(); -torch::Tensor hipb_mm(const torch::Tensor &mat1, const torch::Tensor &mat2, +torch::Tensor hipb_mm(const torch::Tensor& mat1, + const torch::Tensor& mat2, const int solution_index, - std::optional bias = std::nullopt, + std::optional bias = std::nullopt, std::optional out_dtype = std::nullopt, - std::optional scaleA = std::nullopt, - std::optional scaleB = std::nullopt, - std::optional scaleOut = std::nullopt); - -std::vector hipb_findallsols( - const torch::Tensor &mat1, const torch::Tensor &mat2, - std::optional bias = std::nullopt, - std::optional out_dtype = std::nullopt, - std::optional scaleA = std::nullopt, - std::optional scaleB = std::nullopt, - std::optional scaleC = std::nullopt); - -std::string getHipblasltKernelName(int solution_index); \ No newline at end of file + std::optional scaleA = std::nullopt, + std::optional scaleB = std::nullopt, + std::optional scaleOut = std::nullopt); + +std::vector hipb_findallsols(const torch::Tensor& mat1, + const torch::Tensor& mat2, + std::optional bias = std::nullopt, + std::optional out_dtype = std::nullopt, + std::optional scaleA = std::nullopt, + std::optional scaleB = std::nullopt, + std::optional scaleC = std::nullopt); + +std::string getHipblasltKernelName(int solution_index); diff --git a/gradlib/include/rocsolgemm.cuh b/gradlib/include/rocsolgemm.cuh index abd71ebbd6..07d2901ad7 100644 --- a/gradlib/include/rocsolgemm.cuh +++ b/gradlib/include/rocsolgemm.cuh @@ -1,34 +1,29 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #define ROCBLAS_NO_DEPRECATED_WARNINGS #define ROCBLAS_BETA_FEATURES_API -#include -#include #include #include -#include -#include -#include -#include -// #include +#include +#include +#include +#include #include #include #include -#include #include // #include #include +#include #include #include #include #include #include -#include -#include "nvToolsExt.h" #include @@ -36,11 +31,8 @@ void rocb_create_extension(); void rocb_destroy_extension(); -torch::Tensor RocSolIdxBlas( - const torch::Tensor &mat1, - const torch::Tensor &mat2, - const int32_t solution_index = 0); +torch::Tensor RocSolIdxBlas(const torch::Tensor& mat1, + const torch::Tensor& mat2, + const int32_t solution_index = 0); -std::vector RocFindAllSolIdxBlas( - const torch::Tensor &mat1, - const torch::Tensor &mat2); +std::vector RocFindAllSolIdxBlas(const torch::Tensor& mat1, const torch::Tensor& mat2);