Fix GPT-2 no-past attention fusion for transformers >= 4.27 #27449
Conversation
@microsoft-github-policy-service agree
Thanks for the review! Updated the test and added two expected model files.
/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline
Azure Pipelines successfully started running 4 pipeline(s).
Force-pushed 444e962 to 65bee87
@Rishi-Dave, please merge latest main and resolve the conflicts.
In transformers >= 4.27, the causal attention mask changed from torch.uint8 to torch.bool, removing the Cast node from the ONNX graph. FusionGptAttentionNoPast.fuse() hardcoded Cast as the first element in the mask match path, causing fusion to silently fail for all modern transformers exports. Replace match_parent_path with match_parent_paths to try both the old Cast-prefixed pattern and the new Cast-less pattern, mirroring the fix already applied to FusionGptAttention (with-past) at lines 426-438. Add a synthetic no-past graph generator and unit test covering both mask variants (with and without Cast). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
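The dtype change described above can be illustrated without torch; a minimal NumPy sketch (the 4x4 causal mask size is arbitrary) shows why a bool mask needs no Cast before feeding a Where:

```python
import numpy as np

# Old export (transformers < 4.27): the causal mask is stored as uint8,
# so the exporter emits a Cast-to-bool node before the Where.
mask_u8 = np.tril(np.ones((4, 4), dtype=np.uint8))
cond_old = mask_u8.astype(bool)  # this astype is the Cast node in the old graph

# New export (transformers >= 4.27): the mask is bool from the start,
# so no Cast node appears in the exported ONNX graph.
cond_new = np.tril(np.ones((4, 4), dtype=bool))

# Same condition either way; only the graph structure differs.
assert (cond_old == cond_new).all()
```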
Replace manual Attention node count assertion with verify_fusion() to match the pattern used by all other tests in the suite. Add golden expected model files for both add-input orderings. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
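The core of the fix is a fallback over candidate parent paths. A self-contained toy sketch of the idea (simplified: a "graph" here is just a chain of op-type strings, and the op names are illustrative, unlike onnxruntime's real match_parent_paths, which walks actual ONNX nodes by input index):

```python
def match_parent_path(chain, pattern):
    """Return the matched prefix of the parent chain, or None."""
    if chain[: len(pattern)] == pattern:
        return chain[: len(pattern)]
    return None

def match_parent_paths(chain, patterns):
    """Try each candidate pattern in order; return (index, match),
    or (-1, None) when nothing matches."""
    for i, pattern in enumerate(patterns):
        match = match_parent_path(chain, pattern)
        if match is not None:
            return i, match
    return -1, None

# Old export: the mask path starts with a Cast; the new export drops it.
old_chain = ["Cast", "Slice", "Slice", "Unsqueeze", "Unsqueeze"]
new_chain = ["Slice", "Slice", "Unsqueeze", "Unsqueeze"]
candidates = [
    ["Cast", "Slice", "Slice", "Unsqueeze", "Unsqueeze"],  # transformers < 4.27
    ["Slice", "Slice", "Unsqueeze", "Unsqueeze"],          # transformers >= 4.27
]

assert match_parent_paths(old_chain, candidates)[0] == 0
assert match_parent_paths(new_chain, candidates)[0] == 1
```

Hardcoding only the first candidate is exactly the failure mode described in the PR: the new chain silently matches nothing and fusion is skipped.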
Force-pushed 65bee87 to 604aa89
/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline
Azure Pipelines successfully started running 4 pipeline(s).
LGTM
gpt2_model_generator.py — Test Graph Generator (+378 lines)
Minor Issues
- Typo: epsion (lines ~944 and ~1306 in the generator). This is a known pre-existing typo for the epsilon attribute of LayerNormalization that exists throughout the file. Not introduced by this PR but worth noting.
- Large function: At 370+ lines, the generator is verbose, but this is consistent with the other generators in the file (create_gpt2_attention is similarly long). Acceptable.
test_attention_fusion.py — New Test (+32 lines)
Suggestion
The test saves the model to "." (current working directory) which is fragile — it depends on the CWD when tests are run. Consider using tempfile.TemporaryDirectory() for the intermediate model path, similar to how other tests in the file handle temp files. However, the existing test_gpt2_attention_fusion test uses the same pattern, so this is consistent.
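A hedged sketch of the suggested pattern (save_model is a placeholder standing in for the test's real onnx.save / optimizer calls; the bytes written are not a real model):

```python
import os
import tempfile

def save_model(model_bytes, path):
    # Placeholder for the test's actual model-saving call.
    with open(path, "wb") as f:
        f.write(model_bytes)

# Write the intermediate model into a temp directory instead of ".",
# so the test does not depend on where pytest is invoked from.
with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = os.path.join(tmp_dir, "gpt2_attention_no_past.onnx")
    save_model(b"\x08\x07", model_path)  # placeholder bytes
    assert os.path.exists(model_path)
# The directory and everything in it are removed automatically on exit.
```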
Summary
Fix FusionGptAttentionNoPast mask pattern matching to support both torch.uint8 (old) and torch.bool (new) causal masks.
Motivation
Fixes #16453
In transformers >= 4.27 (Feb 2023), the causal attention mask dtype changed from torch.uint8 to torch.bool (commit). This removed a Cast node from the exported ONNX graph. FusionGptAttentionNoPast.fuse() hardcoded Cast as the first element in match_parent_path, causing the mask path match to fail silently for all modern transformers exports. The result: zero Attention nodes fused for any GPT-2 model exported without past state.
The sibling class FusionGptAttention (with-past) was already fixed to handle both patterns using match_parent_paths (plural). This PR applies the same approach to the no-past variant.
Changes
fusion_gpt_attention_no_past.py
- Replace match_parent_path with match_parent_paths for the Where-based mask path (lines 187-201), offering both the Cast-prefixed pattern (old transformers) and the Cast-less pattern (transformers >= 4.27)
gpt2_model_generator.py
- Add a create_gpt2_attention_no_past() function that builds a synthetic GPT-2 no-past attention graph with the Where-based mask pattern
- Add an add_cast parameter to test both mask variants
test_attention_fusion.py
- Add test_gpt2_attention_no_past_fusion() that verifies an Attention node is fused for all combinations of add_cast and switch_add_inputs
Test Plan
- test_gpt2_attention_no_past_fusion passes (4 variants: with/without Cast × normal/switched Add inputs)
- lintrunner reports no issues for new code
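The four variants can be enumerated the way a parametrized test would run them; a sketch under stated assumptions (run_fusion_case is a stub standing in for building the synthetic graph, running the optimizer, and counting fused Attention nodes):

```python
from itertools import product

def run_fusion_case(add_cast, switch_add_inputs):
    # Stub: with the fix applied, each variant should fuse exactly one
    # Attention node; the real test builds and optimizes an ONNX model.
    return 1

# Cover all combinations of add_cast x switch_add_inputs, as in the test plan.
results = {
    (add_cast, switch): run_fusion_case(add_cast, switch)
    for add_cast, switch in product([True, False], repeat=2)
}

assert len(results) == 4
assert all(count == 1 for count in results.values())
```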