diff --git a/benchmarks/benchmark_attn.py b/benchmarks/benchmark_attn.py
index b3902110eea..7830477a68a 100644
--- a/benchmarks/benchmark_attn.py
+++ b/benchmarks/benchmark_attn.py
@@ -70,7 +70,7 @@ def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, w
     else:
         row_idx = torch.arange(seqlen_q, device='cuda')
         col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) if window_size[0] is not None else torch.zeros_like(row_idx)
-        col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1)) if window_size[1] is not None else torch.full_like(row_idx, seqlen_k - 1)
+        col_right = torch.minimum(row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1)) if window_size[1] is not None else torch.full_like(row_idx, seqlen_k - 1)
         avg_seqlen = (col_right - col_left + 1).float().mean().item()
     return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v)
diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py
index 33e5d282716..e94d325d42d 100644
--- a/hopper/benchmark_attn.py
+++ b/hopper/benchmark_attn.py
@@ -68,7 +68,7 @@ def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, w
     else:
         row_idx = torch.arange(seqlen_q, device='cuda')
         col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0))
-        col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1))
+        col_right = torch.minimum(row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1))
         avg_seqlen = (col_right - col_left + 1).float().mean().item()
     return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v)