QOL improvements to float8 gemm benchmark (#596)
Summary:

1. add more options for shape generation (example invocations are shown after this list), such as
  - square: M == K == N sweeping through powers of 2
  - sweep: M, K, N each sweeping through powers of 2
  - custom: user specifies a single value of M, K, N
2. fix a bug when calling `torch._scaled_mm`: the scales should be created
   outside the benchmarked region for a less biased result
3. add a sweep over the `fast_accum` setting
4. add the ability to save results to a file for easier analysis later
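
As a rough sketch of how the new options are exercised (flag names come from the
`run()` signature via `fire`; the output filename is a placeholder):

```
python benchmarks/float8/bench_matmul.py --shape_gen_name square
python benchmarks/float8/bench_matmul.py --shape_gen_name sweep --out_filename sweep_results.csv
python benchmarks/float8/bench_matmul.py --shape_gen_name custom --M 4096 --K 4096 --N 8192
```

The `fast_accum` sweep needs no flag: each shape is benchmarked with
`use_fast_accum` set to both True and False.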

Test Plan:

```
time python benchmarks/float8/bench_matmul.py --out_filename ~/local/tmp/20240803_f8_gemm_sweep_2.csv --shape_gen_name sweep
// result: https://gist.github.com/vkuzo/1d82e84ddd8aac8166695d819ebc8883
```
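
The saved CSV can then be loaded back for analysis; a minimal sketch with pandas
(the path is a placeholder, column names follow the new `headers` tuple):

```
import pandas as pd

# load a CSV produced via --out_filename (placeholder path)
df = pd.read_csv("f8_gemm_sweep.csv")

# mean fp8 speedup over the bf16 baseline, split by the fast_accum setting
print(df.groupby("fast_accum")["fp8_speedup"].mean())

# shapes where fp8 helps the most
print(df.sort_values("fp8_speedup", ascending=False).head())
```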

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo authored Aug 5, 2024
1 parent 1328787 commit de4a1fb
Showing 1 changed file with 87 additions and 31 deletions.
118 changes: 87 additions & 31 deletions benchmarks/float8/bench_matmul.py
@@ -48,39 +48,91 @@ def do_benchmarks(tops, peak_tops, f, *args, **kwargs):
    return time_sec, tops_sec, pct_top_peak


def get_name_to_shapes_iter(
    shape_gen_name: str,
    M: Optional[int],
    K: Optional[int],
    N: Optional[int],
):
    if shape_gen_name == 'llama':
        assert M == K == N == None, \
            f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
        bsz, seq_len = 4, 4096
        M = bsz * seq_len
        # LLaMa 2 70B single-node weight shapes
        # assumes fused attn.wqkv and ffn.w13
        # source: https://fburl.com/gsheet/g8onr7rh
        name_to_shapes_70b = {
            "attn.wqkv": (M, 8192, 1280),
            "attn.w0": (M, 1024, 8192),
            "ffn.w13": (M, 8192, 7168),
            "ffn.w2": (M, 3584, 8192),
        }
        return name_to_shapes_70b.items()

    elif shape_gen_name == 'square':
        assert M == K == N == None, \
            f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
        name_to_shapes = {}
        min_power_of_2 = 5  # 32
        max_power_of_2 = 16  # 65,536
        for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
            val = 2 ** power_of_2
            name_to_shapes[idx] = val, val, val
        return name_to_shapes.items()

    elif shape_gen_name == 'sweep':
        assert M == K == N == None, \
            f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
        name_to_shapes = {}
        min_p2 = 5  # 32
        max_p2 = 16  # 65,536
        counter = 0
        for M_p2 in range(min_p2, max_p2 + 1):
            M = 2 ** M_p2
            for K_p2 in range(min_p2, max_p2 + 1):
                K = 2 ** K_p2
                for N_p2 in range(min_p2, max_p2 + 1):
                    N = 2 ** N_p2
                    name_to_shapes[counter] = M, K, N
                    counter += 1
        return name_to_shapes.items()

    elif shape_gen_name == 'custom':
        assert M is not None and K is not None and N is not None, \
            'M, K, N must be specified for custom shape_gen'
        name_to_shapes = {
            1: (M, K, N),
        }
        return name_to_shapes.items()

    raise AssertionError(f'unknown shape_gen_name {shape_gen_name}')


@torch.inference_mode()
def run(n_limit: Optional[int] = None):
def run(
    n_limit: Optional[int] = None,
    shape_gen_name: str = 'llama',
    out_filename: Optional[str] = None,
    M: Optional[int] = None,
    K: Optional[int] = None,
    N: Optional[int] = None,
):
device = "cuda"

# LLaMa 2 70B single-node weight shapes
# assumes fused attn.wqkv and ffn.w13
# source: https://fburl.com/gsheet/g8onr7rh
name_to_shapes_70b = {
"attn.wqkv": (8192, 1280),
"attn.w0": (1024, 8192),
"ffn.w13": (8192, 7168),
"ffn.w2": (3584, 8192),
}

headers = ("name", "shape", "dtype", "ref_time_s", "fp8_time_s", "fp8_speedup")
headers = ("fast_accum", "name", "M", "K", "N", "ref_time_s", "fp8_time_s", "fp8_speedup")
results = []

name_to_shapes = name_to_shapes_70b
dtypes = torch.bfloat16, torch.float16
dtype = torch.bfloat16
name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
fast_accum_vals = [True, False]

for idx, (dtype, (name, (K, N))) in enumerate(
itertools.product(dtypes, name_to_shapes.items())
):
for idx, (fast_accum, (name, (M, K, N))) in enumerate(itertools.product(fast_accum_vals, name_to_shapes)):
if n_limit is not None and idx >= n_limit:
break

# source: Xiao Sun, these are realistic for LLaMa 70B training
bsz, seq_len = 4, 4096

M = bsz * seq_len
print("M, K, N:", M, K, N)
tops = 2 * M * N * K
print(f"tops: {tops:.2E}")
print("M, K, N:", M, K, N, f"tops: {tops:.2E}")

# raw torch.mm
A = torch.randn(M, K, device=device, dtype=dtype)
@@ -99,12 +151,12 @@ def run(n_limit: Optional[int] = None):
        d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
        A = torch.zeros(M, K, device=device, dtype=d1)
        B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
        scale_a = torch.tensor([1.0], device=device)
        scale_b = torch.tensor([1.0], device=device)

        def do_matmul(A, B):
            scale_a = torch.tensor([1.0], device=device)
            scale_b = torch.tensor([1.0], device=device)
            return torch._scaled_mm(
                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
            )

        fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
@@ -114,22 +166,26 @@ def do_matmul(A, B):
f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
)

del A, B
del A, B, scale_a, scale_b

results.append(
[
fast_accum,
name,
(M, K, N),
dtype,
M,
K,
N,
ref_time_sec,
fp8_time_sec,
ref_time_sec / fp8_time_sec,
]
)

data_pd = pd.DataFrame(results, columns=headers)
print(data_pd)
data_df = pd.DataFrame(results, columns=headers)
print(data_df)

if out_filename is not None:
data_df.to_csv(out_filename)

def main() -> None:
fire.Fire(run)