Fix transformer 2.9.0 (torch 2.9.1 used by SGLang 0.5.5) build #2445

yiakwy-xpu-ml-framework-team wants to merge 2 commits into NVIDIA:release_v2.9 from
Conversation
for more information, see https://pre-commit.ci
Greptile Overview

Greptile Summary

This PR fixes build issues for Transformer Engine 2.9.0 when used with torch 2.9.1 (as required by SGLang 0.5.5). The changes address three main issues:

- a circular include between `common.h` and `cuda_driver.h`, fixed by moving the `transformer_engine.h` include after the local headers;
- the `assert te_installed_via_pypi` installation check in `__init__.py`, commented out to allow non-PyPI installations;
- a flash attention import fallback in `backends.py` that retries from `flash_attn_interface` when the `flash_attn_3` imports fail.

The C++ header fixes are solid. However, the Python changes weaken installation safety checks, and the flash attention fallback logic appears incomplete: the bare `flash_attn_interface` import may still fail.

Confidence Score: 3/5

Important Files Changed

File Analysis
Sequence Diagram

```mermaid
sequenceDiagram
    participant Build as Build System
    participant PyInit as Python __init__.py
    participant CommonH as common.h
    participant CUDADriver as cuda_driver.h
    participant NVTE as transformer_engine.h
    participant Backends as backends.py
    participant FlashAttn as Flash Attention
    Note over Build,FlashAttn: Transformer Engine 2.9.0 Build Process
    Build->>CommonH: Compile C++ headers
    CommonH->>CUDADriver: #include cuda_driver.h
    CUDADriver->>CommonH: #include common.h (circular!)
    Note over CommonH,CUDADriver: Fixed: Move NVTE include after local headers
    CommonH->>NVTE: #include transformer_engine.h (reordered)
    Build->>PyInit: Import transformer_engine
    PyInit->>PyInit: Check package installation
    Note over PyInit: Commented out: assert te_installed_via_pypi
    PyInit->>PyInit: Allow non-PyPI installations
    Build->>Backends: Import backends.py
    Backends->>FlashAttn: try: from flash_attn_3...
    alt flash_attn_3 module exists
        FlashAttn-->>Backends: Import successful
    else ModuleNotFoundError
        Note over Backends: Print error message
        Backends->>FlashAttn: Fallback: from flash_attn_interface...
        FlashAttn-->>Backends: Import may fail
    end
```
```diff
 # extension are all installed via PyPI and have matching versions.
 if te_framework_installed:
-    assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
+    # assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
```
style: Commenting out this validation weakens installation safety checks. Users with mismatched installations (e.g., PyPI framework package + source-built core) may encounter runtime issues. Consider a more targeted fix that specifically allows the SGLang 0.5.5 scenario while keeping validation for other cases.
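One way to keep the safety check while still unblocking source builds is an explicit opt-out rather than deleting the assertion. A minimal sketch of that idea, assuming a hypothetical environment variable `NVTE_ALLOW_NON_PYPI_INSTALL` (not an existing Transformer Engine flag):

```python
import os
import warnings


def check_install(te_framework_installed: bool, te_installed_via_pypi: bool) -> None:
    """Validate the installation, with an explicit opt-out instead of removing the check."""
    if not te_framework_installed or te_installed_via_pypi:
        return
    # Hypothetical escape hatch for source builds (e.g. the SGLang 0.5.5 scenario).
    if os.environ.get("NVTE_ALLOW_NON_PYPI_INSTALL") == "1":
        warnings.warn(
            "transformer-engine was not installed via PyPI; "
            "skipping version-match validation as requested."
        )
        return
    raise AssertionError("Could not find `transformer-engine` PyPI package.")
```

Users in the known-good scenario set the variable once; everyone else keeps the original failure mode and error message.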
```python
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flash_attn_interface import (
    flash_attn_varlen_func as flash_attn_varlen_func_v3,
)
from flash_attn_interface import (
    flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
)
from flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
from flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
```
logic: This fallback will fail if the `flash_attn_interface` module doesn't exist. The `flash-attn-3` package was detected, but imports from `flash_attn_3.flash_attn_interface` failed. If the goal is to support an alternative location such as `flash_attn.flash_attn_interface` (from flash-attn v2), the fallback should import from `flash_attn.flash_attn_interface` instead of bare `flash_attn_interface`. Without the proper module path, these imports will raise `ModuleNotFoundError` and the variables will remain undefined, causing a `NameError` when they are first used.
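A defensive alternative is to try the candidate module paths in order and fall back to an explicit sentinel. This is a sketch, not the PR's code; the symbol names follow the diff above, and `flash_attn.flash_attn_interface` is the flash-attn v2 location the review suggests:

```python
import importlib

# Candidate module paths, most specific first. Which one exists depends on
# whether flash-attn-3 or flash-attn v2 is installed.
_FA3_CANDIDATES = (
    "flash_attn_3.flash_attn_interface",
    "flash_attn.flash_attn_interface",
    "flash_attn_interface",
)


def load_fa3_interface():
    """Return the first importable flash-attn interface module, or None."""
    for name in _FA3_CANDIDATES:
        try:
            return importlib.import_module(name)
        except ModuleNotFoundError:
            continue
    return None


_mod = load_fa3_interface()
if _mod is not None:
    flash_attn_func_v3 = _mod.flash_attn_func
    flash_attn_varlen_func_v3 = _mod.flash_attn_varlen_func
else:
    # Leave explicit sentinels so later code can detect the missing backend
    # instead of hitting a NameError on an undefined name.
    flash_attn_func_v3 = flash_attn_varlen_func_v3 = None
```

This keeps the import failure contained and makes the "backend unavailable" state checkable (`flash_attn_func_v3 is None`) rather than implicit.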
Description

Fixes the Transformer Engine 2.9.0 build against torch 2.9.1 (the version used by SGLang 0.5.5).

Fixes # (issue)

Type of change

- Bug fix

Changes

Please list the changes introduced in this PR:

- Reorder includes in `common.h` so `transformer_engine.h` comes after local headers, breaking the circular include with `cuda_driver.h`.
- Comment out the `te_installed_via_pypi` assertion in `__init__.py` to allow non-PyPI installations.
- Add a fallback import from `flash_attn_interface` in `backends.py` when the `flash_attn_3` imports fail.

Checklist: