333 changes: 332 additions & 1 deletion benchmark/test_unary_pointwise_perf.py
@@ -1,17 +1,35 @@
import gc
import logging
from typing import Generator

import pytest
import torch
from pytest import mark

import flag_gems

try:
from transformer_engine.pytorch import cpp_extensions as tex

TE_AVAILABLE = True
except ImportError:
TE_AVAILABLE = False

from benchmark.attri_util import (
BOOL_DTYPES,
COMPLEX_DTYPES,
DEFAULT_METRICS,
FLOAT_DTYPES,
INT_DTYPES,
)
from benchmark.performance_utils import Benchmark, generate_tensor_input
from benchmark.performance_utils import ( # noqa
Benchmark,
BenchmarkMetrics,
BenchmarkResult,
Config,
generate_tensor_input,
vendor_name,
)

fp64_is_supported = flag_gems.runtime.device.support_fp64

@@ -270,3 +288,316 @@ def test_bitwise_right_shift_perf():
dtypes=INT_DTYPES,
)
bench.run()


class SwigluBenchmarkResult(BenchmarkResult):
Collaborator: Can we add a base class that can be used by TE ops benchmark tests such as #1062, #1055, #1052, #1056?

def __str__(self) -> str:
header_title = (
f"\nOperator: {self.op_name} Performance Test (dtype={self.dtype}, mode={self.mode},"
f"level={self.level})\n"
)
col_names = [
f"{'Status':<12}",
f"{'TE Latency (ms)':>20}",
f"{'Gems Latency (ms)':>20}",
f"{'Gems Speedup':>20}",
f"{'TE GBPS':>20}",
f"{'Gems GBPS':>20}",
" Size Detail",
]

header_col_names = "".join(col_names)
header_break = "\n" + "-" * (len(header_col_names) + 10)
header = header_title + header_col_names + header_break

metrics_lines = "".join(self._format_metrics(ele) for ele in self.result)
return header + metrics_lines

def _format_metrics(self, metrics: BenchmarkMetrics) -> str:
status = "SUCCESS" if metrics.error_msg is None else "FAILED"
latency_base_str = (
f"{metrics.latency_base:.6f}" if metrics.latency_base is not None else "N/A"
)
latency_str = f"{metrics.latency:.6f}" if metrics.latency is not None else "N/A"
speedup_str = f"{metrics.speedup:.3f}" if metrics.speedup is not None else "N/A"
gbps_base_str = (
f"{metrics.gbps_base:.3f}" if metrics.gbps_base is not None else "N/A"
)
gbps_str = f"{metrics.gbps:.3f}" if metrics.gbps is not None else "N/A"
shape_detail_str = f"{metrics.shape_detail}"

data_line = (
f"\n{status:<12}"
f"{latency_base_str:>20}"
f"{latency_str:>20}"
f"{speedup_str:>20}"
f"{gbps_base_str:>20}"
f"{gbps_str:>20}"
f" {shape_detail_str}"
)
return data_line


class SwigluForwardBenchmark(Benchmark):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.metrics = DEFAULT_METRICS + ["gbps"]
self.to_bench_metrics = self.metrics

def set_shapes(self, shape_file_path: str = None):
core_shapes = [
(1024, 1024),
(4096, 2048),
(16, 1024, 4096),
(8, 512, 8192),
(4, 128, 8, 2048),
]
self.shapes = core_shapes

if Config.bench_level.value == "comprehensive":
additional_shapes = self.set_more_shapes()
if additional_shapes:
self.shapes = list(dict.fromkeys(self.shapes + additional_shapes))

def set_more_shapes(self):
special_shapes_2d = [(4096, 2**i) for i in range(8, 14, 2)]
sp_shapes_3d = [(16, 1024, 2**i) for i in range(10, 15, 2)]
return special_shapes_2d + sp_shapes_3d

def get_input_iter(self, cur_dtype: torch.dtype) -> Generator:
for input_shape in self.shapes:
if input_shape[-1] % 2 != 0:
raise ValueError(
f"Swiglu forward input shape {input_shape} has odd last dimension"
)
input_tensor = generate_tensor_input(input_shape, cur_dtype, self.device)
yield (input_tensor,)

def get_gbps(self, args: tuple, latency: float) -> float:
if not latency:
return 0.0
(input_tensor,) = args
element_size = input_tensor.element_size()
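# Bytes moved (assumption): the forward op reads the full input and writes an output
# with half as many elements (the last dim is split in two), i.e. 1.5 * numel elements.
# Dividing by latency * 1e6 gives GB/s, assuming get_latency() reports milliseconds.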
total_bytes = (
input_tensor.numel() + (input_tensor.numel() // 2)
) * element_size
return total_bytes / (latency * 1e6)

def run(self):
if Config.query:
super().run()
return

self.init_user_config()
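# gbps is derived from the measured latencies, so collect it whenever any
# latency metric is requested.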
if "gbps" not in self.to_bench_metrics and any(
m in self.to_bench_metrics for m in ["latency", "latency_base"]
):
self.to_bench_metrics.append("gbps")

for dtype in self.to_bench_dtypes:
metrics_list = []
for input_data in self.get_input_iter(dtype):
metric = BenchmarkMetrics()
try:
args, kwargs = self.unpack_to_args_kwargs(input_data)
metric.shape_detail = self.record_shapes(*args, **kwargs)

if "latency_base" in self.to_bench_metrics:
metric.latency_base = self.get_latency(
self.torch_op, *args, **kwargs
)

if "latency" in self.to_bench_metrics:
if not self.gems_op:
raise ValueError(
"GEMS operator not set. Use bench.set_gems()."
)
metric.latency = self.get_latency(self.gems_op, *args, **kwargs)

if (
"speedup" in self.to_bench_metrics
and metric.latency_base is not None
and metric.latency is not None
and metric.latency > 0
):
metric.speedup = metric.latency_base / metric.latency

if "gbps" in self.to_bench_metrics:
metric.gbps_base = self.get_gbps(
args, latency=metric.latency_base
)
metric.gbps = self.get_gbps(args, latency=metric.latency)

except Exception as e:
metric.error_msg = str(e)
print(f"\nBenchmark failed for shape {metric.shape_detail}: {e}")
finally:
metrics_list.append(metric)
gc.collect()

if not metrics_list:
continue

result_formatter = SwigluBenchmarkResult(
level=Config.bench_level.value,
op_name=self.op_name,
dtype=str(dtype),
mode=Config.mode.value,
result=metrics_list,
)
print(result_formatter)
logging.info(result_formatter.to_json())


class SwigluBackwardBenchmark(Benchmark):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.metrics = DEFAULT_METRICS + ["gbps"]
self.to_bench_metrics = self.metrics

def set_shapes(self, shape_file_path: str = None):
core_shapes = [
(1024, 1024),
(4096, 2048),
(16, 1024, 4096),
(8, 512, 8192),
(4, 128, 8, 2048),
]
self.shapes = core_shapes

if Config.bench_level.value == "comprehensive":
additional_shapes = self.set_more_shapes()
if additional_shapes:
self.shapes = list(dict.fromkeys(self.shapes + additional_shapes))

def set_more_shapes(self):
special_shapes_2d = [(4096, 2**i) for i in range(8, 14, 2)]
sp_shapes_3d = [(16, 1024, 2**i) for i in range(10, 15, 2)]
return special_shapes_2d + sp_shapes_3d

def get_input_iter(self, cur_dtype: torch.dtype) -> Generator:
for input_shape in self.shapes:
if input_shape[-1] % 2 != 0:
raise ValueError(
f"Swiglu backward input shape {input_shape} has odd last dimension"
)

input_tensor = generate_tensor_input(input_shape, cur_dtype, self.device)

grad_output_shape = list(input_shape)
grad_output_shape[-1] = input_shape[-1] // 2
grad_output = generate_tensor_input(
tuple(grad_output_shape), cur_dtype, self.device
)

yield (grad_output, input_tensor)

def record_shapes(self, *args, **kwargs):
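# args order is (grad_output, input_tensor); report the original input's shape.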
input_tensor = args[1]
return str(input_tensor.shape)

def get_gbps(self, args: tuple, latency: float) -> float:
if not latency:
return 0.0
grad_output, input_tensor = args
element_size = grad_output.element_size()
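# Bytes moved (assumption): read grad_output and the full input, and write a
# grad_input with the same number of elements as the input; GB/s again assumes
# the latency is reported in milliseconds.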

total_bytes = (
grad_output.numel() + input_tensor.numel() + input_tensor.numel()
) * element_size
return total_bytes / (latency * 1e6)

def run(self):
if Config.query:
super().run()
return

self.init_user_config()
if "gbps" not in self.to_bench_metrics and any(
m in self.to_bench_metrics for m in ["latency", "latency_base"]
):
self.to_bench_metrics.append("gbps")

for dtype in self.to_bench_dtypes:
metrics_list = []
for input_data in self.get_input_iter(dtype):
metric = BenchmarkMetrics()
try:
args, kwargs = self.unpack_to_args_kwargs(input_data)
metric.shape_detail = self.record_shapes(*args, **kwargs)

if "latency_base" in self.to_bench_metrics:
metric.latency_base = self.get_latency(
self.torch_op, *args, **kwargs
)

if "latency" in self.to_bench_metrics:
if not self.gems_op:
raise ValueError(
"GEMS operator not set. Use bench.set_gems()."
)
metric.latency = self.get_latency(self.gems_op, *args, **kwargs)

if (
"speedup" in self.to_bench_metrics
and metric.latency_base is not None
and metric.latency is not None
and metric.latency > 0
):
metric.speedup = metric.latency_base / metric.latency

if "gbps" in self.to_bench_metrics:
metric.gbps_base = self.get_gbps(
args, latency=metric.latency_base
)
metric.gbps = self.get_gbps(args, latency=metric.latency)

except Exception as e:
metric.error_msg = str(e)
print(f"\nBenchmark failed for shape {metric.shape_detail}: {e}")
finally:
metrics_list.append(metric)
gc.collect()

if not metrics_list:
continue

result_formatter = SwigluBenchmarkResult(
level=Config.bench_level.value,
op_name=self.op_name,
dtype=str(dtype),
mode=Config.mode.value,
result=metrics_list,
)
print(result_formatter)
logging.info(result_formatter.to_json())


@mark.skipif(
not TE_AVAILABLE,
reason="Transformer Engine backend is not available for reference.",
)
@mark.swiglu
def test_swiglu_forward_perf():
bench = SwigluForwardBenchmark(
op_name="swiglu_forward",
torch_op=lambda x: tex.swiglu(x, None),
dtypes=FLOAT_DTYPES,
)
bench.set_gems(flag_gems.swiglu)
bench.run()


@mark.skipif(
not TE_AVAILABLE,
reason="Transformer Engine backend is not available for reference.",
)
@mark.swiglu
def test_swiglu_backward_perf():
bench = SwigluBackwardBenchmark(
op_name="swiglu_backward",
torch_op=lambda grad_output, input_tensor: tex.dswiglu(
grad_output, input_tensor, None
),
dtypes=FLOAT_DTYPES,
)
bench.set_gems(flag_gems.dswiglu)
bench.run()
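

The review comment on SwigluBenchmarkResult above asks for a base class that later TE-op benchmark PRs (#1062, #1055, #1052, #1056) could reuse. Below is a minimal sketch of what that could look like, factoring out the shape setup that SwigluForwardBenchmark and SwigluBackwardBenchmark currently duplicate; the class name TEOpBenchmark and the CORE_SHAPES attribute are hypothetical and not part of this PR.

class TEOpBenchmark(Benchmark):
    """Shared scaffolding for benchmarks comparing a Transformer Engine op
    against its flag_gems counterpart; subclasses keep only get_input_iter()
    and get_gbps()."""

    # Hypothetical shared shape list, mirroring the core_shapes duplicated above.
    CORE_SHAPES = [
        (1024, 1024),
        (4096, 2048),
        (16, 1024, 4096),
        (8, 512, 8192),
        (4, 128, 8, 2048),
    ]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.metrics = DEFAULT_METRICS + ["gbps"]
        self.to_bench_metrics = self.metrics

    def set_shapes(self, shape_file_path: str = None):
        self.shapes = list(self.CORE_SHAPES)
        if Config.bench_level.value == "comprehensive":
            additional_shapes = self.set_more_shapes()
            if additional_shapes:
                self.shapes = list(dict.fromkeys(self.shapes + additional_shapes))

    def set_more_shapes(self):
        special_shapes_2d = [(4096, 2**i) for i in range(8, 14, 2)]
        sp_shapes_3d = [(16, 1024, 2**i) for i in range(10, 15, 2)]
        return special_shapes_2d + sp_shapes_3d

    # The run() loop that is defined identically in SwigluForwardBenchmark and
    # SwigluBackwardBenchmark above would move here unchanged.

With such a base in a shared module, the two Swiglu benchmark classes above would shrink to their get_input_iter() and get_gbps() overrides.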
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -287,6 +287,7 @@ def enable(
("silu", silu),
("silu_", silu_),
("silu_backward", silu_backward),
("swiglu", swiglu),
("sin", sin),
("sin_", sin_),
("slice_scatter", slice_scatter),
4 changes: 4 additions & 0 deletions src/flag_gems/fused/__init__.py
@@ -16,6 +16,7 @@
from flag_gems.fused.rwkv_mm_sparsity import rwkv_mm_sparsity
from flag_gems.fused.silu_and_mul import silu_and_mul, silu_and_mul_out
from flag_gems.fused.skip_layernorm import skip_layer_norm
from flag_gems.fused.swiglu import SwiGLU, dswiglu, swiglu
from flag_gems.fused.topk_softmax import topk_softmax
from flag_gems.fused.weight_norm import weight_norm

@@ -25,6 +26,9 @@
"fused_add_rms_norm",
"silu_and_mul",
"silu_and_mul_out",
"swiglu",
"dswiglu",
"SwiGLU",
"gelu_and_mul",
"cross_entropy_loss",
"outer",