Auto-tune vector sizes for NVLS allreduce6 #338

Merged: 1 commit, Aug 16, 2024

49 changes: 30 additions & 19 deletions include/mscclpp/nvls_device.hpp
@@ -25,71 +25,82 @@ struct DeviceMulticastPointerDeviceHandle {
size_t bufferSize;

#if defined(MSCCLPP_DEVICE_CUDA)
template <typename TValue = float4, typename T = float>
template <typename TValue, typename T>
MSCCLPP_DEVICE_INLINE static void multimemLoadReduce(TValue& val, T* ptr) {
if constexpr (std::is_same<TValue, float4>::value && std::is_same<T, float>::value) {
if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, float>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.f32 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(ptr)
: "memory");
} else if constexpr (std::is_same<TValue, uint4>::value && std::is_same<T, __half2>::value) {
} else if constexpr (std::is_same_v<TValue, uint2> && std::is_same_v<T, float>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v2.f32 {%0,%1}, [%2];"
: "=r"(val.x), "=r"(val.y)
: "l"(ptr)
: "memory");
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, float>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.f32 {%0}, [%1];" : "=r"(val.x) : "l"(ptr) : "memory");
} else if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, __half2>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.f16x2 {%0,%1,%2,%3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(ptr)
: "memory");
} else if constexpr (std::is_same<TValue, uint2>::value && std::is_same<T, __half2>::value) {
} else if constexpr (std::is_same_v<TValue, uint2> && std::is_same_v<T, __half2>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.v2.f16x2 {%0,%1}, [%2];"
: "=r"(val.x), "=r"(val.y)
: "l"(ptr)
: "memory");
} else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half2>::value) {
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, __half2>) {
asm("multimem.ld_reduce.relaxed.sys.global.add.f16x2 {%0}, [%1];" : "=r"(val.x) : "l"(ptr) : "memory");
} else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half>::value) {
asm("multimem.ld_reduce.relaxed.sys.global.add.f16 {%0}, [%1];" : "=r"(val.x) : "l"(ptr) : "memory");
} else {
static_assert(dependentFalse<T>, "Not supported type");
}
};

template <typename TValue, typename T>
MSCCLPP_DEVICE_INLINE static void multimemStore(const TValue& val, T* ptr) {
if constexpr (std::is_same<TValue, float4>::value && std::is_same<T, float>::value) {
if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, float>) {
asm volatile("multimem.st.relaxed.sys.global.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
"r"(val.z), "r"(val.w)
: "memory");
} else if constexpr (std::is_same<TValue, uint4>::value && std::is_same<T, __half2>::value) {
} else if constexpr (std::is_same_v<TValue, uint2> && std::is_same_v<T, float>) {
asm volatile("multimem.st.relaxed.sys.global.v2.f32 [%0], {%1,%2};" ::"l"(ptr), "r"(val.x), "r"(val.y)
: "memory");
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, float>) {
asm volatile("multimem.st.relaxed.sys.global.f32 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, __half2>) {
asm volatile("multimem.st.relaxed.sys.global.v4.f16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
"r"(val.z), "r"(val.w)
: "memory");
} else if constexpr (std::is_same<TValue, uint2>::value && std::is_same<T, __half2>::value) {
} else if constexpr (std::is_same_v<TValue, uint2> && std::is_same_v<T, __half2>) {
asm volatile("multimem.st.relaxed.sys.global.v2.f16x2 [%0], {%1,%2};" ::"l"(ptr), "r"(val.x), "r"(val.y)
: "memory");
} else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half2>::value) {
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, __half2>) {
asm volatile("multimem.st.relaxed.sys.global.f16x2 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half>::value) {
asm volatile("multimem.st.relaxed.sys.global.f16 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else {
static_assert(dependentFalse<T>, "Not supported type");
}
};

template <typename TValue, typename T>
MSCCLPP_DEVICE_INLINE static void multimemStoreReduce(const TValue& val, T* ptr) {
if constexpr (std::is_same<TValue, float4>::value && std::is_same<T, float>::value) {
if constexpr (std::is_same_v<TValue, float4> && std::is_same_v<T, float>) {
asm volatile("multimem.red.relaxed.sys.global.add.v4.f32 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x), "r"(val.y),
"r"(val.z), "r"(val.w)
: "memory");
} else if constexpr (std::is_same<TValue, uint4>::value && std::is_same<T, __half2>::value) {
} else if constexpr (std::is_same_v<TValue, uint2> && std::is_same_v<T, float>) {
asm volatile("multimem.red.relaxed.sys.global.add.v2.f32 [%0], {%1,%2};" ::"l"(ptr), "r"(val.x), "r"(val.y)
: "memory");
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, float>) {
asm volatile("multimem.red.relaxed.sys.global.add.f32 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else if constexpr (std::is_same_v<TValue, uint4> && std::is_same_v<T, __half2>) {
asm volatile("multimem.red.relaxed.sys.global.add.v4.f16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.x),
"r"(val.y), "r"(val.z), "r"(val.w)
: "memory");
} else if constexpr (std::is_same<TValue, uint2>::value && std::is_same<T, __half2>::value) {
} else if constexpr (std::is_same_v<TValue, uint2> && std::is_same_v<T, __half2>) {
asm volatile("multimem.red.relaxed.sys.global.add.v2.f16x2 [%0], {%1,%2};" ::"l"(ptr), "r"(val.x), "r"(val.y)
: "memory");
} else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half2>::value) {
} else if constexpr (std::is_same_v<TValue, uint1> && std::is_same_v<T, __half2>) {
asm volatile("multimem.red.relaxed.sys.global.add.f16x2 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else if constexpr (std::is_same<TValue, uint1>::value && std::is_same<T, __half>::value) {
asm volatile("multimem.red.relaxed.sys.global.add.f16 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory");
} else {
static_assert(dependentFalse<T>, "Not supported type");
}
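For context on how these templates are meant to be called (a minimal device-side sketch, not part of this diff): TValue is the integer register vector that carries the PTX operands, and T selects the .f32 / .f16x2 / .f16 instruction variant, so the caller picks the pair whose payload width matches the transaction it wants. Reducing and re-broadcasting two floats per transaction with the v2.f32 variants added above might look like this (function and variable names are illustrative):

#include <mscclpp/nvls_device.hpp>

// Illustrative only: assumes mc_ptr points into the NVLS multicast mapping and
// that this translation unit is compiled as CUDA device code.
__device__ void reduceTwoFloats(float* mc_ptr, size_t idx) {
  uint2 val;  // receives the 2 reduced float values from multimem.ld_reduce
  mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, mc_ptr + idx);
  mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + idx);
}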
1 change: 1 addition & 0 deletions python/mscclpp/__init__.py
@@ -6,6 +6,7 @@
from ._mscclpp import (
Communicator,
Connection,
connect_nvls_collective,
EndpointConfig,
Fifo,
Host2DeviceSemaphore,
1 change: 1 addition & 0 deletions python/mscclpp/comm.py
@@ -8,6 +8,7 @@
from ._mscclpp import (
Communicator,
Connection,
connect_nvls_collective,
EndpointConfig,
Host2DeviceSemaphore,
Host2HostSemaphore,
82 changes: 62 additions & 20 deletions python/mscclpp_benchmark/allreduce.cu
@@ -788,7 +788,7 @@ extern "C" __global__ void __launch_bounds__(1024, 1)

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900

// Barrier among all devices followed by a memory fence
// Barrier among all devices
// Should be called by all threads on all devices
// Assumes \p num_threads_per_block >= \p num_ranks
__forceinline__ __device__ void barrier(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores, int thread_id,
@@ -806,36 +806,78 @@ __forceinline__ __device__ void barrier(mscclpp::SmDevice2DeviceSemaphoreDeviceH
deviceSyncer.sync(num_blocks);
}

extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce6(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores,
mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, TYPE* buff, int my_rank, int nranks,
size_t nelem) {
float* dev_ptr = (float*)nvlsPtrs.devicePtr;
float* mc_ptr = (float*)nvlsPtrs.mcPtr;
// Assumes \p kVecSize is 1, 2, 4, or 8 (default 8)
template <typename DataType = float, int kVecSize = 8>
MSCCLPP_DEVICE_INLINE void allreduce6_helper(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores,
mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, int my_rank,
int num_ranks, size_t num_elements) {
DataType* mc_ptr = (DataType*)nvlsPtrs.mcPtr;
int tid = threadIdx.x;
int bid = blockIdx.x;
int num_threads_per_block = blockDim.x;
int num_blocks = gridDim.x;

// start with a barrier to ensure all devices have written their values
// to their own memory (that is part of the multicast memory)
// before reading them in this kernel
barrier(semaphores, tid, bid, num_blocks, nranks);

int my_st = ((int64_t)nelem * (int64_t)my_rank) / (int64_t)nranks;
int my_en = ((int64_t)nelem * (int64_t)(my_rank + 1)) / (int64_t)nranks;

int my_offset = (tid + bid * blockDim.x) * 4;
int my_step = blockDim.x * gridDim.x * 4;

for (int idx = my_st + my_offset; idx < my_en; idx += my_step) {
uint4 val; // fits 8 cutlass::half_t elements; i.e., 4 half2 elements
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, mc_ptr + idx);
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, mc_ptr + idx);
barrier(semaphores, tid, bid, num_blocks, num_ranks);

// every device loads, reduces, and stores a partition of the multicast memory
int rank_start = ((int64_t)num_elements * (int64_t)my_rank) / (int64_t)num_ranks;
int rank_end = ((int64_t)num_elements * (int64_t)(my_rank + 1)) / (int64_t)num_ranks;

int thread_offset = (bid * num_threads_per_block + tid) * kVecSize;
int thread_step = (num_threads_per_block * num_blocks) * kVecSize; // number of threads * vector size

for (int idx = rank_start + thread_offset; idx < rank_end; idx += thread_step) {
if constexpr (std::is_same_v<DataType, float> && (kVecSize == 4)) {
uint4 val; // fits 4 float elements
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (float*)(mc_ptr + idx));
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (float*)(mc_ptr + idx));
} else if constexpr (std::is_same_v<DataType, float> && (kVecSize == 2)) {
uint2 val; // fits 2 float elements
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (float*)(mc_ptr + idx));
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (float*)(mc_ptr + idx));
} else if constexpr (std::is_same_v<DataType, float> && (kVecSize == 1)) {
uint1 val; // fits 1 float element
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (float*)(mc_ptr + idx));
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (float*)(mc_ptr + idx));
} else if constexpr (std::is_same_v<DataType, __half> && (kVecSize == 8)) {
uint4 val; // fits 8 cutlass::half_t elements; i.e., 4 half2 elements
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (half2*)(mc_ptr + idx));
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (half2*)(mc_ptr + idx));
} else if constexpr (std::is_same_v<DataType, __half> && (kVecSize == 4)) {
uint2 val; // fits 4 cutlass::half_t elements; i.e., 2 half2 elements
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (half2*)(mc_ptr + idx));
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (half2*)(mc_ptr + idx));
} else if constexpr (std::is_same_v<DataType, __half> && (kVecSize == 2)) {
uint1 val; // fits 2 cutlass::half_t elements; i.e., 1 half2 element
mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (half2*)(mc_ptr + idx));
mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (half2*)(mc_ptr + idx));
} else {
// not supported: cannot use static_assert because of the way TYPE is handled in this file
assert(false); // Unsupported data type and vector size combination
}
}

// end with a barrier to ensure all devices can now read their values
// from their own memory (that is part of the multicast memory)
// after writing them in this kernel
barrier(semaphores, tid, bid, num_blocks, nranks);
barrier(semaphores, tid, bid, num_blocks, num_ranks);
}

extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce6(mscclpp::SmDevice2DeviceSemaphoreDeviceHandle* semaphores,
mscclpp::DeviceMulticastPointerDeviceHandle nvlsPtrs, int my_rank, int num_ranks, size_t num_elements,
size_t vector_size) {
if (vector_size == 8) {
allreduce6_helper<TYPE, 8>(semaphores, nvlsPtrs, my_rank, num_ranks, num_elements);
} else if (vector_size == 4) {
allreduce6_helper<TYPE, 4>(semaphores, nvlsPtrs, my_rank, num_ranks, num_elements);
} else if (vector_size == 2) {
allreduce6_helper<TYPE, 2>(semaphores, nvlsPtrs, my_rank, num_ranks, num_elements);
} else {
allreduce6_helper<TYPE, 1>(semaphores, nvlsPtrs, my_rank, num_ranks, num_elements);
}
}
#endif
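A note on the dispatch above (an illustration, not part of this diff): each (DataType, kVecSize) branch is valid exactly when the per-thread payload, kVecSize * sizeof(DataType), is 4, 8, or 16 bytes, matching the uint1 / uint2 / uint4 operand vectors of the multimem instructions. A compile-time sketch of that constraint:

#include <cuda_fp16.h>

// Sketch: a (DataType, kVecSize) pair maps onto uint1/uint2/uint4 only if its
// payload is 4, 8, or 16 bytes; these are the combinations allreduce6_helper
// dispatches on.
template <typename DataType, int kVecSize>
constexpr bool isSupportedVector() {
  constexpr int bytes = kVecSize * static_cast<int>(sizeof(DataType));
  return bytes == 4 || bytes == 8 || bytes == 16;
}

static_assert(isSupportedVector<float, 4>(), "uint4, 16 bytes");   // float x 4
static_assert(isSupportedVector<float, 1>(), "uint1, 4 bytes");    // float x 1
static_assert(isSupportedVector<__half, 8>(), "uint4, 16 bytes");  // __half x 8
static_assert(isSupportedVector<__half, 2>(), "uint1, 4 bytes");   // __half x 2

The runtime vector_size argument of allreduce6 then only chooses which of these compile-time instantiations to launch.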
2 changes: 1 addition & 1 deletion python/mscclpp_benchmark/allreduce_bench.py
@@ -175,7 +175,7 @@ def run_benchmark(
MscclppAllReduce1(mscclpp_group, memory),
MscclppAllReduce3(mscclpp_group, memory, proxy_service),
]
if is_nvls_supported():
if is_nvls_supported() and (data_type == cp.float32 or data_type == cp.float16):
mscclpp_algos.append(MscclppAllReduce6(mscclpp_group, nelem, data_type))
else:
if memory.nbytes < 2**22:
27 changes: 22 additions & 5 deletions python/mscclpp_benchmark/mscclpp_op.py
@@ -468,7 +468,16 @@ def __init__(
self.device_handles_cp = cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8)
self.nvls_handle = self.nvls_mem_handle.device_handle().raw

self.set_params(nblocks, block_size)
if self.memory.dtype != cp.float16 and self.memory.dtype != cp.float32:
raise RuntimeError("Unsupported data type")

if self.memory.dtype == cp.float16:
vector_size = 8
elif self.memory.dtype == cp.float32:
vector_size = 4
else:
vector_size = 1
self.set_params(nblocks, block_size, vector_size)

def get_memory(self):
return self.memory
@@ -477,23 +477,31 @@ def __call__(self, stream_ptr):
self.kernel.launch_kernel(self.params, self.nblocks, self.block_size, 0, stream_ptr)
return self.memory

def set_params(self, nblocks, block_size):
def set_params(self, nblocks, block_size, vector_size):
self.nblocks = nblocks
self.block_size = block_size
self.vector_size = vector_size
self.params = b""
self.params += pack(
self.device_handles_cp,
self.nvls_handle,
self.memory,
self.group.my_rank,
self.group.nranks,
ctypes.c_size_t(self.memory.size),
self.vector_size,
)

def auto_tune(self):
nblocks_to_try = [8, 12, 16, 24, 32, 48, 64, 72, 96, 108]
block_size_to_try = [256, 512, 1024]
if self.memory.dtype == cp.float16:
vector_size_to_try = [8, 4, 2]
elif self.memory.dtype == cp.float32:
vector_size_to_try = [4, 2, 1]
else:
vector_size_to_try = [1]
for nblocks in nblocks_to_try:
for block_size in block_size_to_try:
self.set_params(nblocks, block_size)
yield nblocks, block_size
for vector_size in vector_size_to_try:
self.set_params(nblocks, block_size, vector_size)
yield nblocks, block_size, vector_size