Commit ee20e67

[https://nvbugs/5636986][fix] Fix DeepGemmMoe get_buffer calls (#8939)
Signed-off-by: Xiwen Yu <[email protected]>
Signed-off-by: xiweny <[email protected]>
1 parent b53961e commit ee20e67

2 files changed: 7 additions and 5 deletions

tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_deepgemm.py

Lines changed: 6 additions & 4 deletions
@@ -89,22 +89,24 @@ def _get_deepgemm_workspace(self, module: 'MoE', m_max: int,
         # Workspace for FP8 activations
         capture_graph = torch.cuda.is_current_stream_capturing()
         workspace["workspace_0"] = DeepGemmMoEOp.buffers.get_buffer(
-            (expert_size_per_partition * m_max * fp8_dim),
+            [expert_size_per_partition, m_max, fp8_dim],
             dtype=torch.float8_e4m3fn,
             buffer_name='workspace_0',
             reserve_buffer=capture_graph)
 
         # Workspace for intermediate results
         workspace["workspace_1"] = DeepGemmMoEOp.buffers.get_buffer(
-            (expert_size_per_partition * m_max *
-             max(intermediate_size * 2, hidden_size)),
+            [
+                expert_size_per_partition, m_max,
+                max(intermediate_size * 2, hidden_size)
+            ],
             dtype=torch.bfloat16,
             buffer_name='workspace_1',
             reserve_buffer=capture_graph)
 
         # Workspace for scaling factors
         workspace["workspace_sf"] = DeepGemmMoEOp.buffers.get_buffer(
-            expert_size_per_partition * (scale_k_padded // 4) * m_padded,
+            [expert_size_per_partition, (scale_k_padded // 4), m_padded],
             dtype=torch.int32,
             buffer_name='workspace_sf',
             reserve_buffer=capture_graph)
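
Each of the three calls above changes its first argument from a flat element count to a three-element shape list. The sketch below illustrates that difference with a toy allocator that derives its byte size from the shape. `ToyBufferPool`, `BufferView`, and the `itemsize` parameter are hypothetical names made up for this illustration; they are not the TensorRT-LLM buffer class, whose actual `get_buffer` behavior is not shown in this diff.

```python
import math
from dataclasses import dataclass
from typing import Dict, List, Sequence


@dataclass
class BufferView:
    """Hypothetical result type: a byte view plus the shape it was requested with."""
    shape: List[int]
    nbytes: int
    data: memoryview


class ToyBufferPool:
    """Toy named buffer pool (an assumption for illustration, not the real class)."""

    def __init__(self) -> None:
        self._storage: Dict[str, bytearray] = {}

    def get_buffer(self, shape: Sequence[int], itemsize: int,
                   buffer_name: str) -> BufferView:
        # A bare int carries the element count but loses the layout,
        # so this toy pool rejects it outright.
        if isinstance(shape, int):
            raise TypeError("shape must be a sequence of ints, not a flat count")
        nbytes = math.prod(shape) * itemsize
        backing = self._storage.get(buffer_name)
        # Reuse the named backing store when it is already large enough.
        if backing is None or len(backing) < nbytes:
            backing = bytearray(nbytes)
            self._storage[buffer_name] = backing
        return BufferView(list(shape), nbytes, memoryview(backing)[:nbytes])


pool = ToyBufferPool()
# Shape-list call, mirroring the "+" lines of the diff (sizes are made up).
view = pool.get_buffer([4, 8, 16], itemsize=1, buffer_name="workspace_0")
```

The element count is the same either way (`4 * 8 * 16` here); only the shape-list form lets the allocator hand back a buffer that also records its dimensions.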

tests/integration/test_lists/waives.txt

Lines changed: 1 addition & 1 deletion
@@ -404,6 +404,6 @@ accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-2] S
 accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-4] SKIP (https://nvbugs/5636912)
 accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_attention_dp] SKIP (https://nvbugs/5637220)
 llmapi/test_llm_examples.py::test_llmapi_example_multilora SKIP (https://nvbugs/5636857)
-unittest/_torch/modules SKIP (https://nvbugs/5636986,https://nvbugs/5637012,https://nvbugs/5637037)
+unittest/_torch/modules SKIP (https://nvbugs/5637012,https://nvbugs/5637037)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass] SKIP (https://nvbugs/5636916)
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=TRTLLM-torch_compile=False] SKIP (https://nvbugs/5616182)
