Skip to content

Commit

Permalink
BlockReduceMulti utility
Browse files Browse the repository at this point in the history
  • Loading branch information
ProExpertProg committed Aug 6, 2024
1 parent 14ba59f commit 1920584
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 37 deletions.
49 changes: 12 additions & 37 deletions csrc/layernorm_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif

#include "reduction_utils.cuh"

namespace vllm {

// TODO(woosuk): Further optimize this kernel.
Expand All @@ -36,9 +38,8 @@ __global__ void rms_norm_kernel(
}

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStorage;
variance =
BlockReduce(reduceStorage).Reduce(variance, cub::Sum{}, blockDim.x);
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
Expand Down Expand Up @@ -235,23 +236,10 @@ fused_add_rms_norm_kernel(
variance += temp.sum_squares();
residual_v[id] = temp;
}
using BlockReduce1024 = cub::BlockReduce<float, 1024>;
using BlockReduce256 = cub::BlockReduce<float, 256>;

__shared__ union {
typename BlockReduce1024::TempStorage s1024;
typename BlockReduce256::TempStorage s256;
} reduceStorage;

/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) {
variance = BlockReduce1024(reduceStorage.s1024)
.Reduce(variance, cub::Sum{}, blockDim.x);
} else {
variance = BlockReduce256(reduceStorage.s256)
.Reduce(variance, cub::Sum{}, blockDim.x);
}

using BlockReduce = BlockReduceMulti<float, 256, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
Expand Down Expand Up @@ -287,23 +275,10 @@ fused_add_rms_norm_kernel(
variance += x * x;
residual[blockIdx.x * hidden_size + idx] = z;
}
using BlockReduce1024 = cub::BlockReduce<float, 1024>;
using BlockReduce256 = cub::BlockReduce<float, 256>;

__shared__ union {
typename BlockReduce1024::TempStorage s1024;
typename BlockReduce256::TempStorage s256;
} reduceStorage;

/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) {
variance = BlockReduce1024(reduceStorage.s1024)
.Reduce(variance, cub::Sum{}, blockDim.x);
} else {
variance = BlockReduce256(reduceStorage.s256)
.Reduce(variance, cub::Sum{}, blockDim.x);
}

using BlockReduce = BlockReduceMulti<float, 256, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
Expand Down
126 changes: 126 additions & 0 deletions csrc/reduction_utils.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#pragma once

#ifndef USE_ROCM
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#endif

namespace vllm {

namespace detail {

// Variadic union of storage types: only one alternative is live at a time, so
// the whole union occupies max(sizeof(Ts)...) bytes. Specialized below for the
// single-type base case and the recursive multi-type case.
template <typename... Ts>
union MultiUnion;

// Base case: a union of exactly one type.
template <typename T>
union MultiUnion<T> {
  using type = T;
  type data;

  // Terminal node of the recursive union chain.
  // NOTE(review): is_last() is not referenced anywhere in this file —
  // candidate for removal; confirm there are no external callers.
  constexpr bool is_last() { return true; }

  // Returns the sole stored member; any offset other than 0 is a compile
  // error.
  // NOTE(review): unlike BlockReduceMulti::get, this is not marked
  // __host__ __device__, so device-side calls depend on the build enabling
  // --expt-relaxed-constexpr — verify against the build flags.
  template <size_t offset>
  constexpr T& get() {
    static_assert(offset == 0);
    return data;
  }
};

// Recursive case: head holds the first type, tail a MultiUnion of the rest.
// head and tail share storage (union semantics), so exactly one alternative
// may be active at any time.
template <typename T, typename... Ts>
union MultiUnion<T, Ts...> {
  MultiUnion<T> head;
  MultiUnion<Ts...> tail;

  // Returns a reference to the storage for the offset-th type, peeling one
  // layer of tail per recursion step until offset reaches 0.
  template <size_t offset>
  constexpr auto& get() {
    if constexpr (offset == 0) {
      return head.template get<0>();
    } else {
      return tail.template get<offset - 1>();
    }
  }
};

// Compile-time check that a parameter pack of values is strictly ascending.
// Packs of size zero or one are trivially ascending.
template <typename T, T... sizes>
struct is_ascending {
 private:
  static constexpr bool compute() {
    if constexpr (sizeof...(sizes) < 2) {
      return true;
    } else {
      // Materialize the pack as an array and scan adjacent pairs.
      constexpr T vals[] = {sizes...};
      for (size_t i = 0; i + 1 < sizeof...(sizes); ++i) {
        if (!(vals[i] < vals[i + 1])) {
          return false;
        }
      }
      return true;
    }
  }

 public:
  static constexpr bool value = compute();
};

// Convenience alias for is_ascending<...>::value.
template <typename T, T... sizes>
static constexpr bool is_ascending_v = is_ascending<T, sizes...>::value;

// Example usage/tests:
static_assert(is_ascending_v<size_t, 32, 64, 128, 256, 512, 1024>);
static_assert(!is_ascending_v<size_t, 64, 64>);
static_assert(!is_ascending_v<size_t, 64, 32, 80>);

} // namespace detail

// BlockReduceMulti is a helper class that allows runtime dispatching to
// multiple block sizes for block reductions. When the number of threads
// participating in the reduction is not known at compile time, it selects the
// smallest available block size that is >= the number of threads (falling
// back to the largest available size when none is).
//
// It uses a union to represent its shared storage, as only one block size is
// used at a time. This way no memory is wasted for the unused block sizes.
template <typename T, size_t... BlockSizes>
class BlockReduceMulti {
  static_assert(sizeof...(BlockSizes) > 0, "At least one block size required");
  static_assert(detail::is_ascending_v<size_t, BlockSizes...>,
                "Block sizes must be in ascending order");

  // Compile-time lookup of the I-th value of the (I0, Is...) pack, recursing
  // with a shrinking pack until I reaches 0.
  template <size_t I, size_t I0, size_t... Is>
  __device__ __host__ static constexpr size_t get() {
    static_assert(I < sizeof...(Is) + 1, "Index out of bounds");
    if constexpr (I == 0) {
      return I0;
    } else {
      return get<I - 1, Is...>();
    }
  }

 public:
  // Underlying CUB primitive for one candidate compile-time block size.
  // NOTE(review): under USE_ROCM this header includes hipcub but still spells
  // cub:: — this relies on a `namespace cub = hipcub;` alias provided by an
  // earlier include; confirm, or add the alias to this header.
  template <size_t BlockSize>
  using BlockReduce = cub::BlockReduce<T, BlockSize>;

  // Shared-memory storage: a union over the TempStorage of every candidate
  // block size, so it is only as large as the largest single alternative.
  using TempStorage =
      detail::MultiUnion<typename BlockReduce<BlockSizes>::TempStorage...>;

  // Dispatch step: use the I-th block size if num_valid fits in it (or if it
  // is the largest size available); otherwise recurse to the next size.
  // NOTE(review): if num_valid exceeds the largest BlockSize the largest is
  // still used — callers must ensure blockDim.x <= max(BlockSizes...).
  template <size_t I, typename ReductionOp>
  __device__ T reduce_impl(T input, ReductionOp op, size_t num_valid) {
    constexpr size_t block_size = get<I, BlockSizes...>();
    // If larger blocks are available and num_valid is larger than current,
    // try the next block size
    if constexpr (I < sizeof...(BlockSizes) - 1) {
      if (num_valid > block_size) {
        return reduce_impl<I + 1>(input, op, num_valid);
      }
    }

    // Either this is the last block size or num_valid is smaller than
    // block_size, so use it
    return BlockReduce<block_size>(storage.template get<I>())
        .Reduce(input, op, num_valid);
  }

  // Block-wide reduction of `input` across the first num_valid threads.
  // Per CUB BlockReduce semantics the aggregate is only valid in thread 0
  // (callers in layernorm_kernels.cu guard on threadIdx.x == 0); all
  // participating threads must call with the same op and num_valid.
  template <typename ReductionOp>
  __device__ T Reduce(T input, ReductionOp op, size_t num_valid) {
    return reduce_impl<0>(input, op, num_valid);
  }

  // Binds this reducer to caller-provided __shared__ temp storage.
  __device__ BlockReduceMulti(TempStorage& storage) : storage(storage) {}

 private:
  TempStorage& storage;
};

} // namespace vllm

0 comments on commit 1920584

Please sign in to comment.