302 changes: 301 additions & 1 deletion benchmark/test_unary_pointwise_perf.py
@@ -1,3 +1,5 @@
import gc
import logging
from typing import Generator

import pytest
@@ -10,8 +12,22 @@
DEFAULT_METRICS,
FLOAT_DTYPES,
INT_DTYPES,
BenchmarkMetrics,
BenchmarkResult,
)
from benchmark.performance_utils import Benchmark, generate_tensor_input, vendor_name
from benchmark.performance_utils import (
Benchmark,
Config,
generate_tensor_input,
vendor_name,
)

try:
from transformer_engine.pytorch import cpp_extensions as tex

TE_AVAILABLE = True
except ImportError:
TE_AVAILABLE = False

fp64_is_supported = flag_gems.runtime.device.support_fp64

@@ -269,3 +285,287 @@ def test_bitwise_right_shift_perf():
dtypes=INT_DTYPES,
)
bench.run()


class TEBenchmarkResult(BenchmarkResult):
def __str__(self) -> str:
        header_title = (
            f"\nOperator: {self.op_name} Performance Test (dtype={self.dtype}, mode={self.mode}, "
            f"level={self.level})\n"
        )
col_names = [
f"{'Status':<12}",
f"{'TE Latency (ms)':>20}",
f"{'Gems Latency (ms)':>20}",
f"{'Gems Speedup':>20}",
f"{'TE GBPS':>20}",
f"{'Gems GBPS':>20}",
" Size Detail",
]

header_col_names = "".join(col_names)
header_break = "\n" + "-" * (len(header_col_names) + 10)
header = header_title + header_col_names + header_break

metrics_lines = "".join(self._format_metrics(ele) for ele in self.result)
return header + metrics_lines

def _format_metrics(self, metrics: BenchmarkMetrics) -> str:
status = "SUCCESS" if metrics.error_msg is None else "FAILED"
latency_base_str = (
f"{metrics.latency_base:.6f}" if metrics.latency_base is not None else "N/A"
)
latency_str = f"{metrics.latency:.6f}" if metrics.latency is not None else "N/A"
speedup_str = f"{metrics.speedup:.3f}" if metrics.speedup is not None else "N/A"
gbps_base_str = (
f"{metrics.gbps_base:.3f}" if metrics.gbps_base is not None else "N/A"
)
gbps_str = f"{metrics.gbps:.3f}" if metrics.gbps is not None else "N/A"
shape_detail_str = f"{metrics.shape_detail}"

data_line = (
f"\n{status:<12}"
f"{latency_base_str:>20}"
f"{latency_str:>20}"
f"{speedup_str:>20}"
f"{gbps_base_str:>20}"
f"{gbps_str:>20}"
f" {shape_detail_str}"
)
return data_line


class DregluBenchmark(Benchmark):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.metrics = DEFAULT_METRICS + ["gbps"]
self.to_bench_metrics = self.metrics

def set_shapes(self, shape_file_path: str = None):
core_shapes = [
(1024 * 1024,),
(4096, 2048),
(16, 1024, 4096),
(8, 512, 8192),
(4, 128, 8, 2048),
]
self.shapes = core_shapes

if Config.bench_level.value == "comprehensive":
additional_shapes = self.set_more_shapes()
if additional_shapes:
self.shapes = list(dict.fromkeys(self.shapes + additional_shapes))

    def set_more_shapes(self):
        # Extra shapes for the comprehensive level: 2D with last dims
        # 256, 1024, 4096 and 3D with last dims 1024, 4096, 16384.
        special_shapes_2d = [(4096, 2**i) for i in range(8, 14, 2)]
        sp_shapes_3d = [(16, 1024, 2**i) for i in range(10, 15, 2)]
        return special_shapes_2d + sp_shapes_3d

def get_input_iter(self, cur_dtype: torch.dtype) -> Generator:
for input_shape in self.shapes:
if input_shape[-1] % 2 != 0:
continue
input_tensor = generate_tensor_input(input_shape, cur_dtype, self.device)
grad_output_shape = list(input_shape)
grad_output_shape[-1] //= 2
grad_output = generate_tensor_input(
tuple(grad_output_shape), cur_dtype, self.device
)
yield grad_output, input_tensor

    def get_gbps(self, args: tuple, latency: float) -> float:
        if not latency:
            return 0.0
        grad_output, input_tensor = args
        element_size = grad_output.element_size()
        # Memory traffic: read grad_output, read input, write grad_input (same
        # shape as input) -> grad_output.numel() + 2 * input_tensor.numel() elements.
        total_bytes = (grad_output.numel() + 2 * input_tensor.numel()) * element_size
        # latency is in ms, so bytes / (ms * 1e6) gives GB/s.
        return total_bytes / (latency * 1e6)

def run(self):
if Config.query:
super().run()
return

self.init_user_config()
if "gbps" not in self.to_bench_metrics and any(
m in self.to_bench_metrics for m in ["latency", "latency_base"]
):
self.to_bench_metrics.append("gbps")

for dtype in self.to_bench_dtypes:
metrics_list = []
for input_data in self.get_input_iter(dtype):
metric = BenchmarkMetrics()
try:
args, kwargs = self.unpack_to_args_kwargs(input_data)
metric.shape_detail = self.record_shapes(*args, **kwargs)
if "latency_base" in self.to_bench_metrics:
metric.latency_base = self.get_latency(
self.torch_op, *args, **kwargs
)
if "latency" in self.to_bench_metrics:
if not self.gems_op:
raise ValueError(
"GEMS operator not set. Use bench.set_gems()."
)
metric.latency = self.get_latency(self.gems_op, *args, **kwargs)
if (
"speedup" in self.to_bench_metrics
and metric.latency is not None
and metric.latency > 0
):
metric.speedup = metric.latency_base / metric.latency
if "gbps" in self.to_bench_metrics:
metric.gbps_base = self.get_gbps(
args, latency=metric.latency_base
)
metric.gbps = self.get_gbps(args, latency=metric.latency)
except Exception as e:
metric.error_msg = str(e)
print(f"\nBenchmark failed for shape {metric.shape_detail}: {e}")
finally:
metrics_list.append(metric)
gc.collect()

if not metrics_list:
continue

result_formatter = TEBenchmarkResult(
level=Config.bench_level.value,
op_name=self.op_name,
dtype=str(dtype),
mode=Config.mode.value,
result=metrics_list,
)
print(result_formatter)
logging.info(result_formatter.to_json())


@pytest.mark.skipif(
not TE_AVAILABLE,
reason="Transformer Engine backend is not available for reference.",
)
@pytest.mark.dreglu
def test_dreglu_perf():
bench = DregluBenchmark(
op_name="dreglu_backward",
torch_op=lambda grad, inp: tex.dreglu(grad, inp, None),
dtypes=FLOAT_DTYPES,
)
bench.set_gems(flag_gems.dreglu)
bench.run()
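

# For orientation only: a plain-PyTorch sketch of the ReGLU backward that
# tex.dreglu and flag_gems.dreglu are benchmarked on above. The split
# convention (first half is the value, second half is the ReLU gate) and the
# helper name are assumptions for illustration, not taken from either backend.
def _dreglu_reference(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
    a, b = inp.chunk(2, dim=-1)  # value half, gate half
    grad_a = grad_output * torch.relu(b)  # d(a * relu(b)) / da
    grad_b = grad_output * a * (b > 0).to(a.dtype)  # d(a * relu(b)) / db
    return torch.cat([grad_a, grad_b], dim=-1)  # gradient w.r.t. the full input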


class RegluBenchmark(Benchmark):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.metrics = DEFAULT_METRICS + ["gbps"]
self.to_bench_metrics = self.metrics

def set_shapes(self, shape_file_path: str = None):
core_shapes = [
(4096, 2048),
(16, 1024, 4096),
(8, 512, 8192),
(4, 128, 8, 2048),
]
self.shapes = core_shapes

if Config.bench_level.value == "comprehensive":
additional_shapes = self.set_more_shapes()
if additional_shapes:
self.shapes = list(dict.fromkeys(self.shapes + additional_shapes))

def set_more_shapes(self):
special_shapes_2d = [(4096, 2**i) for i in range(8, 14, 2)]
sp_shapes_3d = [(16, 1024, 2**i) for i in range(10, 15, 2)]
return special_shapes_2d + sp_shapes_3d

def get_input_iter(self, cur_dtype: torch.dtype) -> Generator:
for input_shape in self.shapes:
if input_shape[-1] % 2 != 0:
continue
input_tensor = generate_tensor_input(input_shape, cur_dtype, self.device)
yield input_tensor,

    def get_gbps(self, args: tuple, latency: float) -> float:
        if not latency:
            return 0.0
        (input_tensor,) = args
        element_size = input_tensor.element_size()
        # Memory traffic: read the input, write an output with half as many
        # elements -> 1.5 * input_tensor.numel() elements in total.
        total_bytes = (input_tensor.numel() + 0.5 * input_tensor.numel()) * element_size
        # latency is in ms, so bytes / (ms * 1e6) gives GB/s.
        return total_bytes / (latency * 1e6)

def run(self):
if Config.query:
super().run()
return

self.init_user_config()
if "gbps" not in self.to_bench_metrics and any(
m in self.to_bench_metrics for m in ["latency", "latency_base"]
):
self.to_bench_metrics.append("gbps")

for dtype in self.to_bench_dtypes:
metrics_list = []
for input_data in self.get_input_iter(dtype):
metric = BenchmarkMetrics()
try:
args, kwargs = self.unpack_to_args_kwargs(input_data)
metric.shape_detail = self.record_shapes(*args, **kwargs)
if "latency_base" in self.to_bench_metrics:
metric.latency_base = self.get_latency(
self.torch_op, *args, **kwargs
)
if "latency" in self.to_bench_metrics:
if not self.gems_op:
raise ValueError(
"GEMS operator not set. Use bench.set_gems()."
)
metric.latency = self.get_latency(self.gems_op, *args, **kwargs)
if (
"speedup" in self.to_bench_metrics
and metric.latency is not None
and metric.latency > 0
):
metric.speedup = metric.latency_base / metric.latency
if "gbps" in self.to_bench_metrics:
metric.gbps_base = self.get_gbps(
args, latency=metric.latency_base
)
metric.gbps = self.get_gbps(args, latency=metric.latency)
except Exception as e:
metric.error_msg = str(e)
print(f"\nBenchmark failed for shape {metric.shape_detail}: {e}")
finally:
metrics_list.append(metric)
gc.collect()

if not metrics_list:
continue

result_formatter = TEBenchmarkResult(
level=Config.bench_level.value,
op_name=self.op_name,
dtype=str(dtype),
mode=Config.mode.value,
result=metrics_list,
)
print(result_formatter)
logging.info(result_formatter.to_json())


@pytest.mark.skipif(
not TE_AVAILABLE,
reason="Transformer Engine backend is not available for reference.",
)
@pytest.mark.reglu
def test_reglu_perf():
bench = RegluBenchmark(
op_name="reglu_forward",
torch_op=lambda inp: tex.reglu(inp, None),
dtypes=FLOAT_DTYPES,
)
bench.set_gems(flag_gems.reglu)
bench.run()
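

# A matching plain-PyTorch sketch of the ReGLU forward, with the same caveat
# that the split convention is an assumption made for illustration only.
def _reglu_reference(inp: torch.Tensor) -> torch.Tensor:
    a, b = inp.chunk(2, dim=-1)
    return a * torch.relu(b)  # output has half the last dimension of the input


# The tests above are selected via their custom pytest markers, e.g.
# `pytest benchmark/test_unary_pointwise_perf.py -m dreglu`; registering the
# `dreglu` and `reglu` markers in the pytest configuration avoids
# unknown-marker warnings.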
2 changes: 2 additions & 0 deletions src/flag_gems/__init__.py
@@ -321,6 +321,8 @@ def enable(
("where.self_out", where_self_out),
("zeros", zeros),
("zeros_like", zeros_like),
("dreglu", dreglu),
("reglu", reglu),
),
user_unused_ops_list=list(set(unused or [])),
cpp_patched_ops_list=list(set(aten_patch_list)),
3 changes: 3 additions & 0 deletions src/flag_gems/fused/__init__.py
@@ -9,6 +9,7 @@
moe_align_block_size_triton,
)
from flag_gems.fused.outer import outer
from flag_gems.fused.reglu import dreglu, reglu
from flag_gems.fused.reshape_and_cache import reshape_and_cache
from flag_gems.fused.reshape_and_cache_flash import reshape_and_cache_flash
from flag_gems.fused.rotary_embedding import apply_rotary_pos_emb
@@ -39,4 +40,6 @@
"topk_softmax",
"rwkv_ka_fusion",
"rwkv_mm_sparsity",
"dreglu",
"reglu",
]
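
For reference, the benchmarks call the new operators directly through the package namespace. A minimal usage sketch, assuming the call signatures exercised by the benchmark lambdas (reglu takes one tensor whose last dimension is even; dreglu takes the upstream gradient and that tensor) and a device on which flag_gems provides these ops:

    import torch
    import flag_gems

    x = torch.randn(16, 1024, 4096, device="cuda", dtype=torch.float16)
    y = flag_gems.reglu(x)  # last dimension halves: (16, 1024, 2048)
    grad_x = flag_gems.dreglu(torch.ones_like(y), x)  # same shape as x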