Merged
34 commits
3fa3733
add page size 16 to test and op
ltqin Sep 8, 2025
113217e
add num_total_pages to kernel parameter
ltqin Sep 9, 2025
6e2d9e4
add is_sglang parameter
ltqin Sep 9, 2025
7a463b7
change is_sglang to is_sglang_layout
ltqin Sep 9, 2025
ee72e04
kv last page size=16 pass
ltqin Sep 12, 2025
ae459b0
pass kv_last_page_lens to kernel
ltqin Sep 13, 2025
b25cee7
add parameters check before calling kernel
ltqin Sep 15, 2025
93754f4
change kv layout to [page_num, page_size, nhead, hdim]
ltqin Sep 17, 2025
8c52122
adopt the changes of struct fmha_fwd_batch_prefill_traits
Jeff-Huang Dec 13, 2025
9d7cd3f
change kv cache memory layout to [num_blocks, num_kv_heads, head_size…
Jeff-Huang Dec 19, 2025
e0cb1ea
[FMHA] Integrate vLLM block table support and enforce vectorized KV l…
Jeff-Huang Dec 24, 2025
ac28e9d
update CK
Jeff-Huang Dec 30, 2025
9d69a01
Merge branch 'main' into batch_prefill_page_size_16_rebase
Jeff-Huang Dec 30, 2025
688b141
update ck
Jeff-Huang Dec 30, 2025
0c9c886
adopt api changes from fmha_batch_prefill_traits
Jeff-Huang Dec 30, 2025
c75fee4
add support for linear kv cache layout
Jeff-Huang Dec 31, 2025
d144a76
update api
Jeff-Huang Dec 31, 2025
d727a92
Refactor the test code by gathering the different test functions into…
Jeff-Huang Dec 31, 2025
7642e79
Merge branch 'main' into batch_prefill_page_size_16_rebase
Jeff-Huang Dec 31, 2025
2917917
Merge branch 'main' into batch_prefill_page_size_16_rebase
Jeff-Huang Jan 5, 2026
b1f452c
update ck
Jeff-Huang Jan 5, 2026
ed5f66a
update ck
Jeff-Huang Jan 5, 2026
f5cc627
Add profile measurements for batch prefill function
Jeff-Huang Jan 6, 2026
c7dd47f
Merge branch 'main' into batch_prefill_page_size_16_rebase
Jeff-Huang Jan 7, 2026
9e10ffc
update ck
Jeff-Huang Jan 7, 2026
6a06de9
fix style
Jeff-Huang Jan 7, 2026
ae12e04
Merge branch 'main' into batch_prefill_page_size_16_rebase
Jeff-Huang Jan 7, 2026
db5f333
fix style
Jeff-Huang Jan 7, 2026
44a5cc7
Merge branch 'main' into batch_prefill_page_size_16_rebase
Jeff-Huang Jan 8, 2026
4de0de3
Merge branch 'main' into batch_prefill_page_size_16_rebase
Jeff-Huang Jan 9, 2026
1ed076f
[FMHA] Support 3D linear layout (page_size=1) and non-contiguous KV t…
Jeff-Huang Jan 10, 2026
ec79599
Merge branch 'main' into batch_prefill_page_size_16_rebase
Jeff-Huang Jan 12, 2026
ba88187
Merge branch 'main' into batch_prefill_page_size_16_rebase
Jeff-Huang Jan 13, 2026
e7af363
update ck
Jeff-Huang Jan 13, 2026
2 changes: 1 addition & 1 deletion 3rdparty/composable_kernel
Submodule composable_kernel updated 251 files
2 changes: 1 addition & 1 deletion aiter/aot/test/matmul_fp16.py
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import triton
import triton.language as tl
5 changes: 4 additions & 1 deletion aiter/dist/device_communicators/communicator_cuda.py
@@ -155,7 +155,10 @@ def all_reduce(
qr_comm is not None
and not qr_comm.disabled
and qr_comm.should_quick_allreduce(input_)
and (input_.nelement() * input_.element_size()) >= 4*1024*1024 # input shape should be such that quick reduce will show benefits.
and (input_.nelement() * input_.element_size())
>= 4
* 1024
* 1024 # input shape should be such that quick reduce will show benefits.
# input shape estimated at 2 * max concurrency for now. if performance issues, subject to change
):
out = qr_comm.quick_all_reduce(input_)
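For reference, a minimal sketch of the size gate this hunk reflows: quick all-reduce is only attempted once the tensor's payload reaches 4 MiB (4 * 1024 * 1024 bytes). The standalone helper below is illustrative, not part of the patch.

import torch

# Illustrative restatement (assumption: not an aiter API) of the reflowed
# condition in all_reduce(): payload bytes = nelement * element_size must be
# at least 4 MiB before quick all-reduce is expected to show benefits.
def meets_quick_allreduce_threshold(t: torch.Tensor) -> bool:
    return t.nelement() * t.element_size() >= 4 * 1024 * 1024

print(meets_quick_allreduce_threshold(torch.empty(1 << 21, dtype=torch.float16)))  # True: exactly 4 MiB
print(meets_quick_allreduce_threshold(torch.empty(1024, dtype=torch.float16)))     # False: 2 KiB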
2 changes: 1 addition & 1 deletion aiter/jit/core.py
@@ -818,7 +818,7 @@ def wrapper(*args, custom_build_args={}, **kwargs):
if module is None:
try:
module = get_module(md_name)
except Exception as e:
except Exception:
md = custom_build_args.get("md_name", md_name)
module = get_module(md)
except ModuleNotFoundError:
69 changes: 60 additions & 9 deletions aiter/ops/mha.py
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

from typing import Any, Optional, Tuple

@@ -973,6 +973,9 @@ def cmdGenFunc_mha_batch_prefill(
k_descale: Optional[Tensor] = None,
v_descale: Optional[Tensor] = None,
gen: Optional[Generator] = None,
kv_last_page_lens: Optional[Tensor] = None,
block_table: Optional[Tensor] = None,
seqlen_k: Optional[Tensor] = None,
):
# causal=true is the same as causal=false in this case
causal = is_causal
@@ -2598,15 +2601,26 @@ def mha_batch_prefill_fake_tensors(
return_softmax_lse: bool,
return_dropout_randval: bool,
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
gen: Optional[Generator] = None,
kv_last_page_lens: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
seqlen_k: Optional[torch.Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
# ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
is_vectorized = k.dim() == 5 and v.dim() == 5
is_linear = (k.dim() == 4 and v.dim() == 4) or (k.dim() == 3 and v.dim() == 3)
if not (is_vectorized or is_linear):
raise ValueError(
"Batch prefill requires 5D vectorized, 4D linear, or 3D linear (page_size=1) K/V"
" tensors"
)
num_heads = q.size(1) # num_heads = q.sizes()[1]
head_size_v = v.size(2) # head_size_v = v.size(2)
head_size_v = v.size(-2) if is_vectorized else v.size(-1)
total_q = q.size(0) # total_q = q.size(0)

if out is None:
@@ -2671,6 +2685,9 @@ def mha_batch_prefill(
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
kv_last_page_lens: Optional[Tensor] = None,
block_table: Optional[Tensor] = None,
seqlen_k: Optional[Tensor] = None,
gen: Optional[Generator] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...

@@ -2696,6 +2713,9 @@ def _mha_batch_prefill(
return_softmax: bool = False,
zero_tensors: bool = False,
out: torch.Tensor = None,
kv_last_page_lens: torch.Tensor = None,
block_table: torch.Tensor = None,
seqlen_k: torch.Tensor = None,
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
@@ -2726,6 +2746,9 @@ def _mha_batch_prefill(
q_descale,
k_descale,
v_descale,
kv_last_page_lens,
block_table,
seqlen_k,
# custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd},
)
return out, softmax_lse, S_dmask, rng_state
@@ -2750,19 +2773,44 @@ def mha_batch_prefill_func(
return_lse=False,
return_attn_probs=False,
out=None,
kv_last_page_lens=None,
block_table=None,
seqlen_k=None,
q_descale=None,
k_descale=None,
v_descale=None,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
head_size_q_og = q.size(2)
head_size_v_og = v.size(2)
if head_size_q_og % 8 != 0:
q = torch.nn.functional.pad(q, [0, 8 - head_size_q_og % 8])
k = torch.nn.functional.pad(k, [0, 8 - head_size_q_og % 8])
if head_size_v_og % 8 != 0:
v = torch.nn.functional.pad(v, [0, 8 - head_size_v_og % 8])
head_size_q_og = q.size(-1)
# 16 bytes = 128-bit (dwordx4) vector width assumed by CK kernels.
k_vector_size = 16 // k.element_size()
is_vectorized = k.dim() == 5 and v.dim() == 5
is_linear = (k.dim() == 4 and v.dim() == 4) or (k.dim() == 3 and v.dim() == 3)
if not (is_vectorized or is_linear):
raise ValueError(
"Batch prefill requires 5D vectorized, 4D linear, or 3D linear (page_size=1) K/V"
" tensors"
)
head_size_v_og = v.size(-2) if is_vectorized else v.size(-1)
if head_size_q_og % k_vector_size != 0 or head_size_v_og % k_vector_size != 0:
raise ValueError("Batch prefill requires head size divisible by vector size")
if is_vectorized:
if k.size(-3) * k_vector_size != head_size_q_og:
raise ValueError("K vectorized layout does not match Q head size")
if k.size(-2) % k_vector_size != 0:
raise ValueError(
"Vectorized KV requires page size divisible by vector size"
)
if v.size(-1) != k_vector_size:
raise ValueError("Vectorized KV requires last dim equal to vector size")
else:
if k.size(-1) != head_size_q_og:
raise ValueError("K linear layout does not match Q head size")
if k.size(1) != v.size(1) or k.size(2) != v.size(2):
raise ValueError("K/V linear layout must match page size and head count")
if k.stride(-1) != 1 or v.stride(-1) != 1:
raise ValueError("Batch prefill requires K/V with contiguous last dimension")
out_padded, softmax_lse, S_dmask, rng_state = _mha_batch_prefill(
q,
k,
Expand All @@ -2782,6 +2830,9 @@ def mha_batch_prefill_func(
return_lse=return_lse,
return_softmax=return_attn_probs and dropout_p > 0,
out=out,
kv_last_page_lens=kv_last_page_lens,
block_table=block_table,
seqlen_k=seqlen_k,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
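The new validation in mha_batch_prefill_func encodes two accepted KV-cache forms: a 5D "vectorized" layout whose innermost dimension holds one 128-bit (16-byte) vector, and a 3D/4D "linear" layout (3D corresponding to page_size=1), both requiring a contiguous last dimension. Below is a minimal sketch of the same shape arithmetic restated outside the wrapper; the helper and the example shapes are illustrative assumptions, since the hunk pins down only the constraints, not a canonical layout.

import torch

# Illustrative restatement (assumption: not an aiter API) of the wrapper's
# K/V layout checks. vec is the element count of one 128-bit (dwordx4)
# vector, the width the CK kernels assume: 16 bytes // element size.
def check_kv_layout(q, k, v):
    head_size_q = q.size(-1)
    vec = 16 // k.element_size()                 # 8 for fp16/bf16, 16 for fp8
    is_vectorized = k.dim() == 5 and v.dim() == 5
    is_linear = k.dim() in (3, 4) and v.dim() == k.dim()
    assert is_vectorized or is_linear, "K/V must be 5D vectorized or 3D/4D linear"
    head_size_v = v.size(-2) if is_vectorized else v.size(-1)
    assert head_size_q % vec == 0 and head_size_v % vec == 0
    if is_vectorized:
        assert k.size(-3) * vec == head_size_q   # head dim is split into vectors
        assert k.size(-2) % vec == 0             # page size divisible by vec
        assert v.size(-1) == vec                 # innermost dim holds one vector
    else:
        assert k.size(-1) == head_size_q
        assert k.shape[1:3] == v.shape[1:3]      # same page size and head count
    assert k.stride(-1) == 1 and v.stride(-1) == 1  # contiguous last dim

# fp16, so vec = 8; a 5D vectorized cache and a 4D linear cache both pass.
q = torch.randn(32, 4, 128, dtype=torch.float16)
check_kv_layout(q, torch.randn(8, 1, 16, 16, 8, dtype=torch.float16),
                   torch.randn(8, 1, 16, 128, 8, dtype=torch.float16))
check_kv_layout(q, torch.randn(8, 16, 1, 128, dtype=torch.float16),
                   torch.randn(8, 16, 1, 128, dtype=torch.float16))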
2 changes: 1 addition & 1 deletion aiter/ops/moe_op.py
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import torch
from torch import Tensor
4 changes: 2 additions & 2 deletions aiter/ops/triton/__init__.py
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import importlib.util
import sys
@@ -42,7 +42,7 @@
)

"""
These following help implement backward-compatibility
These following help implement backward-compatibility
for modules that were reorganized so that external repos (like sglang for example),
which depend on the old module names, can still import it the old "way" of importing.
"""
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

# The kernel in this file is adapted from the VLLM project:
# https://github.com/ROCm/vllm/blob/aiter_integration_final/vllm/attention/ops/chunked_prefill_paged_decode.py
@@ -1,4 +1,4 @@
# Copyright (C) 2023-2025 SGLang Team
# Copyright (C) 2023-2026 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -1,5 +1,5 @@
# Copyright (C) Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2025, The vLLM team.
# Copyright (C) 2024-2026, The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
2 changes: 1 addition & 1 deletion aiter/ops/triton/_triton_kernels/attention/lean_atten.py
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

"""
Lean Attention
2 changes: 1 addition & 1 deletion aiter/ops/triton/_triton_kernels/attention/mha.py
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import functools
import json
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import functools
import json
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import functools
import json
4 changes: 2 additions & 2 deletions aiter/ops/triton/_triton_kernels/attention/mla_decode_rope.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

# Copyright (C) 2023-2025 SGLang Team
# Copyright (C) 2023-2026 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
2 changes: 1 addition & 1 deletion aiter/ops/triton/_triton_kernels/attention/pa_decode.py
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import triton
import triton.language as tl
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import triton
import triton.language as tl
4 changes: 2 additions & 2 deletions aiter/ops/triton/_triton_kernels/attention/pa_prefill.py
@@ -1,8 +1,8 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

# Copyright (C) 2023-2025 SGLang Team
# Copyright (C) 2023-2026 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import triton.language as tl
from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import triton.language as tl
from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import triton.language as tl
from aiter.ops.triton._triton_kernels.quant.fused_fp8_quant import _fp8_quant_op
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import triton.language as tl
from aiter.ops.triton._triton_kernels.quant.quant import _mxfp4_quant_op
2 changes: 1 addition & 1 deletion aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a8w8.py
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import triton.language as tl
from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import triton
import triton.language as tl
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import triton.language as tl
from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr
2 changes: 1 addition & 1 deletion aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a8wfp4.py
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import triton.language as tl
from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import triton.language as tl
from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import triton.language as tl
from aiter.ops.triton._triton_kernels.quant.quant import _mxfp4_quant_op
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import triton.language as tl
from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import triton.language as tl
from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import triton.language as tl
from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import triton.language as tl
from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import triton.language as tl
from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid, remap_xcd
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import triton.language as tl
from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid, remap_xcd