From 2cb23c09fae5d7ad7cf07d8db86f4aba62c4549f Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Mon, 6 Apr 2026 15:48:47 +0000 Subject: [PATCH 1/3] Improve cache modifiers support for loads This PR: - Enables CS and CV load cache modifier in Nvidia backend - Enables CS cache modifier in frontend - Improves test coverage for CDNA3 and CDNA4 This PR enables CS cache modifier, underlying infrastructure already exists, for some reason this modifier is disabled in frontend. --- python/test/unit/language/test_core.py | 34 ++++++++++++++----- python/triton/_internal_testing.py | 5 --- python/triton/language/semantic.py | 2 ++ .../LoadStoreOpToLLVM.cpp | 2 ++ 4 files changed, 29 insertions(+), 14 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 6593950f07eb..0603676c63a2 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -37,7 +37,6 @@ is_hip_rdna4, is_hip_gfx1250, is_xpu, - get_arch, torch_float8_dtypes, torch_dtypes, numpy_random, @@ -4302,7 +4301,7 @@ def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_ @pytest.mark.interpreter -@pytest.mark.parametrize("cache", ["", ".ca", ".cg", ".cv"]) +@pytest.mark.parametrize("cache", ["", ".ca", ".cg", ".cs", ".cv"]) def test_load_cache_modifier(cache, device): src = torch.empty(128, device=device) dst = torch.empty(128, device=device) @@ -4316,12 +4315,12 @@ def _kernel(dst, src, CACHE: tl.constexpr): pgm = _kernel[(1, )](dst, src, CACHE=cache) if is_hip(): - target_arch = get_arch() # TODO: support testing for remaining architectures - if 'gfx94' not in target_arch: + if not is_hip_cdna3() and not is_hip_cdna4(): return amdgcn = pgm.asm['amdgcn'] cg_cache_modifier_str = 'nt' + cs_cache_modifier_str = 'sc0 nt' cv_cache_modifier_str = 'sc0 sc1' buffer_load_line = [line for line in amdgcn.splitlines() if "buffer_load" in line] global_load_line = [line for line in amdgcn.splitlines() if "global_load" in line] @@ -4330,6 +4329,8 @@ def _kernel(dst, src, CACHE: tl.constexpr): assert cg_cache_modifier_str not in load_line if cache == '.cg': assert cg_cache_modifier_str in load_line + if cache == ".cs": + assert cs_cache_modifier_str in load_line if cache == '.cv': assert cv_cache_modifier_str in load_line @@ -4338,12 +4339,28 @@ def _kernel(dst, src, CACHE: tl.constexpr): if cache == '': assert 'ld.global.ca' not in ptx assert 'ld.global.cg' not in ptx + assert 'ld.global.cs' not in ptx + assert 'ld.global.cv' not in ptx + if cache == '.ca': + assert 'ld.global.ca' in ptx + assert 'ld.global.cs' not in ptx + assert 'ld.global.cg' not in ptx + assert 'ld.global.cv' not in ptx if cache == '.cg': + assert 'ld.global.ca' not in ptx assert 'ld.global.cg' in ptx + assert 'ld.global.cs' not in ptx + assert 'ld.global.cv' not in ptx + if cache == '.cs': assert 'ld.global.ca' not in ptx - if cache == '.ca': - assert 'ld.global.ca' in ptx assert 'ld.global.cg' not in ptx + assert 'ld.global.cs' in ptx + assert 'ld.global.cv' not in ptx + if cache == '.cv': + assert 'ld.global.ca' not in ptx + assert 'ld.global.cs' not in ptx + assert 'ld.global.cg' not in ptx + assert 'ld.global.cv' in ptx @pytest.mark.interpreter @@ -4444,9 +4461,8 @@ def _kernel(dst, src, CACHE: tl.constexpr): pgm = _kernel[(1, )](dst, src, CACHE=cache) if is_hip(): - target_arch = get_arch() # TODO: support testing for remaining architectures - if 'gfx94' not in target_arch: + if not is_hip_cdna3() and not is_hip_cdna4(): return amdgcn = pgm.asm['amdgcn'] cs_cache_modifier_str = 'nt' @@ -4454,7 +4470,7 @@ def _kernel(dst, src, CACHE: tl.constexpr): buffer_store_line = [line for line in amdgcn.splitlines() if "buffer_store" in line] global_store_line = [line for line in amdgcn.splitlines() if "global_store" in line] store_line = global_store_line[0] if global_store_line else buffer_store_line[0] - if cache == '' or cache == '.cg': + if cache == '' or cache == '.wb' or cache == '.cg': assert cs_cache_modifier_str not in store_line assert wt_cache_modifier_str not in store_line if cache == '.cs': diff --git a/python/triton/_internal_testing.py b/python/triton/_internal_testing.py index 8c52bb81d877..7199fedd8ca4 100644 --- a/python/triton/_internal_testing.py +++ b/python/triton/_internal_testing.py @@ -119,11 +119,6 @@ def is_xpu(): return False if target is None else target.backend == "xpu" -def get_arch(): - target = get_current_target() - return "" if target is None else str(target.arch) - - def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None): """ Override `rs` if you're calling this function twice and don't want the same diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 42ba268365ea..0b80d19acb05 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -919,6 +919,8 @@ def _str_to_load_cache_modifier(self, cache_modifier): cache = ir.CACHE_MODIFIER.CA elif cache_modifier == ".cg": cache = ir.CACHE_MODIFIER.CG + elif cache_modifier == ".cs": + cache = ir.CACHE_MODIFIER.CS elif cache_modifier == ".cv": cache = ir.CACHE_MODIFIER.CV else: diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index 5a9c020ad3fd..f41c22605841 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -303,6 +303,8 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, .global() .o("ca", op.getCache() == triton::CacheModifier::CA) .o("cg", op.getCache() == triton::CacheModifier::CG) + .o("cs", op.getCache() == triton::CacheModifier::CS) + .o("cv", op.getCache() == triton::CacheModifier::CV) .o("L1::evict_first", op.getEvict() == triton::EvictionPolicy::EVICT_FIRST) .o("L1::evict_last", From d539b3b4833d8db72206509aee36336483745df0 Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Tue, 7 Apr 2026 15:30:23 +0000 Subject: [PATCH 2/3] reduce code duplication --- python/test/unit/language/test_core.py | 62 +++++--------------------- 1 file changed, 12 insertions(+), 50 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 0603676c63a2..6401354d5412 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -4336,31 +4336,12 @@ def _kernel(dst, src, CACHE: tl.constexpr): if is_cuda(): ptx = pgm.asm['ptx'] - if cache == '': - assert 'ld.global.ca' not in ptx - assert 'ld.global.cg' not in ptx - assert 'ld.global.cs' not in ptx - assert 'ld.global.cv' not in ptx - if cache == '.ca': - assert 'ld.global.ca' in ptx - assert 'ld.global.cs' not in ptx - assert 'ld.global.cg' not in ptx - assert 'ld.global.cv' not in ptx - if cache == '.cg': - assert 'ld.global.ca' not in ptx - assert 'ld.global.cg' in ptx - assert 'ld.global.cs' not in ptx - assert 'ld.global.cv' not in ptx - if cache == '.cs': - assert 'ld.global.ca' not in ptx - assert 'ld.global.cg' not in ptx - assert 'ld.global.cs' in ptx - assert 'ld.global.cv' not in ptx - if cache == '.cv': - assert 'ld.global.ca' not in ptx - assert 'ld.global.cs' not in ptx - assert 'ld.global.cg' not in ptx - assert 'ld.global.cv' in ptx + all_modifiers = ['.ca', '.cg', '.cs', '.cv'] + for modifier in all_modifiers: + if modifier == cache: + assert f'ld.global.{modifier}' in ptx + else: + assert f'ld.global.{modifier}' not in ptx @pytest.mark.interpreter @@ -4482,31 +4463,12 @@ def _kernel(dst, src, CACHE: tl.constexpr): if is_cuda(): ptx = pgm.asm['ptx'] - if cache == '': - assert 'st.global.wb' not in ptx - assert 'st.global.cg' not in ptx - assert 'st.global.cs' not in ptx - assert 'st.global.wt' not in ptx - if cache == '.wb': - assert 'st.global.wb' in ptx - assert 'st.global.cg' not in ptx - assert 'st.global.cs' not in ptx - assert 'st.global.wt' not in ptx - if cache == '.cg': - assert 'st.global.wb' not in ptx - assert 'st.global.cg' in ptx - assert 'st.global.cs' not in ptx - assert 'st.global.wt' not in ptx - if cache == '.cs': - assert 'st.global.wb' not in ptx - assert 'st.global.cg' not in ptx - assert 'st.global.cs' in ptx - assert 'st.global.wt' not in ptx - if cache == '.wt': - assert 'st.global.wb' not in ptx - assert 'st.global.cg' not in ptx - assert 'st.global.cs' not in ptx - assert 'st.global.wt' in ptx + all_modifiers = ['.wb', '.cg', '.cs', '.wt'] + for modifier in all_modifiers: + if modifier == cache: + assert f'st.global.{modifier}' in ptx + else: + assert f'st.global.{modifier}' not in ptx @pytest.mark.interpreter From 110be2659bbef2de0c73340bb5caac7b0218f78e Mon Sep 17 00:00:00 2001 From: Alexander Efimov Date: Tue, 7 Apr 2026 18:24:29 +0000 Subject: [PATCH 3/3] fix prefix concatenation in test --- python/test/unit/language/test_core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 6401354d5412..af9e1e582ca6 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -4339,9 +4339,9 @@ def _kernel(dst, src, CACHE: tl.constexpr): all_modifiers = ['.ca', '.cg', '.cs', '.cv'] for modifier in all_modifiers: if modifier == cache: - assert f'ld.global.{modifier}' in ptx + assert f'ld.global{modifier}' in ptx else: - assert f'ld.global.{modifier}' not in ptx + assert f'ld.global{modifier}' not in ptx @pytest.mark.interpreter @@ -4466,9 +4466,9 @@ def _kernel(dst, src, CACHE: tl.constexpr): all_modifiers = ['.wb', '.cg', '.cs', '.wt'] for modifier in all_modifiers: if modifier == cache: - assert f'st.global.{modifier}' in ptx + assert f'st.global{modifier}' in ptx else: - assert f'st.global.{modifier}' not in ptx + assert f'st.global{modifier}' not in ptx @pytest.mark.interpreter