Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions ggml/src/ggml-cuda/vendors/hip.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,32 @@
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
// __shfl_sync: support both 3-arg (mask, var, srcLane) and 4-arg (mask, var, srcLane, width) calls
// HIP ignores the mask but requires it to be 64-bit, so we cast explicitly.
#define __SHFL_SYNC_3(mask, var, srcLane) __shfl(var, srcLane, warpSize)
#define __SHFL_SYNC_4(mask, var, srcLane, width) __shfl(var, srcLane, width)
#define __SHFL_GET_MACRO(_1, _2, _3, _4, NAME, ...) NAME
#define __shfl_sync(...) __SHFL_GET_MACRO(__VA_ARGS__, __SHFL_SYNC_4, __SHFL_SYNC_3)(__VA_ARGS__)
// __shfl_up_sync: support 3-arg and 4-arg calls (HIP ignores mask)
#define __SHFL_UP_SYNC_3(mask, var, delta) __shfl_up(var, delta, warpSize)
#define __SHFL_UP_SYNC_4(mask, var, delta, width) __shfl_up(var, delta, width)
#define __SHFL_UP_GET(_1, _2, _3, _4, NAME, ...) NAME
#define __shfl_up_sync(...) __SHFL_UP_GET(__VA_ARGS__, __SHFL_UP_SYNC_4, __SHFL_UP_SYNC_3)(__VA_ARGS__)

// __shfl_xor_sync: support 3-arg and 4-arg calls (HIP ignores mask)
#define __SHFL_XOR_SYNC_3(mask, var, laneMask) __shfl_xor(var, laneMask, warpSize)
#define __SHFL_XOR_SYNC_4(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define __SHFL_XOR_GET(_1, _2, _3, _4, NAME, ...) NAME
#define __shfl_xor_sync(...) __SHFL_XOR_GET(__VA_ARGS__, __SHFL_XOR_SYNC_4, __SHFL_XOR_SYNC_3)(__VA_ARGS__)

// __shfl_down_sync: support 3-arg and 4-arg calls (HIP ignores mask)
#define __SHFL_DOWN_SYNC_3(mask, var, delta) __shfl_down(var, delta, warpSize)
#define __SHFL_DOWN_SYNC_4(mask, var, delta, width) __shfl_down(var, delta, width)
#define __SHFL_DOWN_GET(_1, _2, _3, _4, NAME, ...) NAME
#define __shfl_down_sync(...) __SHFL_DOWN_GET(__VA_ARGS__, __SHFL_DOWN_SYNC_4, __SHFL_DOWN_SYNC_3)(__VA_ARGS__)
#define __all_sync(mask, var) __all(var)
#define __any_sync(mask, var) __any(var)
#define __ballot_sync(mask, var) ((uint32_t)__ballot(var))
#define cublasStrsmBatched hipblasStrsmBatched
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
Expand Down Expand Up @@ -113,6 +134,10 @@
#define cudaStreamPerThread hipStreamPerThread
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent hipStreamWaitEvent
#define cudaMemcpyToSymbol hipMemcpyToSymbol
#define cudaMemcpyFromSymbol hipMemcpyFromSymbol
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaGraphExec_t hipGraphExec_t
#define cudaGraphNode_t hipGraphNode_t
#define cudaKernelNodeParams hipKernelNodeParams
Expand Down
10 changes: 9 additions & 1 deletion ggml/src/ggml-hip/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")

file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu")
# Exclude D>=576 tile kernels: exceed HIP local memory limit (67584 > 65536)
list(FILTER SRCS EXCLUDE REGEX "dkq(576|640)")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
Expand All @@ -75,7 +77,13 @@ else()
../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu
../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu
../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu)
../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu
../ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo3_0.cu
../ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-q8_0.cu
../ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo3_0.cu
../ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo2_0.cu
../ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-q8_0.cu
../ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo2_0.cu)
endif()

ggml_add_backend_library(ggml-hip
Expand Down