Skip to content

Commit aa1aef6

Browse files
committed
refine
1 parent 4cdbea4 commit aa1aef6

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

benchmark/kernels/minmax-text-01-lighting_attention/benchmark_lighting_attention_decode.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def forward(self, hidden_states):
269269
def test_lightning_attention_implementations(model_params):
270270
torch.manual_seed(42)
271271

272-
batch_size = 2
272+
batch_size = 64
273273
seq_len = 1
274274
dtype = torch.bfloat16
275275
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -285,7 +285,6 @@ def test_lightning_attention_implementations(model_params):
285285
model_attn = MiniMaxText01LightningAttention(**model_params).to(dtype).to(device)
286286
model_attn.eval()
287287

288-
# 创建一个假的past_key_value
289288
d = model_params["head_dim"]
290289
past_kv = torch.randn(
291290
batch_size,
@@ -398,7 +397,6 @@ def benchmark(batch_size, seq_len, provider):
398397
model_attn = MiniMaxText01LightningAttention(**params).to(dtype).to(device)
399398
model_attn.eval()
400399

401-
# 创建一个假的past_key_value
402400
d = params["head_dim"]
403401
past_kv = torch.randn(
404402
batch_size,
@@ -460,15 +458,17 @@ def run_triton():
460458
)
461459
args = parser.parse_args()
462460

463-
# 运行正确性测试
464461
params = {
465462
"hidden_size": 6144,
466463
"num_attention_heads": 64,
467464
"head_dim": 96,
468465
"hidden_act": "silu",
469466
}
467+
468+
# Run correctness test first
469+
# Adapted from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/config.json
470470
test_lightning_attention_implementations(params)
471471

472-
# 运行性能测试
472+
# Run performance benchmark
473473
benchmark = get_benchmark()
474474
benchmark.run(print_data=True, save_path=args.save_path)

0 commit comments

Comments (0)