4 changes: 2 additions & 2 deletions sgl-kernel/CMakeLists.txt
@@ -58,8 +58,8 @@ FetchContent_Populate(repo-deepgemm)
 # flashinfer
 FetchContent_Declare(
   repo-flashinfer
-  GIT_REPOSITORY https://github.com/sgl-project/flashinfer
-  GIT_TAG sgl-kernel
+  GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
+  GIT_TAG 9220fb3443b5a5d274f00ca5552f798e225239b7
   GIT_SHALLOW OFF
 )
 FetchContent_Populate(repo-flashinfer)
29 changes: 12 additions & 17 deletions sgl-kernel/csrc/common_extension.cc
@@ -58,16 +58,16 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   /*
    * From csrc/elementwise
    */
-  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
   m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
 
-  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
+  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
   m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
 
-  m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
   m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
 
-  m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
   m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
 
   m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
@@ -186,29 +186,24 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
 
   m.def(
-      "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
-      "min_p_val, bool deterministic, int cuda_stream) -> ()");
+      "min_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_min_p_arr, float "
+      "min_p_val, bool deterministic, Generator? gen) -> ()");
   m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
 
-  m.def(
-      "top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
-      "cuda_stream) -> ()");
+  m.def("top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
   m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs);
 
-  m.def(
-      "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
-      "cuda_stream) -> ()");
+  m.def("top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val) -> ()");
   m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
 
   m.def(
-      "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
-      "maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
-      "cuda_stream) -> ()");
+      "top_k_top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_top_k_arr, "
+      "float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
   m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
 
   m.def(
-      "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
-      "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
+      "top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? "
+      "maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
   m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
 
   /*
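The net effect of this hunk: the norm ops swap the trailing `int cuda_stream` argument for `bool enable_pdl` (the kernel now picks up the current CUDA stream itself), and the sampling ops drop pre-drawn `uniform_samples` in favor of an optional `Generator`. A minimal sketch of calling the new norm schemas directly through the op registry, assuming sgl_kernel is built from this revision; shapes and dtypes are illustrative:

```python
import torch

batch, hidden = 4, 1024
x = torch.randn(batch, hidden, device="cuda", dtype=torch.float16)
w = torch.ones(hidden, device="cuda", dtype=torch.float16)
out = torch.empty_like(x)

# Trailing argument is now enable_pdl (bool), not a raw CUDA stream handle.
torch.ops.sgl_kernel.rmsnorm.default(out, x, w, 1e-6, False)

# In-place fused variant; enable_pdl only has an effect on GPUs that
# support programmatic dependent launch (assumption: Hopper or newer).
residual = torch.zeros_like(x)
torch.ops.sgl_kernel.fused_add_rmsnorm.default(x, residual, w, 1e-6, False)
```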
6 changes: 5 additions & 1 deletion sgl-kernel/csrc/elementwise/fused_add_rms_norm_kernel.cu
@@ -21,7 +21,8 @@ limitations under the License.

 using namespace flashinfer;
 
-void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) {
+void sgl_fused_add_rmsnorm(
+    torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl) {
   CHECK_INPUT(input);
   CHECK_INPUT(residual);
   CHECK_INPUT(weight);
@@ -46,7 +47,10 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T
         static_cast<c_type*>(weight.data_ptr()),
         batch_size,
         hidden_size,
+        input.stride(0),
+        residual.stride(0),
         eps,
+        enable_pdl,
         torch_current_stream);
     TORCH_CHECK(
         status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
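The two extra `stride(0)` arguments match the updated flashinfer `FusedAddRMSNorm` entry point in the pinned revision, which now takes per-row strides. For reference, the op's semantics (as documented in the Python docstrings later in this diff) can be written in plain PyTorch; this is an unfused sketch, not the kernel's actual implementation:

```python
import torch

def fused_add_rmsnorm_ref(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
    # Step 1: residual[i] += input[i]
    residual.add_(input)
    # Step 2: input[i] = (residual[i] / RMS(residual)) * weight[i],
    # accumulated in float32 for stability, as fused kernels typically do.
    acc = residual.float()
    rms = torch.sqrt(acc.pow(2).mean(dim=-1, keepdim=True) + eps)
    input.copy_(((acc / rms) * weight.float()).to(input.dtype))
```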
8 changes: 4 additions & 4 deletions sgl-kernel/csrc/speculative/speculative_sampling.cuh
@@ -54,10 +54,10 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
     DType threshold_acc) {
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
 
-  extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
+  extern __shared__ __align__(alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
       uint8_t smem_sampling[];
   auto& temp_storage =
-      reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
+      reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
 
   DType prob_acc = 0.0;
   uint32_t cur_prob_offset = bx * num_draft_tokens * d;
@@ -144,7 +144,7 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
       relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0));
     }
 
-    DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC, DType>(
+    DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC>(
         i, d, [&](DType x) { return x > 0; }, u, relu_q_minus_p_vec, aggregate_relu_q_minus_p, &temp_storage);
     if (aggregate_relu_q_minus_p > u) {
       break;
@@ -179,7 +179,7 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
   constexpr uint32_t BLOCK_THREADS = 1024;
   const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
 
-  const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+  const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
   dim3 nblks(batch_size);
   dim3 nthrs(BLOCK_THREADS);
   float capped_threshold_acc = fmaxf(threshold_acc, 1e-9f);
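Dropping `DType` from `SamplingTempStorage` and `DeviceSamplingFromProb` tracks the pinned flashinfer revision, where these templates are no longer parameterized on the element type. For intuition, the step the kernel performs here, resampling from the positive residual between target distribution `q` and draft distribution `p`, looks roughly like the following. This is a hedged host-side sketch of the math, not the CUDA implementation:

```python
import torch

def residual_sample(q: torch.Tensor, p: torch.Tensor, u: float) -> int:
    """Sample from relu(q - p) by inverse CDF, mirroring the kernel's
    DeviceSamplingFromProb call over relu_q_minus_p_vec with uniform u."""
    relu_q_minus_p = torch.clamp(q - p, min=0.0)
    cdf = torch.cumsum(relu_q_minus_p, dim=-1)
    # The kernel compares a running aggregate of this mass against u;
    # scaling u by the total unnormalized mass has the same effect here.
    target = u * cdf[-1]
    return int(torch.searchsorted(cdf, target.reshape(1)).item())
```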
42 changes: 16 additions & 26 deletions sgl-kernel/include/sgl_kernel_ops.h
@@ -102,11 +102,11 @@ int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches,
 /*
  * From csrc/elementwise
  */
-void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
-void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
-void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
-void gemma_fused_add_rmsnorm(
-    at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
+void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
+void sgl_fused_add_rmsnorm(
+    torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl);
+void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
+void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, bool enable_pdl);
 void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
 void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
 void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
@@ -254,48 +254,38 @@ void segment_packbits(
  */
 void min_p_sampling_from_probs(
     at::Tensor probs,
-    at::Tensor uniform_samples,
-    at::Tensor samples,
+    at::Tensor output,
+    std::optional<at::Tensor> maybe_indices,
     std::optional<at::Tensor> maybe_min_p_arr,
     double min_p_val,
     bool deterministic,
-    int64_t cuda_stream);
+    std::optional<at::Generator> gen);
 
 void top_k_renorm_probs(
-    at::Tensor probs,
-    at::Tensor renorm_probs,
-    std::optional<at::Tensor> maybe_top_k_arr,
-    int64_t top_k_val,
-    int64_t cuda_stream);
+    at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
 
 void top_p_renorm_probs(
-    at::Tensor probs,
-    at::Tensor renorm_probs,
-    std::optional<at::Tensor> maybe_top_p_arr,
-    double top_p_val,
-    int64_t cuda_stream);
+    at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr, double top_p_val);
 
 void top_k_top_p_sampling_from_probs(
     at::Tensor probs,
-    at::Tensor uniform_samples,
-    at::Tensor samples,
-    at::Tensor success,
+    at::Tensor output,
+    std::optional<at::Tensor> maybe_indices,
     std::optional<at::Tensor> maybe_top_k_arr,
     double top_k_val,
     std::optional<at::Tensor> maybe_top_p_arr,
     double top_p_val,
     bool deterministic,
-    int64_t cuda_stream);
+    std::optional<at::Generator> gen);
 
 void top_p_sampling_from_probs(
     at::Tensor probs,
-    at::Tensor uniform_samples,
-    at::Tensor samples,
-    at::Tensor success,
+    at::Tensor output,
+    std::optional<at::Tensor> maybe_indices,
     std::optional<at::Tensor> maybe_top_p_arr,
     double top_p_val,
     bool deterministic,
-    int64_t cuda_stream);
+    std::optional<at::Generator> gen);
 
 namespace flash {
 /*
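Call sites migrate accordingly: the pre-drawn `uniform_samples` and the `success` mask disappear, sampled token ids land in `output`, the optional `maybe_indices` selects a subset of rows, and reproducibility comes from an optional `at::Generator`. A hedged before/after sketch from the Python side, with binding names assumed to match this header:

```python
import torch

probs = torch.softmax(torch.randn(8, 32000, device="cuda"), dim=-1)
samples = torch.empty(8, dtype=torch.int32, device="cuda")  # dtype assumed
gen = torch.Generator(device="cuda").manual_seed(42)

# Old: top_k_top_p_sampling_from_probs(probs, uniform_samples, samples,
#          success, None, 20, None, 0.9, True, stream)
# New: randomness is drawn from `gen` inside the kernel, and there is no
# success mask to check afterwards.
torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs.default(
    probs, samples, None, None, 20.0, None, 0.9, True, gen
)
```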
116 changes: 108 additions & 8 deletions sgl-kernel/python/sgl_kernel/elementwise.py
@@ -11,38 +11,138 @@ def rmsnorm(
     weight: torch.Tensor,
     eps: float = 1e-6,
     out: Optional[torch.Tensor] = None,
+    enable_pdl: bool = False,
 ) -> torch.Tensor:
+    r"""Root mean square normalization.
+
+    ``out[i] = (input[i] / RMS(input)) * weight[i]``
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    out: Optional[torch.Tensor]
+        If specified, the kernel writes the result into this tensor in place.
+    enable_pdl: bool
+        Whether to enable `programmatic dependent launch
+        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
+
+    Returns
+    -------
+    output: torch.Tensor
+        Normalized tensor, shape (batch_size, hidden_size).
+    """
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl)
     return out
 
 
 def fused_add_rmsnorm(
-    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+    input: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    enable_pdl: bool = False,
 ) -> None:
-    torch.ops.sgl_kernel.fused_add_rmsnorm.default(input, residual, weight, eps)
+    r"""Fused add root mean square normalization.
+
+    Step 1:
+    ``residual[i] += input[i]``
+
+    Step 2:
+    ``input[i] = (residual[i] / RMS(residual)) * weight[i]``
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    residual: torch.Tensor
+        Residual tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    enable_pdl: bool
+        Whether to enable `programmatic dependent launch
+        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
+    """
+    torch.ops.sgl_kernel.fused_add_rmsnorm.default(
+        input, residual, weight, eps, enable_pdl
+    )
 
 
 def gemma_rmsnorm(
     input: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
     out: Optional[torch.Tensor] = None,
+    enable_pdl: bool = False,
 ) -> torch.Tensor:
+    r"""Gemma-style root mean square normalization.
+
+    ``out[i] = (input[i] / RMS(input)) * (weight[i] + 1)``
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    out: Optional[torch.Tensor]
+        If specified, the kernel writes the result into this tensor in place.
+    enable_pdl: bool
+        Whether to enable `programmatic dependent launch
+        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
+
+    Returns
+    -------
+    output: torch.Tensor
+        Gemma-normalized tensor, shape (batch_size, hidden_size).
+    """
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.gemma_rmsnorm.default(
-        out, input, weight, eps, get_cuda_stream()
-    )
+    torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl)
     return out
 
 
 def gemma_fused_add_rmsnorm(
-    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+    input: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    enable_pdl: bool = False,
 ) -> None:
+    r"""Gemma-style fused add root mean square normalization.
+
+    Step 1:
+    ``residual[i] += input[i]``
+
+    Step 2:
+    ``input[i] = (residual[i] / RMS(residual)) * (weight[i] + 1)``
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    residual: torch.Tensor
+        Residual tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    enable_pdl: bool
+        Whether to enable `programmatic dependent launch
+        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
+    """
     torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
-        input, residual, weight, eps, get_cuda_stream()
+        input, residual, weight, eps, enable_pdl
    )


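Taken together, the Python-facing change is additive: `enable_pdl` defaults to `False`, so existing callers keep working unchanged. A usage sketch, assuming these wrappers are re-exported from the `sgl_kernel` package and the GPU supports PDL:

```python
import torch
from sgl_kernel import fused_add_rmsnorm, gemma_rmsnorm, rmsnorm  # assumed re-exports

x = torch.randn(2, 4096, device="cuda", dtype=torch.float16)
res = torch.zeros_like(x)
w = torch.ones(4096, device="cuda", dtype=torch.float16)

y = rmsnorm(x, w)                                        # out-of-place, PDL off
fused_add_rmsnorm(x, res, w, eps=1e-6, enable_pdl=True)  # updates x and res in place
g = gemma_rmsnorm(x, w)                                  # (weight + 1) scaling variant
```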