@factnn factnn commented Nov 20, 2025

Description

This PR enhances test coverage for the aten::index operator and fixes an AttributeError raised when basic and advanced indexing are mixed.

Development Tool:

  • This operator was developed with Triton-Copilot, an AI-powered tool for Triton kernel development.

Changes

  • Fix AttributeError in index operator for mixed basic/advanced indexing
  • Add comprehensive test cases for index operator
  • Support combining advanced and basic indexing using Triton

Verification

  • ✅ All pytest tests pass
  • ✅ Supports mixed basic/advanced indexing (e.g., a[None, idx]; see the snippet below)
  • ✅ Comprehensive test coverage

Fixes #635
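
For context, a minimal PyTorch snippet (not FlagGems-specific) showing the mixed basic/advanced indexing pattern the fix targets; the shapes follow standard PyTorch/NumPy indexing semantics:

```python
import torch

a = torch.arange(16, dtype=torch.float32).reshape(4, 4)
idx = torch.tensor([0, 2, 3])

# Advanced indexing only: gathers rows 0, 2, 3 -> shape (3, 4)
rows = a[idx]

# Mixed basic/advanced indexing: None (basic) inserts a unit dimension
# alongside the advanced index idx -> shape (1, 3, 4)
mixed = a[None, idx]

print(rows.shape, mixed.shape)  # torch.Size([3, 4]) torch.Size([1, 3, 4])
```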

- Fix AttributeError in index operator for mixed basic/advanced indexing
- Add comprehensive test cases for index operator
- Support combining advanced and basic indexing using Triton

Fixes flagos-ai#635

factnn commented Nov 25, 2025

Operator: index Performance Test (dtype=torch.float16, mode=kernel,level=comprehensive)
Status Torch Latency (ms) Gems Latency (ms) Gems Speedup Size Detail

SUCCESS 0.021216 0.018368 1.155 [torch.Size([268435456]), [torch.Size([65536])]]
SUCCESS 0.009376 0.007328 1.279 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([8])]]
SUCCESS 0.011040 0.008256 1.337 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([2, 8])]]
SUCCESS 0.010992 0.008480 1.296 [torch.Size([32, 32]), [torch.Size([2, 8])]]
SUCCESS 0.009264 0.007616 1.216 [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([64])]]
SUCCESS 0.009152 0.008496 1.077 [torch.Size([512, 512, 512]), [torch.Size([128]), torch.Size([128]), torch.Size([128])]]
SUCCESS 0.012400 0.008096 1.532 [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([2, 128]), torch.Size([2, 128])]]
SUCCESS 0.017216 0.008448 2.038 [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([128]), torch.Size([128])]]
SUCCESS 0.497120 0.189728 2.620 [torch.Size([512, 512, 512]), [torch.Size([2, 128])]]
SUCCESS 0.013248 0.008768 1.511 [torch.Size([64, 64, 64]), [torch.Size([2, 8]), torch.Size([2, 8])]]

Operator: index Performance Test (dtype=torch.float32, mode=kernel,level=comprehensive)
Status Torch Latency (ms) Gems Latency (ms) Gems Speedup Size Detail

SUCCESS 0.018176 0.016896 1.076 [torch.Size([268435456]), [torch.Size([65536])]]
SUCCESS 0.009056 0.007360 1.230 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([8])]]
SUCCESS 0.011488 0.008272 1.389 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([2, 8])]]
SUCCESS 0.010752 0.009056 1.187 [torch.Size([32, 32]), [torch.Size([2, 8])]]
SUCCESS 0.008416 0.007616 1.105 [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([64])]]
SUCCESS 0.009760 0.008464 1.153 [torch.Size([512, 512, 512]), [torch.Size([128]), torch.Size([128]), torch.Size([128])]]
SUCCESS 0.012480 0.008096 1.542 [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([2, 128]), torch.Size([2, 128])]]
SUCCESS 0.017248 0.008816 1.956 [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([128]), torch.Size([128])]]
SUCCESS 0.536176 0.379424 1.413 [torch.Size([512, 512, 512]), [torch.Size([2, 128])]]
SUCCESS 0.012752 0.008960 1.423 [torch.Size([64, 64, 64]), [torch.Size([2, 8]), torch.Size([2, 8])]]

Operator: index Performance Test (dtype=torch.bfloat16, mode=kernel,level=comprehensive)
Status Torch Latency (ms) Gems Latency (ms) Gems Speedup Size Detail

SUCCESS 0.018944 0.017408 1.088 [torch.Size([268435456]), [torch.Size([65536])]]
SUCCESS 0.009344 0.007360 1.270 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([8])]]
SUCCESS 0.011040 0.007616 1.450 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([2, 8])]]
SUCCESS 0.010144 0.007840 1.294 [torch.Size([32, 32]), [torch.Size([2, 8])]]
SUCCESS 0.008416 0.008416 1.000 [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([64])]]
SUCCESS 0.010016 0.008192 1.223 [torch.Size([512, 512, 512]), [torch.Size([128]), torch.Size([128]), torch.Size([128])]]
SUCCESS 0.012128 0.008128 1.492 [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([2, 128]), torch.Size([2, 128])]]
SUCCESS 0.017024 0.008128 2.094 [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([128]), torch.Size([128])]]
SUCCESS 0.498368 0.195168 2.554 [torch.Size([512, 512, 512]), [torch.Size([2, 128])]]
SUCCESS 0.012848 0.007840 1.639 [torch.Size([64, 64, 64]), [torch.Size([2, 8]), torch.Size([2, 8])]]

kiddyjinjin previously approved these changes Nov 25, 2025

@kiddyjinjin kiddyjinjin left a comment

lgtm

@kiddyjinjin kiddyjinjin self-requested a review November 27, 2025 13:05
- Update get_max_rank_shape() and broadcast_indices() in index_put.py to support None values (consistent with index.py)
- Fix precision issue: create tensor_indices AFTER broadcast_indices so that the broadcasted tensors are the ones actually used
- Add gen_indices_for_index_put() function in test_reduction_ops.py to properly handle multi-dimensional index shapes
- Update all index_put tests to use gen_indices_for_index_put()

This fixes the pipeline failures and ensures consistency between index and index_put operators.
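
A minimal sketch of the ordering point above, assuming a broadcast_indices helper that broadcasts tensor indices while passing None entries through; the names mirror index_put.py, but the exact signatures there are not reproduced here:

```python
import torch

def broadcast_indices(indices):
    # Stand-in for the helper described above: broadcast all tensor indices
    # to a common shape and pass None (basic-indexing) entries through.
    tensors = [i for i in indices if i is not None]
    broadcasted = list(torch.broadcast_tensors(*tensors)) if tensors else []
    it = iter(broadcasted)
    return [next(it) if i is not None else None for i in indices]

indices = [torch.tensor([[0], [1]]), torch.tensor([0, 1, 2]), None]

# Broadcast FIRST, then collect the tensor indices the kernel will read;
# collecting before broadcasting would hand the kernel un-broadcast,
# shape-mismatched index tensors (the precision issue described above).
indices = broadcast_indices(indices)
tensor_indices = [i for i in indices if i is not None]
print([tuple(i.shape) for i in tensor_indices])  # [(2, 3), (2, 3)]
```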

factnn commented Dec 2, 2025

Performance of index and index_put:

(zpy_triton) (.venv) root@job-4211e20e-bdf4-4193-bf7e-7448650342e5-master-0:/share/project/zpy/FlagGems_eval# pytest benchmark/test_select_and_slice_perf.py -m index -s
/usr/local/lib/python3.10/dist-packages/hypothesis/entry_points.py:23: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
import pkg_resources
========================== test session starts ===========================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.5.0
rootdir: /share/project/zpy/FlagGems_eval
configfile: pytest.ini
plugins: anyio-4.11.0, hypothesis-5.35.1, flakefinder-1.1.0, rerunfailures-14.0, shard-0.1.2, xdist-3.6.1, xdoctest-1.0.2
collected 19 items / 18 deselected / 1 selected
Running 1 items in this shard

benchmark/test_select_and_slice_perf.py
Operator: index Performance Test (dtype=torch.float16, mode=kernel,level=comprehensive)
Status Torch Latency (ms) Gems Latency (ms) Gems Speedup Size Detail


SUCCESS 0.019120 0.016832 1.136 [torch.Size([268435456]), [torch.Size([65536])]]
SUCCESS 0.010240 0.007392 1.385 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([8])]]
SUCCESS 0.012832 0.007424 1.728 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([2, 8])]]
SUCCESS 0.012544 0.008928 1.405 [torch.Size([32, 32]), [torch.Size([2, 8])]]
SUCCESS 0.009440 0.007648 1.234 [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([64])]]
SUCCESS 0.011296 0.007648 1.477 [torch.Size([512, 512, 512]), [torch.Size([128]), torch.Size([128]), torch.Size([128])]]
SUCCESS 0.012448 0.008128 1.531 [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([2, 128]), torch.Size([2, 128])]]
SUCCESS 0.019328 0.008128 2.378 [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([128]), torch.Size([128])]]
SUCCESS 0.500224 0.196992 2.539 [torch.Size([512, 512, 512]), [torch.Size([2, 128])]]
SUCCESS 0.013216 0.008096 1.632 [torch.Size([64, 64, 64]), [torch.Size([2, 8]), torch.Size([2, 8])]]

Operator: index Performance Test (dtype=torch.float32, mode=kernel,level=comprehensive)
Status Torch Latency (ms) Gems Latency (ms) Gems Speedup Size Detail


SUCCESS 0.018176 0.017088 1.064 [torch.Size([268435456]), [torch.Size([65536])]]
SUCCESS 0.009216 0.007392 1.247 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([8])]]
SUCCESS 0.012640 0.007616 1.660 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([2, 8])]]
SUCCESS 0.012064 0.009120 1.323 [torch.Size([32, 32]), [torch.Size([2, 8])]]
SUCCESS 0.010288 0.008304 1.239 [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([64])]]
SUCCESS 0.011200 0.007648 1.464 [torch.Size([512, 512, 512]), [torch.Size([128]), torch.Size([128]), torch.Size([128])]]
SUCCESS 0.012672 0.008128 1.559 [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([2, 128]), torch.Size([2, 128])]]
SUCCESS 0.019168 0.009120 2.102 [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([128]), torch.Size([128])]]
SUCCESS 0.536864 0.376064 1.428 [torch.Size([512, 512, 512]), [torch.Size([2, 128])]]
SUCCESS 0.012192 0.008832 1.380 [torch.Size([64, 64, 64]), [torch.Size([2, 8]), torch.Size([2, 8])]]

Operator: index Performance Test (dtype=torch.bfloat16, mode=kernel,level=comprehensive)
Status Torch Latency (ms) Gems Latency (ms) Gems Speedup Size Detail


SUCCESS 0.018896 0.017584 1.075 [torch.Size([268435456]), [torch.Size([65536])]]
SUCCESS 0.009920 0.008640 1.148 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([8])]]
SUCCESS 0.012704 0.008144 1.560 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([2, 8])]]
SUCCESS 0.012368 0.008928 1.385 [torch.Size([32, 32]), [torch.Size([2, 8])]]
SUCCESS 0.009888 0.008128 1.217 [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([64])]]
SUCCESS 0.010656 0.007872 1.354 [torch.Size([512, 512, 512]), [torch.Size([128]), torch.Size([128]), torch.Size([128])]]
SUCCESS 0.011968 0.008160 1.467 [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([2, 128]), torch.Size([2, 128])]]
SUCCESS 0.020384 0.008512 2.395 [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([128]), torch.Size([128])]]
SUCCESS 0.500016 0.193888 2.579 [torch.Size([512, 512, 512]), [torch.Size([2, 128])]]
SUCCESS 0.013248 0.007904 1.676 [torch.Size([64, 64, 64]), [torch.Size([2, 8]), torch.Size([2, 8])]]

(zpy_triton) (.venv) root@job-4211e20e-bdf4-4193-bf7e-7448650342e5-master-0:/share/project/zpy/FlagGems_eval# pytest benchmark/test_select_and_slice_perf.py -m index_put -s
/usr/local/lib/python3.10/dist-packages/hypothesis/entry_points.py:23: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
import pkg_resources
========================== test session starts ===========================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.5.0
rootdir: /share/project/zpy/FlagGems_eval
configfile: pytest.ini
plugins: anyio-4.11.0, hypothesis-5.35.1, flakefinder-1.1.0, rerunfailures-14.0, shard-0.1.2, xdist-3.6.1, xdoctest-1.0.2
collected 19 items / 17 deselected / 2 selected
Running 2 items in this shard

benchmark/test_select_and_slice_perf.py
Operator: index_put Performance Test (dtype=torch.float16, mode=kernel,level=comprehensive)
Status Torch Latency (ms) Gems Latency (ms) Gems Speedup Size Detail


SUCCESS 0.830000 0.834320 0.995 [torch.Size([268435456]), [torch.Size([65536])], torch.Size([65536]), False]
SUCCESS 0.011936 0.010176 1.173 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([8])], torch.Size([8]), False]
SUCCESS 0.014848 0.010400 1.428 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([2, 8])], torch.Size([8]), False]
SUCCESS 0.012608 0.010624 1.187 [torch.Size([32, 32]), [torch.Size([2, 8])], torch.Size([32]), False]
SUCCESS 0.016032 0.014752 1.087 [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([64])], torch.Size([64]), False]
SUCCESS 0.022240 0.015104 1.472 [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([4, 64])], torch.Size([64]), False]
SUCCESS 0.018560 0.015872 1.169 [torch.Size([1024, 1024]), [torch.Size([4, 64])], torch.Size([1024]), False]
SUCCESS 0.419712 0.416704 1.007 [torch.Size([512, 512, 512]), [torch.Size([128]), torch.Size([128]), torch.Size([128])], torch.Size([128]), False]
SUCCESS 0.427440 0.417600 1.024 [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([128]), torch.Size([128])], torch.Size([128]), False]
SUCCESS 0.821888 0.504608 1.629 [torch.Size([512, 512, 512]), [torch.Size([2, 128])], torch.Size([512]), False]

Operator: index_put Performance Test (dtype=torch.float32, mode=kernel,level=comprehensive)
Status Torch Latency (ms) Gems Latency (ms) Gems Speedup Size Detail


SUCCESS 1.637488 1.641984 0.997 [torch.Size([268435456]), [torch.Size([65536])], torch.Size([65536]), False]
SUCCESS 0.011456 0.010176 1.126 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([8])], torch.Size([8]), False]
SUCCESS 0.014944 0.010368 1.441 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([2, 8])], torch.Size([8]), False]
SUCCESS 0.012448 0.010704 1.163 [torch.Size([32, 32]), [torch.Size([2, 8])], torch.Size([32]), False]
SUCCESS 0.020032 0.018816 1.065 [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([64])], torch.Size([64]), False]
SUCCESS 0.024928 0.019168 1.301 [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([4, 64])], torch.Size([64]), False]
SUCCESS 0.022560 0.020608 1.095 [torch.Size([1024, 1024]), [torch.Size([4, 64])], torch.Size([1024]), False]
SUCCESS 0.825568 0.823728 1.002 [torch.Size([512, 512, 512]), [torch.Size([128]), torch.Size([128]), torch.Size([128])], torch.Size([128]), False]
SUCCESS 1.346064 0.825792 1.630 [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([128]), torch.Size([128])], torch.Size([128]), False]
SUCCESS 1.721248 2.251792 0.764 [torch.Size([512, 512, 512]), [torch.Size([2, 128])], torch.Size([512]), False]

Operator: index_put Performance Test (dtype=torch.bfloat16, mode=kernel,level=comprehensive)
Status Torch Latency (ms) Gems Latency (ms) Gems Speedup Size Detail


SUCCESS 0.830336 0.834160 0.995 [torch.Size([268435456]), [torch.Size([65536])], torch.Size([65536]), False]
SUCCESS 0.011360 0.010176 1.116 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([8])], torch.Size([8]), False]
SUCCESS 0.014912 0.010400 1.434 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([2, 8])], torch.Size([8]), False]
SUCCESS 0.012576 0.011264 1.116 [torch.Size([32, 32]), [torch.Size([2, 8])], torch.Size([32]), False]
SUCCESS 0.016128 0.014656 1.100 [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([64])], torch.Size([64]), False]
SUCCESS 0.020928 0.015104 1.386 [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([4, 64])], torch.Size([64]), False]
SUCCESS 0.018528 0.015968 1.160 [torch.Size([1024, 1024]), [torch.Size([4, 64])], torch.Size([1024]), False]
SUCCESS 0.419264 0.416288 1.007 [torch.Size([512, 512, 512]), [torch.Size([128]), torch.Size([128]), torch.Size([128])], torch.Size([128]), False]
SUCCESS 0.427072 0.417184 1.024 [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([128]), torch.Size([128])], torch.Size([128]), False]
SUCCESS 0.818960 0.508672 1.610 [torch.Size([512, 512, 512]), [torch.Size([2, 128])], torch.Size([512]), False]

.
Operator: index_put Performance Test (dtype=torch.float16, mode=kernel,level=comprehensive)
Status Torch Latency (ms) Gems Latency (ms) Gems Speedup Size Detail


SUCCESS 1.006752 0.838432 1.201 [torch.Size([268435456]), [torch.Size([65536])], torch.Size([65536]), True]
SUCCESS 0.083392 0.010592 7.873 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([8])], torch.Size([8]), True]
SUCCESS 0.092256 0.015488 5.957 [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([64])], torch.Size([64]), True]
SUCCESS 0.524256 0.417312 1.256 [torch.Size([512, 512, 512]), [torch.Size([128]), torch.Size([128]), torch.Size([128])], torch.Size([128]), True]
SUCCESS 0.524864 0.417152 1.258 [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([2, 128]), torch.Size([2, 128])], torch.Size([2, 128]), True]

Operator: index_put Performance Test (dtype=torch.float32, mode=kernel,level=comprehensive)
Status Torch Latency (ms) Gems Latency (ms) Gems Speedup Size Detail


SUCCESS 1.811648 1.643200 1.103 [torch.Size([268435456]), [torch.Size([65536])], torch.Size([65536]), True]
SUCCESS 0.084288 0.010368 8.130 [torch.Size([32, 32]), [torch.Size([8]), torch.Size([8])], torch.Size([8]), True]
SUCCESS 0.095712 0.018912 5.061 [torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([64])], torch.Size([64]), True]
SUCCESS 0.932416 0.823936 1.132 [torch.Size([512, 512, 512]), [torch.Size([128]), torch.Size([128]), torch.Size([128])], torch.Size([128]), True]
SUCCESS 0.931776 0.823840 1.131 [torch.Size([512, 512, 512]), [torch.Size([2, 128]), torch.Size([2, 128]), torch.Size([2, 128])], torch.Size([2, 128]), True]

.

============================ warnings summary ============================
../../../../usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py:108: 11 warnings
/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py:108: DeprecationWarning: warmup, rep, and use_cuda_graph parameters are deprecated. See triton-lang/triton#4496 for details.
warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "

benchmark/test_select_and_slice_perf.py::test_index_put_acc_false_perf
/usr/local/lib/python3.10/dist-packages/torch/library.py:365: UserWarning: Warning only once for all operators, other operators may also be overridden.
Overriding a previously registered kernel for the same operator and the same dispatch key
operator: aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
registered at /pytorch/build/aten/src/ATen/RegisterSchema.cpp:6
dispatch key: CUDA
previous kernel: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:18106
new kernel: registered at /dev/null:396 (Triggered internally at /pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:154.)
self.m.impl(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======= 2 passed, 17 deselected, 12 warnings in 142.52s (0:02:22) ========

- Remove excessive test cases added to INDEX_ACC_SHAPE
- Keep only the original 8 test cases to match the baseline
- This should prevent CI timeout issues
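
For illustration only, a hypothetical sketch of what trimming a benchmark shape list back to a small baseline might look like; the real INDEX_ACC_SHAPE entries live in the benchmark configuration and are not reproduced here:

```python
import torch

# Hypothetical baseline-style entries: (input_shape, [index_shapes]) pairs.
# Keeping the list small bounds accuracy-CI runtime.
INDEX_ACC_SHAPE = [
    (torch.Size([32, 32]), [torch.Size([8]), torch.Size([8])]),
    (torch.Size([1024, 1024]), [torch.Size([64]), torch.Size([64])]),
    # ... remaining baseline cases
]
```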
Development

Successfully merging this pull request may close these issues.

【operator】aten::index support
