
Conversation

@jikechao (Contributor) commented on Nov 24, 2025

PR Category

Operator

Type of Change

New Feature

Description

Support a new operator: atan2.
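
For reviewers' context, here is a minimal usage sketch (not part of this PR's diff) of how the new operator is expected to be reached through the ATen override. The `flag_gems.use_gems()` context manager follows the pattern shown in the FlagGems README; the shapes and tolerances below are illustrative assumptions.

```python
import torch
import flag_gems

# Illustrative sketch only: route torch.atan2 through the FlagGems kernel.
# flag_gems.use_gems() follows the README pattern; the exact enable API may
# differ between FlagGems versions.
x = torch.randn(1024, 4096, device="cuda", dtype=torch.float16)
y = torch.randn(1024, 4096, device="cuda", dtype=torch.float16)

with flag_gems.use_gems():
    out = torch.atan2(x, y)  # dispatched to the Triton-based Gems implementation

# Sanity check against the eager ATen kernel in higher precision.
ref = torch.atan2(x.to(torch.float32), y.to(torch.float32))
torch.testing.assert_close(out.to(torch.float32), ref, rtol=1e-3, atol=1e-3)
```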

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change is fully covered by a UT.

Performance

platform linux -- Python 3.12.12, pytest-9.0.0, pluggy-1.6.0
benchmark: 5.2.3 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /data/shenqingchao/software/FlagGems
configfile: pytest.ini
plugins: anyio-4.9.0, langsmith-0.4.32, benchmark-5.2.3
collected 33 items / 32 deselected / 1 selected                                                    

test_binary_pointwise_perf.py 
Operator: atan2  Performance Test (dtype=torch.float16, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup               TFLOPS          Size Detail
--------------------------------------------------------------------------------------------------------------------
SUCCESS               9.151488            9.484288               0.965               0.226          [torch.Size([1073741824]), torch.Size([1073741824])]
SUCCESS               0.005120            0.005120               1.000               0.002          [torch.Size([64, 64]), torch.Size([64, 64])]
SUCCESS               0.125952            0.173056               0.728               0.194          [torch.Size([4096, 4096]), torch.Size([4096, 4096])]
SUCCESS               0.125952            0.124928               1.008               0.269          [torch.Size([64, 512, 512]), torch.Size([64, 512, 512])]
SUCCESS               9.337856            9.671680               0.965               0.222          [torch.Size([1024, 1024, 1024]), torch.Size([1024, 1024, 1024])]
SUCCESS               0.005120            0.005120               1.000               0.000          [torch.Size([1024, 1]), torch.Size([1024, 1])]
SUCCESS               0.005120            0.005120               1.000               0.006          [torch.Size([1024, 16]), torch.Size([1024, 16])]
SUCCESS               0.007168            0.007168               1.000               0.073          [torch.Size([1024, 256]), torch.Size([1024, 256])]
SUCCESS               0.035840            0.034816               1.029               0.241          [torch.Size([1024, 4096]), torch.Size([1024, 4096])]
SUCCESS               0.579584            0.618496               0.937               0.217          [torch.Size([1024, 65536]), torch.Size([1024, 65536])]
SUCCESS               0.005120            0.005120               1.000               0.002          [torch.Size([64, 64, 1]), torch.Size([64, 64, 1])]
SUCCESS               0.005120            0.005120               1.000               0.026          [torch.Size([64, 64, 16]), torch.Size([64, 64, 16])]
SUCCESS               0.012288            0.013312               0.923               0.158          [torch.Size([64, 64, 256]), torch.Size([64, 64, 256])]
SUCCESS               0.124928            0.158720               0.787               0.211          [torch.Size([64, 64, 4096]), torch.Size([64, 64, 4096])]
SUCCESS               2.208768            2.295808               0.962               0.234          [torch.Size([64, 64, 65536]), torch.Size([64, 64, 65536])]


Operator: atan2  Performance Test (dtype=torch.float32, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup               TFLOPS          Size Detail
--------------------------------------------------------------------------------------------------------------------
SUCCESS              14.853120           15.513600               0.957               0.138          [torch.Size([1073741824]), torch.Size([1073741824])]
SUCCESS               0.005120            0.005120               1.000               0.002          [torch.Size([64, 64]), torch.Size([64, 64])]
SUCCESS               0.242688            0.242688               1.000               0.138          [torch.Size([4096, 4096]), torch.Size([4096, 4096])]
SUCCESS               0.242688            0.242688               1.000               0.138          [torch.Size([64, 512, 512]), torch.Size([64, 512, 512])]
SUCCESS              14.870528           15.404032               0.965               0.139          [torch.Size([1024, 1024, 1024]), torch.Size([1024, 1024, 1024])]
SUCCESS               0.005120            0.005120               1.000               0.000          [torch.Size([1024, 1]), torch.Size([1024, 1])]
SUCCESS               0.005120            0.005120               1.000               0.006          [torch.Size([1024, 16]), torch.Size([1024, 16])]
SUCCESS               0.009216            0.009216               1.000               0.057          [torch.Size([1024, 256]), torch.Size([1024, 256])]
SUCCESS               0.064512            0.064512               1.000               0.130          [torch.Size([1024, 4096]), torch.Size([1024, 4096])]
SUCCESS               1.044480            0.954368               1.094               0.141          [torch.Size([1024, 65536]), torch.Size([1024, 65536])]
SUCCESS               0.005120            0.005120               1.000               0.002          [torch.Size([64, 64, 1]), torch.Size([64, 64, 1])]
SUCCESS               0.006144            0.006144               1.000               0.021          [torch.Size([64, 64, 16]), torch.Size([64, 64, 16])]
SUCCESS               0.019456            0.019456               1.000               0.108          [torch.Size([64, 64, 256]), torch.Size([64, 64, 256])]
SUCCESS               0.242688            0.243712               0.996               0.138          [torch.Size([64, 64, 4096]), torch.Size([64, 64, 4096])]
SUCCESS               3.771392            3.811328               0.990               0.141          [torch.Size([64, 64, 65536]), torch.Size([64, 64, 65536])]


Operator: atan2  Performance Test (dtype=torch.bfloat16, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup               TFLOPS          Size Detail
--------------------------------------------------------------------------------------------------------------------
SUCCESS               9.000960            9.400320               0.958               0.228          [torch.Size([1073741824]), torch.Size([1073741824])]
SUCCESS               0.005120            0.005120               1.000               0.002          [torch.Size([64, 64]), torch.Size([64, 64])]
SUCCESS               0.124928            0.149504               0.836               0.224          [torch.Size([4096, 4096]), torch.Size([4096, 4096])]
SUCCESS               0.124928            0.130048               0.961               0.258          [torch.Size([64, 512, 512]), torch.Size([64, 512, 512])]
SUCCESS               9.005056            9.529344               0.945               0.225          [torch.Size([1024, 1024, 1024]), torch.Size([1024, 1024, 1024])]
SUCCESS               0.005120            0.005120               1.000               0.000          [torch.Size([1024, 1]), torch.Size([1024, 1])]
SUCCESS               0.005120            0.005120               1.000               0.006          [torch.Size([1024, 16]), torch.Size([1024, 16])]
SUCCESS               0.007168            0.007168               1.000               0.073          [torch.Size([1024, 256]), torch.Size([1024, 256])]
SUCCESS               0.035840            0.035840               1.000               0.234          [torch.Size([1024, 4096]), torch.Size([1024, 4096])]
SUCCESS               0.580608            0.533504               1.088               0.252          [torch.Size([1024, 65536]), torch.Size([1024, 65536])]
SUCCESS               0.005120            0.005120               1.000               0.002          [torch.Size([64, 64, 1]), torch.Size([64, 64, 1])]
SUCCESS               0.005120            0.005120               1.000               0.026          [torch.Size([64, 64, 16]), torch.Size([64, 64, 16])]
SUCCESS               0.013312            0.012288               1.083               0.171          [torch.Size([64, 64, 256]), torch.Size([64, 64, 256])]
SUCCESS               0.124928            0.148480               0.841               0.226          [torch.Size([64, 64, 4096]), torch.Size([64, 64, 4096])]
SUCCESS               2.214912            2.320384               0.955               0.231          [torch.Size([64, 64, 65536]), torch.Size([64, 64, 65536])]
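
The session header above suggests the numbers were produced by the comprehensive benchmark suite under pytest-benchmark. A hedged reproduction sketch follows; the `--level`/`--mode` option names are inferred from the `mode=kernel, level=comprehensive` labels in the report header and are assumptions, not confirmed flags.

```python
# Hedged reproduction sketch; run from the FlagGems benchmark directory.
# The custom --level/--mode options are inferred from the report header
# ("mode=kernel, level=comprehensive") and may be named differently.
import pytest

pytest.main([
    "test_binary_pointwise_perf.py",
    "-k", "atan2",                # select only the atan2 performance case
    "--level", "comprehensive",   # assumed option matching level=comprehensive
    "--mode", "kernel",           # assumed option matching mode=kernel
])
```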

cc @0x45f

@jikechao (Contributor, Author) commented on Nov 24, 2025

The accuracy tests also passed:

========================================= warnings summary =========================================
tests/test_binary_pointwise_ops.py::test_accuracy_atan2[dtype0-shape0]
  /home/shenqingchao/miniconda3/lib/python3.12/site-packages/torch/library.py:323: UserWarning: Warning only once for all operators,  other operators may also be overridden.
    Overriding a previously registered kernel for the same operator and the same dispatch key
    operator: aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
      registered at aten/src/ATen/RegisterSchema.cpp:6
    dispatch key: CUDA
    previous kernel: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:17993
         new kernel: registered at /dev/null:1265 (Triggered internally at ../aten/src/ATen/core/dispatch/OperatorEntry.cpp:155.)
    self.m.impl(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================== 18 passed, 7104 deselected, 1 warning in 1.71s =========================
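
For reference, a hedged sketch of the kind of accuracy check referenced above; the real test lives in tests/test_binary_pointwise_ops.py::test_accuracy_atan2, and the shapes, dtypes, and tolerances below are illustrative assumptions rather than the PR's exact parametrization.

```python
import pytest
import torch
import flag_gems

# Illustrative accuracy check; not the PR's actual test body.
@pytest.mark.parametrize("shape", [(64, 64), (1024, 4096), (64, 64, 256)])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16])
def test_accuracy_atan2(shape, dtype):
    x = torch.randn(shape, dtype=dtype, device="cuda")
    y = torch.randn(shape, dtype=dtype, device="cuda")
    # High-precision reference computed by eager PyTorch.
    ref = torch.atan2(x.to(torch.float64), y.to(torch.float64))
    with flag_gems.use_gems():
        res = torch.atan2(x, y)
    torch.testing.assert_close(res.to(torch.float64), ref, rtol=1e-3, atol=1e-3)
```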

Commits (each signed off by Qingchao Shen <[email protected]>):

  • Removed duplicate import of atan2 from flag_gems.ops.
  • Removed duplicate entry for 'atan2' from the list of functions.

@jikechao changed the title from "Add atan2 operator" to "[operator] Add atan2" on Nov 25, 2025.
