diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel index e339101e9c..c9f112b026 160000 --- a/3rdparty/composable_kernel +++ b/3rdparty/composable_kernel @@ -1 +1 @@ -Subproject commit e339101e9c9961fe1bc8305d5c316b39d1980d3e +Subproject commit c9f112b0267625016a58ce3465ee34232c85812b diff --git a/aiter/aot/test/matmul_fp16.py b/aiter/aot/test/matmul_fp16.py index f12c1623c4..1d419f67c8 100644 --- a/aiter/aot/test/matmul_fp16.py +++ b/aiter/aot/test/matmul_fp16.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton import triton.language as tl diff --git a/aiter/dist/device_communicators/communicator_cuda.py b/aiter/dist/device_communicators/communicator_cuda.py index b37af44810..4094ca1be6 100644 --- a/aiter/dist/device_communicators/communicator_cuda.py +++ b/aiter/dist/device_communicators/communicator_cuda.py @@ -155,7 +155,10 @@ def all_reduce( qr_comm is not None and not qr_comm.disabled and qr_comm.should_quick_allreduce(input_) - and (input_.nelement() * input_.element_size()) >= 4*1024*1024 # input shape should be such that quick reduce will show benefits. + and (input_.nelement() * input_.element_size()) + >= 4 + * 1024 + * 1024 # input shape should be such that quick reduce will show benefits. # input shape estimated at 2 * max concurrency for now. if performance issues, subject to change ): out = qr_comm.quick_all_reduce(input_) diff --git a/aiter/jit/core.py b/aiter/jit/core.py index 73849a05a5..7e446636fc 100644 --- a/aiter/jit/core.py +++ b/aiter/jit/core.py @@ -818,7 +818,7 @@ def wrapper(*args, custom_build_args={}, **kwargs): if module is None: try: module = get_module(md_name) - except Exception as e: + except Exception: md = custom_build_args.get("md_name", md_name) module = get_module(md) except ModuleNotFoundError: diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py index 2674a772b1..0005321c17 100644 --- a/aiter/ops/mha.py +++ b/aiter/ops/mha.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Any, Optional, Tuple @@ -973,6 +973,9 @@ def cmdGenFunc_mha_batch_prefill( k_descale: Optional[Tensor] = None, v_descale: Optional[Tensor] = None, gen: Optional[Generator] = None, + kv_last_page_lens: Optional[Tensor] = None, + block_table: Optional[Tensor] = None, + seqlen_k: Optional[Tensor] = None, ): # causal=true is the same as causal=false in this case causal = is_causal @@ -2598,15 +2601,26 @@ def mha_batch_prefill_fake_tensors( return_softmax_lse: bool, return_dropout_randval: bool, out: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, q_descale: Optional[torch.Tensor] = None, k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, gen: Optional[Generator] = None, + kv_last_page_lens: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, + seqlen_k: Optional[torch.Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: # ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + is_vectorized = k.dim() == 5 and v.dim() == 5 + is_linear = (k.dim() == 4 and v.dim() == 4) or (k.dim() == 3 and v.dim() == 3) + if not (is_vectorized or is_linear): + raise ValueError( + "Batch prefill requires 5D vectorized, 4D linear, or 3D linear (page_size=1) K/V" + " tensors" + ) num_heads = q.size(1) # num_heads = q.sizes()[1] - head_size_v = v.size(2) # head_size_v = v.size(2) + head_size_v = v.size(-2) if is_vectorized else v.size(-1) total_q = q.size(0) # total_q = q.size(0) if out is None: @@ -2671,6 +2685,9 @@ def mha_batch_prefill( q_descale: Optional[torch.Tensor] = None, k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, + kv_last_page_lens: Optional[Tensor] = None, + block_table: Optional[Tensor] = None, + seqlen_k: Optional[Tensor] = None, gen: Optional[Generator] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... @@ -2696,6 +2713,9 @@ def _mha_batch_prefill( return_softmax: bool = False, zero_tensors: bool = False, out: torch.Tensor = None, + kv_last_page_lens: torch.Tensor = None, + block_table: torch.Tensor = None, + seqlen_k: torch.Tensor = None, q_descale: Optional[torch.Tensor] = None, k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, @@ -2726,6 +2746,9 @@ def _mha_batch_prefill( q_descale, k_descale, v_descale, + kv_last_page_lens, + block_table, + seqlen_k, # custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, ) return out, softmax_lse, S_dmask, rng_state @@ -2750,19 +2773,44 @@ def mha_batch_prefill_func( return_lse=False, return_attn_probs=False, out=None, + kv_last_page_lens=None, + block_table=None, + seqlen_k=None, q_descale=None, k_descale=None, v_descale=None, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) - head_size_q_og = q.size(2) - head_size_v_og = v.size(2) - if head_size_q_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_q_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_q_og % 8]) - if head_size_v_og % 8 != 0: - v = torch.nn.functional.pad(v, [0, 8 - head_size_v_og % 8]) + head_size_q_og = q.size(-1) + # 16 bytes = 128-bit (dwordx4) vector width assumed by CK kernels. + k_vector_size = 16 // k.element_size() + is_vectorized = k.dim() == 5 and v.dim() == 5 + is_linear = (k.dim() == 4 and v.dim() == 4) or (k.dim() == 3 and v.dim() == 3) + if not (is_vectorized or is_linear): + raise ValueError( + "Batch prefill requires 5D vectorized, 4D linear, or 3D linear (page_size=1) K/V" + " tensors" + ) + head_size_v_og = v.size(-2) if is_vectorized else v.size(-1) + if head_size_q_og % k_vector_size != 0 or head_size_v_og % k_vector_size != 0: + raise ValueError("Batch prefill requires head size divisible by vector size") + if is_vectorized: + if k.size(-3) * k_vector_size != head_size_q_og: + raise ValueError("K vectorized layout does not match Q head size") + if k.size(-2) % k_vector_size != 0: + raise ValueError( + "Vectorized KV requires page size divisible by vector size" + ) + if v.size(-1) != k_vector_size: + raise ValueError("Vectorized KV requires last dim equal to vector size") + else: + if k.size(-1) != head_size_q_og: + raise ValueError("K linear layout does not match Q head size") + if k.size(1) != v.size(1) or k.size(2) != v.size(2): + raise ValueError("K/V linear layout must match page size and head count") + if k.stride(-1) != 1 or v.stride(-1) != 1: + raise ValueError("Batch prefill requires K/V with contiguous last dimension") out_padded, softmax_lse, S_dmask, rng_state = _mha_batch_prefill( q, k, @@ -2782,6 +2830,9 @@ def mha_batch_prefill_func( return_lse=return_lse, return_softmax=return_attn_probs and dropout_p > 0, out=out, + kv_last_page_lens=kv_last_page_lens, + block_table=block_table, + seqlen_k=seqlen_k, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, diff --git a/aiter/ops/moe_op.py b/aiter/ops/moe_op.py index a8f36637d7..3068dff1e1 100755 --- a/aiter/ops/moe_op.py +++ b/aiter/ops/moe_op.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import torch from torch import Tensor diff --git a/aiter/ops/triton/__init__.py b/aiter/ops/triton/__init__.py index fc10be22fb..d09175d91c 100644 --- a/aiter/ops/triton/__init__.py +++ b/aiter/ops/triton/__init__.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import importlib.util import sys @@ -42,7 +42,7 @@ ) """ -These following help implement backward-compatibility +These following help implement backward-compatibility for modules that were reorganized so that external repos (like sglang for example), which depend on the old module names, can still import it the old "way" of importing. """ diff --git a/aiter/ops/triton/_triton_kernels/attention/chunked_pa_prefill.py b/aiter/ops/triton/_triton_kernels/attention/chunked_pa_prefill.py index 889d3631c9..1429ca99a2 100644 --- a/aiter/ops/triton/_triton_kernels/attention/chunked_pa_prefill.py +++ b/aiter/ops/triton/_triton_kernels/attention/chunked_pa_prefill.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # The kernel in this file is adapted from the VLLM project: # https://github.com/ROCm/vllm/blob/aiter_integration_final/vllm/attention/ops/chunked_prefill_paged_decode.py diff --git a/aiter/ops/triton/_triton_kernels/attention/extend_attention.py b/aiter/ops/triton/_triton_kernels/attention/extend_attention.py index 9ba1d04097..c71908e497 100644 --- a/aiter/ops/triton/_triton_kernels/attention/extend_attention.py +++ b/aiter/ops/triton/_triton_kernels/attention/extend_attention.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023-2025 SGLang Team +# Copyright (C) 2023-2026 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/aiter/ops/triton/_triton_kernels/attention/hstu_attention.py b/aiter/ops/triton/_triton_kernels/attention/hstu_attention.py index 59ac5ab620..4eea668c74 100644 --- a/aiter/ops/triton/_triton_kernels/attention/hstu_attention.py +++ b/aiter/ops/triton/_triton_kernels/attention/hstu_attention.py @@ -1,5 +1,5 @@ # Copyright (C) Advanced Micro Devices, Inc. All rights reserved. -# Copyright (C) 2024-2025, The vLLM team. +# Copyright (C) 2024-2026, The vLLM team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/aiter/ops/triton/_triton_kernels/attention/lean_atten.py b/aiter/ops/triton/_triton_kernels/attention/lean_atten.py index 73fd70f430..20dfe36610 100644 --- a/aiter/ops/triton/_triton_kernels/attention/lean_atten.py +++ b/aiter/ops/triton/_triton_kernels/attention/lean_atten.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. """ Lean Attention diff --git a/aiter/ops/triton/_triton_kernels/attention/mha.py b/aiter/ops/triton/_triton_kernels/attention/mha.py index b3acb81f44..610b36c30a 100644 --- a/aiter/ops/triton/_triton_kernels/attention/mha.py +++ b/aiter/ops/triton/_triton_kernels/attention/mha.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import functools import json diff --git a/aiter/ops/triton/_triton_kernels/attention/mha_fused_bwd.py b/aiter/ops/triton/_triton_kernels/attention/mha_fused_bwd.py index da98768b2c..f774867ebc 100644 --- a/aiter/ops/triton/_triton_kernels/attention/mha_fused_bwd.py +++ b/aiter/ops/triton/_triton_kernels/attention/mha_fused_bwd.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import functools import json diff --git a/aiter/ops/triton/_triton_kernels/attention/mha_onekernel_bwd.py b/aiter/ops/triton/_triton_kernels/attention/mha_onekernel_bwd.py index c6015c2d30..b845b781f0 100644 --- a/aiter/ops/triton/_triton_kernels/attention/mha_onekernel_bwd.py +++ b/aiter/ops/triton/_triton_kernels/attention/mha_onekernel_bwd.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import functools import json diff --git a/aiter/ops/triton/_triton_kernels/attention/mla_decode_rope.py b/aiter/ops/triton/_triton_kernels/attention/mla_decode_rope.py index ed783b0619..acc03bf28d 100644 --- a/aiter/ops/triton/_triton_kernels/attention/mla_decode_rope.py +++ b/aiter/ops/triton/_triton_kernels/attention/mla_decode_rope.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. -# Copyright (C) 2023-2025 SGLang Team +# Copyright (C) 2023-2026 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/aiter/ops/triton/_triton_kernels/attention/pa_decode.py b/aiter/ops/triton/_triton_kernels/attention/pa_decode.py index 4499e9b234..4c5d62b5a7 100644 --- a/aiter/ops/triton/_triton_kernels/attention/pa_decode.py +++ b/aiter/ops/triton/_triton_kernels/attention/pa_decode.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton import triton.language as tl diff --git a/aiter/ops/triton/_triton_kernels/attention/pa_mqa_logits.py b/aiter/ops/triton/_triton_kernels/attention/pa_mqa_logits.py index 0bbe5f6096..d6df12174a 100644 --- a/aiter/ops/triton/_triton_kernels/attention/pa_mqa_logits.py +++ b/aiter/ops/triton/_triton_kernels/attention/pa_mqa_logits.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton import triton.language as tl diff --git a/aiter/ops/triton/_triton_kernels/attention/pa_prefill.py b/aiter/ops/triton/_triton_kernels/attention/pa_prefill.py index c511eccf65..82d36a1956 100644 --- a/aiter/ops/triton/_triton_kernels/attention/pa_prefill.py +++ b/aiter/ops/triton/_triton_kernels/attention/pa_prefill.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # The kernels in this file are adapted from LightLLM's context_attention_fwd: # https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py diff --git a/aiter/ops/triton/_triton_kernels/attention/prefill_attention.py b/aiter/ops/triton/_triton_kernels/attention/prefill_attention.py index 72b1ff8e8f..690feb60c1 100644 --- a/aiter/ops/triton/_triton_kernels/attention/prefill_attention.py +++ b/aiter/ops/triton/_triton_kernels/attention/prefill_attention.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. -# Copyright (C) 2023-2025 SGLang Team +# Copyright (C) 2023-2026 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16w16_atomic.py b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16w16_atomic.py index 71e0d4c2f6..dca0c22a83 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16w16_atomic.py +++ b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16w16_atomic.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton.language as tl from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr diff --git a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16w16_gated.py b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16w16_gated.py index eeeaa57a09..2b9dae137d 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16w16_gated.py +++ b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16w16_gated.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton.language as tl from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr diff --git a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16w8_blockscale.py b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16w8_blockscale.py index fe1e927791..5910ed0e64 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16w8_blockscale.py +++ b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16w8_blockscale.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton.language as tl from aiter.ops.triton._triton_kernels.quant.fused_fp8_quant import _fp8_quant_op diff --git a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16wfp4.py b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16wfp4.py index 0528288add..30f81e5f03 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16wfp4.py +++ b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a16wfp4.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton.language as tl from aiter.ops.triton._triton_kernels.quant.quant import _mxfp4_quant_op diff --git a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a8w8.py b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a8w8.py index 1aa6659bcd..cd4ec07ab0 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a8w8.py +++ b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a8w8.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton.language as tl from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr diff --git a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a8w8_blockscale.py b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a8w8_blockscale.py index 943d6032da..1d75520b58 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a8w8_blockscale.py +++ b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a8w8_blockscale.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton import triton.language as tl diff --git a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a8w8_per_token_scale.py b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a8w8_per_token_scale.py index cb0ef83816..686ccba53e 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a8w8_per_token_scale.py +++ b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a8w8_per_token_scale.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton.language as tl from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr diff --git a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a8wfp4.py b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a8wfp4.py index 18721ab392..66e8b41653 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a8wfp4.py +++ b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_a8wfp4.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton.language as tl from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr diff --git a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp4wfp4.py b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp4wfp4.py index e389cb292e..b420eabea3 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp4wfp4.py +++ b/aiter/ops/triton/_triton_kernels/gemm/basic/gemm_afp4wfp4.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton.language as tl from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr diff --git a/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a16wfp4.py b/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a16wfp4.py index 120d4ff0af..4c433c4ab6 100755 --- a/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a16wfp4.py +++ b/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a16wfp4.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton.language as tl from aiter.ops.triton._triton_kernels.quant.quant import _mxfp4_quant_op diff --git a/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8.py b/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8.py index 89d8384aad..29f0a34443 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8.py +++ b/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton.language as tl from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr diff --git a/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py b/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py index 5e5241a51b..bffb53963a 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py +++ b/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton.language as tl from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr diff --git a/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_afp4wfp4.py b/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_afp4wfp4.py index 66e7bb0282..54ef0fd89e 100755 --- a/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_afp4wfp4.py +++ b/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_afp4wfp4.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton.language as tl from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr diff --git a/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_bf16.py b/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_bf16.py index 029f4e57e2..6d3554a118 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_bf16.py +++ b/aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_bf16.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton.language as tl from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr diff --git a/aiter/ops/triton/_triton_kernels/gemm/feed_forward/ff_a16w16_fused_gated.py b/aiter/ops/triton/_triton_kernels/gemm/feed_forward/ff_a16w16_fused_gated.py index 854ace7086..86c87a545a 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/feed_forward/ff_a16w16_fused_gated.py +++ b/aiter/ops/triton/_triton_kernels/gemm/feed_forward/ff_a16w16_fused_gated.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton.language as tl from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid, remap_xcd diff --git a/aiter/ops/triton/_triton_kernels/gemm/feed_forward/ff_a16w16_fused_ungated.py b/aiter/ops/triton/_triton_kernels/gemm/feed_forward/ff_a16w16_fused_ungated.py index 331b7c0ded..3365fcef7c 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/feed_forward/ff_a16w16_fused_ungated.py +++ b/aiter/ops/triton/_triton_kernels/gemm/feed_forward/ff_a16w16_fused_ungated.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton.language as tl from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid, remap_xcd diff --git a/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a8w8_blockscale_a16w16.py b/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a8w8_blockscale_a16w16.py index 2290f8b92e..3c73b94389 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a8w8_blockscale_a16w16.py +++ b/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a8w8_blockscale_a16w16.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton.language as tl from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid, remap_xcd diff --git a/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a8w8_blockscale_mul_add.py b/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a8w8_blockscale_mul_add.py index 83c78c1b48..cf36ee4b4c 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a8w8_blockscale_mul_add.py +++ b/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_a8w8_blockscale_mul_add.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton import triton.language as tl diff --git a/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_afp4wfp4_a16w16.py b/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_afp4wfp4_a16w16.py index 05601f7df5..611532300e 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_afp4wfp4_a16w16.py +++ b/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_afp4wfp4_a16w16.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton.language as tl from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr diff --git a/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_afp4wfp4_mul_add.py b/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_afp4wfp4_mul_add.py index 34af0f38da..0f55772fe5 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_afp4wfp4_mul_add.py +++ b/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_afp4wfp4_mul_add.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton.language as tl from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr diff --git a/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_afp4wfp4_split_cat.py b/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_afp4wfp4_split_cat.py index cf8831d504..3c9bc52985 100644 --- a/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_afp4wfp4_split_cat.py +++ b/aiter/ops/triton/_triton_kernels/gemm/fused/fused_gemm_afp4wfp4_split_cat.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton import triton.language as tl diff --git a/aiter/ops/triton/_triton_kernels/gmm.py b/aiter/ops/triton/_triton_kernels/gmm.py index b1baf95568..71e643cd40 100644 --- a/aiter/ops/triton/_triton_kernels/gmm.py +++ b/aiter/ops/triton/_triton_kernels/gmm.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Imports. diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_align_block_size.py b/aiter/ops/triton/_triton_kernels/moe/moe_align_block_size.py index b135454adc..97d8ba1095 100644 --- a/aiter/ops/triton/_triton_kernels/moe/moe_align_block_size.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_align_block_size.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton import triton.language as tl diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_op.py b/aiter/ops/triton/_triton_kernels/moe/moe_op.py index 137dea368e..b2147bcf89 100644 --- a/aiter/ops/triton/_triton_kernels/moe/moe_op.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_op.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton import triton.language as tl diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_op_e2e.py b/aiter/ops/triton/_triton_kernels/moe/moe_op_e2e.py index 659f58b918..7f22c3c88e 100644 --- a/aiter/ops/triton/_triton_kernels/moe/moe_op_e2e.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_op_e2e.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton import triton.language as tl diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_op_gelu.py b/aiter/ops/triton/_triton_kernels/moe/moe_op_gelu.py index f3a0dde8d4..d41a68605e 100644 --- a/aiter/ops/triton/_triton_kernels/moe/moe_op_gelu.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_op_gelu.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton import triton.language as tl diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_op_mxfp4.py b/aiter/ops/triton/_triton_kernels/moe/moe_op_mxfp4.py index 50d1ac5ca9..b46741cfaa 100644 --- a/aiter/ops/triton/_triton_kernels/moe/moe_op_mxfp4.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_op_mxfp4.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton import triton.language as tl diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_op_mxfp4_silu_fused.py b/aiter/ops/triton/_triton_kernels/moe/moe_op_mxfp4_silu_fused.py index ab042f878d..97d66598e3 100644 --- a/aiter/ops/triton/_triton_kernels/moe/moe_op_mxfp4_silu_fused.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_op_mxfp4_silu_fused.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton import triton.language as tl diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_op_silu_fused.py b/aiter/ops/triton/_triton_kernels/moe/moe_op_silu_fused.py index 8f99e0f9d2..702e3e076f 100644 --- a/aiter/ops/triton/_triton_kernels/moe/moe_op_silu_fused.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_op_silu_fused.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton import triton.language as tl diff --git a/aiter/ops/triton/_triton_kernels/moe/moe_routing_sigmoid_top1_fused.py b/aiter/ops/triton/_triton_kernels/moe/moe_routing_sigmoid_top1_fused.py index 4fbfec9bb2..8434028428 100644 --- a/aiter/ops/triton/_triton_kernels/moe/moe_routing_sigmoid_top1_fused.py +++ b/aiter/ops/triton/_triton_kernels/moe/moe_routing_sigmoid_top1_fused.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import functools import json diff --git a/aiter/ops/triton/_triton_kernels/moe/quant_moe.py b/aiter/ops/triton/_triton_kernels/moe/quant_moe.py index f6cf1431e5..121d76f362 100644 --- a/aiter/ops/triton/_triton_kernels/moe/quant_moe.py +++ b/aiter/ops/triton/_triton_kernels/moe/quant_moe.py @@ -405,7 +405,7 @@ def _upcast_from_mxfp( # 3) x is zero, do nothing dst_tensor = tl.interleave(x0, x1).to(dst_dtype, bitcast=True) - # Reshape for proper broadcasting: the scale was stored with a 32‐sized “inner” grouping. + # Reshape for proper broadcasting: the scale was stored with a 32-sized "inner" grouping. dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 32]) dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1]) scale = scale.reshape(dst_scale.shape) diff --git a/aiter/ops/triton/_triton_kernels/normalization/norm.py b/aiter/ops/triton/_triton_kernels/normalization/norm.py index 77b7e60410..50068e12ae 100644 --- a/aiter/ops/triton/_triton_kernels/normalization/norm.py +++ b/aiter/ops/triton/_triton_kernels/normalization/norm.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton import triton.language as tl diff --git a/aiter/ops/triton/_triton_kernels/normalization/rmsnorm.py b/aiter/ops/triton/_triton_kernels/normalization/rmsnorm.py index ac478dd5d7..7e889b4b1d 100644 --- a/aiter/ops/triton/_triton_kernels/normalization/rmsnorm.py +++ b/aiter/ops/triton/_triton_kernels/normalization/rmsnorm.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton import triton.language as tl diff --git a/aiter/ops/triton/_triton_kernels/quant/quant.py b/aiter/ops/triton/_triton_kernels/quant/quant.py index 3773fb077b..3b88c8b2b2 100644 --- a/aiter/ops/triton/_triton_kernels/quant/quant.py +++ b/aiter/ops/triton/_triton_kernels/quant/quant.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton import triton.language as tl diff --git a/aiter/ops/triton/_triton_kernels/rope/rope.py b/aiter/ops/triton/_triton_kernels/rope/rope.py index 077eb23f8f..85c88d03e6 100644 --- a/aiter/ops/triton/_triton_kernels/rope/rope.py +++ b/aiter/ops/triton/_triton_kernels/rope/rope.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton import triton.language as tl diff --git a/aiter/ops/triton/_triton_kernels/topk.py b/aiter/ops/triton/_triton_kernels/topk.py index 1f6d8f9536..30bbe18819 100644 --- a/aiter/ops/triton/_triton_kernels/topk.py +++ b/aiter/ops/triton/_triton_kernels/topk.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # The kernel in this file is adapted from FlagGems' topk: # https://github.com/FlagOpen/FlagGems/blob/master/src/flag_gems/ops/topk.py diff --git a/aiter/ops/triton/attention/chunked_pa_prefill.py b/aiter/ops/triton/attention/chunked_pa_prefill.py index 2f40e4f30c..b791834838 100644 --- a/aiter/ops/triton/attention/chunked_pa_prefill.py +++ b/aiter/ops/triton/attention/chunked_pa_prefill.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # The kernel in this file is adapted from the VLLM project: # https://github.com/ROCm/vllm/blob/aiter_integration_final/vllm/attention/ops/chunked_prefill_paged_decode.py diff --git a/aiter/ops/triton/attention/extend_attention.py b/aiter/ops/triton/attention/extend_attention.py index 1a7dc1d085..b45c4a0998 100644 --- a/aiter/ops/triton/attention/extend_attention.py +++ b/aiter/ops/triton/attention/extend_attention.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023-2025 SGLang Team +# Copyright (C) 2023-2026 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/aiter/ops/triton/attention/hstu_attention.py b/aiter/ops/triton/attention/hstu_attention.py index 3bf333f51c..344c8efcd6 100644 --- a/aiter/ops/triton/attention/hstu_attention.py +++ b/aiter/ops/triton/attention/hstu_attention.py @@ -1,5 +1,5 @@ # Copyright (C) Advanced Micro Devices, Inc. All rights reserved. -# Copyright (C) 2024-2025, The vLLM team. +# Copyright (C) 2024-2026, The vLLM team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/aiter/ops/triton/attention/lean_atten.py b/aiter/ops/triton/attention/lean_atten.py index 8f981f9613..74839501eb 100644 --- a/aiter/ops/triton/attention/lean_atten.py +++ b/aiter/ops/triton/attention/lean_atten.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. """ Lean Attention diff --git a/aiter/ops/triton/attention/mha.py b/aiter/ops/triton/attention/mha.py index ae8bbf78fd..e4332046c3 100644 --- a/aiter/ops/triton/attention/mha.py +++ b/aiter/ops/triton/attention/mha.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional, Tuple, Union import torch diff --git a/aiter/ops/triton/attention/mha_fused_bwd.py b/aiter/ops/triton/attention/mha_fused_bwd.py index 6049cb6b36..c9cc67743d 100644 --- a/aiter/ops/triton/attention/mha_fused_bwd.py +++ b/aiter/ops/triton/attention/mha_fused_bwd.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional, Dict import torch diff --git a/aiter/ops/triton/attention/mha_onekernel_bwd.py b/aiter/ops/triton/attention/mha_onekernel_bwd.py index db79aa2c3d..7205bb1150 100644 --- a/aiter/ops/triton/attention/mha_onekernel_bwd.py +++ b/aiter/ops/triton/attention/mha_onekernel_bwd.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional, Dict import torch diff --git a/aiter/ops/triton/attention/mha_v3.py b/aiter/ops/triton/attention/mha_v3.py index c2fa24c769..6b93c99581 100644 --- a/aiter/ops/triton/attention/mha_v3.py +++ b/aiter/ops/triton/attention/mha_v3.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from __future__ import annotations from typing import Optional, Tuple, Union diff --git a/aiter/ops/triton/attention/mla_decode_rope.py b/aiter/ops/triton/attention/mla_decode_rope.py index 9332b6d8a6..9e2532c743 100644 --- a/aiter/ops/triton/attention/mla_decode_rope.py +++ b/aiter/ops/triton/attention/mla_decode_rope.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. -# Copyright (C) 2023-2025 SGLang Team +# Copyright (C) 2023-2026 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/aiter/ops/triton/attention/pa_decode.py b/aiter/ops/triton/attention/pa_decode.py index 50d0dfe2fa..3a38ca4c8c 100644 --- a/aiter/ops/triton/attention/pa_decode.py +++ b/aiter/ops/triton/attention/pa_decode.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import math from typing import Optional diff --git a/aiter/ops/triton/attention/pa_mqa_logits.py b/aiter/ops/triton/attention/pa_mqa_logits.py index f5d9573d7c..1f28abf18b 100644 --- a/aiter/ops/triton/attention/pa_mqa_logits.py +++ b/aiter/ops/triton/attention/pa_mqa_logits.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # ======================================================================== # How to use AOT gluon kernel for pa_mqa_logits on lower triton version (below 3.4.0): diff --git a/aiter/ops/triton/attention/pa_prefill.py b/aiter/ops/triton/attention/pa_prefill.py index 2be4bb4fc0..7ce3f1c814 100644 --- a/aiter/ops/triton/attention/pa_prefill.py +++ b/aiter/ops/triton/attention/pa_prefill.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # The kernels in this file are adapted from LightLLM's context_attention_fwd: # https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py diff --git a/aiter/ops/triton/attention/prefill_attention.py b/aiter/ops/triton/attention/prefill_attention.py index e8ae05cccc..f5fdac2d9d 100644 --- a/aiter/ops/triton/attention/prefill_attention.py +++ b/aiter/ops/triton/attention/prefill_attention.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. -# Copyright (C) 2023-2025 SGLang Team +# Copyright (C) 2023-2026 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/aiter/ops/triton/comms/fused/reduce_scatter_rmsnorm_quant_all_gather.py b/aiter/ops/triton/comms/fused/reduce_scatter_rmsnorm_quant_all_gather.py index 447a502352..259a54ecc6 100644 --- a/aiter/ops/triton/comms/fused/reduce_scatter_rmsnorm_quant_all_gather.py +++ b/aiter/ops/triton/comms/fused/reduce_scatter_rmsnorm_quant_all_gather.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. """ Fused Reduce-Scatter + RMSNorm + Quantization + All-Gather diff --git a/aiter/ops/triton/gemm/basic/gemm_a16w16.py b/aiter/ops/triton/gemm/basic/gemm_a16w16.py index 8e98407863..d06ba0eb3a 100644 --- a/aiter/ops/triton/gemm/basic/gemm_a16w16.py +++ b/aiter/ops/triton/gemm/basic/gemm_a16w16.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/basic/gemm_a16w16_agnostic.py b/aiter/ops/triton/gemm/basic/gemm_a16w16_agnostic.py index d56f3bd452..d04d56626c 100644 --- a/aiter/ops/triton/gemm/basic/gemm_a16w16_agnostic.py +++ b/aiter/ops/triton/gemm/basic/gemm_a16w16_agnostic.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import triton diff --git a/aiter/ops/triton/gemm/basic/gemm_a16w16_atomic.py b/aiter/ops/triton/gemm/basic/gemm_a16w16_atomic.py index 3622837d2a..01234f64c5 100644 --- a/aiter/ops/triton/gemm/basic/gemm_a16w16_atomic.py +++ b/aiter/ops/triton/gemm/basic/gemm_a16w16_atomic.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/basic/gemm_a16w16_gated.py b/aiter/ops/triton/gemm/basic/gemm_a16w16_gated.py index 0ebddce0e6..3ce568a941 100644 --- a/aiter/ops/triton/gemm/basic/gemm_a16w16_gated.py +++ b/aiter/ops/triton/gemm/basic/gemm_a16w16_gated.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/basic/gemm_a16w8_blockscale.py b/aiter/ops/triton/gemm/basic/gemm_a16w8_blockscale.py index aeeb013aa6..3105641bcd 100644 --- a/aiter/ops/triton/gemm/basic/gemm_a16w8_blockscale.py +++ b/aiter/ops/triton/gemm/basic/gemm_a16w8_blockscale.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/basic/gemm_a16wfp4.py b/aiter/ops/triton/gemm/basic/gemm_a16wfp4.py index fcef64d6d6..ab691ba860 100644 --- a/aiter/ops/triton/gemm/basic/gemm_a16wfp4.py +++ b/aiter/ops/triton/gemm/basic/gemm_a16wfp4.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/basic/gemm_a8w8.py b/aiter/ops/triton/gemm/basic/gemm_a8w8.py index eb65ffe6df..40b1fe8f68 100644 --- a/aiter/ops/triton/gemm/basic/gemm_a8w8.py +++ b/aiter/ops/triton/gemm/basic/gemm_a8w8.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/basic/gemm_a8w8_blockscale.py b/aiter/ops/triton/gemm/basic/gemm_a8w8_blockscale.py index d6e9954d1b..ca28714eaa 100644 --- a/aiter/ops/triton/gemm/basic/gemm_a8w8_blockscale.py +++ b/aiter/ops/triton/gemm/basic/gemm_a8w8_blockscale.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/basic/gemm_a8w8_per_token_scale.py b/aiter/ops/triton/gemm/basic/gemm_a8w8_per_token_scale.py index e0ea2e6428..34dcd4873b 100644 --- a/aiter/ops/triton/gemm/basic/gemm_a8w8_per_token_scale.py +++ b/aiter/ops/triton/gemm/basic/gemm_a8w8_per_token_scale.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/basic/gemm_a8wfp4.py b/aiter/ops/triton/gemm/basic/gemm_a8wfp4.py index 884d95da18..3a16dd9ecb 100644 --- a/aiter/ops/triton/gemm/basic/gemm_a8wfp4.py +++ b/aiter/ops/triton/gemm/basic/gemm_a8wfp4.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py b/aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py index 19c7948d5d..763090904d 100644 --- a/aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py +++ b/aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/basic/gemm_afp4wfp4_pre_quant_atomic.py b/aiter/ops/triton/gemm/basic/gemm_afp4wfp4_pre_quant_atomic.py index 86b87530f4..da26736a96 100644 --- a/aiter/ops/triton/gemm/basic/gemm_afp4wfp4_pre_quant_atomic.py +++ b/aiter/ops/triton/gemm/basic/gemm_afp4wfp4_pre_quant_atomic.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/batched/batched_gemm_a16wfp4.py b/aiter/ops/triton/gemm/batched/batched_gemm_a16wfp4.py index 4622416c43..11b41b0b47 100755 --- a/aiter/ops/triton/gemm/batched/batched_gemm_a16wfp4.py +++ b/aiter/ops/triton/gemm/batched/batched_gemm_a16wfp4.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/batched/batched_gemm_a8w8.py b/aiter/ops/triton/gemm/batched/batched_gemm_a8w8.py index 596e16d850..875b17e696 100644 --- a/aiter/ops/triton/gemm/batched/batched_gemm_a8w8.py +++ b/aiter/ops/triton/gemm/batched/batched_gemm_a8w8.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/batched/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py b/aiter/ops/triton/gemm/batched/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py index 51c379e695..a413116877 100644 --- a/aiter/ops/triton/gemm/batched/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py +++ b/aiter/ops/triton/gemm/batched/batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/batched/batched_gemm_afp4wfp4.py b/aiter/ops/triton/gemm/batched/batched_gemm_afp4wfp4.py index ed930be0c0..40acb42adc 100755 --- a/aiter/ops/triton/gemm/batched/batched_gemm_afp4wfp4.py +++ b/aiter/ops/triton/gemm/batched/batched_gemm_afp4wfp4.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/batched/batched_gemm_afp4wfp4_pre_quant.py b/aiter/ops/triton/gemm/batched/batched_gemm_afp4wfp4_pre_quant.py index 48ae17a996..b6548bff98 100755 --- a/aiter/ops/triton/gemm/batched/batched_gemm_afp4wfp4_pre_quant.py +++ b/aiter/ops/triton/gemm/batched/batched_gemm_afp4wfp4_pre_quant.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/batched/batched_gemm_bf16.py b/aiter/ops/triton/gemm/batched/batched_gemm_bf16.py index eb4cdd4343..43ca825a10 100644 --- a/aiter/ops/triton/gemm/batched/batched_gemm_bf16.py +++ b/aiter/ops/triton/gemm/batched/batched_gemm_bf16.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/feed_forward/ff_a16w16.py b/aiter/ops/triton/gemm/feed_forward/ff_a16w16.py index c4326f6279..f579b78366 100644 --- a/aiter/ops/triton/gemm/feed_forward/ff_a16w16.py +++ b/aiter/ops/triton/gemm/feed_forward/ff_a16w16.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/feed_forward/ff_a16w16_fused_gated.py b/aiter/ops/triton/gemm/feed_forward/ff_a16w16_fused_gated.py index 856dc880d4..feca3bff30 100644 --- a/aiter/ops/triton/gemm/feed_forward/ff_a16w16_fused_gated.py +++ b/aiter/ops/triton/gemm/feed_forward/ff_a16w16_fused_gated.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/feed_forward/ff_a16w16_fused_ungated.py b/aiter/ops/triton/gemm/feed_forward/ff_a16w16_fused_ungated.py index dbf0da5eee..3428ade14b 100644 --- a/aiter/ops/triton/gemm/feed_forward/ff_a16w16_fused_ungated.py +++ b/aiter/ops/triton/gemm/feed_forward/ff_a16w16_fused_ungated.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/fused/fused_gemm_a8w8_blockscale_a16w16.py b/aiter/ops/triton/gemm/fused/fused_gemm_a8w8_blockscale_a16w16.py index 3cb6f2a5b5..ae76e79012 100644 --- a/aiter/ops/triton/gemm/fused/fused_gemm_a8w8_blockscale_a16w16.py +++ b/aiter/ops/triton/gemm/fused/fused_gemm_a8w8_blockscale_a16w16.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gemm/fused/fused_gemm_a8w8_blockscale_mul_add.py b/aiter/ops/triton/gemm/fused/fused_gemm_a8w8_blockscale_mul_add.py index ee6dda5999..2104f612dd 100644 --- a/aiter/ops/triton/gemm/fused/fused_gemm_a8w8_blockscale_mul_add.py +++ b/aiter/ops/triton/gemm/fused/fused_gemm_a8w8_blockscale_mul_add.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional, Union import torch diff --git a/aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_a16w16.py b/aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_a16w16.py index b56c54b76f..9728919c94 100644 --- a/aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_a16w16.py +++ b/aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_a16w16.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import os diff --git a/aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_mul_add.py b/aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_mul_add.py index 40121d7a01..0d1956ae0d 100644 --- a/aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_mul_add.py +++ b/aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_mul_add.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional, Union import torch diff --git a/aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_split_cat.py b/aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_split_cat.py index 92ef460c4b..89a43c8f50 100644 --- a/aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_split_cat.py +++ b/aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_split_cat.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/gluon/gemm_a8w8_blockscale.py b/aiter/ops/triton/gluon/gemm_a8w8_blockscale.py index 3bb3d7519a..499f2025d4 100644 --- a/aiter/ops/triton/gluon/gemm_a8w8_blockscale.py +++ b/aiter/ops/triton/gluon/gemm_a8w8_blockscale.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import functools diff --git a/aiter/ops/triton/moe/moe_align_block_size.py b/aiter/ops/triton/moe/moe_align_block_size.py index b0918981e7..66bb631c25 100644 --- a/aiter/ops/triton/moe/moe_align_block_size.py +++ b/aiter/ops/triton/moe/moe_align_block_size.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import torch from aiter.ops.triton.utils.logger import AiterTritonLogger diff --git a/aiter/ops/triton/moe/moe_op.py b/aiter/ops/triton/moe/moe_op.py index 9bdcfdf48f..677afc4ec6 100644 --- a/aiter/ops/triton/moe/moe_op.py +++ b/aiter/ops/triton/moe/moe_op.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import torch import triton diff --git a/aiter/ops/triton/moe/moe_op_e2e.py b/aiter/ops/triton/moe/moe_op_e2e.py index d251ae5aae..f571ffdcd6 100644 --- a/aiter/ops/triton/moe/moe_op_e2e.py +++ b/aiter/ops/triton/moe/moe_op_e2e.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import torch import triton diff --git a/aiter/ops/triton/moe/moe_op_gelu.py b/aiter/ops/triton/moe/moe_op_gelu.py index 492c3a49e2..6abac7ebe0 100644 --- a/aiter/ops/triton/moe/moe_op_gelu.py +++ b/aiter/ops/triton/moe/moe_op_gelu.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import torch import triton diff --git a/aiter/ops/triton/moe/moe_op_mxfp4.py b/aiter/ops/triton/moe/moe_op_mxfp4.py index d58d9f399e..2143c74cc3 100644 --- a/aiter/ops/triton/moe/moe_op_mxfp4.py +++ b/aiter/ops/triton/moe/moe_op_mxfp4.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import torch import triton diff --git a/aiter/ops/triton/moe/moe_op_mxfp4_silu_fused.py b/aiter/ops/triton/moe/moe_op_mxfp4_silu_fused.py index 493f4d66b3..32dd7b3a67 100644 --- a/aiter/ops/triton/moe/moe_op_mxfp4_silu_fused.py +++ b/aiter/ops/triton/moe/moe_op_mxfp4_silu_fused.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import torch import triton diff --git a/aiter/ops/triton/moe/moe_op_silu_fused.py b/aiter/ops/triton/moe/moe_op_silu_fused.py index 646f0897aa..b9718e9486 100644 --- a/aiter/ops/triton/moe/moe_op_silu_fused.py +++ b/aiter/ops/triton/moe/moe_op_silu_fused.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import torch import triton diff --git a/aiter/ops/triton/moe/moe_routing_sigmoid_top1_fused.py b/aiter/ops/triton/moe/moe_routing_sigmoid_top1_fused.py index 596e6d2856..f6898d2f91 100644 --- a/aiter/ops/triton/moe/moe_routing_sigmoid_top1_fused.py +++ b/aiter/ops/triton/moe/moe_routing_sigmoid_top1_fused.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from typing import Optional import torch diff --git a/aiter/ops/triton/normalization/norm.py b/aiter/ops/triton/normalization/norm.py index 4305e9a921..11d862acd3 100644 --- a/aiter/ops/triton/normalization/norm.py +++ b/aiter/ops/triton/normalization/norm.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import torch import triton diff --git a/aiter/ops/triton/normalization/rmsnorm.py b/aiter/ops/triton/normalization/rmsnorm.py index f08549ffe8..a4127feec2 100644 --- a/aiter/ops/triton/normalization/rmsnorm.py +++ b/aiter/ops/triton/normalization/rmsnorm.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import torch import triton diff --git a/aiter/ops/triton/quant/quant.py b/aiter/ops/triton/quant/quant.py index d6ee3f0387..0883d78df0 100644 --- a/aiter/ops/triton/quant/quant.py +++ b/aiter/ops/triton/quant/quant.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import triton import torch diff --git a/aiter/ops/triton/rope/rope.py b/aiter/ops/triton/rope/rope.py index 570963e512..b02927d3ce 100644 --- a/aiter/ops/triton/rope/rope.py +++ b/aiter/ops/triton/rope/rope.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import torch import triton diff --git a/aiter/rotary_embedding.py b/aiter/rotary_embedding.py index 32bbe43dc9..1f1a1d3068 100644 --- a/aiter/rotary_embedding.py +++ b/aiter/rotary_embedding.py @@ -2,8 +2,8 @@ # coding=utf-8 # Adapted from # https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py -# Copyright (C) 2023-2025 The vLLM team. -# Copyright (C) 2022-2025 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# Copyright (C) 2023-2026 The vLLM team. +# Copyright (C) 2022-2026 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its diff --git a/aiter/utility/dtypes.py b/aiter/utility/dtypes.py index a5a837e222..dba8279da3 100644 --- a/aiter/utility/dtypes.py +++ b/aiter/utility/dtypes.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import torch from ..jit.utils.chip_info import get_gfx from ..ops.enum import QuantType @@ -58,9 +58,9 @@ def str2bool(v): def str2tuple(v): """ Convert string to int or tuple of ints. - - "512" → 512 (single value without comma returns int) - - "512," → (512,) (trailing comma returns tuple) - - "512,1024" → (512, 1024) (multiple values return tuple) + - "512" -> 512 (single value without comma returns int) + - "512," -> (512,) (trailing comma returns tuple) + - "512,1024" -> (512, 1024) (multiple values return tuple) """ try: parts = [int(p.strip()) for p in v.strip("()").split(",") if p.strip()] diff --git a/csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py b/csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py index 32a375832e..ca3a59c7ee 100644 --- a/csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py +++ b/csrc/ck_batched_gemm_a8w8/batched_gemm_a8w8_tune.py @@ -7,7 +7,6 @@ from aiter.jit.core import AITER_CONFIG_A8W8_BATCHED_GEMM from aiter.utility.base_tuner import GemmCommonTuner from batched_gemm_a8w8_common import kernels_list -import argparse from aiter.utility.mp_tuner import mp_tuner diff --git a/csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py b/csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py index b0e8990b35..e03891aa12 100644 --- a/csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py +++ b/csrc/ck_batched_gemm_bf16/batched_gemm_bf16_tune.py @@ -8,7 +8,6 @@ from aiter import dtypes from batched_gemm_bf16_common import kernels_list from aiter.utility.mp_tuner import mp_tuner -import argparse def run_torch(x, weight, bias=None, dtype=dtypes.bf16): diff --git a/csrc/ck_deepgemm/include/deepgemm_common.cuh b/csrc/ck_deepgemm/include/deepgemm_common.cuh index 801ff24ea1..a524e25e32 100644 --- a/csrc/ck_deepgemm/include/deepgemm_common.cuh +++ b/csrc/ck_deepgemm/include/deepgemm_common.cuh @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include @@ -132,11 +132,10 @@ void grouped_flatmm(KernelArguments& args, ck_stream_config& s) const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem 0: profile_result = pd.concat(prorfiles) profile_result["err"] = profile_result["err"].apply(lambda x: f"{x:.1%}") - profile_file = f"aiter/configs/profile_fmoe.csv" + profile_file = "aiter/configs/profile_fmoe.csv" old_profile = self.get_tuned_gemm_list( profile_file, profile_result.columns.tolist() ) diff --git a/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages.h b/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages.h index 6ae10e9bf6..8069534061 100644 --- a/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages.h +++ b/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages.h @@ -1,6 +1,6 @@ #pragma once // SPDX-License-Identifier: MIT -// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. // #include "moe_flatmm.hpp" #include "ck_tile/core.hpp" @@ -20,7 +20,7 @@ #include #include -using MoeKernel = std::function -struct moe_gemm1_heuristic_dispatcher{ - static MoeKernel dispatch(int M, int N, int K, int block_m){} +template +struct moe_gemm1_heuristic_dispatcher +{ + static MoeKernel dispatch(int M, int N, int K, int block_m) {} }; - -template -struct moe_gemm2_heuristic_dispatcher{ - static MoeKernel dispatch(int M, int N, int K, int block_m){} +template +struct moe_gemm2_heuristic_dispatcher +{ + static MoeKernel dispatch(int M, int N, int K, int block_m) {} }; __attribute__((visibility("default"))) torch::Tensor diff --git a/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages_common.cuh b/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages_common.cuh index b9c0ba84a6..38af9b3bce 100644 --- a/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages_common.cuh +++ b/csrc/ck_tile_gemm_moe_2stages/include/moe_cktile2stages_common.cuh @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" @@ -67,13 +67,13 @@ struct MoeFlatmmConfig __host__ static constexpr int32_t GetBMemNTType(int32_t M, int32_t N, int32_t K) { - (void)N; - (void)K; - if(M <= 416) - { - return 2; - } - return 0; + (void)N; + (void)K; + if(M <= 416) + { + return 2; + } + return 0; } template (b_mem_nt_type_.value); @@ -232,7 +232,6 @@ void moe_gemm(const MoeFlatmmHostArgs& args, const ck_stream_config& s) ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1>, ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1>; - // TODO: support more act type. using FusedAct = std::conditional_t; @@ -322,33 +321,28 @@ void moe_gemm(const MoeFlatmmHostArgs& args, const ck_stream_config& s) // return ave_time; }; - const auto RunBMem = - [&](const auto has_hot_loop_, const auto tail_number_) { - switch(b_mem_nt_type) - { - case 2: { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - break; - default: { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - } - }; + const auto RunBMem = [&](const auto has_hot_loop_, const auto tail_number_) { + switch(b_mem_nt_type) + { + case 2: { + Run(has_hot_loop_, tail_number_, ck_tile::integral_constant{}); + } + break; + default: { + Run(has_hot_loop_, tail_number_, ck_tile::integral_constant{}); + } + } + }; if(tail_num == ck_tile::TailNumber::Odd) { RunBMem(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { RunBMem(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + ck_tile::integral_constant{}); } else { diff --git a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.py b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.py index 765a3228dd..c04f9b62ca 100644 --- a/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.py +++ b/csrc/ck_tile_gemm_moe_2stages/moe_cktile2stages_common.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. from dataclasses import dataclass import os import sys diff --git a/csrc/cktile_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_cktile_common.cuh b/csrc/cktile_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_cktile_common.cuh index 9a07ebba97..06ef0f9e9f 100644 --- a/csrc/cktile_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_cktile_common.cuh +++ b/csrc/cktile_gemm_a8w8_bpreshuffle/include/gemm_a8w8_bpreshuffle_cktile_common.cuh @@ -1,6 +1,6 @@ #pragma once // SPDX-License-Identifier: MIT -// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. #ifdef USE_ROCM @@ -100,11 +100,10 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem namespace aiter { -mha_fwd_traits get_mha_fwd_traits(int head_size_q, - int head_size_v, - std::string dtype, - bool is_group_mode, - bool has_logits_soft_cap, - mask_enum mask_type, - bias_enum bias_type, - bool has_lse, - bool has_dropout, - quant_scale_enum qscale_type, - bool use_ext_asm, - bool has_sink = false, - int how_v3_bf16_cvt = 1, - bool skip_min_seqlen_q = false) +mha_batch_prefill_traits +get_mha_batch_prefill_traits(int head_size_q, + int head_size_v, + std::string dtype, + bool is_group_mode, + bool has_logits_soft_cap, + mask_enum mask_type, + bias_enum bias_type, + bool has_lse, + bool has_dropout, + quant_scale_enum qscale_type, + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kv_memory_layout, + ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table, + int page_size, + bool skip_min_seqlen_q = false) { - return mha_fwd_traits(head_size_q, - head_size_v, - dtype, - is_group_mode, - has_logits_soft_cap, - mask_type, - bias_type, - has_lse, - has_dropout, - qscale_type, - use_ext_asm, - how_v3_bf16_cvt, - skip_min_seqlen_q, - has_sink); + return mha_batch_prefill_traits(head_size_q, + head_size_v, + dtype, + is_group_mode, + has_logits_soft_cap, + mask_type, + bias_type, + has_lse, + has_dropout, + qscale_type, + skip_min_seqlen_q, + kv_memory_layout, + kv_lookup_table, + page_size); } float mha_batch_prefill(mha_batch_prefill_args args, @@ -46,17 +47,19 @@ float mha_batch_prefill(mha_batch_prefill_args args, int head_size_q = args.hdim_q; int head_size_v = args.hdim_v; bool has_dropout = args.p_drop > 0.f; - auto traits = get_mha_fwd_traits(head_size_q, - head_size_v, - q_dtype_str, - is_group_mode, - args.logits_soft_cap > 0.f, - mask_type, - bias_type, - has_lse, - has_dropout, - qscale_type, - use_ext_asm); + auto traits = get_mha_batch_prefill_traits(head_size_q, + head_size_v, + q_dtype_str, + is_group_mode, + args.logits_soft_cap > 0.f, + mask_type, + bias_type, + has_lse, + has_dropout, + qscale_type, + args.kv_memory_layout, + args.kv_lookup_table, + args.page_block_size); return fmha_batch_prefill(traits, args, stream_config); } diff --git a/csrc/include/aiter_hip_common.h b/csrc/include/aiter_hip_common.h index 3e7b4ad097..73915a3b84 100644 --- a/csrc/include/aiter_hip_common.h +++ b/csrc/include/aiter_hip_common.h @@ -2,8 +2,8 @@ // Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" -#include #include +#include #include enum class GPUArch @@ -12,15 +12,14 @@ enum class GPUArch gfx950 }; - -#define CHECK_COND(x) \ - do { \ - if (!(x)) { \ - std::cerr << "check failed, file=" \ - << __FILE__ << ", line=" \ - << __LINE__ << std::endl; \ - std::terminate(); \ - } \ +#define CHECK_COND(x) \ + do \ + { \ + if(!(x)) \ + { \ + std::cerr << "check failed, file=" << __FILE__ << ", line=" << __LINE__ << std::endl; \ + std::terminate(); \ + } \ } while(0) #define HIP_CALL(call) \ diff --git a/csrc/include/custom_all_reduce.cuh b/csrc/include/custom_all_reduce.cuh index a66ba77bc1..349404f5bc 100644 --- a/csrc/include/custom_all_reduce.cuh +++ b/csrc/include/custom_all_reduce.cuh @@ -28,127 +28,125 @@ #include #include +namespace aiter { -namespace aiter +constexpr int kMaxBlocks = 80; +// note: we don't want to use atomics for signals because peer atomics are no +// supported on PCIe links +struct Signal { - - constexpr int kMaxBlocks = 80; - // note: we don't want to use atomics for signals because peer atomics are no - // supported on PCIe links - struct Signal - { alignas(128) uint32_t start[kMaxBlocks][8]; alignas(128) uint32_t end[kMaxBlocks][8]; alignas(128) uint32_t _flag[kMaxBlocks]; // incremental flags for each rank - }; +}; #ifdef USE_ROCM - struct __align__(16) RankData { const void *ptrs[8]; }; +struct __align__(16) RankData { const void* ptrs[8]; }; #else - struct __align__(16) RankData { const void *__restrict__ ptrs[8]; }; +struct __align__(16) RankData { const void* __restrict__ ptrs[8]; }; #endif - struct __align__(16) RankSignals - { +struct __align__(16) RankSignals +{ #ifndef USE_ROCM volatile #endif - Signal *signals[8]; - }; + Signal* signals[8]; +}; - // like std::array, but aligned - template - struct __align__(alignof(T) * sz) array_t - { +// like std::array, but aligned +template +struct __align__(alignof(T) * sz) array_t +{ T data[sz]; - using type = T; + using type = T; static constexpr int size = sz; - }; +}; - // use packed type to maximize memory efficiency - // goal: generate ld.128 and st.128 instructions - template - struct packed_t - { +// use packed type to maximize memory efficiency +// goal: generate ld.128 and st.128 instructions +template +struct packed_t +{ // the (P)acked type for load/store using P = array_t; // the (A)ccumulator type for reduction using A = array_t; - }; +}; #define DINLINE __device__ __forceinline__ - // scalar cast functions - DINLINE float upcast_s(half val) { return __half2float(val); } +// scalar cast functions +DINLINE float upcast_s(half val) { return __half2float(val); } - template - DINLINE T downcast_s(float val); - template <> - DINLINE half downcast_s(float val) - { +template +DINLINE T downcast_s(float val); +template <> +DINLINE half downcast_s(float val) +{ return __float2half(val); - } +} - // scalar add functions - // for some reason when compiling with Pytorch, the + operator for half and - // bfloat is disabled so we call the intrinsics directly - DINLINE half &assign_add(half &a, half b) - { +// scalar add functions +// for some reason when compiling with Pytorch, the + operator for half and +// bfloat is disabled so we call the intrinsics directly +DINLINE half& assign_add(half& a, half b) +{ a = __hadd(a, b); return a; - } - DINLINE float &assign_add(float &a, float b) { return a += b; } - -#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) - DINLINE float upcast_s(__hip_bfloat16 val) { return __bfloat162float(val); } - template <> - DINLINE __hip_bfloat16 downcast_s(float val) - { +} +DINLINE float& assign_add(float& a, float b) { return a += b; } + +#if(__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) +DINLINE float upcast_s(__hip_bfloat16 val) { return __bfloat162float(val); } +template <> +DINLINE __hip_bfloat16 downcast_s(float val) +{ return __float2bfloat16(val); - } - DINLINE __hip_bfloat16 &assign_add(__hip_bfloat16 &a, __hip_bfloat16 b) - { +} +DINLINE __hip_bfloat16& assign_add(__hip_bfloat16& a, __hip_bfloat16 b) +{ a = __hadd(a, b); return a; - } +} #endif - template - DINLINE array_t &packed_assign_add(array_t &a, array_t b) - { +template +DINLINE array_t& packed_assign_add(array_t& a, array_t b) +{ #pragma unroll - for (int i = 0; i < N; i++) + for(int i = 0; i < N; i++) { - assign_add(a.data[i], b.data[i]); + assign_add(a.data[i], b.data[i]); } return a; - } +} - template - DINLINE array_t upcast(array_t val) - { - if constexpr (std::is_same::value) +template +DINLINE array_t upcast(array_t val) +{ + if constexpr(std::is_same::value) { - return val; + return val; } else { - array_t out; + array_t out; #pragma unroll - for (int i = 0; i < N; i++) - { - out.data[i] = upcast_s(val.data[i]); - } - return out; + for(int i = 0; i < N; i++) + { + out.data[i] = upcast_s(val.data[i]); + } + return out; } - } +} - template - DINLINE O downcast(array_t val) - { - if constexpr (std::is_same::value) +template +DINLINE O downcast(array_t val) +{ + if constexpr(std::is_same::value) { - return val; + return val; } // else if constexpr (std::is_same::value) // { @@ -167,73 +165,75 @@ namespace aiter // } else { - O out; + O out; #pragma unroll - for (int i = 0; i < O::size; i++) - { - out.data[i] = downcast_s(val.data[i]); - } - return out; - } - } - - // This function is meant to be used as the first synchronization in the all - // reduce kernel. Thus, it doesn't need to make any visibility guarantees for - // prior memory accesses. Note: volatile writes will not be reordered against - // other volatile writes. - template - DINLINE void start_sync(const RankSignals &sg, + for(int i = 0; i < O::size; i++) + { + out.data[i] = downcast_s(val.data[i]); + } + return out; + } +} + +// This function is meant to be used as the first synchronization in the all +// reduce kernel. Thus, it doesn't need to make any visibility guarantees for +// prior memory accesses. Note: volatile writes will not be reordered against +// other volatile writes. +template +DINLINE void start_sync(const RankSignals& sg, #ifndef USE_ROCM - volatile + volatile #endif - Signal *self_sg, - int rank) - { + Signal* self_sg, + int rank) +{ #ifdef USE_ROCM uint32_t flag = self_sg->_flag[blockIdx.x] + 1; - if (threadIdx.x < ngpus) - { - // simultaneously write to the corresponding flag of all ranks. - // Latency = 1 p2p write - __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], - flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM); - // wait until we got true from all ranks - while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], - __ATOMIC_RELAXED, - __MEMORY_SCOPE_DEVICE) < flag) - ; + if(threadIdx.x < ngpus) + { + // simultaneously write to the corresponding flag of all ranks. + // Latency = 1 p2p write + __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], + flag, + __ATOMIC_RELAXED, + __MEMORY_SCOPE_SYSTEM); + // wait until we got true from all ranks + while(__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], + __ATOMIC_RELAXED, + __MEMORY_SCOPE_DEVICE) < flag) + ; } __syncthreads(); // use one thread to update flag - if (threadIdx.x == 0) - self_sg->_flag[blockIdx.x] = flag; + if(threadIdx.x == 0) + self_sg->_flag[blockIdx.x] = flag; #else - if (threadIdx.x < ngpus) + if(threadIdx.x < ngpus) { - // reset flag for next time - self_sg->end[blockIdx.x][threadIdx.x] = 0; - // simultaneously write to the corresponding flag of all ranks. - // Latency = 1 p2p write - sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; - // wait until we got true from all ranks - while (!self_sg->start[blockIdx.x][threadIdx.x]) - ; + // reset flag for next time + self_sg->end[blockIdx.x][threadIdx.x] = 0; + // simultaneously write to the corresponding flag of all ranks. + // Latency = 1 p2p write + sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; + // wait until we got true from all ranks + while(!self_sg->start[blockIdx.x][threadIdx.x]) + ; } __syncthreads(); #endif - } +} - // This function is meant to be used as the second or the final synchronization - // barrier in the all reduce kernel. If it's the final synchronization barrier, - // we don't need to make any visibility guarantees for prior memory accesses. - template - DINLINE void end_sync(const RankSignals &sg, +// This function is meant to be used as the second or the final synchronization +// barrier in the all reduce kernel. If it's the final synchronization barrier, +// we don't need to make any visibility guarantees for prior memory accesses. +template +DINLINE void end_sync(const RankSignals& sg, #ifndef USE_ROCM - volatile + volatile #endif - Signal *self_sg, - int rank) - { + Signal* self_sg, + int rank) +{ #ifdef USE_ROCM __syncthreads(); // eliminate the case that prior writes are not visible after signals become @@ -241,70 +241,71 @@ namespace aiter // testing. Might be the case that hardware provides stronger guarantee than // the memory model. uint32_t flag = self_sg->_flag[blockIdx.x] + 1; - if (threadIdx.x < ngpus) - { - // simultaneously write to the corresponding flag of all ranks. - // Latency = 1 p2p write - __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], - flag, - final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE, - __MEMORY_SCOPE_SYSTEM); - // wait until we got true from all ranks - while ( - __scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x], - final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE, - __MEMORY_SCOPE_DEVICE) < flag) - ; + if(threadIdx.x < ngpus) + { + // simultaneously write to the corresponding flag of all ranks. + // Latency = 1 p2p write + __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], + flag, + final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE, + __MEMORY_SCOPE_SYSTEM); + // wait until we got true from all ranks + while(__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x], + final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE, + __MEMORY_SCOPE_DEVICE) < flag) + ; } __syncthreads(); // use one thread to update flag - if (threadIdx.x == 0) - self_sg->_flag[blockIdx.x] = flag; + if(threadIdx.x == 0) + self_sg->_flag[blockIdx.x] = flag; #else __syncthreads(); // eliminate the case that prior writes are not visible after signals become // visible. Note that I did not managed to make this happen through a lot of // testing. Might be the case that hardware provides stronger guarantee than // the memory model. - if constexpr (!final_sync) - __threadfence_system(); - if (threadIdx.x < ngpus) - { - // reset flag for next time - self_sg->start[blockIdx.x][threadIdx.x] = 0; - // simultaneously write to the corresponding flag of all ranks. - // Latency = 1 p2p write - sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; - // wait until we got true from all ranks - while (!self_sg->end[blockIdx.x][threadIdx.x]) - ; - } - if constexpr (!final_sync) - __syncthreads(); + if constexpr(!final_sync) + __threadfence_system(); + if(threadIdx.x < ngpus) + { + // reset flag for next time + self_sg->start[blockIdx.x][threadIdx.x] = 0; + // simultaneously write to the corresponding flag of all ranks. + // Latency = 1 p2p write + sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; + // wait until we got true from all ranks + while(!self_sg->end[blockIdx.x][threadIdx.x]) + ; + } + if constexpr(!final_sync) + __syncthreads(); #endif - } +} - template - DINLINE P packed_reduce(const P *ptrs[], int idx) - { +template +DINLINE P packed_reduce(const P* ptrs[], int idx) +{ A tmp = upcast(ptrs[0][idx]); #pragma unroll - for (int i = 1; i < ngpus; i++) + for(int i = 1; i < ngpus; i++) { - packed_assign_add(tmp, upcast(ptrs[i][idx])); + packed_assign_add(tmp, upcast(ptrs[i][idx])); } return downcast

(tmp); - } +} - template - __global__ void __launch_bounds__(512, 1) - cross_device_reduce_1stage_naive(RankData *_dp, RankSignals sg, +template +__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage_naive(RankData* _dp, + RankSignals sg, #ifndef USE_ROCM - volatile + volatile #endif - Signal *self_sg, - T *__restrict__ result, int rank, int size) - { + Signal* self_sg, + T* __restrict__ result, + int rank, + int size) +{ using P = typename packed_t::P; using A = typename packed_t::A; // note: we don't reorder the address so the accumulation order is the same @@ -312,57 +313,58 @@ namespace aiter auto dp = *_dp; start_sync(sg, self_sg, rank); // do the actual reduction - for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; - idx += gridDim.x * blockDim.x) + for(int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { - ((P *)result)[idx] = packed_reduce((const P **)&dp.ptrs[0], idx); + ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); } end_sync(sg, self_sg, rank); - } +} - template +template #ifdef USE_ROCM - DINLINE P *get_tmp_buf(Signal *sg) - { +DINLINE P* get_tmp_buf(Signal* sg) +{ #else - DINLINE P *get_tmp_buf(volatile Signal *sg) - { +DINLINE P* get_tmp_buf(volatile Signal* sg) +{ #endif - return (P *)(((Signal *)sg) + 1); - } + return (P*)(((Signal*)sg) + 1); +} - template - __global__ void __launch_bounds__(512, 1) - cross_device_reduce_2stage_naive(RankData *_dp, RankSignals sg, +template +__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage_naive(RankData* _dp, + RankSignals sg, #ifndef USE_ROCM - volatile + volatile #endif - Signal *self_sg, - T *__restrict__ result, int rank, int size) - { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = gridDim.x * blockDim.x; - using P = typename packed_t::P; - using A = typename packed_t::A; - int part = size / ngpus; - int start = rank * part; - int end = rank == ngpus - 1 ? size : start + part; + Signal* self_sg, + T* __restrict__ result, + int rank, + int size) +{ + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = gridDim.x * blockDim.x; + using P = typename packed_t::P; + using A = typename packed_t::A; + int part = size / ngpus; + int start = rank * part; + int end = rank == ngpus - 1 ? size : start + part; int largest_part = part + size % ngpus; - const P *ptrs[ngpus]; - P *tmps[ngpus]; + const P* ptrs[ngpus]; + P* tmps[ngpus]; #pragma unroll - for (int i = 0; i < ngpus; i++) + for(int i = 0; i < ngpus; i++) { - int target = (rank + i) % ngpus; - ptrs[i] = (const P *)_dp->ptrs[target]; - tmps[i] = get_tmp_buf

(sg.signals[target]); + int target = (rank + i) % ngpus; + ptrs[i] = (const P*)_dp->ptrs[target]; + tmps[i] = get_tmp_buf

(sg.signals[target]); } auto tmp_out = tmps[0]; start_sync(sg, self_sg, rank); // stage 1: reduce scatter - for (int idx = start + tid; idx < end; idx += stride) + for(int idx = start + tid; idx < end; idx += stride) { - tmp_out[idx - start] = packed_reduce(ptrs, idx); + tmp_out[idx - start] = packed_reduce(ptrs, idx); } end_sync(sg, self_sg, rank); @@ -371,36 +373,38 @@ namespace aiter // between threads that have the same tid. If thread i computes the sum of // start + i in the first stage, then thread i also gathers start + i from all // ranks. - for (int idx = tid; idx < largest_part; idx += stride) + for(int idx = tid; idx < largest_part; idx += stride) { #pragma unroll - for (int i = 0; i < ngpus; i++) - { - int gather_from_rank = ((rank + i) % ngpus); - if (gather_from_rank == ngpus - 1 || idx < part) + for(int i = 0; i < ngpus; i++) { - int dst_idx = gather_from_rank * part + idx; - ((P *)result)[dst_idx] = tmps[i][idx]; + int gather_from_rank = ((rank + i) % ngpus); + if(gather_from_rank == ngpus - 1 || idx < part) + { + int dst_idx = gather_from_rank * part + idx; + ((P*)result)[dst_idx] = tmps[i][idx]; + } } - } } - } +} #define THREAD_NUM 512 - template - __global__ void __launch_bounds__(512, 1) - cross_device_reduce_1stage(RankData *_dp, RankSignals sg, +template +__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(RankData* _dp, + RankSignals sg, #ifndef USE_ROCM - volatile + volatile #endif - Signal *self_sg, - T *__restrict__ result, int rank, int size) - { - using P = typename packed_t::P; - using A = typename packed_t::A; + Signal* self_sg, + T* __restrict__ result, + int rank, + int size) +{ + using P = typename packed_t::P; + using A = typename packed_t::A; constexpr int pack_size = packed_t::P::size; - constexpr int tnum_gpu = THREAD_NUM / ngpus; + constexpr int tnum_gpu = THREAD_NUM / ngpus; __shared__ T tmp_smem[tnum_gpu * ngpus * pack_size]; // note: we don't reorder the address so the accumulation order is the same // for all ranks, ensuring bitwise identical results @@ -411,109 +415,115 @@ namespace aiter int lane_id = threadIdx.x % tnum_gpu; start_sync(sg, self_sg, rank); // do the actual reduction - for (int idx = blockIdx.x * tnum_gpu + lane_id; idx < size; - idx += gridDim.x * tnum_gpu) - { - *(reinterpret_cast(&tmp_smem[0]) + threadIdx.x) = ((const P**)&dp.ptrs[0])[warp_id][idx]; - __syncthreads(); - if (warp_id == 0) - { - A add_reg; -#pragma unroll - for (int i = 0; i < pack_size; ++i) + for(int idx = blockIdx.x * tnum_gpu + lane_id; idx < size; idx += gridDim.x * tnum_gpu) + { + *(reinterpret_cast(&tmp_smem[0]) + threadIdx.x) = + ((const P**)&dp.ptrs[0])[warp_id][idx]; + __syncthreads(); + if(warp_id == 0) { - add_reg.data[i] = ck_tile::type_convert(tmp_smem[threadIdx.x * pack_size + i]); - } - constexpr int smem_gpu_loop_stride = tnum_gpu * pack_size; + A add_reg; #pragma unroll - for (int i = 1; i < ngpus; ++i) - { + for(int i = 0; i < pack_size; ++i) + { + add_reg.data[i] = + ck_tile::type_convert(tmp_smem[threadIdx.x * pack_size + i]); + } + constexpr int smem_gpu_loop_stride = tnum_gpu * pack_size; #pragma unroll - for (int j = 0; j < pack_size; ++j) - { - add_reg.data[j] += ck_tile::type_convert(tmp_smem[smem_gpu_loop_stride * i + threadIdx.x * pack_size + j]); - } - } - P write_reg; + for(int i = 1; i < ngpus; ++i) + { #pragma unroll - for (int i = 0; i < pack_size; ++i) - { - write_reg.data[i] = ck_tile::type_convert(add_reg.data[i]); + for(int j = 0; j < pack_size; ++j) + { + add_reg.data[j] += ck_tile::type_convert( + tmp_smem[smem_gpu_loop_stride * i + threadIdx.x * pack_size + j]); + } + } + P write_reg; +#pragma unroll + for(int i = 0; i < pack_size; ++i) + { + write_reg.data[i] = ck_tile::type_convert(add_reg.data[i]); + } + ((P*)result)[idx] = write_reg; } - ((P *)result)[idx] = write_reg; - } - __syncthreads(); + __syncthreads(); } // maybe do not need device sync // end_sync(sg, self_sg, rank); - } +} - template - __global__ void __launch_bounds__(512, 1) - cross_device_reduce_2stage(RankData *_dp, RankSignals sg, +template +__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(RankData* _dp, + RankSignals sg, #ifndef USE_ROCM - volatile + volatile #endif - Signal *self_sg, - T *__restrict__ result, int rank, int size) - { + Signal* self_sg, + T* __restrict__ result, + int rank, + int size) +{ constexpr int pack_size = packed_t::P::size; - constexpr int tnum_gpu = THREAD_NUM / ngpus; - using P = typename packed_t::P; - using A = typename packed_t::A; + constexpr int tnum_gpu = THREAD_NUM / ngpus; + using P = typename packed_t::P; + using A = typename packed_t::A; __shared__ T tmp_smem[tnum_gpu * ngpus * pack_size]; - int warp_id = threadIdx.x / tnum_gpu; - int lane_id = threadIdx.x % tnum_gpu; - int tid = blockIdx.x * tnum_gpu + lane_id; - int stride = gridDim.x * tnum_gpu; - int part = size / ngpus; - int start = rank * part; - int end = rank == ngpus - 1 ? size : start + part; + int warp_id = threadIdx.x / tnum_gpu; + int lane_id = threadIdx.x % tnum_gpu; + int tid = blockIdx.x * tnum_gpu + lane_id; + int stride = gridDim.x * tnum_gpu; + int part = size / ngpus; + int start = rank * part; + int end = rank == ngpus - 1 ? size : start + part; int largest_part = part + size % ngpus; - const P *ptrs[ngpus]; - P *tmps[ngpus]; + const P* ptrs[ngpus]; + P* tmps[ngpus]; #pragma unroll - for (int i = 0; i < ngpus; i++) + for(int i = 0; i < ngpus; i++) { - int target = (rank + i) % ngpus; - ptrs[i] = (const P *)_dp->ptrs[target]; - tmps[i] = get_tmp_buf

(sg.signals[target]); + int target = (rank + i) % ngpus; + ptrs[i] = (const P*)_dp->ptrs[target]; + tmps[i] = get_tmp_buf

(sg.signals[target]); } auto tmp_out = tmps[0]; start_sync(sg, self_sg, rank); // stage 1: reduce scatter - for (int idx = start + tid; idx < end; idx += stride) - { - *(reinterpret_cast(&tmp_smem[0]) + threadIdx.x) = ptrs[warp_id][idx]; - __syncthreads(); - // cal add in first 64 threads - if (warp_id == 0) - { - A add_reg; -#pragma unroll - for (int i = 0; i < pack_size; ++i) + for(int idx = start + tid; idx < end; idx += stride) + { + *(reinterpret_cast(&tmp_smem[0]) + threadIdx.x) = ptrs[warp_id][idx]; + __syncthreads(); + // cal add in first 64 threads + if(warp_id == 0) { - add_reg.data[i] = ck_tile::type_convert(tmp_smem[pack_size * threadIdx.x + i]); - } - constexpr int smem_gpu_loop_stride = tnum_gpu * pack_size; + A add_reg; #pragma unroll - for (int i = 1; i < ngpus; ++i) - { + for(int i = 0; i < pack_size; ++i) + { + add_reg.data[i] = + ck_tile::type_convert(tmp_smem[pack_size * threadIdx.x + i]); + } + constexpr int smem_gpu_loop_stride = tnum_gpu * pack_size; #pragma unroll - for (int j = 0; j < pack_size; ++j) - { - add_reg.data[j] += ck_tile::type_convert(tmp_smem[i * smem_gpu_loop_stride + pack_size * threadIdx.x + j]); - } - } - P write_reg; + for(int i = 1; i < ngpus; ++i) + { #pragma unroll - for (int i = 0; i < pack_size; ++i) - { - write_reg.data[i] = ck_tile::type_convert(add_reg.data[i]); + for(int j = 0; j < pack_size; ++j) + { + add_reg.data[j] += ck_tile::type_convert( + tmp_smem[i * smem_gpu_loop_stride + pack_size * threadIdx.x + j]); + } + } + P write_reg; +#pragma unroll + for(int i = 0; i < pack_size; ++i) + { + write_reg.data[i] = ck_tile::type_convert(add_reg.data[i]); + } + tmp_out[idx - start] = write_reg; } - tmp_out[idx - start] = write_reg; - } - __syncthreads(); + __syncthreads(); } end_sync(sg, self_sg, rank); @@ -522,734 +532,706 @@ namespace aiter // between threads that have the same tid. If thread i computes the sum of // start + i in the first stage, then thread i also gathers start + i from all // ranks. - for (int idx = tid; idx < largest_part; idx += stride) - { - int dst_idx = (warp_id + rank) % ngpus * part + idx; - ((P *)result)[dst_idx] = tmps[warp_id][idx]; - } - } - - /* - * naive allgather - * for case: input(1345,) - * */ - template - __global__ void __launch_bounds__(512, 1) allgather_naive( - RankData* _dp, - RankSignals sg, - Signal* self_sg, - T* __restrict__ result, - int rank, - int size - ) - { + for(int idx = tid; idx < largest_part; idx += stride) + { + int dst_idx = (warp_id + rank) % ngpus * part + idx; + ((P*)result)[dst_idx] = tmps[warp_id][idx]; + } +} + +/* + * naive allgather + * for case: input(1345,) + * */ +template +__global__ void __launch_bounds__(512, 1) allgather_naive( + RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) +{ constexpr int tnum_gpu = THREAD_NUM / ngpus; - int warp_id = threadIdx.x / tnum_gpu; - int lane_id = threadIdx.x % tnum_gpu; - int tid = blockIdx.x * tnum_gpu + lane_id; - int stride = gridDim.x * tnum_gpu; + int warp_id = threadIdx.x / tnum_gpu; + int lane_id = threadIdx.x % tnum_gpu; + int tid = blockIdx.x * tnum_gpu + lane_id; + int stride = gridDim.x * tnum_gpu; const T* ptrs[ngpus]; #pragma unroll - for (int i = 0; i < ngpus; ++i) + for(int i = 0; i < ngpus; ++i) { - ptrs[i] = (const T*)_dp->ptrs[i]; + ptrs[i] = (const T*)_dp->ptrs[i]; } start_sync(sg, self_sg, rank); - for (int idx = tid; idx < size; idx += stride) + for(int idx = tid; idx < size; idx += stride) { - int write_idx = warp_id * size + idx; - result[write_idx] = ptrs[warp_id][idx]; + int write_idx = warp_id * size + idx; + result[write_idx] = ptrs[warp_id][idx]; } - } +} - template - __global__ void __launch_bounds__(512, 1) allgather_vec( - RankData* _dp, - RankSignals sg, - Signal* self_sg, - T* __restrict__ result, - int rank, - int size - ) - { +template +__global__ void __launch_bounds__(512, 1) allgather_vec( + RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) +{ constexpr int tnum_gpu = THREAD_NUM / ngpus; - using P = typename packed_t::P; - int warp_id = threadIdx.x / tnum_gpu; - int lane_id = threadIdx.x % tnum_gpu; - int tid = blockIdx.x * tnum_gpu + lane_id; - int stride = gridDim.x * tnum_gpu; + using P = typename packed_t::P; + int warp_id = threadIdx.x / tnum_gpu; + int lane_id = threadIdx.x % tnum_gpu; + int tid = blockIdx.x * tnum_gpu + lane_id; + int stride = gridDim.x * tnum_gpu; const P* ptrs[ngpus]; #pragma unroll - for (int i = 0; i < ngpus; ++i) + for(int i = 0; i < ngpus; ++i) { - ptrs[i] = (const P*)_dp->ptrs[i]; + ptrs[i] = (const P*)_dp->ptrs[i]; } start_sync(sg, self_sg, rank); - for (int idx = tid; idx < size; idx += stride) - { - int write_idx = warp_id * size + idx; - *(reinterpret_cast(&result[0]) + write_idx) = ptrs[warp_id][idx]; - } - } - - /* - * reduce_scatter, at first dim - * range = size / (pack_size * ngpu) - * for case: - * input:(ngpus * n) -> output:(n) - * input:(ngpus * m, n, ...) -> output(m, n, ...) - * cond: size % (pack_size * ngpus) == 0 - * */ - template - __global__ void __launch_bounds__(512, 1) reduce_scatter_first_dim( - RankData *_dp, - RankSignals sg, - Signal *self_sg, - T *__restrict__ result, - int rank, - int range - ) - { - int tid = blockIdx.x * blockDim.x + threadIdx.x; + for(int idx = tid; idx < size; idx += stride) + { + int write_idx = warp_id * size + idx; + *(reinterpret_cast(&result[0]) + write_idx) = ptrs[warp_id][idx]; + } +} + +/* + * reduce_scatter, at first dim + * range = size / (pack_size * ngpu) + * for case: + * input:(ngpus * n) -> output:(n) + * input:(ngpus * m, n, ...) -> output(m, n, ...) + * cond: size % (pack_size * ngpus) == 0 + * */ +template +__global__ void __launch_bounds__(512, 1) reduce_scatter_first_dim( + RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int range) +{ + int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; - using P = typename packed_t::P; - using A = typename packed_t::A; - const P *ptrs[ngpus]; + using P = typename packed_t::P; + using A = typename packed_t::A; + const P* ptrs[ngpus]; #pragma unroll - for (int i = 0; i < ngpus; i++) + for(int i = 0; i < ngpus; i++) { - int target = (rank + i) % ngpus; - ptrs[i] = (const P *)_dp->ptrs[target]; + int target = (rank + i) % ngpus; + ptrs[i] = (const P*)_dp->ptrs[target]; } start_sync(sg, self_sg, rank); - for (int idx = tid; idx < range; idx += stride) + for(int idx = tid; idx < range; idx += stride) { - int load_index = rank * range + idx; - int store_index = idx; - *(reinterpret_cast(result) + store_index) = packed_reduce(ptrs, load_index); + int load_index = rank * range + idx; + int store_index = idx; + *(reinterpret_cast(result) + store_index) = + packed_reduce(ptrs, load_index); } - } +} - // fp8 quant all-reduce code start - template - struct Fp16Filter - { +// fp8 quant all-reduce code start +template +struct Fp16Filter +{ static const bool value = false; - }; +}; - template <> - struct Fp16Filter - { +template <> +struct Fp16Filter +{ static const bool value = true; - }; +}; - template - struct Bf16Filter - { +template +struct Bf16Filter +{ static const bool value = false; - }; +}; - template <> - struct Bf16Filter<__hip_bfloat16> - { +template <> +struct Bf16Filter<__hip_bfloat16> +{ static const bool value = true; - }; +}; - // dtypes only support half and bf16 now -#define FP16_FILTER \ - typename std::enable_if::value, void>::type* = nullptr +// dtypes only support half and bf16 now +#define FP16_FILTER typename std::enable_if::value, void>::type* = nullptr -#define BF16_FILTER \ - typename std::enable_if::value, void>::type* = nullptr +#define BF16_FILTER typename std::enable_if::value, void>::type* = nullptr - template