-
Notifications
You must be signed in to change notification settings - Fork 163
【Triton Copilot】Enhance test coverage for aten::index operator #1083
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
- Fix AttributeError in index operator for mixed basic/advanced indexing - Add comprehensive test cases for index operator - Support combining advanced and basic indexing using Triton Fixes flagos-ai#635
SUCCESS 0.009376 0.007328 1.279 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([8])]] Operator: index Performance Test (dtype=torch.float32, mode=kernel,level=comprehensive)
|
kiddyjinjin
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
- Update get_max_rank_shape() and broadcast_indices() in index_put.py to support None values (consistent with index.py) - Fix precision issue: create tensor_indices AFTER broadcast_indices to ensure using broadcasted tensors - Add gen_indices_for_index_put() function in test_reduction_ops.py to properly handle multi-dimensional index shapes - Update all index_put tests to use gen_indices_for_index_put() This fixes the pipeline failures and ensures consistency between index and index_put operators.
|
performance of index and index_put (zpy_triton) (.venv) root@job-4211e20e-bdf4-4193-bf7e-7448650342e5-master-0:/share/project/zpy/FlagGems_eval# pytest benchmark/test_select_and_slice_perf.py -m index -s benchmark/test_select_and_slice_perf.py SUCCESS 0.019120 0.016832 1.136 [torch.Size([268435456]), [torch.Size([65536])]] Operator: index Performance Test (dtype=torch.float32, mode=kernel,level=comprehensive) SUCCESS 0.018176 0.017088 1.064 [torch.Size([268435456]), [torch.Size([65536])]] Operator: index Performance Test (dtype=torch.bfloat16, mode=kernel,level=comprehensive) SUCCESS 0.018896 0.017584 1.075 [torch.Size([268435456]), [torch.Size([65536])]] (zpy_triton) (.venv) root@job-4211e20e-bdf4-4193-bf7e-7448650342e5-master-0:/share/project/zpy/FlagGems_eval# pytest benchmark/test_select_and_slice_perf.py -m index_put -s benchmark/test_select_and_slice_perf.py SUCCESS 0.830000 0.834320 0.995 [torch.Size([268435456]), [torch.Size([65536])], torch.Size([65536]), False] Operator: index_put Performance Test (dtype=torch.float32, mode=kernel,level=comprehensive) SUCCESS 1.637488 1.641984 0.997 [torch.Size([268435456]), [torch.Size([65536])], torch.Size([65536]), False] Operator: index_put Performance Test (dtype=torch.bfloat16, mode=kernel,level=comprehensive) SUCCESS 0.830336 0.834160 0.995 [torch.Size([268435456]), [torch.Size([65536])], torch.Size([65536]), False] . SUCCESS 1.006752 0.838432 1.201 [torch.Size([268435456]), [torch.Size([65536])], torch.Size([65536]), True] Operator: index_put Performance Test (dtype=torch.float32, mode=kernel,level=comprehensive) SUCCESS 1.811648 1.643200 1.103 [torch.Size([268435456]), [torch.Size([65536])], torch.Size([65536]), True] . ============================ warnings summary ============================ benchmark/test_select_and_slice_perf.py::test_index_put_acc_false_perf -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html |
- Remove excessive test cases added to INDEX_ACC_SHAPE - Keep only the original 8 test cases to match the baseline - This should prevent CI timeout issues
Description
This PR enhances test coverage for the
aten::indexoperator and fixes the AttributeError for mixed basic/advanced indexing.Development Tool:
Changes
Verification
a[None, idx])Fixes #635