diff --git a/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_deepgemm.py index 186a164ab17..3d8c1ea5787 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_deepgemm.py @@ -89,22 +89,24 @@ def _get_deepgemm_workspace(self, module: 'MoE', m_max: int, # Workspace for FP8 activations capture_graph = torch.cuda.is_current_stream_capturing() workspace["workspace_0"] = DeepGemmMoEOp.buffers.get_buffer( - (expert_size_per_partition * m_max * fp8_dim), + [expert_size_per_partition, m_max, fp8_dim], dtype=torch.float8_e4m3fn, buffer_name='workspace_0', reserve_buffer=capture_graph) # Workspace for intermediate results workspace["workspace_1"] = DeepGemmMoEOp.buffers.get_buffer( - (expert_size_per_partition * m_max * - max(intermediate_size * 2, hidden_size)), + [ + expert_size_per_partition, m_max, + max(intermediate_size * 2, hidden_size) + ], dtype=torch.bfloat16, buffer_name='workspace_1', reserve_buffer=capture_graph) # Workspace for scaling factors workspace["workspace_sf"] = DeepGemmMoEOp.buffers.get_buffer( - expert_size_per_partition * (scale_k_padded // 4) * m_padded, + [expert_size_per_partition, (scale_k_padded // 4), m_padded], dtype=torch.int32, buffer_name='workspace_sf', reserve_buffer=capture_graph) diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index aeacc5f897f..5430afdb23d 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -423,6 +423,6 @@ accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-2] S accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-4] SKIP (https://nvbugs/5636912) accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_trtllm_attention_dp] SKIP (https://nvbugs/5637220) llmapi/test_llm_examples.py::test_llmapi_example_multilora SKIP (https://nvbugs/5636857) -unittest/_torch/modules SKIP (https://nvbugs/5636986,https://nvbugs/5637012,https://nvbugs/5637037) +unittest/_torch/modules SKIP (https://nvbugs/5637012,https://nvbugs/5637037) accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_eagle3[cutlass] SKIP (https://nvbugs/5636916) accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=TRTLLM-torch_compile=False] SKIP (https://nvbugs/5616182)