diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 9a438591a7aa..3fd6140426cd 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -332,6 +332,9 @@ set(SOURCES "csrc/moe/prepare_moe_input.cu" "csrc/quantization/gguf/gguf_kernel.cu" + + "csrc/sgl_diffusion/scale_residual_norm_scale_shift/scale_residual_norm_scale_shift_host.cu" + "csrc/speculative/eagle_utils.cu" "csrc/speculative/ngram_utils.cu" "csrc/speculative/packbit.cu" diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 0145676e6770..2ff15cd1df45 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -84,6 +84,11 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); + m.def( + "scale_residual_norm_scale_shift(Tensor residual, Tensor x, Tensor? gate, Tensor? norm_weight, " + "Tensor? norm_bias, Tensor scale, Tensor shift, float eps, bool use_rms_norm) -> (Tensor, Tensor)"); + m.impl("scale_residual_norm_scale_shift", torch::kCUDA, &scale_residual_norm_scale_shift); + m.def( "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, " "Tensor pos_ids, bool interleave, bool enable_pdl, " diff --git a/sgl-kernel/csrc/sgl_diffusion/scale_residual_norm_scale_shift/kernel_welford.cuh b/sgl-kernel/csrc/sgl_diffusion/scale_residual_norm_scale_shift/kernel_welford.cuh new file mode 100644 index 000000000000..6dbe8280f819 --- /dev/null +++ b/sgl-kernel/csrc/sgl_diffusion/scale_residual_norm_scale_shift/kernel_welford.cuh @@ -0,0 +1,422 @@ +#ifndef SCALE_RESIDUAL_NORM_SCALE_SHIFT_KERNEL_H +#define SCALE_RESIDUAL_NORM_SCALE_SHIFT_KERNEL_H + +#include + +struct WelfordValue { + float mean = 0.0f, m2 = 0.0f; + int count = 0; +}; + +template +union PtrValUnion { + const DType* ptr; + DType value; +}; + +template +struct BroadcastDesc { + PtrValUnion union_value; + int32_t stride_b; + int32_t frame_len; +}; + +constexpr int THREADS_PER_WARP = 32; +constexpr int THREADS_PER_CTA = 128; +constexpr int WARP_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP; +constexpr int64_t CTA_REDUCE_THRESHOLD = 1024; + +enum NormType : int { + LayerNorm = 0, + RMSNorm = 1, +}; + +// Load 4 elements from global memory and convert to RegT (float). +template +__inline__ __device__ void load4_cast(const PtrTy* ptr, RegT v[4]) { + using Raw = std::conditional_t, float4, ushort4>; + Raw raw = *reinterpret_cast(ptr); + const PtrTy* t4 = reinterpret_cast(&raw); + v[0] = static_cast(t4[0]); + v[1] = static_cast(t4[1]); + v[2] = static_cast(t4[2]); + v[3] = static_cast(t4[3]); +} + +// Store 4 float values back to memory as PtrTy, using vectorized write. +template +__inline__ __device__ void store4_cast(PtrTy* ptr, const RegTy v[4]) { + using Raw = std::conditional_t, float4, ushort4>; + PtrTy cast_v[4]; + cast_v[0] = static_cast(v[0]); + cast_v[1] = static_cast(v[1]); + cast_v[2] = static_cast(v[2]); + cast_v[3] = static_cast(v[3]); + Raw raw = *reinterpret_cast(cast_v); + *reinterpret_cast(ptr) = raw; +} + +template +__inline__ __device__ WelfordValue compute_scale_residual(float x, float g, float r, WelfordValue welf, float& out) { + out = fmaf(x, g, r); + if constexpr (norm_type == LayerNorm) { + welf.count += 1; + float delta = out - welf.mean; + welf.mean = welf.mean + delta / welf.count; + float delta2 = (out - welf.mean); + welf.m2 = fmaf(delta, delta2, welf.m2); + } else { + welf.count += 1; + float delta = out * out - welf.mean; + welf.mean = welf.mean + delta / welf.count; + } + return welf; +} + +// Vectorized path of (x*gate + residual), computing 4 elements per thread. +template +__inline__ __device__ WelfordValue scale_residual_aligned( + WelfordValue welf, + const DType* x, + const DType* gate, + const DType* residual, + DType* residual_output, + bool is_warp_reduce, + bool has_gate_tensor, + int D, + uint32_t thr_id, + uint32_t lane_id) { + uint32_t idx = (is_warp_reduce ? lane_id : thr_id) * 4; + uint32_t stride = (is_warp_reduce ? THREADS_PER_WARP : THREADS_PER_CTA) * 4; + + while (idx + 3 < D) { + float x_i[4], gate_i[4], residual_i[4]; + load4_cast(x + idx, x_i); + if (has_gate_tensor) { + load4_cast(gate + idx, gate_i); + } else { + gate_i[0] = gate_i[1] = gate_i[2] = gate_i[3] = 1.0f; + } + load4_cast(residual + idx, residual_i); + float resi_out[4]; +#pragma unroll + for (int j = 0; j < 4; j++) { + welf = compute_scale_residual(x_i[j], gate_i[j], residual_i[j], welf, resi_out[j]); + } + store4_cast(residual_output + idx, resi_out); + idx += stride; + } + return welf; +} + +// Scalar fallback path for residual = x * gate + residual. +template +__inline__ __device__ WelfordValue scale_residual_general( + WelfordValue welf, + const DType* x, + const DType* gate, + const DType* residual, + DType* residual_output, + bool is_warp_reduce, + bool has_gate_tensor, + int D, + uint32_t thr_id, + uint32_t lane_id) { + uint32_t idx = is_warp_reduce ? lane_id : thr_id; + uint32_t stride = is_warp_reduce ? THREADS_PER_WARP : THREADS_PER_CTA; + while (idx < D) { + float resi_out; + float gate_v = has_gate_tensor ? static_cast(gate[idx]) : 1.0f; + welf = compute_scale_residual( + static_cast(x[idx]), gate_v, static_cast(residual[idx]), welf, resi_out); + residual_output[idx] = static_cast(resi_out); + idx += stride; + } + return welf; +} + +// Warp-level mean reduction using shuffle instructions. +template +__inline__ __device__ WelfordValue warp_reduce(WelfordValue welf) { +#pragma unroll + for (int offset = thread_group_width >> 1; offset > 0; offset >>= 1) { + if constexpr (norm_type == LayerNorm) { + float other_mean = __shfl_down_sync(0xffffffff, welf.mean, offset); + float other_m2 = __shfl_down_sync(0xffffffff, welf.m2, offset); + int other_count = __shfl_down_sync(0xffffffff, welf.count, offset); + if (other_count == 0) continue; + float total = welf.count + other_count; + float delta = other_mean - welf.mean; + float rate = other_count / total; + welf.mean = fmaf(delta, rate, welf.mean); + welf.m2 = fmaf(delta * delta, welf.count * rate, welf.m2 + other_m2); + welf.count = total; + } else { + float other_mean = __shfl_down_sync(0xffffffff, welf.mean, offset); + int other_count = __shfl_down_sync(0xffffffff, welf.count, offset); + if (other_count == 0) continue; + float total = welf.count + other_count; + float delta = other_mean - welf.mean; + welf.mean = fmaf(delta, other_count / total, welf.mean); + welf.count = total; + } + } + return welf; +} + +// CTA-level reduction for LayerNorm/RMSNorm. +template +__inline__ __device__ void cta_reduce( + int lane, + int warp, + WelfordValue welf, + int D, + float eps, + float* __restrict__ shm_mean, + float* __restrict__ shm_m2, + int* __restrict__ shm_count) { + if (lane == 0) { + shm_mean[warp] = welf.mean; + shm_m2[warp] = welf.m2; + shm_count[warp] = welf.count; + } + __syncthreads(); + + if (warp == 0) { + welf.mean = (lane < WARP_PER_CTA) ? shm_mean[lane] : 0; + welf.m2 = (lane < WARP_PER_CTA) ? shm_m2[lane] : 0; + welf.count = (lane < WARP_PER_CTA) ? shm_count[lane] : 0; + welf = warp_reduce(welf); + } + + if (warp == 0 && lane == 0) { + if constexpr (norm_type == LayerNorm) { + shm_mean[0] = welf.mean; + shm_m2[0] = rsqrtf(welf.m2 / D + eps); + } else { + shm_mean[0] = rsqrtf(welf.mean + eps); + } + } +} + +// Vectorized path for norm (LayerNorm/RMSNorm) + scale/shift modulation. +template +__inline__ __device__ void norm_scale_shift_aligned( + const DType* residual_output, + const ParamDType* norm_weight, + const ParamDType* norm_bias, + PtrValUnion shift_union, + PtrValUnion scale_union, + DType* modulated, + float mean, + float inv, + bool is_warp_reduce, + bool is_scale_shift_tensor, + bool has_weight_tensor, + bool has_bias_tensor, + int D, + uint32_t thr_id, + uint32_t lane_id) { + uint32_t idx = (is_warp_reduce ? lane_id : thr_id) * 4; + uint32_t stride = (is_warp_reduce ? THREADS_PER_WARP : THREADS_PER_CTA) * 4; + while (idx + 3 < D) { + float resi_out_i[4], weight_i[4], bias_i[4], scale_i[4], shift_i[4]; + float mod_i[4]; + load4_cast(residual_output + idx, resi_out_i); + if (has_weight_tensor) { + load4_cast(norm_weight + idx, weight_i); + } else { + weight_i[0] = weight_i[1] = weight_i[2] = weight_i[3] = 1.0f; + } + if (has_bias_tensor) { + load4_cast(norm_bias + idx, bias_i); + } else { + bias_i[0] = bias_i[1] = bias_i[2] = bias_i[3] = 0.0f; + } + if (is_scale_shift_tensor) { + load4_cast(scale_union.ptr + idx, scale_i); + load4_cast(shift_union.ptr + idx, shift_i); + } else { + scale_i[0] = scale_i[1] = scale_i[2] = scale_i[3] = scale_union.value; + shift_i[0] = shift_i[1] = shift_i[2] = shift_i[3] = shift_union.value; + } +#pragma unroll + for (int j = 0; j < 4; j++) { + float norm_x; + if constexpr (norm_type == LayerNorm) { + norm_x = (resi_out_i[j] - mean) * inv; + norm_x = fmaf(weight_i[j], norm_x, bias_i[j]); + } else if constexpr (norm_type == RMSNorm) { + norm_x = weight_i[j] * resi_out_i[j] * inv; + } + // 3. Modulate + mod_i[j] = fmaf(norm_x, (1.0f + scale_i[j]), shift_i[j]); + } + store4_cast(modulated + idx, mod_i); + idx += stride; + } +} + +// Scalar fallback path for norm + scale/shift, used when D is unaligned. +template +__inline__ __device__ void norm_scale_shift_general( + const DType* residual_output, + const ParamDType* norm_weight, + const ParamDType* norm_bias, + PtrValUnion shift_union, + PtrValUnion scale_union, + DType* modulated, + float mean, + float inv, + bool is_warp_reduce, + bool is_scale_shift_tensor, + bool has_weight_tensor, + bool has_bias_tensor, + int D, + uint32_t thr_id, + uint32_t lane_id) { + uint32_t idx = is_warp_reduce ? lane_id : thr_id; + uint32_t stride = is_warp_reduce ? THREADS_PER_WARP : THREADS_PER_CTA; + while (idx < D) { + float resi_out = static_cast(residual_output[idx]); + float norm_weight_v = has_weight_tensor ? static_cast(norm_weight[idx]) : 1.0f; + float norm_x; + if constexpr (norm_type == LayerNorm) { + float norm_bias_v = has_bias_tensor ? static_cast(norm_bias[idx]) : 0.0f; + norm_x = (resi_out - mean) * inv; + norm_x = fmaf(norm_weight_v, norm_x, norm_bias_v); + } else if constexpr (norm_type == RMSNorm) { + norm_x = norm_weight_v * resi_out * inv; + } + // 3. Modulate + float scale_value = is_scale_shift_tensor ? scale_union.ptr[idx] : scale_union.value; + float shift_value = is_scale_shift_tensor ? shift_union.ptr[idx] : shift_union.value; + float mod = fmaf(norm_x, (1.0f + scale_value), shift_value); + modulated[idx] = static_cast(mod); + idx += stride; + } +} + +/** + * @brief ScaleResidualNormScaleShift. + */ +template +__global__ __launch_bounds__(THREADS_PER_CTA) void scale_residual_norm_scale_shift_kernel( + const DType* residual, + const DType* x, + const DType* gate, + const ParamDType* norm_weight, + const ParamDType* norm_bias, + BroadcastDesc shift_desc, + BroadcastDesc scale_desc, + double eps, + DType* modulated, + DType* residual_output, + int B, + int S, + int D, + int gate_frame_len, + bool is_warp_reduce, + bool has_weight_tensor, + bool has_bias_tensor) { + uint32_t cta_id = blockIdx.x; + uint32_t thr_id = threadIdx.x; + uint32_t lane_id = thr_id & 31; + uint32_t warp_id = thr_id >> 5; + + // Pointer Offset + int64_t tile_id = is_warp_reduce ? cta_id * WARP_PER_CTA + warp_id : cta_id; + if (tile_id >= B * S) return; + + residual += tile_id * D; + x += tile_id * D; + bool has_gate_tensor = gate_frame_len != -1; + if (has_gate_tensor && gate_frame_len != -2) { + gate += tile_id / gate_frame_len * D; + } + bool is_scale_shift_tensor = scale_desc.stride_b != -1; + if (is_scale_shift_tensor) { + const int64_t batch_idx = tile_id / S; + const int64_t seq_idx = tile_id % S; + shift_desc.union_value.ptr += (batch_idx * shift_desc.stride_b + seq_idx) / shift_desc.frame_len * D; + scale_desc.union_value.ptr += (batch_idx * scale_desc.stride_b + seq_idx) / scale_desc.frame_len * D; + } + modulated += tile_id * D; + residual_output += tile_id * D; + + // Scale & Residual + WelfordValue welf; + if constexpr (is_d_aligned) { + welf = scale_residual_aligned( + welf, x, gate, residual, residual_output, is_warp_reduce, has_gate_tensor, D, thr_id, lane_id); + } else { + welf = scale_residual_general( + welf, x, gate, residual, residual_output, is_warp_reduce, has_gate_tensor, D, thr_id, lane_id); + } + + // Reduce + __shared__ float shm_mean[WARP_PER_CTA]; // mean of {LayerNorm: x, RMSNorm: x^2} + __shared__ float shm_m2[WARP_PER_CTA]; + __shared__ int shm_count[WARP_PER_CTA]; + welf = warp_reduce(welf); + float mean = 0.0f, inv; + if (is_warp_reduce) { + if constexpr (norm_type == LayerNorm) { + welf.mean = __shfl_sync(0xffffffff, welf.mean, 0); + welf.m2 = __shfl_sync(0xffffffff, welf.m2, 0); + mean = welf.mean; + inv = rsqrtf(welf.m2 / D + eps); + } else { + welf.mean = __shfl_sync(0xffffffff, welf.mean, 0); + inv = rsqrtf(welf.mean + eps); + } + } else { + cta_reduce(lane_id, warp_id, welf, D, eps, shm_mean, shm_m2, shm_count); + __syncthreads(); + if constexpr (norm_type == LayerNorm) { + mean = shm_mean[0]; + inv = shm_m2[0]; + } else { + inv = shm_mean[0]; + } + } + + // Norm & Modulate + if constexpr (is_d_aligned) { + norm_scale_shift_aligned( + residual_output, + norm_weight, + norm_bias, + shift_desc.union_value, + scale_desc.union_value, + modulated, + mean, + inv, + is_warp_reduce, + is_scale_shift_tensor, + has_weight_tensor, + has_bias_tensor, + D, + thr_id, + lane_id); + } else { + norm_scale_shift_general( + residual_output, + norm_weight, + norm_bias, + shift_desc.union_value, + scale_desc.union_value, + modulated, + mean, + inv, + is_warp_reduce, + is_scale_shift_tensor, + has_weight_tensor, + has_bias_tensor, + D, + thr_id, + lane_id); + } +} + +#endif diff --git a/sgl-kernel/csrc/sgl_diffusion/scale_residual_norm_scale_shift/scale_residual_norm_scale_shift_host.cu b/sgl-kernel/csrc/sgl_diffusion/scale_residual_norm_scale_shift/scale_residual_norm_scale_shift_host.cu new file mode 100644 index 000000000000..88e1053d13b3 --- /dev/null +++ b/sgl-kernel/csrc/sgl_diffusion/scale_residual_norm_scale_shift/scale_residual_norm_scale_shift_host.cu @@ -0,0 +1,411 @@ + + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "kernel_welford.cuh" + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_DIM(x, n) TORCH_CHECK((x).dim() == (n), #x " must have " #n " dimensions") + +namespace { + +template +struct BroadcastParam { + at::Tensor tensor; + BroadcastDesc desc; +}; + +bool tensor_aligned_for_vectorized_load(const at::Tensor& t) { + auto dt = t.scalar_type(); + uintptr_t addr = reinterpret_cast(t.data_ptr()); + if (dt == at::kFloat) { + return addr % 16 == 0; + } + if (dt == at::kHalf || dt == at::kBFloat16) { + return addr % 8 == 0; + } + return false; +} + +bool optional_tensor_aligned_for_vectorized_load(const c10::optional& t_opt) { + if (!t_opt.has_value()) return true; + return tensor_aligned_for_vectorized_load(t_opt.value()); +} + +template +BroadcastParam prepare_scale_shift_tensor( + const at::Tensor& tensor, int64_t B, int64_t S, int64_t D, const at::TensorOptions& options) { + BroadcastParam param; + TORCH_CHECK(tensor.defined(), "Tensor must be defined."); + auto t = tensor.to(options.dtype()); + param.desc.frame_len = S; + + const int64_t ndim = t.dim(); + if (ndim == 0) { + // (scalar) -> layout(shape=(1), stride=(0)) + DType value = t.item(); + param.desc.stride_b = -1; + param.tensor = t; + param.desc.union_value.value = value; + } else if (ndim == 1) { + // (1) -> layout(shape=(1), stride=(0)) + TORCH_CHECK(t.size(0) == 1, "Expected shape [1] for broadcast tensor."); + DType value = t[0].item(); + param.desc.stride_b = -1; + param.tensor = t; + param.desc.union_value.value = value; + } else if (ndim == 2) { + // (B,D) -> layout(shape=(B,1,D), stride=(stride_B,0,1)) + // (1,D) -> layout(shape=(1,1,D), stride=(0,0,1)) + TORCH_CHECK(t.size(1) == D, "Trailing dim must match hidden size."); + TORCH_CHECK(t.size(0) == B || t.size(0) == 1, "Leading dim must be batch size or 1."); + param.desc.stride_b = t.size(0) == B ? S : 0; + param.desc.frame_len = S; + t = t.reshape({t.size(0), 1, D}); + t = t.contiguous(); + param.tensor = t; + param.desc.union_value.ptr = t.data_ptr(); + } else if (ndim == 3) { + // (B,S,D) + // (B,1,D) + // (1,S,D) + // (1,1,D) + TORCH_CHECK(t.size(2) == D, "Trailing dim must match hidden size."); + TORCH_CHECK(t.size(0) == B || t.size(0) == 1, "Leading dim must be batch size or 1."); + TORCH_CHECK(t.size(1) == S || t.size(1) == 1, "Middle dim must be sequence length or 1."); + param.desc.stride_b = t.size(0) == B ? S : 0; + param.desc.frame_len = S / t.size(1); + t = t.contiguous(); + param.tensor = t; + param.desc.union_value.ptr = t.data_ptr(); + } else if (ndim == 4) { + // (B,F,1,D) -> (B,F,D) + TORCH_CHECK(t.size(2) == 1 && t.size(3) == D, "Expected [B,F,1,D] for frame broadcast."); + TORCH_CHECK(t.size(0) == B || t.size(0) == 1, "Leading dim must be batch size or 1."); + auto num_frames = t.size(1); + TORCH_CHECK(S % num_frames == 0, "Sequence length must be divisible by num_frames."); + t = t.reshape({t.size(0), num_frames, D}); + param.desc.stride_b = S; + param.desc.frame_len = S / num_frames; + t = t.contiguous(); + param.tensor = t; + param.desc.union_value.ptr = t.data_ptr(); + } else { + TORCH_CHECK(false, "Unsupported rank for broadcast tensor."); + } + return param; +} + +struct GateParam { + at::Tensor storage; + int frame_len; +}; + +template +GateParam prepare_gate( + const c10::optional& gate_opt, int64_t B, int64_t S, int64_t D, const at::TensorOptions& options) { + GateParam gate_param; + int64_t num_frames = 1; + at::Tensor gate_prepared; + if (gate_opt.has_value() && gate_opt.value().defined()) { + const auto& gate = gate_opt.value(); + CHECK_CUDA(gate); + if (gate.dim() == 2) { + TORCH_CHECK(gate.size(1) == D, "2D-gate hidden size mismatch"); + TORCH_CHECK(gate.size(0) == 1, "2D-gate tensor must be [1,D]"); + gate_prepared = gate.contiguous().to(options.dtype()).view({1, D}); + gate_param.frame_len = -2; + } else if (gate.dim() == 3) { + TORCH_CHECK(gate.size(0) == B, "3D-gate batch size mismatch"); + TORCH_CHECK(gate.size(2) == D, "3D-gate hidden size mismatch"); + TORCH_CHECK(gate.size(1) == 1, "3D-gate tensor must be [B,1,D]"); + gate_prepared = gate.contiguous().to(options.dtype()).view({B, 1, D}); + gate_param.frame_len = S / num_frames; + } else if (gate.dim() == 4) { + TORCH_CHECK(gate.size(0) == B, "4D-gate batch size mismatch"); + TORCH_CHECK(gate.size(3) == D, "4D-gate hidden size mismatch"); + TORCH_CHECK(gate.size(2) == 1, "4D-gate tensor must be [B,F,1,D]"); + num_frames = gate.size(1); + TORCH_CHECK(S % num_frames == 0, "sequence length must be divisible by num_frames"); + gate_prepared = gate.contiguous().to(options.dtype()).view({B, num_frames, 1, D}); + gate_param.frame_len = S / num_frames; + } else { + TORCH_CHECK(false, "gate tensor must be rank 3 or 4"); + } + gate_param.storage = gate_prepared; + } else { + gate_param.frame_len = -1; + } + return gate_param; +} + +struct NormParams { + at::Tensor weight; + bool has_weight_tensor; + at::Tensor bias; + bool has_bias_tensor; +}; + +template +NormParams prepare_norm_params( + const c10::optional& weight_opt, + const c10::optional& bias_opt, + int64_t D, + const at::TensorOptions& options) { + NormParams params; + if (weight_opt.has_value() && weight_opt.value().defined()) { + const auto& norm_weight = weight_opt.value(); + CHECK_CUDA(norm_weight); + TORCH_CHECK(norm_weight.numel() == D, "norm_weight must have length D"); + params.weight = norm_weight.contiguous().to(options.dtype()); + params.has_weight_tensor = true; + } else { + params.has_weight_tensor = false; + } + if (bias_opt.has_value() && bias_opt.value().defined()) { + const auto& norm_bias = bias_opt.value(); + CHECK_CUDA(norm_bias); + TORCH_CHECK(norm_bias.numel() == D, "norm_bias must have length D"); + params.bias = norm_bias.contiguous().to(options.dtype()); + params.has_bias_tensor = true; + } else { + params.has_bias_tensor = false; + } + return params; +} + +template +void launch_fused( + dim3 grid, + dim3 block, + cudaStream_t stream, + DType* residual, + DType* x, + DType* gate, + const ParamDType* w, + const ParamDType* b, + BroadcastDesc shift_desc, + BroadcastDesc scale_desc, + double eps, + DType* modulated, + DType* residual_output, + int B, + int S, + int D, + int frame_len, + bool is_warp_reduce, + bool has_weight_tensor, + bool has_bias_tensor) { + scale_residual_norm_scale_shift_kernel<<>>( + residual, + x, + gate, + w, + b, + shift_desc, + scale_desc, + eps, + modulated, + residual_output, + B, + S, + D, + frame_len, + is_warp_reduce, + has_weight_tensor, + has_bias_tensor); +} + +template +using LauncherFn = void (*)( + dim3, + dim3, + cudaStream_t, + DType*, + DType*, + DType*, + const ParamDType*, + const ParamDType*, + BroadcastDesc, + BroadcastDesc, + double, + DType*, + DType*, + int, + int, + int, + int, + bool, + bool, + bool); + +template +static constexpr LauncherFn DISPATCH_TABLE[2][2] = { + {&launch_fused, + &launch_fused}, + {&launch_fused, + &launch_fused}}; +} // namespace + +/*==========================================================================* + * Public entry point invoked from Python. It validates inputs, prepares * + * all broadcast buffers (gate/norm/scale/shift), and dispatches the CUDA * + * kernel that fuses gate + normalization + scale/shift. * + *==========================================================================*/ +std::tuple scale_residual_norm_scale_shift( + const at::Tensor& residual, + const at::Tensor& x, + const c10::optional& gate_opt, + const c10::optional& norm_weight_opt, + const c10::optional& norm_bias_opt, + const at::Tensor& shift, + const at::Tensor& scale, + double eps, + bool use_rms_norm) { + // --- basic input validation --- + CHECK_CUDA(residual); + CHECK_CUDA(x); + CHECK_CUDA(shift); + CHECK_CUDA(scale); + TORCH_CHECK(residual.dim() == 3, "residual must be [B, S, D]"); + TORCH_CHECK(x.sizes() == residual.sizes(), "x must match residual shape"); + + const auto B = residual.size(0); + const auto S = residual.size(1); + const auto D = residual.size(2); + auto orig_dtype = residual.dtype(); + + c10::SmallVector activation_types = { + residual.scalar_type(), x.scalar_type(), scale.scalar_type(), shift.scalar_type()}; + if (gate_opt.has_value() && gate_opt.value().defined()) { + activation_types.push_back(gate_opt.value().scalar_type()); + } + auto activation_scalar = activation_types.front(); + bool has_mixed_activation = false; + for (const auto& st : activation_types) { + if (st != activation_scalar) { + has_mixed_activation = true; + break; + } + } + if (has_mixed_activation) { + activation_scalar = at::ScalarType::Float; + } + auto act_opts = residual.options().dtype(activation_scalar); + auto cast_activation = [&](const at::Tensor& t) { + if (t.scalar_type() == activation_scalar) { + return t.contiguous(); + } + return t.to(act_opts).contiguous(); + }; + auto residual_f = cast_activation(residual); + auto x_f = cast_activation(x); + auto modulated = at::empty_like(residual_f); + auto residual_output = at::empty_like(residual_f); + + auto param_scalar = at::ScalarType::Float; + bool param_scalar_set = false; + auto set_param_scalar = [&](const at::Tensor& t, const char* name) { + CHECK_CUDA(t); + if (!param_scalar_set) { + param_scalar = t.scalar_type(); + param_scalar_set = true; + } else { + TORCH_CHECK(t.scalar_type() == param_scalar, name, " dtype must match other norm parameters."); + } + }; + if (norm_weight_opt.has_value() && norm_weight_opt.value().defined()) { + set_param_scalar(norm_weight_opt.value(), "norm_weight"); + } + if (norm_bias_opt.has_value() && norm_bias_opt.value().defined()) { + set_param_scalar(norm_bias_opt.value(), "norm_bias"); + } + auto act_opts_const = act_opts; + + bool is_warp_reduce = D <= CTA_REDUCE_THRESHOLD; + bool is_d_aligned = D % 4 == 0 && tensor_aligned_for_vectorized_load(residual) && + tensor_aligned_for_vectorized_load(x) && optional_tensor_aligned_for_vectorized_load(gate_opt) && + optional_tensor_aligned_for_vectorized_load(norm_weight_opt) && + optional_tensor_aligned_for_vectorized_load(norm_bias_opt) && + tensor_aligned_for_vectorized_load(shift) && tensor_aligned_for_vectorized_load(scale); + dim3 block(THREADS_PER_CTA); + uint32_t cta_per_grid = is_warp_reduce ? (B * S + WARP_PER_CTA - 1) / WARP_PER_CTA : B * S; + dim3 grid(dim3(cta_per_grid, 1, 1)); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto param_opts = torch::TensorOptions().device(residual.device()).dtype(param_scalar); + + auto dispatch_activation = [&](auto dtype_tag) { + using DType = decltype(dtype_tag); + auto gate_param = prepare_gate(gate_opt, B, S, D, act_opts_const); + bool has_gate_tensor = gate_param.frame_len != -1; + auto shift_param = prepare_scale_shift_tensor(shift, B, S, D, act_opts_const); + auto scale_param = prepare_scale_shift_tensor(scale, B, S, D, act_opts_const); + + auto dispatch_param = [&](auto param_tag) { + using ParamDType = decltype(param_tag); + auto norm_params = prepare_norm_params(norm_weight_opt, norm_bias_opt, D, param_opts); + auto launcher = DISPATCH_TABLE[use_rms_norm][is_d_aligned]; + launcher( + grid, + block, + stream, + residual_f.data_ptr(), + x_f.data_ptr(), + has_gate_tensor ? gate_param.storage.template data_ptr() : nullptr, + norm_params.has_weight_tensor ? norm_params.weight.template data_ptr() : nullptr, + norm_params.has_bias_tensor ? norm_params.bias.template data_ptr() : nullptr, + shift_param.desc, + scale_param.desc, + eps, + modulated.data_ptr(), + residual_output.data_ptr(), + B, + S, + D, + static_cast(gate_param.frame_len), + is_warp_reduce, + norm_params.has_weight_tensor, + norm_params.has_bias_tensor); + }; + + switch (param_scalar) { + case at::ScalarType::Float: + dispatch_param(float{}); + break; + case at::ScalarType::Half: + dispatch_param(at::Half{}); + break; + case at::ScalarType::BFloat16: + dispatch_param(at::BFloat16{}); + break; + default: + TORCH_CHECK(false, "Unsupported parameter dtype for fused kernel."); + } + }; + + switch (activation_scalar) { + case at::ScalarType::Float: + dispatch_activation(float{}); + break; + case at::ScalarType::Half: + dispatch_activation(at::Half{}); + break; + case at::ScalarType::BFloat16: + dispatch_activation(at::BFloat16{}); + break; + default: + TORCH_CHECK(false, "Unsupported activation dtype for fused kernel."); + } + TORCH_CHECK(cudaGetLastError() == cudaSuccess, "CUDA kernel launch failed"); + + return {modulated.to(orig_dtype), residual_output.to(orig_dtype)}; +} diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index e171ab3168f1..af39937ab9b4 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -135,6 +135,17 @@ void silu_and_mul(at::Tensor& out, at::Tensor& input); void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input); void gelu_and_mul(at::Tensor& out, at::Tensor& input); +std::tuple scale_residual_norm_scale_shift( + const at::Tensor& residual, + const at::Tensor& x, + const c10::optional& gate_opt, + const c10::optional& norm_weight_opt, + const c10::optional& norm_bias_opt, + const at::Tensor& scale, + const at::Tensor& shift, + double eps, + bool use_rms_norm); + void apply_rope_pos_ids_cos_sin_cache( at::Tensor q, at::Tensor k, diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 8e8994e04c95..55b0d7a32c42 100644 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -31,6 +31,7 @@ gemma_rmsnorm, rmsnorm, rotary_embedding, + scale_residual_norm_scale_shift, silu_and_mul, ) from sgl_kernel.expert_specialization import ( diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py index c7a6a0ed1ba4..53f72e518dc2 100644 --- a/sgl-kernel/python/sgl_kernel/elementwise.py +++ b/sgl-kernel/python/sgl_kernel/elementwise.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Tuple import torch from sgl_kernel.utils import is_arch_support_pdl @@ -122,6 +122,42 @@ def gemma_rmsnorm( return out +def scale_residual_norm_scale_shift( + residual: torch.Tensor, + x: torch.Tensor, + gate: torch.Tensor | None, + weight: torch.Tensor, + bias: torch.Tensor, + scale: torch.Tensor, + shift: torch.Tensor, + eps: float, + norm_type: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Fused kernel: scale_residual_norm_scale_shift. + + 1. residual_out = residual + gate * x + 2. normalized = norm(residual_out) + 3. modulated = (1 + scale) * normalized + shift. + + Returns + ------- + output: Tuple[torch.Tensor, torch.Tensor] + Modulated tensor, shape (batch_size, seq_len, hidden_dim); + Residual Output tensor, shape (batch_size, seq_len, hidden_dim). + """ + return torch.ops.sgl_kernel.scale_residual_norm_scale_shift( + residual, + x, + gate, + weight, + bias, + scale, + shift, + eps, + norm_type == "rms", + ) + + def gemma_fused_add_rmsnorm( input: torch.Tensor, residual: torch.Tensor, diff --git a/sgl-kernel/tests/sgl_diffusion/test_fused_scale_residual_norm_scale_shift.py b/sgl-kernel/tests/sgl_diffusion/test_fused_scale_residual_norm_scale_shift.py new file mode 100644 index 000000000000..1bf8a2d5fac6 --- /dev/null +++ b/sgl-kernel/tests/sgl_diffusion/test_fused_scale_residual_norm_scale_shift.py @@ -0,0 +1,281 @@ +import pytest +import sgl_kernel +import torch + +device = "cuda" if torch.cuda.is_available() else None + + +def make_tensor_unaligned(x: torch.Tensor | None, unaligned: bool): + if not unaligned: + return x + if x is None: + return None + numel = x.numel() + buf = torch.empty(numel + 1, dtype=x.dtype, device=x.device) + buf[1 : numel + 1].copy_(x.view(-1)) + return buf[1:].view_as(x) + + +# Data Generation +def datagen( + dtype, + param_type, + batch, + seq, + frame, + hidden_dim, + use_affine, + use_bias, + eps, + norm_type, + gate_shape, + scale_shift_shape, + unaligned: bool = False, +): + residual = torch.randn(batch, seq, hidden_dim, dtype=dtype, device=device) + x = torch.randn(batch, seq, hidden_dim, dtype=dtype, device=device) + if gate_shape == "1": + gate = None + elif gate_shape == "1D": + gate = torch.randn(1, hidden_dim, dtype=dtype, device=device) + elif gate_shape == "BF1D": + if seq % frame != 0: + pytest.skip(f"seq ({seq}) must be divisible by frame ({frame}).") + gate = torch.randn(batch, frame, 1, hidden_dim, dtype=dtype, device=device) + elif gate_shape == "B1D": + gate = torch.randn(batch, 1, hidden_dim, dtype=dtype, device=device) + else: + raise ValueError("Unknown gate shape.") + norm_weight, norm_bias = None, None + if use_affine: + norm_weight = torch.randn(hidden_dim, dtype=param_type, device=device) + if norm_type == "layer" and use_bias: + norm_bias = torch.randn(hidden_dim, dtype=param_type, device=device) + if "1" == scale_shift_shape: + scale = torch.tensor(1.0, dtype=dtype, device=device) + shift = torch.tensor(1.0, dtype=dtype, device=device) + elif scale_shift_shape == "BD": + scale = torch.randn(batch, hidden_dim, dtype=dtype, device=device) + shift = torch.randn(batch, hidden_dim, dtype=dtype, device=device) + elif scale_shift_shape == "1D": + scale = torch.randn(1, hidden_dim, dtype=dtype, device=device) + shift = torch.randn(1, hidden_dim, dtype=dtype, device=device) + elif scale_shift_shape == "BSD": + scale = torch.randn(batch, seq, hidden_dim, dtype=dtype, device=device) + shift = torch.randn(batch, seq, hidden_dim, dtype=dtype, device=device) + elif scale_shift_shape == "B1D": + scale = torch.randn(batch, 1, hidden_dim, dtype=dtype, device=device) + shift = torch.randn(batch, 1, hidden_dim, dtype=dtype, device=device) + elif scale_shift_shape == "1SD": + scale = torch.randn(1, seq, hidden_dim, dtype=dtype, device=device) + shift = torch.randn(1, seq, hidden_dim, dtype=dtype, device=device) + elif scale_shift_shape == "11D": + scale = torch.randn(1, 1, hidden_dim, dtype=dtype, device=device) + shift = torch.randn(1, 1, hidden_dim, dtype=dtype, device=device) + elif scale_shift_shape == "BF1D": + if seq % frame != 0: + pytest.skip(f"seq ({seq}) must be divisible by frame ({frame}).") + scale = torch.randn(batch, frame, 1, hidden_dim, dtype=dtype, device=device) + shift = torch.randn(batch, frame, 1, hidden_dim, dtype=dtype, device=device) + return ( + make_tensor_unaligned(residual, unaligned=unaligned), + make_tensor_unaligned(x, unaligned=unaligned), + make_tensor_unaligned(gate, unaligned=unaligned), + make_tensor_unaligned(norm_weight, unaligned=unaligned), + make_tensor_unaligned(norm_bias, unaligned=unaligned), + make_tensor_unaligned(shift, unaligned=unaligned), + make_tensor_unaligned(scale, unaligned=unaligned), + eps, + norm_type == "rms", + ) + + +# Reference +def scale_residual_norm_scale_shift_ref( + residual: torch.Tensor, + x: torch.Tensor, + gate: torch.Tensor | None, + norm_weight: torch.Tensor | None, + norm_bias: torch.Tensor | None, + shift: torch.Tensor, + scale: torch.Tensor, + eps: float, + norm_type: bool, +): + # 1. residual add + if isinstance(gate, torch.Tensor): + if gate.dim() == 4: + # gate.shape: [batch_size, num_frames, 1, inner_dim] + num_frames = gate.shape[1] + frame_seqlen = x.shape[1] // num_frames + residual_out = residual + ( + x.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * gate + ).flatten(1, 2) + else: + # gate.shape: [batch_size, 1, inner_dim] + residual_out = residual + x * gate + else: + gate = 1 + residual_out = residual + x * gate + # 2. normalize + if norm_type == False: # LayerNorm + mean = residual_out.mean(dim=-1, keepdim=True) + var = residual_out.var(dim=-1, unbiased=False, keepdim=True) + normalized = (residual_out - mean) / torch.sqrt(var + eps) + elif norm_type == True: # RMSNorm + rms = residual_out.pow(2).mean(dim=-1, keepdim=True) + normalized = residual_out / torch.sqrt(rms + eps) + # 3. apply affine transform if given + if norm_weight is not None and norm_bias is not None: + normalized = normalized * norm_weight + norm_bias + elif norm_weight is not None: + normalized = normalized * norm_weight + # 4. apply scale/shift if given + batch, seq_len, hidden_dim = x.shape + if scale.ndim <= 3: + if scale.ndim == 0 or (scale.ndim == 1 and scale.numel() == 1): + # (), (1) → (B, S, D) + scale = scale.expand(batch, seq_len, hidden_dim) + shift = shift.expand(batch, seq_len, hidden_dim) + elif scale.ndim == 2 and scale.shape in [ + (1, hidden_dim), + (batch, hidden_dim), + ]: + # (B, D) or (1, D) → (B, S, 1, D) + scale = scale[:, None, :].expand(batch, seq_len, hidden_dim) + shift = shift[:, None, :].expand(batch, seq_len, hidden_dim) + elif scale.ndim == 3 and scale.shape in [ + (batch, seq_len, hidden_dim), + (batch, 1, hidden_dim), + (1, seq_len, hidden_dim), + (1, 1, hidden_dim), + ]: + # (B, S, D), (B, 1, D), (1, S, D), (1, 1, D) → (B, S, 1, D) + scale = scale.expand(batch, seq_len, hidden_dim) + shift = shift.expand(batch, seq_len, hidden_dim) + normalized = normalized * (1.0 + scale) + shift + elif scale.ndim == 4 and scale.shape == (batch, scale.shape[1], 1, hidden_dim): + num_frames = scale.shape[1] + frame_seqlen = normalized.shape[1] // num_frames + normalized = ( + normalized.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) + * (1.0 + scale) + + shift + ).flatten(1, 2) + return normalized, residual_out + + +compiled_scale_residual_norm = torch.compile(scale_residual_norm_scale_shift_ref) + + +def _run_test( + *, + dtype=torch.float32, + batch=1, + seq=2048, + frame=4, + hidden_dim=1536, + use_affine=True, + use_bias=True, + eps=1e-6, + norm_type="layer", + gate_shape="1D", + scale_shift_shape="B1D", + unaligned: bool = False, +): + if device is None: + pytest.skip("No cuda device available for this test") + + param_type = dtype + input_data = datagen( + dtype, + param_type, + batch, + seq, + frame, + hidden_dim, + use_affine, + use_bias, + eps, + norm_type, + gate_shape, + scale_shift_shape, + unaligned, + ) + mod_ref, resi_out_ref = compiled_scale_residual_norm(*input_data) + mod, resi_out = sgl_kernel.scale_residual_norm_scale_shift(*input_data) + + if dtype == torch.float32: + torch.testing.assert_close(mod, mod_ref, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(resi_out, resi_out_ref, rtol=1e-5, atol=1e-5) + elif dtype == torch.float16: + torch.testing.assert_close(mod, mod_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(resi_out, resi_out_ref, rtol=1e-2, atol=1e-2) + elif dtype == torch.bfloat16: + torch.testing.assert_close(mod, mod_ref, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(resi_out, resi_out_ref, rtol=1e-2, atol=1e-2) + else: + raise ValueError(f"Not implement data type: {dtype}") + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("norm_type", ["layer", "rms"]) +def test_scale_residual_norm_scale_shift_dtype(dtype, norm_type): + _run_test(dtype=dtype, norm_type=norm_type) + + +@pytest.mark.parametrize("batch", [1, 2, 4, 8]) +def test_scale_residual_norm_scale_shift_batch(batch): + _run_test(batch=batch) + + +@pytest.mark.parametrize("seq", [83, 1024, 2047, 32760]) +def test_scale_residual_norm_scale_shift_seq(seq): + _run_test(seq=seq) + + +@pytest.mark.parametrize("frame", [4, 8]) +def test_scale_residual_norm_scale_shift_frame(frame): + _run_test(frame=frame) + + +@pytest.mark.parametrize("hidden_dim", [83, 1024, 1536, 3072, 4096]) +def test_scale_residual_norm_scale_shift_hidden_dim(hidden_dim): + _run_test(hidden_dim=hidden_dim) + + +@pytest.mark.parametrize("use_affine", [False, True]) +def test_scale_residual_norm_scale_shift_affine(use_affine): + _run_test(use_affine=use_affine) + + +@pytest.mark.parametrize("use_bias", [True]) +def test_scale_residual_norm_scale_shift_bias(use_bias): + _run_test(use_bias=use_bias) + + +@pytest.mark.parametrize("norm_type", ["layer", "rms"]) +def test_scale_residual_norm_scale_shift_norm_type(norm_type): + _run_test(norm_type=norm_type) + + +@pytest.mark.parametrize("gate_shape", ["1", "1D", "B1D", "BF1D"]) +def test_scale_residual_norm_scale_shift_gate_shape(gate_shape): + _run_test(gate_shape=gate_shape) + + +@pytest.mark.parametrize( + "scale_shift_shape", ["1", "1D", "BD", "BSD", "B1D", "1SD", "11D", "BF1D"] +) +def test_scale_residual_norm_scale_shift_scale_shift_shape(scale_shift_shape): + _run_test(scale_shift_shape=scale_shift_shape) + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("unaligned", [True]) +def test_scale_residual_norm_unaligned(dtype, unaligned): + _run_test(dtype=dtype, unaligned=unaligned) + + +if __name__ == "__main__": + pytest.main([__file__])