155 changes: 155 additions & 0 deletions benchmark/test_transformer_engine_perf.py
@@ -0,0 +1,155 @@
import pytest
import torch

import flag_gems
from benchmark.attri_util import DEFAULT_METRICS, FLOAT_DTYPES
from benchmark.performance_utils import Benchmark, generate_tensor_input

try:
from transformer_engine.pytorch import cpp_extensions as tex

TE_AVAILABLE = True
except ImportError:
TE_AVAILABLE = False


class TexGluBenchmark(Benchmark):
DEFAULT_METRICS = DEFAULT_METRICS[:] + ["tflops"]
# Triton grid_y is capped at 65535, BLOCK_SIZE_H=64 -> last dim <= 8388480.
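# Worked out: 2 (halves) * 64 (BLOCK_SIZE_H) * 65535 (grid_y cap) = 8388480.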
MAX_LAST_DIM = 2 * 64 * 65535

def set_more_shapes(self):
# The last dim must be even so the GLU op can split it into two halves.
special_shapes_2d = [(1024, 2**i) for i in range(1, 20, 4)]
sp_shapes_3d = [(64, 64, 2**i) for i in range(1, 15, 4)]
return special_shapes_2d + sp_shapes_3d

def init_user_config(self):
super().init_user_config()
supported = []
for shape in self.shapes:
last_dim = shape[-1]
if last_dim % 2 != 0:
continue
if last_dim > self.MAX_LAST_DIM:
continue
supported.append(shape)
if not supported:
pytest.skip(
"No geglu shapes satisfy the constraints of FlagGems implementation."
)
self.shapes = supported


class TexGluForwardBenchmark(TexGluBenchmark):
def get_input_iter(self, cur_dtype):
for shape in self.shapes:
x = generate_tensor_input(shape, cur_dtype, self.device)
# TE GLU APIs typically accept (input, quantizer).
yield (x, None)

def get_tflops(self, op, *args, **kwargs):
# args[0] is the input tensor x
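# The GLU forward is elementwise, so the proxy FLOP count is just the number of input elements.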
shape = list(args[0].shape)
return torch.tensor(shape).prod().item()


class TexGluBackwardBenchmark(TexGluBenchmark):
def get_input_iter(self, cur_dtype):
for shape in self.shapes:
inp = generate_tensor_input(shape, cur_dtype, self.device)

out_shape = list(shape)
out_shape[-1] = out_shape[-1] // 2

grad_out = torch.randn(out_shape, dtype=cur_dtype, device=self.device)

yield grad_out, inp, None

def get_tflops(self, op, *args, **kwargs):
# args[1] is the original input tensor 'inp'
inp_shape = list(args[1].shape)
# Proxy FLOPs estimate: ~2x the input element count to cover the forward recompute plus backward work.
return torch.tensor(inp_shape).prod().item() * 2


glu_forward_ops = [
("geglu", "geglu", FLOAT_DTYPES),
# ("swiglu", "swiglu", FLOAT_DTYPES),
# ("reglu", "reglu", FLOAT_DTYPES),
]

glu_backward_ops = [
("dgeglu", "dgeglu", FLOAT_DTYPES),
# ("dswiglu", "dswiglu", FLOAT_DTYPES),
# ("dreglu", "dreglu", FLOAT_DTYPES),
]


def gems_geglu_wrapper(x, *_):
return flag_gems.geglu(x)


def gems_dgeglu_wrapper(grad_out, inp, *_args, **_kwargs):
return flag_gems.dgeglu(grad_out, inp)


@pytest.mark.parametrize(
"op_name, tex_attr_name, dtypes",
[
pytest.param(
name,
tex_attr,
dtype,
marks=getattr(pytest.mark, name, None),
)
for name, tex_attr, dtype in glu_forward_ops
],
)
def test_tex_glu_forward_perf(op_name, tex_attr_name, dtypes):
if not TE_AVAILABLE:
pytest.skip("TransformerEngine not installed")

if not hasattr(tex, tex_attr_name):
pytest.skip(f"Operator {tex_attr_name} not found in transformer_engine")

te_op = getattr(tex, tex_attr_name)

bench = TexGluForwardBenchmark(
op_name=op_name,
torch_op=te_op,
dtypes=dtypes,
gems_op=gems_geglu_wrapper,
)
bench.run()


@pytest.mark.parametrize(
"op_name, tex_attr_name, dtypes",
[
pytest.param(
name,
tex_attr,
dtype,
marks=getattr(pytest.mark, name, None),
)
for name, tex_attr, dtype in glu_backward_ops
],
)
def test_tex_glu_backward_perf(op_name, tex_attr_name, dtypes):
if not TE_AVAILABLE:
pytest.skip("TransformerEngine not installed")

if not hasattr(tex, tex_attr_name):
pytest.skip(f"Operator {tex_attr_name} not found in transformer_engine")

te_op = getattr(tex, tex_attr_name)

bench = TexGluBackwardBenchmark(
op_name=op_name,
torch_op=te_op,
dtypes=dtypes,
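# grad_out is supplied explicitly and dgeglu is benchmarked as a direct call, so no autograd backward pass is run.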
is_backward=False,
gems_op=gems_dgeglu_wrapper,
)
bench.run()
1 change: 1 addition & 0 deletions benchmark/test_unary_pointwise_perf.py
Original file line number Diff line number Diff line change
@@ -97,6 +97,7 @@ def test_general_unary_pointwise_perf(op_name, torch_op, dtypes):

forward_inplace_operations = [
("abs_", torch.abs_, FLOAT_DTYPES),
# ("angle", torch.angle, COMPLEX_DTYPES + [torch.float32] + INT_DTYPES + BOOL_DTYPES),
("erf_", torch.erf_, FLOAT_DTYPES),
("exp_", torch.exp_, FLOAT_DTYPES),
("exp2_", torch.exp2_, FLOAT_DTYPES),
2 changes: 2 additions & 0 deletions src/flag_gems/__init__.py
@@ -150,6 +150,8 @@ def enable(
("gather_backward", gather_backward),
("ge.Scalar", ge_scalar),
("ge.Tensor", ge),
("geglu", geglu),
("dgeglu", dgeglu),
("gelu", gelu),
("gelu_", gelu_),
("gelu_backward", gelu_backward),
3 changes: 3 additions & 0 deletions src/flag_gems/fused/__init__.py
@@ -2,6 +2,7 @@
from flag_gems.fused.cross_entropy_loss import cross_entropy_loss
from flag_gems.fused.flash_mla import flash_mla
from flag_gems.fused.fused_add_rms_norm import fused_add_rms_norm
from flag_gems.fused.geglu import dgeglu, geglu
from flag_gems.fused.gelu_and_mul import gelu_and_mul
from flag_gems.fused.instance_norm import instance_norm
from flag_gems.fused.moe_align_block_size import (
@@ -25,6 +26,8 @@
"fused_add_rms_norm",
"silu_and_mul",
"silu_and_mul_out",
"geglu",
"dgeglu",
"gelu_and_mul",
"cross_entropy_loss",
"outer",
185 changes: 185 additions & 0 deletions src/flag_gems/fused/geglu.py
@@ -0,0 +1,185 @@
import logging

import torch
import triton
import triton.language as tl

from flag_gems.utils import tl_extra_shim

erf = tl_extra_shim.erf
exp = tl_extra_shim.exp
pow = tl_extra_shim.pow
tanh = tl_extra_shim.tanh

logger = logging.getLogger(__name__)


@triton.jit
def geglu_kernel(
input_ptr,
output_ptr,
M,
H,
stride_in_m,
stride_in_h,
stride_out_m,
stride_out_h,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_H: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_h = tl.program_id(1)

offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)

mask = (offs_m[:, None] < M) & (offs_h[None, :] < H)

# Split the input into x_a and x_b along the last dimension
input_a_ptr = (
input_ptr + offs_m[:, None] * stride_in_m + offs_h[None, :] * stride_in_h
)
input_b_ptr = (
input_ptr + offs_m[:, None] * stride_in_m + (offs_h[None, :] + H) * stride_in_h
)
output_ptr = (
output_ptr + offs_m[:, None] * stride_out_m + offs_h[None, :] * stride_out_h
)

x_a = tl.load(input_a_ptr, mask=mask, other=0.0).to(tl.float32)
x_b = tl.load(input_b_ptr, mask=mask, other=0.0).to(tl.float32)

gelu_out = 0.5 * x_a * (1 + tanh(0.79788456 * x_a * (1 + 0.044715 * pow(x_a, 2))))
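# 0.79788456 ≈ sqrt(2/pi): tanh approximation gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))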
out = gelu_out * x_b

tl.store(output_ptr, out.to(output_ptr.dtype.element_ty), mask=mask)


@triton.jit
def dgeglu_kernel(
grad_out_ptr,
input_ptr,
grad_in_ptr,
M,
H,
stride_grad_out_m,
stride_grad_out_h,
stride_in_m,
stride_in_h,
stride_grad_in_m,
stride_grad_in_h,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_H: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_h = tl.program_id(1)

offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)

mask = (offs_m[:, None] < M) & (offs_h[None, :] < H)

grad_out_ptr = (
grad_out_ptr
+ offs_m[:, None] * stride_grad_out_m
+ offs_h[None, :] * stride_grad_out_h
)
input_a_ptr = (
input_ptr + offs_m[:, None] * stride_in_m + offs_h[None, :] * stride_in_h
)
input_b_ptr = (
input_ptr + offs_m[:, None] * stride_in_m + (offs_h[None, :] + H) * stride_in_h
)
grad_a_ptr = (
grad_in_ptr
+ offs_m[:, None] * stride_grad_in_m
+ offs_h[None, :] * stride_grad_in_h
)
grad_b_ptr = (
grad_in_ptr
+ offs_m[:, None] * stride_grad_in_m
+ (offs_h[None, :] + H) * stride_grad_in_h
)

grad_out = tl.load(grad_out_ptr, mask=mask, other=0.0).to(tl.float32)
x_a = tl.load(input_a_ptr, mask=mask, other=0.0).to(tl.float32)
x_b = tl.load(input_b_ptr, mask=mask, other=0.0).to(tl.float32)

# GELU (tanh approximation) and its derivative
tanh_out = tanh(0.79788456 * x_a * (1 + 0.044715 * pow(x_a, 2)))
gelu_out = 0.5 * x_a * (1 + tanh_out)

# dgelu/dx
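# d/dx gelu(x) = 0.5 * (1 + tanh(u)) + 0.5 * x * (1 - tanh(u)^2) * sqrt(2/pi) * (1 + 3 * 0.044715 * x^2), with u = sqrt(2/pi) * (x + 0.044715 * x^3)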
sech2 = 1 - pow(tanh_out, 2)
dgelu = 0.5 * (1 + tanh_out) + 0.5 * x_a * sech2 * 0.79788456 * (
1 + 3 * 0.044715 * pow(x_a, 2)
)

# Backward pass: chain rule through the gating product gelu(x_a) * x_b
grad_a = grad_out * x_b * dgelu
grad_b = grad_out * gelu_out

tl.store(grad_a_ptr, grad_a.to(grad_a_ptr.dtype.element_ty), mask=mask)
tl.store(grad_b_ptr, grad_b.to(grad_b_ptr.dtype.element_ty), mask=mask)


def geglu(input_tensor: torch.Tensor) -> torch.Tensor:
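"""Split the last dim of input_tensor into halves (x_a, x_b) and return
gelu(x_a) * x_b using the tanh approximation of GELU; the output's last
dim is half the input's.
"""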
shape = input_tensor.shape
H = shape[-1] // 2
M = input_tensor.numel() // (2 * H)

input_2d = input_tensor.contiguous().view(M, 2 * H)
output_2d = torch.empty(M, H, device=input_tensor.device, dtype=input_tensor.dtype)

grid = lambda META: (
triton.cdiv(M, META["BLOCK_SIZE_M"]),
triton.cdiv(H, META["BLOCK_SIZE_H"]),
)

geglu_kernel[grid](
input_2d,
output_2d,
M,
H,
input_2d.stride(0),
input_2d.stride(1),
output_2d.stride(0),
output_2d.stride(1),
BLOCK_SIZE_M=64,
BLOCK_SIZE_H=64,
)

return output_2d.view(*shape[:-1], H)


def dgeglu(grad_output: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
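"""Backward of geglu: given grad_output with shape (..., H) and the forward
input with shape (..., 2 * H), return the gradient w.r.t. the input, also
with shape (..., 2 * H).
"""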
shape = input_tensor.shape
H = shape[-1] // 2
M = input_tensor.numel() // (2 * H)

grad_out_2d = grad_output.contiguous().view(M, H)
input_2d = input_tensor.contiguous().view(M, 2 * H)
grad_in_2d = torch.empty_like(input_2d)

grid = lambda META: (
triton.cdiv(M, META["BLOCK_SIZE_M"]),
triton.cdiv(H, META["BLOCK_SIZE_H"]),
)

dgeglu_kernel[grid](
grad_out_2d,
input_2d,
grad_in_2d,
M,
H,
grad_out_2d.stride(0),
grad_out_2d.stride(1),
input_2d.stride(0),
input_2d.stride(1),
grad_in_2d.stride(0),
grad_in_2d.stride(1),
BLOCK_SIZE_M=64,
BLOCK_SIZE_H=64,
)

return grad_in_2d.view_as(input_tensor)
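A minimal usage sketch (not part of the diff; it assumes a CUDA device, and the shapes and tolerances are illustrative) that checks flag_gems.geglu and flag_gems.dgeglu against a PyTorch reference built from torch.nn.functional.gelu with approximate="tanh":

import torch
import torch.nn.functional as F

import flag_gems

x = torch.randn(64, 256, device="cuda", dtype=torch.float16)
a, b = x.float().chunk(2, dim=-1)

# Reference forward: gelu(a) * b with the same tanh approximation as the kernel.
ref_out = F.gelu(a, approximate="tanh") * b
out = flag_gems.geglu(x)
torch.testing.assert_close(out.float(), ref_out, rtol=1e-2, atol=1e-2)

# Reference backward via autograd on the reference expression.
x_ref = x.detach().float().requires_grad_(True)
a_ref, b_ref = x_ref.chunk(2, dim=-1)
grad_out = torch.randn(64, 128, device="cuda", dtype=torch.float16)
(F.gelu(a_ref, approximate="tanh") * b_ref).backward(grad_out.float())
grad_in = flag_gems.dgeglu(grad_out, x)
torch.testing.assert_close(grad_in.float(), x_ref.grad, rtol=1e-2, atol=1e-2)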