Skip to content

Commit 4aed50c

Browse files
authored
perf: enable pdl for cutlass fp4 gemm (#2095)
<!-- .github/pull_request_template.md --> ## πŸ“Œ Description The `enablePDL` flag is set to false, this PR turned them on. Set to true for both because sm_100 and sm_120 should have support of pdl. ## πŸ” Related Issues <!-- Link any related issues here --> ## πŸš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### βœ… Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## πŸ§ͺ Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Updated runtime configuration for FP4 GEMM operations to enhance execution performance on SM100 and SM120 GPU architectures. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent d42b71f commit 4aed50c

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

β€Žinclude/flashinfer/gemm/fp4_gemm_template_sm100.hβ€Ž

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void
273273
std::string(cutlassGetStatusString(initStatus)); \
274274
throw std::runtime_error("[FP4 gemm Runner] " + errMsg); \
275275
} \
276-
auto runStatus = gemm.run(args, workspace, stream, nullptr, /* enablePDL */ false); \
276+
auto runStatus = gemm.run(args, workspace, stream, nullptr, /*enablePDL=*/true); \
277277
if (runStatus != cutlass::Status::kSuccess) { \
278278
std::string errMsg = "Failed to run cutlass FP4 gemm on sm100. Error: " + \
279279
std::string(cutlassGetStatusString(runStatus)); \

β€Žinclude/flashinfer/gemm/fp4_gemm_template_sm120.hβ€Ž

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ size_t genericFp4GemmKernelLauncher(void* D, void const* A, void const* B, void
257257
std::string(cutlass::cutlassGetStatusString(initStatus)); \
258258
throw std::runtime_error("[FP4 gemm Runner] " + errMsg); \
259259
} \
260-
auto runStatus = gemm.run(args, workspace, stream, nullptr, /* enablePDL */ false); \
260+
auto runStatus = gemm.run(args, workspace, stream, nullptr, /*enablePDL=*/true); \
261261
if (runStatus != cutlass::Status::kSuccess) { \
262262
std::string errMsg = "Failed to run cutlass FP4 gemm on sm120. Error: " + \
263263
std::string(cutlass::cutlassGetStatusString(runStatus)); \

0 commit comments

Comments
Β (0)