diff --git a/src/flag_gems/fused/__init__.py b/src/flag_gems/fused/__init__.py
index b3bd20d44b..76a189cb15 100644
--- a/src/flag_gems/fused/__init__.py
+++ b/src/flag_gems/fused/__init__.py
@@ -19,6 +19,7 @@
 from flag_gems.fused.rwkv_mm_sparsity import rwkv_mm_sparsity
 from flag_gems.fused.silu_and_mul import silu_and_mul, silu_and_mul_out
 from flag_gems.fused.skip_layernorm import skip_layer_norm
+from flag_gems.fused.swiglu import dswiglu, swiglu
 from flag_gems.fused.topk_softmax import topk_softmax
 from flag_gems.fused.weight_norm import weight_norm
 
@@ -28,6 +29,8 @@
     "fused_add_rms_norm",
     "silu_and_mul",
     "silu_and_mul_out",
+    "swiglu",
+    "dswiglu",
     "geglu",
     "dgeglu",
     "gelu_and_mul",
diff --git a/src/flag_gems/fused/swiglu.py b/src/flag_gems/fused/swiglu.py
new file mode 100644
index 0000000000..b5ad34bd33
--- /dev/null
+++ b/src/flag_gems/fused/swiglu.py
@@ -0,0 +1,188 @@
+import logging
+from typing import Any, Optional
+
+import torch
+import triton
+import triton.language as tl
+
+from flag_gems.utils import tl_extra_shim
+
+sigmoid = tl.sigmoid
+exp = tl_extra_shim.exp
+pow = tl_extra_shim.pow
+
+logger = logging.getLogger(__name__)
+
+
+@triton.jit
+def swiglu_kernel(
+    input_ptr,
+    output_ptr,
+    M,
+    H,
+    stride_in_m,
+    stride_in_h,
+    stride_out_m,
+    stride_out_h,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_H: tl.constexpr,
+):
+    pid_m = tl.program_id(0)
+    pid_h = tl.program_id(1)
+
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
+
+    mask = (offs_m[:, None] < M) & (offs_h[None, :] < H)
+
+    input_a_ptr = (
+        input_ptr + offs_m[:, None] * stride_in_m + offs_h[None, :] * stride_in_h
+    )
+    input_b_ptr = (
+        input_ptr + offs_m[:, None] * stride_in_m + (offs_h[None, :] + H) * stride_in_h
+    )
+    output_ptr = (
+        output_ptr + offs_m[:, None] * stride_out_m + offs_h[None, :] * stride_out_h
+    )
+
+    x_a = tl.load(input_a_ptr, mask=mask, other=0.0).to(tl.float32)
+    x_b = tl.load(input_b_ptr, mask=mask, other=0.0).to(tl.float32)
+
+    silu_x_a = x_a * sigmoid(x_a)
+    out = silu_x_a * x_b
+
+    tl.store(output_ptr, out.to(x_a.dtype), mask=mask)
+
+
+@triton.jit
+def dswiglu_kernel(
+    grad_out_ptr,
+    input_ptr,
+    grad_in_ptr,
+    M,
+    H,
+    stride_grad_out_m,
+    stride_grad_out_h,
+    stride_in_m,
+    stride_in_h,
+    stride_grad_in_m,
+    stride_grad_in_h,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_H: tl.constexpr,
+):
+    pid_m = tl.program_id(0)
+    pid_h = tl.program_id(1)
+
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
+
+    mask = (offs_m[:, None] < M) & (offs_h[None, :] < H)
+
+    grad_out_ptr = (
+        grad_out_ptr
+        + offs_m[:, None] * stride_grad_out_m
+        + offs_h[None, :] * stride_grad_out_h
+    )
+    input_a_ptr = (
+        input_ptr + offs_m[:, None] * stride_in_m + offs_h[None, :] * stride_in_h
+    )
+    input_b_ptr = (
+        input_ptr + offs_m[:, None] * stride_in_m + (offs_h[None, :] + H) * stride_in_h
+    )
+    grad_a_ptr = (
+        grad_in_ptr
+        + offs_m[:, None] * stride_grad_in_m
+        + offs_h[None, :] * stride_grad_in_h
+    )
+    grad_b_ptr = (
+        grad_in_ptr
+        + offs_m[:, None] * stride_grad_in_m
+        + (offs_h[None, :] + H) * stride_grad_in_h
+    )
+
+    grad_out = tl.load(grad_out_ptr, mask=mask, other=0.0).to(tl.float32)
+    x_a = tl.load(input_a_ptr, mask=mask, other=0.0).to(tl.float32)
+    x_b = tl.load(input_b_ptr, mask=mask, other=0.0).to(tl.float32)
+
+    sig = sigmoid(x_a)
+    silu = x_a * sig
+    d_silu = sig + x_a * sig * (1 - sig)
+
+    grad_a = grad_out * x_b * d_silu
+    grad_b = grad_out * silu
+
+    tl.store(grad_a_ptr, grad_a.to(x_a.dtype), mask=mask)
+    tl.store(grad_b_ptr, grad_b.to(x_a.dtype), mask=mask)
+
+
+def swiglu(input_tensor: torch.Tensor, quantizer: Optional[Any] = None) -> torch.Tensor:
+    if input_tensor.shape[-1] % 2 != 0:
+        raise ValueError(f"SwiGLU expects an even last dimension, got {input_tensor.shape[-1]}")
+    if not input_tensor.is_cuda:
+        raise ValueError("SwiGLU only supports CUDA tensors")
+
+    shape = input_tensor.shape
+    H = shape[-1] // 2
+    M = input_tensor.numel() // (2 * H)
+    input_2d = input_tensor.contiguous().view(M, 2 * H)
+    output_2d = torch.empty(M, H, device=input_tensor.device, dtype=input_tensor.dtype)
+
+    grid = lambda META: (
+        triton.cdiv(M, META["BLOCK_SIZE_M"]),
+        triton.cdiv(H, META["BLOCK_SIZE_H"]),
+    )
+
+    swiglu_kernel[grid](
+        input_2d,
+        output_2d,
+        M,
+        H,
+        input_2d.stride(0),
+        input_2d.stride(1),
+        output_2d.stride(0),
+        output_2d.stride(1),
+        BLOCK_SIZE_M=64,
+        BLOCK_SIZE_H=64,
+    )
+
+    return output_2d.view(*shape[:-1], H)
+
+
+def dswiglu(
+    grad_output: torch.Tensor,
+    input_tensor: torch.Tensor,
+    quantizer: Optional[Any] = None,
+) -> torch.Tensor:
+    shape = input_tensor.shape
+    assert shape[-1] % 2 == 0, f"dswiglu: the last dim of input_tensor must be even, got {shape[-1]}"
+    H = shape[-1] // 2
+    M = input_tensor.numel() // (2 * H)
+    grad_out_2d = grad_output.contiguous().view(M, H)
+    input_2d = input_tensor.contiguous().view(M, 2 * H)
+    grad_in_2d = torch.empty_like(input_2d)
+
+    grid = lambda META: (
+        triton.cdiv(M, META["BLOCK_SIZE_M"]),
+        triton.cdiv(H, META["BLOCK_SIZE_H"]),
+    )
+
+    dswiglu_kernel[grid](
+        grad_out_2d,
+        input_2d,
+        grad_in_2d,
+        M,
+        H,
+        grad_out_2d.stride(0),
+        grad_out_2d.stride(1),
+        input_2d.stride(0),
+        input_2d.stride(1),
+        grad_in_2d.stride(0),
+        grad_in_2d.stride(1),
+        BLOCK_SIZE_M=64,
+        BLOCK_SIZE_H=64,
+    )
+
+    return grad_in_2d.view_as(input_tensor)
+
+
+__all__ = ["swiglu", "dswiglu"]
diff --git a/tests/accuracy_utils.py b/tests/accuracy_utils.py
index a02405f6b3..257f4cb438 100644
--- a/tests/accuracy_utils.py
+++ b/tests/accuracy_utils.py
@@ -130,6 +130,25 @@ def SkipVersion(module_name, skip_pattern):
     (3, 7, 1023, 1025),
 ]
 
+SWIGLU_SPECIAL_SHAPES = (
+    [(2, 19, 8)]
+    if QUICK_MODE
+    else [
+        (2,),
+        (64,),
+        (32, 64),
+        (256, 512),
+        (1, 128),
+        (8, 16, 32),
+        (16, 32, 64),
+        (20, 320, 16),
+        (4, 8, 16, 32),
+        (8, 16, 32, 64),
+        (10,),
+        (20, 30),
+    ]
+)
+
 KRON_SHAPES = [
     [(), (2, 3)],
     [(2, 3), ()],
diff --git a/tests/test_unary_pointwise_ops.py b/tests/test_unary_pointwise_ops.py
index 3dfb773706..ee7520bdbc 100644
--- a/tests/test_unary_pointwise_ops.py
+++ b/tests/test_unary_pointwise_ops.py
@@ -19,6 +19,7 @@
     FLOAT_DTYPES,
     INT_DTYPES,
     POINTWISE_SHAPES,
+    SWIGLU_SPECIAL_SHAPES,
     SkipVersion,
     gems_assert_close,
     gems_assert_equal,
@@ -472,6 +473,67 @@ def test_accuracy_glu_backward(shape, dtype):
     gems_assert_close(res_in_grad, ref_in_grad, dtype)
 
 
+def generate_input(
+    shape: tuple[int, ...], dtype: torch.dtype, device: torch.device
+) -> torch.Tensor:
+    return torch.randn(shape, dtype=dtype, device=device).contiguous()
+
+
+def filter_valid_shapes(shapes: list[tuple[int, ...]]) -> list[tuple[int, ...]]:
+    valid_shapes = []
+    for shape in shapes:
+        if not shape:
+            continue
+        if shape[-1] % 2 == 0:
+            valid_shapes.append(shape)
+    return valid_shapes
+
+
+VALID_POINTWISE_SHAPES = filter_valid_shapes(SWIGLU_SPECIAL_SHAPES)
+
+
+@pytest.mark.swiglu
+@pytest.mark.parametrize("shape", VALID_POINTWISE_SHAPES)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+@pytest.mark.skipif(not TE_AVAILABLE, reason="transformer engine is not available")
+def test_accuracy_swiglu_forward(shape: tuple[int, ...], dtype: torch.dtype):
+    torch.manual_seed(42)
+    device = flag_gems.device
+
+    input_tensor = generate_input(shape, dtype, device)
+
+    te_forward = tex.swiglu(input_tensor, quantizer=None).to(device)
+    te_forward = to_reference(te_forward)
+
+    with flag_gems.use_gems():
+        fg_forward = flag_gems.swiglu(input_tensor, quantizer=None)
+
+    gems_assert_close(fg_forward, te_forward, dtype)
+
+
+@pytest.mark.swiglu
+@pytest.mark.parametrize("shape", VALID_POINTWISE_SHAPES)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+@pytest.mark.skipif(not TE_AVAILABLE, reason="transformer engine is not available")
+def test_accuracy_swiglu_backward(shape: tuple[int, ...], dtype: torch.dtype):
+    torch.manual_seed(42)
+    device = flag_gems.device
+
+    input_tensor = generate_input(shape, dtype, device)
+
+    grad_shape = list(shape)
+    grad_shape[-1] = grad_shape[-1] // 2
+    grad_output = generate_input(tuple(grad_shape), dtype, device)
+
+    te_grad_input = tex.dswiglu(grad_output, input_tensor, quantizer=None).to(device)
+    te_grad_input = to_reference(te_grad_input)
+
+    with flag_gems.use_gems():
+        fg_grad_input = flag_gems.dswiglu(grad_output, input_tensor, quantizer=None)
+
+    gems_assert_close(fg_grad_input, te_grad_input, dtype)
+
+
 @pytest.mark.isinf
 @pytest.mark.parametrize("shape", POINTWISE_SHAPES)
 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
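
For anyone who wants to exercise the new entry points outside the transformer-engine-based tests, here is a minimal usage sketch. It assumes the top-level flag_gems.swiglu / flag_gems.dswiglu bindings that the tests above already rely on; the SwiGLUFunction autograd wrapper is a hypothetical illustration, not part of this patch.

# Minimal usage sketch (not part of this patch). Assumes the flag_gems.swiglu /
# flag_gems.dswiglu entry points added above; the SwiGLUFunction wrapper name is
# hypothetical and only shows how the forward/backward pair fits together.
import torch

import flag_gems


class SwiGLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # x has shape (..., 2 * H): the first half is the SiLU-gated branch,
        # the second half is the linear branch.
        ctx.save_for_backward(x)
        return flag_gems.swiglu(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # dswiglu returns a gradient with the same (..., 2 * H) layout as x.
        return flag_gems.dswiglu(grad_out.contiguous(), x)


if __name__ == "__main__":
    x = torch.randn(4, 256, device="cuda", dtype=torch.float16, requires_grad=True)
    y = SwiGLUFunction.apply(x)  # shape (4, 128)
    y.float().sum().backward()   # x.grad has shape (4, 256)

The (grad_output, input_tensor) argument order of dswiglu mirrors the tex.dswiglu reference call used in the tests, which is what allows the two implementations to be compared one-for-one.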