diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 6593950f07eb..af9e1e582ca6 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -37,7 +37,6 @@ is_hip_rdna4, is_hip_gfx1250, is_xpu, - get_arch, torch_float8_dtypes, torch_dtypes, numpy_random, @@ -4302,7 +4301,7 @@ def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_ @pytest.mark.interpreter -@pytest.mark.parametrize("cache", ["", ".ca", ".cg", ".cv"]) +@pytest.mark.parametrize("cache", ["", ".ca", ".cg", ".cs", ".cv"]) def test_load_cache_modifier(cache, device): src = torch.empty(128, device=device) dst = torch.empty(128, device=device) @@ -4316,12 +4315,12 @@ def _kernel(dst, src, CACHE: tl.constexpr): pgm = _kernel[(1, )](dst, src, CACHE=cache) if is_hip(): - target_arch = get_arch() # TODO: support testing for remaining architectures - if 'gfx94' not in target_arch: + if not is_hip_cdna3() and not is_hip_cdna4(): return amdgcn = pgm.asm['amdgcn'] cg_cache_modifier_str = 'nt' + cs_cache_modifier_str = 'sc0 nt' cv_cache_modifier_str = 'sc0 sc1' buffer_load_line = [line for line in amdgcn.splitlines() if "buffer_load" in line] global_load_line = [line for line in amdgcn.splitlines() if "global_load" in line] @@ -4330,20 +4329,19 @@ def _kernel(dst, src, CACHE: tl.constexpr): assert cg_cache_modifier_str not in load_line if cache == '.cg': assert cg_cache_modifier_str in load_line + if cache == ".cs": + assert cs_cache_modifier_str in load_line if cache == '.cv': assert cv_cache_modifier_str in load_line if is_cuda(): ptx = pgm.asm['ptx'] - if cache == '': - assert 'ld.global.ca' not in ptx - assert 'ld.global.cg' not in ptx - if cache == '.cg': - assert 'ld.global.cg' in ptx - assert 'ld.global.ca' not in ptx - if cache == '.ca': - assert 'ld.global.ca' in ptx - assert 'ld.global.cg' not in ptx + all_modifiers = ['.ca', '.cg', '.cs', '.cv'] + for modifier in all_modifiers: + if modifier == cache: + assert f'ld.global{modifier}' in ptx + else: + assert f'ld.global{modifier}' not in ptx @pytest.mark.interpreter @@ -4444,9 +4442,8 @@ def _kernel(dst, src, CACHE: tl.constexpr): pgm = _kernel[(1, )](dst, src, CACHE=cache) if is_hip(): - target_arch = get_arch() # TODO: support testing for remaining architectures - if 'gfx94' not in target_arch: + if not is_hip_cdna3() and not is_hip_cdna4(): return amdgcn = pgm.asm['amdgcn'] cs_cache_modifier_str = 'nt' @@ -4454,7 +4451,7 @@ def _kernel(dst, src, CACHE: tl.constexpr): buffer_store_line = [line for line in amdgcn.splitlines() if "buffer_store" in line] global_store_line = [line for line in amdgcn.splitlines() if "global_store" in line] store_line = global_store_line[0] if global_store_line else buffer_store_line[0] - if cache == '' or cache == '.cg': + if cache == '' or cache == '.wb' or cache == '.cg': assert cs_cache_modifier_str not in store_line assert wt_cache_modifier_str not in store_line if cache == '.cs': @@ -4466,31 +4463,12 @@ def _kernel(dst, src, CACHE: tl.constexpr): if is_cuda(): ptx = pgm.asm['ptx'] - if cache == '': - assert 'st.global.wb' not in ptx - assert 'st.global.cg' not in ptx - assert 'st.global.cs' not in ptx - assert 'st.global.wt' not in ptx - if cache == '.wb': - assert 'st.global.wb' in ptx - assert 'st.global.cg' not in ptx - assert 'st.global.cs' not in ptx - assert 'st.global.wt' not in ptx - if cache == '.cg': - assert 'st.global.wb' not in ptx - assert 'st.global.cg' in ptx - assert 'st.global.cs' not in ptx - assert 'st.global.wt' not in ptx - if cache == '.cs': - assert 'st.global.wb' not in ptx - assert 'st.global.cg' not in ptx - assert 'st.global.cs' in ptx - assert 'st.global.wt' not in ptx - if cache == '.wt': - assert 'st.global.wb' not in ptx - assert 'st.global.cg' not in ptx - assert 'st.global.cs' not in ptx - assert 'st.global.wt' in ptx + all_modifiers = ['.wb', '.cg', '.cs', '.wt'] + for modifier in all_modifiers: + if modifier == cache: + assert f'st.global{modifier}' in ptx + else: + assert f'st.global{modifier}' not in ptx @pytest.mark.interpreter diff --git a/python/triton/_internal_testing.py b/python/triton/_internal_testing.py index 8c52bb81d877..7199fedd8ca4 100644 --- a/python/triton/_internal_testing.py +++ b/python/triton/_internal_testing.py @@ -119,11 +119,6 @@ def is_xpu(): return False if target is None else target.backend == "xpu" -def get_arch(): - target = get_current_target() - return "" if target is None else str(target.arch) - - def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None): """ Override `rs` if you're calling this function twice and don't want the same diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 42ba268365ea..0b80d19acb05 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -919,6 +919,8 @@ def _str_to_load_cache_modifier(self, cache_modifier): cache = ir.CACHE_MODIFIER.CA elif cache_modifier == ".cg": cache = ir.CACHE_MODIFIER.CG + elif cache_modifier == ".cs": + cache = ir.CACHE_MODIFIER.CS elif cache_modifier == ".cv": cache = ir.CACHE_MODIFIER.CV else: diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index bcc32d7dc29f..baf959000b60 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -287,6 +287,8 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, .global() .o("ca", op.getCache() == triton::CacheModifier::CA) .o("cg", op.getCache() == triton::CacheModifier::CG) + .o("cs", op.getCache() == triton::CacheModifier::CS) + .o("cv", op.getCache() == triton::CacheModifier::CV) .o("L1::evict_first", op.getEvict() == triton::EvictionPolicy::EVICT_FIRST) .o("L1::evict_last",