diff --git a/benchmark/test_transformer_engine_perf.py b/benchmark/test_transformer_engine_perf.py
index b7f70719a..0b063370a 100644
--- a/benchmark/test_transformer_engine_perf.py
+++ b/benchmark/test_transformer_engine_perf.py
@@ -75,25 +75,17 @@ def get_tflops(self, op, *args, **kwargs):
 
 glu_forward_ops = [
     ("geglu", "geglu", FLOAT_DTYPES),
-    # ("swiglu", "swiglu", FLOAT_DTYPES),
-    # ("reglu", "reglu", FLOAT_DTYPES),
+    ("swiglu", "swiglu", FLOAT_DTYPES),
+    ("reglu", "reglu", FLOAT_DTYPES),
 ]
 
 glu_backward_ops = [
     ("dgeglu", "dgeglu", FLOAT_DTYPES),
-    # ("dswiglu", "dswiglu", FLOAT_DTYPES),
-    # ("dreglu", "dreglu", FLOAT_DTYPES),
+    ("dswiglu", "dswiglu", FLOAT_DTYPES),
+    ("dreglu", "dreglu", FLOAT_DTYPES),
 ]
 
 
-def gems_geglu_wrapper(x, *_):
-    return flag_gems.geglu(x)
-
-
-def gems_dgeglu_wrapper(grad_out, inp, *_args, **_kwargs):
-    return flag_gems.dgeglu(grad_out, inp)
-
-
 @pytest.mark.parametrize(
     "op_name, tex_attr_name, dtypes",
     [
@@ -115,11 +107,15 @@ def test_tex_glu_forward_perf(op_name, tex_attr_name, dtypes):
 
     te_op = getattr(tex, tex_attr_name)
 
+    if not hasattr(flag_gems, op_name):
+        pytest.skip(f"Operator {op_name} not found in flag_gems")
+    gems_op = getattr(flag_gems, op_name)
+
     bench = TexGluForwardBenchmark(
         op_name=op_name,
         torch_op=te_op,
         dtypes=dtypes,
-        gems_op=gems_geglu_wrapper,
+        gems_op=gems_op,
     )
     bench.run()
 
@@ -145,11 +141,15 @@ def test_tex_glu_backward_perf(op_name, tex_attr_name, dtypes):
 
     te_op = getattr(tex, tex_attr_name)
 
+    if not hasattr(flag_gems, op_name):
+        pytest.skip(f"Operator {op_name} not found in flag_gems")
+    gems_op = getattr(flag_gems, op_name)
+
     bench = TexGluBackwardBenchmark(
         op_name=op_name,
         torch_op=te_op,
         dtypes=dtypes,
         is_backward=False,
-        gems_op=gems_dgeglu_wrapper,
+        gems_op=gems_op,
     )
     bench.run()
diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py
index c24ede53c..92bae8702 100644
--- a/src/flag_gems/__init__.py
+++ b/src/flag_gems/__init__.py
@@ -330,6 +330,8 @@ def enable(
             ("where.self_out", where_self_out),
             ("zeros", zeros),
             ("zeros_like", zeros_like),
+            ("dreglu", dreglu),
+            ("reglu", reglu),
         ),
         user_unused_ops_list=list(set(unused or [])),
         cpp_patched_ops_list=list(set(aten_patch_list)),
diff --git a/src/flag_gems/fused/__init__.py b/src/flag_gems/fused/__init__.py
index ea0a0f90d..b3bd20d44 100644
--- a/src/flag_gems/fused/__init__.py
+++ b/src/flag_gems/fused/__init__.py
@@ -11,6 +11,7 @@
 )
 from flag_gems.fused.moe_sum import moe_sum
 from flag_gems.fused.outer import outer
+from flag_gems.fused.reglu import dreglu, reglu
 from flag_gems.fused.reshape_and_cache import reshape_and_cache
 from flag_gems.fused.reshape_and_cache_flash import reshape_and_cache_flash
 from flag_gems.fused.rotary_embedding import apply_rotary_pos_emb
@@ -44,4 +45,6 @@
     "topk_softmax",
     "rwkv_ka_fusion",
     "rwkv_mm_sparsity",
+    "dreglu",
+    "reglu",
 ]
diff --git a/src/flag_gems/fused/geglu.py b/src/flag_gems/fused/geglu.py
index 2bd841c07..4fb6506b6 100644
--- a/src/flag_gems/fused/geglu.py
+++ b/src/flag_gems/fused/geglu.py
@@ -1,4 +1,5 @@
 import logging
+from typing import Any, Optional
 
 import torch
 import triton
@@ -120,7 +121,7 @@ def dgeglu_kernel(
     tl.store(grad_b_ptr, grad_b.to(x_a.dtype), mask=mask)
 
 
-def geglu(input_tensor: torch.Tensor) -> torch.Tensor:
+def geglu(input_tensor: torch.Tensor, quantizer: Optional[Any] = None) -> torch.Tensor:
     shape = input_tensor.shape
     H = shape[-1] // 2
     M = input_tensor.numel() // (2 * H)
@@ -149,7 +150,11 @@ def geglu(input_tensor: torch.Tensor) -> torch.Tensor:
     return output_2d.view(*shape[:-1], H)
 
 
-def dgeglu(grad_output: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+def dgeglu(
+    grad_output: torch.Tensor,
+    input_tensor: torch.Tensor,
+    quantizer: Optional[Any] = None,
+) -> torch.Tensor:
     shape = input_tensor.shape
     H = shape[-1] // 2
     M = input_tensor.numel() // (2 * H)
diff --git a/src/flag_gems/fused/reglu.py b/src/flag_gems/fused/reglu.py
new file mode 100644
index 000000000..23b4a3fcf
--- /dev/null
+++ b/src/flag_gems/fused/reglu.py
@@ -0,0 +1,175 @@
+import logging
+from typing import Any, Optional
+
+import torch
+import triton
+import triton.language as tl
+
+from flag_gems import runtime
+from flag_gems.utils import libentry, libtuner
+
+logger = logging.getLogger(__name__)
+
+
+@libentry()
+@libtuner(
+    configs=runtime.get_tuned_config("gated_activation"),
+    key=["M", "N"],
+)
+@triton.jit
+def dreglu_kernel(
+    grad_output_ptr,
+    input_ptr,
+    grad_input_ptr,
+    M,
+    N,
+    stride_grad_out_m,
+    stride_grad_out_n,
+    stride_in_m,
+    stride_in_n,
+    stride_grad_in_m,
+    stride_grad_in_n,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    pid_m = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    grad_output_ptr += (
+        offs_m[:, None] * stride_grad_out_m + offs_n[None, :] * stride_grad_out_n
+    )
+    input_ptr_a = (
+        input_ptr + offs_m[:, None] * stride_in_m + offs_n[None, :] * stride_in_n
+    )
+    input_ptr_b = (
+        input_ptr + offs_m[:, None] * stride_in_m + (offs_n[None, :] + N) * stride_in_n
+    )
+    grad_input_ptr_a = (
+        grad_input_ptr
+        + offs_m[:, None] * stride_grad_in_m
+        + offs_n[None, :] * stride_grad_in_n
+    )
+    grad_input_ptr_b = (
+        grad_input_ptr
+        + offs_m[:, None] * stride_grad_in_m
+        + (offs_n[None, :] + N) * stride_grad_in_n
+    )
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+    grad_out = tl.load(grad_output_ptr, mask=mask, other=0.0)
+    block_a = tl.load(input_ptr_a, mask=mask, other=0.0)
+    block_b = tl.load(input_ptr_b, mask=mask, other=0.0)
+    relu_a = tl.maximum(block_a, 0.0)
+    d_relu_a = tl.where(block_a > 0, 1.0, 0.0)
+    grad_a = grad_out * d_relu_a * block_b
+    grad_b = grad_out * relu_a
+    tl.store(grad_input_ptr_a, grad_a, mask=mask)
+    tl.store(grad_input_ptr_b, grad_b, mask=mask)
+
+
+@libentry()
+@libtuner(
+    configs=runtime.get_tuned_config("gated_activation"),
+    key=["M", "N_OUT"],
+)
+@triton.jit
+def reglu_kernel(
+    x_ptr,
+    y_ptr,
+    M,
+    N_OUT,
+    stride_x_m,
+    stride_x_n,
+    stride_y_m,
+    stride_y_n,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    pid_m = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    x_ptr_a = x_ptr + offs_m[:, None] * stride_x_m + offs_n[None, :] * stride_x_n
+    x_ptr_b = (
+        x_ptr + offs_m[:, None] * stride_x_m + (offs_n[None, :] + N_OUT) * stride_x_n
+    )
+    y_ptr = y_ptr + offs_m[:, None] * stride_y_m + offs_n[None, :] * stride_y_n
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N_OUT)
+    block_a = tl.load(x_ptr_a, mask=mask, other=0.0)
+    block_b = tl.load(x_ptr_b, mask=mask, other=0.0)
+    gate = tl.where(block_a > 0, block_a, 0.0)
+    output = gate * block_b
+    tl.store(y_ptr, output, mask=mask)
+
+
+def reglu(input_tensor: torch.Tensor, quantizer: Optional[Any] = None) -> torch.Tensor:
+    shape = input_tensor.shape
+    if input_tensor.dim() < 1:
+        raise ValueError("Input tensor must have at least 1 dimension.")
+    last_dim = shape[-1]
+    if last_dim % 2 != 0:
+        raise ValueError(
+            f"The last dimension of the input tensor must be even, but got {last_dim}."
+        )
+    N_OUT = last_dim // 2
+    M = input_tensor.numel() // last_dim
+    if input_tensor.numel() == 0:
+        output_shape = (*shape[:-1], N_OUT)
+        return torch.empty(
+            output_shape, device=input_tensor.device, dtype=input_tensor.dtype
+        )
+    input_2d = input_tensor.contiguous().view(M, last_dim)
+    output_2d = torch.empty(
+        (M, N_OUT), device=input_tensor.device, dtype=input_tensor.dtype
+    )
+    grid = lambda META: (
+        triton.cdiv(M, META["BLOCK_M"]),
+        triton.cdiv(N_OUT, META["BLOCK_N"]),
+    )
+    reglu_kernel[grid](
+        input_2d,
+        output_2d,
+        M,
+        N_OUT,
+        input_2d.stride(0),
+        input_2d.stride(1),
+        output_2d.stride(0),
+        output_2d.stride(1),
+    )
+    output_shape = (*shape[:-1], N_OUT)
+    return output_2d.view(output_shape)
+
+
+def dreglu(
+    grad_output: torch.Tensor,
+    input_tensor: torch.Tensor,
+    quantizer: Optional[Any] = None,
+) -> torch.Tensor:
+    shape = input_tensor.shape
+    if shape[:-1] != grad_output.shape[:-1] or shape[-1] != 2 * grad_output.shape[-1]:
+        raise ValueError(
+            f"Shape mismatch: input {shape} vs grad_output {grad_output.shape}"
+        )
+    M = grad_output.numel() // grad_output.shape[-1]
+    N = grad_output.shape[-1]
+    grad_output_2d = grad_output.contiguous().view(M, N)
+    input_2d = input_tensor.contiguous().view(M, 2 * N)
+    grad_input = torch.empty_like(input_2d)
+    grid = lambda META: (
+        triton.cdiv(M, META["BLOCK_M"]),
+        triton.cdiv(N, META["BLOCK_N"]),
+    )
+    dreglu_kernel[grid](
+        grad_output_2d,
+        input_2d,
+        grad_input,
+        M,
+        N,
+        grad_output_2d.stride(0),
+        grad_output_2d.stride(1),
+        input_2d.stride(0),
+        input_2d.stride(1),
+        grad_input.stride(0),
+        grad_input.stride(1),
+    )
+    return grad_input.view(shape)
diff --git a/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml b/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml
index 60af0a581..01b29f4ab 100644
--- a/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml
+++ b/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml
@@ -1132,3 +1132,18 @@ index:
     block_size1:
       - 1024
      - 2048
+
+gated_activation:
+  - gen: true
+    param_map:
+      META:
+        BLOCK_M: block_m
+        BLOCK_N: block_n
+      num_warps: 4
+    block_m:
+      - 1
+      - 2
+      - 4
+      - 8
+    block_n:
+      - 1024
diff --git a/tests/test_unary_pointwise_ops.py b/tests/test_unary_pointwise_ops.py
index 94985a8ee..a6ee6ade6 100644
--- a/tests/test_unary_pointwise_ops.py
+++ b/tests/test_unary_pointwise_ops.py
@@ -343,32 +343,6 @@ def test_accuracy_geglu(shape, dtype):
     gems_assert_close(res_out, ref_out, dtype)
 
 
-@pytest.mark.dreglu
-@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
-@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
-@pytest.mark.skipif(not TE_AVAILABLE, reason="transformer engine is not available")
-def test_accuracy_dreglu(shape, dtype):
-    if len(shape) == 0 or TO_CPU:
-        pytest.skip("dreglu does not support 0-dim scalar tensors.")
-
-    if shape[-1] % 2 != 0:
-        shape = list(shape)
-        shape[-1] += 1
-        shape = tuple(shape)
-
-    input_tensor = torch.randn(shape, dtype=dtype, device=flag_gems.device)
-
-    grad_output_shape = list(shape)
-    grad_output_shape[-1] //= 2
-    grad_output = torch.randn(
-        tuple(grad_output_shape), dtype=dtype, device=flag_gems.device
-    )
-    ref_out = tex.dgeglu(grad_output, input_tensor, None)
-    with flag_gems.use_gems():
-        res_out = flag_gems.dgeglu(grad_output, input_tensor)
-    gems_assert_close(res_out, ref_out, dtype)
-
-
 @pytest.mark.gelu
 @pytest.mark.parametrize("shape", POINTWISE_SHAPES)
 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@@ -1331,3 +1305,85 @@ def test_accuracy_atan_(shape, dtype):
 
     ref_out = ref_out.to(res_out.dtype)
     gems_assert_close(res_out, ref_out, dtype)
+
+
+DREGLU_SHAPES = [
+    (),
+    (1,),
+    (512, 512),
+    (1, 2048),
+    (2048, 1),
+    (1024, 1024),
+    (20, 320, 15),
+    (4096, 1024),
+    (2048, 2048),
+    (1024, 4096),
+    (512, 512, 512),
+    (512, 256, 512),
+]
+
+
+@pytest.mark.dreglu
+@pytest.mark.parametrize("shape", DREGLU_SHAPES)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+@pytest.mark.skipif(not TE_AVAILABLE, reason="transformer engine is not available")
+def test_accuracy_dreglu(shape, dtype):
+    if len(shape) == 0:
+        pytest.skip("dreglu does not support 0-dim scalar tensors.")
+
+    if shape[-1] % 2 != 0:
+        shape = list(shape)
+        shape[-1] += 1
+        shape = tuple(shape)
+
+    input_tensor = torch.randn(shape, dtype=dtype, device=flag_gems.device)
+
+    grad_output_shape = list(shape)
+    grad_output_shape[-1] //= 2
+    grad_output = torch.randn(
+        tuple(grad_output_shape), dtype=dtype, device=flag_gems.device
+    )
+
+    ref_out = tex.dreglu(grad_output, input_tensor, None)
+    with flag_gems.use_gems():
+        res_out = flag_gems.dreglu(grad_output, input_tensor, None)
+    gems_assert_close(res_out, ref_out, dtype)
+
+
+REGLU_SHAPES = [
+    (),
+    (2,),
+    (512, 512),
+    (1, 2048),
+    (2048, 2),
+    (1024, 1024),
+    (20, 320, 16),
+    (4096, 1024),
+    (2048, 2048),
+    (1024, 4096),
+    (512, 512, 512),
+    (512, 256, 512),
+]
+
+
+@pytest.mark.reglu
+@pytest.mark.parametrize("shape", REGLU_SHAPES)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+@pytest.mark.skipif(not TE_AVAILABLE, reason="transformer engine is not available")
+def test_accuracy_reglu(shape, dtype):
+    if len(shape) == 0:
+        pytest.skip("reglu does not support 0-dim scalar tensors.")
+
+    if shape[-1] % 2 != 0:
+        pytest.skip(
+            f"reglu requires the last dimension to be even, but got shape {shape}."
+        )
+
+    input_tensor = torch.randn(shape, dtype=dtype, device=flag_gems.device)
+
+    ref_out = tex.reglu(input_tensor, None)
+
+    with flag_gems.use_gems():
+        res_out = flag_gems.reglu(input_tensor)
+
+    gems_assert_close(res_out, ref_out, dtype)