[Update] Add rms norm general kernel and update sampler condition #37

Open · wants to merge 1 commit into base: main
23 changes: 23 additions & 0 deletions kernels/csrc/layernorm.cpp
@@ -14,13 +14,28 @@ void rms_norm(torch::Tensor &out, // [num_tokens, hidden_size]
torch::Tensor &weight, // [hidden_size]
float epsilon, bool use_quant);

void layer_norm_general(torch::Tensor &out, // [..., hidden_size]
torch::Tensor &input, // [..., hidden_size]
torch::Tensor &weight, // [hidden_size]
torch::Tensor &scaling, // [tokens] or [1]
float epsilon,
bool use_per_token_quant);

void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
torch::Tensor &input, // [..., hidden_size]
torch::Tensor &weight, // [hidden_size]
torch::Tensor &scaling, // [tokens] or [1]
float epsilon,
bool use_per_token_quant);

void layer_norm_general_fuse_sum(torch::Tensor &out, // [..., hidden_size]
torch::Tensor &input, // [..., hidden_size]
torch::Tensor &weight, // [hidden_size]
torch::Tensor &input_sum, // [tokens] or [1]
torch::Tensor &scaling, // [tokens] or [1]
float epsilon,
bool use_per_token_quant);

void rms_norm_general_fuse_sum(torch::Tensor &out, // [..., hidden_size]
torch::Tensor &input, // [..., hidden_size]
torch::Tensor &weight, // [hidden_size]
@@ -49,10 +64,18 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("weight"), py::arg("epsilon"), py::arg("use_quant") = false,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");

m.def("layer_norm_general", &layer_norm_general, py::arg("out"), py::arg("input"),
py::arg("weight"), py::arg("scaling"), py::arg("epsilon"), py::arg("use_per_token_quant") = false,
"Apply Layer Normalization to the input tensor (modified from TRTLLM kernel).");

m.def("rms_norm_general", &rms_norm_general, py::arg("out"), py::arg("input"),
py::arg("weight"), py::arg("scaling"), py::arg("epsilon"), py::arg("use_per_token_quant") = false,
"Apply Root Mean Square (RMS) Normalization to the input tensor (TRTLLM kernel).");

m.def("layer_norm_general_fuse_sum", &layer_norm_general_fuse_sum, py::arg("out"), py::arg("input"),
py::arg("weight"), py::arg("input_sum"), py::arg("scaling"), py::arg("epsilon"), py::arg("use_per_token_quant") = false,
"Apply Layer Normalization to the input tensor & get input sum (modified from TRTLLM kernel).");

m.def("rms_norm_general_fuse_sum", &rms_norm_general_fuse_sum, py::arg("out"), py::arg("input"),
py::arg("weight"), py::arg("input_sum"), py::arg("scaling"), py::arg("epsilon"), py::arg("use_per_token_quant") = false,
"Apply Root Mean Square (RMS) Normalization to the input tensor & get input sum (TRTLLM kernel).");
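For orientation, here is a minimal sketch of how the new `rms_norm_general` binding could be called from Python once the extension is built. The module name `qserve_backend` is an assumption for illustration (the real name depends on how `TORCH_EXTENSION_NAME` is set by the build); the argument order follows the `m.def` above, and in per-token mode both `out` (int8) and `scaling` (fp16, one scale per token) are written by the kernel.

```python
import torch
import qserve_backend  # hypothetical name for the compiled extension

tokens, hidden = 8, 4096
x = torch.randn(tokens, hidden, dtype=torch.float16, device="cuda")
w = torch.ones(hidden, dtype=torch.float16, device="cuda")

out = torch.empty(tokens, hidden, dtype=torch.int8, device="cuda")  # int8 output written by the kernel
scales = torch.empty(tokens, dtype=torch.float16, device="cuda")    # per-token scales written by the kernel

qserve_backend.rms_norm_general(out, x, w, scales, 1e-6, use_per_token_quant=True)
```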
271 changes: 268 additions & 3 deletions kernels/csrc/layernorm_kernels.cu
@@ -185,6 +185,96 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta,
}
}

template <typename T, typename scale_type>
__global__ void generalRMSNorm(const T* input, const T* gamma, const T* beta, T* normed_output, const float eps,
int tokens, int hidden_dim, const scale_type* scale_orig_quant_per_tensor, scale_type* scale_orig_quant_per_token,
int8_t* normed_output_quant, bool use_shmem)
{
constexpr auto num_elems_T = num_elems<T>::value;
using int8_packed_t = typename packed_as<int8_t, num_elems_T>::type;
using float_packed_t = typename packed_as<float, num_elems_T>::type;
using T_scalar = typename packed_as<T, 1>::type;

extern __shared__ __align__(sizeof(float)) char _shmem[];
T* shmem = reinterpret_cast<T*>(_shmem);
__shared__ float s_mean;
__shared__ float s_variance;

const int tidx = threadIdx.x;
const int bidx = blockIdx.x;

float variance = 0.0f;
float local_var_sum = 0.0f;

const int n_elems = hidden_dim / num_elems_T;

for (int i = tidx; i < n_elems; i += blockDim.x)
{
const T val = use_shmem ? shmem[i] : input[bidx * n_elems + i];
float_packed_t diff = cuda_cast<float_packed_t>(val); // no mean
local_var_sum += cuda_sum<float>(diff * diff);
}
variance = blockReduceSum(local_var_sum);

if (threadIdx.x == 0)
{
s_variance = rsqrtf(variance / hidden_dim + eps);
}
__syncthreads();

const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr;
const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
const float_packed_t scale_orig_quant
= cuda_cast<float_packed_t>(with_per_tensor_scaling ? __half2float(*scale_orig_quant_per_tensor) : 0.0f);
T_scalar amax = 1e-6f;

for (int i = tidx; i < n_elems; i += blockDim.x)
{
const int index = bidx * n_elems + i;
const float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? shmem[i] : input[index]);
const T val = cuda_cast<T>(compute_layernorm(val_f, 0.0f, s_variance, gamma, beta, i));

if (with_per_token_scaling)
{
amax = cuda_max(cuda_max<T_scalar, T>(cuda_abs(val)), amax);
if (use_shmem)
{
shmem[i] = val;
}
}
else if (with_per_tensor_scaling)
{
reinterpret_cast<int8_packed_t*>(normed_output_quant)[index]
= cuda_cast<int8_packed_t>(cuda_cast<float_packed_t>(val) * scale_orig_quant);
}
else
{
normed_output[index] = val;
}
}

if (with_per_token_scaling)
{
float abs_max_f = blockAllReduceMax(cuda_cast<float>(amax));
const float dynamic_per_token_scale = 127.f / abs_max_f;
for (int i = tidx; i < n_elems; i += blockDim.x)
{
const int index = bidx * n_elems + i;
float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? shmem[i] : input[index]);
if (!use_shmem)
{
val_f = compute_layernorm(val_f, 0.0f, s_variance, gamma, beta, i);
}

reinterpret_cast<int8_packed_t*>(normed_output_quant)[index]
= cuda_cast<int8_packed_t>(val_f * cuda_cast<float_packed_t>(dynamic_per_token_scale));
}
if (tidx == 0)
{
scale_orig_quant_per_token[bidx] = abs_max_f / 127.f;
}
}
}
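For readers checking numerics, here is a PyTorch reference of what the per-token path of `generalRMSNorm` computes. It is a sketch only (fp32 math throughout; the kernel's int8 cast may round slightly differently), not part of the PR.

```python
import torch

def rms_norm_per_token_quant_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    """Reference sketch of the per-token quantization path of generalRMSNorm."""
    x = x.float()
    # RMS normalization: no mean subtraction, scale by 1 / sqrt(mean(x^2) + eps).
    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    normed = x * inv_rms * weight.float()
    # Dynamic per-token int8 quantization: 127 / max(|row|), clamped like the kernel's amax init.
    amax = normed.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6)
    quant = torch.clamp(torch.round(normed * (127.0 / amax)), -128, 127).to(torch.int8)
    per_token_scale = (amax / 127.0).squeeze(-1)   # mirrors scale_orig_quant_per_token[bidx]
    return quant, per_token_scale
```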

template <typename T, typename scale_type, bool USE_DIFF_OF_SQUARES = false>
__global__ void generalLayerNorm_fuse_sum(const T* input, const T* gamma, const T* beta, T* normed_output, const float eps,
@@ -325,6 +415,100 @@ __global__ void generalLayerNorm_fuse_sum(const T* input, const T* gamma, const
}
}

template <typename T, typename scale_type, bool USE_DIFF_OF_SQUARES = false>
__global__ void generalRMSNorm_fuse_sum(const T* input, const T* gamma, const T* beta, T* normed_output, const float eps,
int tokens, int hidden_dim, scale_type* input_sum, const scale_type* scale_orig_quant_per_tensor, scale_type* scale_orig_quant_per_token,
int8_t* normed_output_quant, bool use_shmem)
{
constexpr auto num_elems_T = num_elems<T>::value;
using int8_packed_t = typename packed_as<int8_t, num_elems_T>::type;
using float_packed_t = typename packed_as<float, num_elems_T>::type;
using T_scalar = typename packed_as<T, 1>::type;

extern __shared__ __align__(sizeof(float)) char _shmem[];
T* shmem = reinterpret_cast<T*>(_shmem);
__shared__ float s_mean;
__shared__ float s_variance;

const int tidx = threadIdx.x;
const int bidx = blockIdx.x;

float variance = 0.0f;
float local_var_sum = 0.0f;

const int n_elems = hidden_dim / num_elems_T;

for (int i = tidx; i < n_elems; i += blockDim.x)
{
const T val = use_shmem ? shmem[i] : input[bidx * n_elems + i];
float_packed_t diff = cuda_cast<float_packed_t>(val); // no mean
local_var_sum += cuda_sum<float>(diff * diff);
}
variance = blockReduceSum(local_var_sum);

if (threadIdx.x == 0)
{
s_variance = rsqrtf(variance / hidden_dim + eps);
}
__syncthreads();

const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr;
const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
const float_packed_t scale_orig_quant
= cuda_cast<float_packed_t>(with_per_tensor_scaling ? __half2float(*scale_orig_quant_per_tensor) : 0.0f);
T_scalar amax = 1e-6f;
T_scalar sum = 0.0f;

for (int i = tidx; i < n_elems; i += blockDim.x)
{
const int index = bidx * n_elems + i;
const float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? shmem[i] : input[index]);
const T val = cuda_cast<T>(compute_layernorm(val_f, 0.0f, s_variance, gamma, beta, i));

if (with_per_token_scaling)
{
amax = cuda_max(cuda_max<T_scalar, T>(cuda_abs(val)), amax);
sum += cuda_sum<float>(val);
if (use_shmem)
{
shmem[i] = val;
}
}
else if (with_per_tensor_scaling)
{
reinterpret_cast<int8_packed_t*>(normed_output_quant)[index]
= cuda_cast<int8_packed_t>(cuda_cast<float_packed_t>(val) * scale_orig_quant);
}
else
{
normed_output[index] = val;
}
}

if (with_per_token_scaling)
{
float abs_max_f = blockAllReduceMax(cuda_cast<float>(amax));
float sum_f = blockAllReduceSum(cuda_cast<float>(sum));
const float dynamic_per_token_scale = 127.f / abs_max_f;
for (int i = tidx; i < n_elems; i += blockDim.x)
{
const int index = bidx * n_elems + i;
float_packed_t val_f = cuda_cast<float_packed_t>(use_shmem ? shmem[i] : input[index]);
if (!use_shmem)
{
val_f = compute_layernorm(val_f, 0.0f, s_variance, gamma, beta, i);
}

reinterpret_cast<int8_packed_t*>(normed_output_quant)[index]
= cuda_cast<int8_packed_t>(val_f * cuda_cast<float_packed_t>(dynamic_per_token_scale));
}
if (tidx == 0)
{
scale_orig_quant_per_token[bidx] = abs_max_f / 127.f;
input_sum[bidx] = sum_f;
}
}
}
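The `_fuse_sum` variant additionally writes one value per token in the per-token branch. As the kernel is written, the accumulator adds the normalized (post-weight) values rather than the raw input, so a reference sketch of what lands in `input_sum` under that reading is:

```python
import torch

def rms_norm_fused_sum_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    """Sketch of the extra per-token sum written by generalRMSNorm_fuse_sum."""
    x = x.float()
    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    normed = x * inv_rms * weight.float()
    # The kernel accumulates the normalized values, so this is a row sum of the output.
    return normed.sum(dim=-1)   # written to input_sum[token] in the per-token branch
```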

// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t, typename out_type, bool use_quant>
@@ -424,7 +608,7 @@ void rms_norm(torch::Tensor &out, // [..., hidden_size]
});
}

void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
void layer_norm_general(torch::Tensor &out, // [..., hidden_size]
torch::Tensor &input, // [..., hidden_size]
torch::Tensor &weight, // [hidden_size]
torch::Tensor &scaling, // [tokens] or [1]
@@ -463,7 +647,46 @@ void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
});
}

void rms_norm_general_fuse_sum(torch::Tensor &out, // [..., hidden_size]
void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
torch::Tensor &input, // [..., hidden_size]
torch::Tensor &weight, // [hidden_size]
torch::Tensor &scaling, // [tokens] or [1]
float epsilon,
bool use_per_token_quant) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
block.x = 32 * ((block.x + 31) / 32);

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "generalRMSNorm", [&] {
using T = typename FloatTypeConverter<scalar_t>::Type;
if (use_per_token_quant) {
// per-token
vllm::generalRMSNorm<T, at::Half><<<grid, block, 0, stream>>>(
reinterpret_cast<T*>(input.data_ptr<scalar_t>()),
reinterpret_cast<T*>(weight.data_ptr<scalar_t>()), nullptr,
nullptr, epsilon, num_tokens, hidden_size, nullptr, scaling.data_ptr<at::Half>(),
out.data_ptr<int8_t>(), false
);
// input, gamma, beta, normed_output, eps, tokens, hidden_dim, per_tensor_scale, per_token_scale
// normed_output_quant, use_shmem
// out.data_ptr<int8_t>(), input.data_ptr<scalar_t>(),
// weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
} else {
// per-tensor
vllm::generalRMSNorm<T, at::Half><<<grid, block, 0, stream>>>(
reinterpret_cast<T*>(input.data_ptr<scalar_t>()),
reinterpret_cast<T*>(weight.data_ptr<scalar_t>()), nullptr,
nullptr, epsilon, num_tokens, hidden_size, scaling.data_ptr<at::Half>(), nullptr,
out.data_ptr<int8_t>(), false
);
}
});
}
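Note how `scaling` is wired in the two branches above: in the per-token branch it is passed as `scale_orig_quant_per_token` (an output buffer the kernel fills), while in the per-tensor branch it is passed as `scale_orig_quant_per_tensor` (a scale the kernel reads). The launch uses one block per token, with the block size capped at 1024 threads and rounded up to a multiple of the warp size (32). A sketch of the per-tensor path, assuming the caller supplies the scale:

```python
import torch

def rms_norm_per_tensor_quant_ref(x: torch.Tensor, weight: torch.Tensor,
                                  scale: float, eps: float = 1e-6) -> torch.Tensor:
    """Sketch of the per-tensor branch: `scale` is supplied by the caller and read by the kernel."""
    x = x.float()
    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    normed = x * inv_rms * weight.float()
    # One scale for the whole tensor; the kernel multiplies by it before the int8 cast.
    return torch.clamp(torch.round(normed * scale), -128, 127).to(torch.int8)
```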

void layer_norm_general_fuse_sum(torch::Tensor &out, // [..., hidden_size]
torch::Tensor &input, // [..., hidden_size]
torch::Tensor &weight, // [hidden_size]
torch::Tensor &input_sum, // [tokens] or [1]
@@ -507,7 +730,49 @@ void rms_norm_general_fuse_sum(torch::Tensor &out, // [..., hidden_size]
});
}


void rms_norm_general_fuse_sum(torch::Tensor &out, // [..., hidden_size]
torch::Tensor &input, // [..., hidden_size]
torch::Tensor &weight, // [hidden_size]
torch::Tensor &input_sum, // [tokens] or [1]
torch::Tensor &scaling, // [tokens] or [1]
float epsilon,
bool use_per_token_quant) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
block.x = 32 * ((block.x + 31) / 32);

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "generalRMSNorm_fuse_sum", [&] {
using T = typename FloatTypeConverter<scalar_t>::Type;
if (use_per_token_quant) {
// per-token
vllm::generalRMSNorm_fuse_sum<T, at::Half><<<grid, block, 0, stream>>>(
reinterpret_cast<T*>(input.data_ptr<scalar_t>()),
reinterpret_cast<T*>(weight.data_ptr<scalar_t>()), nullptr,
nullptr, epsilon, num_tokens, hidden_size, input_sum.data_ptr<at::Half>(), nullptr, scaling.data_ptr<at::Half>(),
out.data_ptr<int8_t>(), false
);
// input, gamma, beta, normed_output, eps, tokens, hidden_dim, per_tensor_scale, per_token_scale
// normed_output_quant, use_shmem
// out.data_ptr<int8_t>(), input.data_ptr<scalar_t>(),
// weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
} else {
// per-tensor
// Per-tensor input_sum is not implemented; raise an error here.
assert(false);

vllm::generalRMSNorm_fuse_sum<T, at::Half><<<grid, block, 0, stream>>>(
reinterpret_cast<T*>(input.data_ptr<scalar_t>()),
reinterpret_cast<T*>(weight.data_ptr<scalar_t>()), nullptr,
nullptr, epsilon, num_tokens, hidden_size, nullptr, scaling.data_ptr<at::Half>(), nullptr,
out.data_ptr<int8_t>(), false
);
}
});
}

void invoke_dequant_add_residual_rms_norm_quant(
torch::Tensor &out, // [..., hidden_size]
2 changes: 1 addition & 1 deletion qserve/modeling/layers/sampler.py
@@ -84,7 +84,7 @@ def forward(
else:
last_token_logits = logits
if (
self.sampling_params.temperature < 1e-5 or self.sampling_params.top_p < 1e-8 # greedy
self.sampling_params.temperature < 1e-5 or self.sampling_params.top_p < 1e-8 or self.sampling_params.top_k == 1 # greedy
):
token = torch.argmax(last_token_logits, dim=-1)
else:
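The sampler change adds `top_k == 1` as a third trigger for greedy decoding, alongside near-zero temperature and near-zero top_p. Together with the benchmark change below (which now passes `temperature=0.0`), the benchmark falls into the greedy `argmax` branch. A condensed, self-contained sketch of the updated check; the `SamplingParams` stand-in and logits tensor here are placeholders, not the repo's actual classes:

```python
import torch
from dataclasses import dataclass

@dataclass
class SamplingParams:              # minimal stand-in for the repo's class
    temperature: float = 0.0
    top_p: float = 1.0
    top_k: int = -1

sampling_params = SamplingParams(temperature=0.0)   # as in qserve_benchmark.py
last_token_logits = torch.randn(4, 32000)           # [batch, vocab] placeholder

greedy = (
    sampling_params.temperature < 1e-5
    or sampling_params.top_p < 1e-8
    or sampling_params.top_k == 1
)
if greedy:
    token = torch.argmax(last_token_logits, dim=-1)  # greedy decoding path
```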
2 changes: 1 addition & 1 deletion qserve_benchmark.py
@@ -32,7 +32,7 @@ def process_requests(
str(b),
prompt=None,
profiling_config=profiling_config,
sampling_params=SamplingParams(top_p=0.95, top_k=40, temperature=0.7),
sampling_params=SamplingParams(temperature=0.0),
)

if engine.ifb_mode == False: