Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 19 additions & 41 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
is_hip_rdna4,
is_hip_gfx1250,
is_xpu,
get_arch,
torch_float8_dtypes,
torch_dtypes,
numpy_random,
Expand Down Expand Up @@ -4302,7 +4301,7 @@ def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_


@pytest.mark.interpreter
@pytest.mark.parametrize("cache", ["", ".ca", ".cg", ".cv"])
@pytest.mark.parametrize("cache", ["", ".ca", ".cg", ".cs", ".cv"])
def test_load_cache_modifier(cache, device):
src = torch.empty(128, device=device)
dst = torch.empty(128, device=device)
Expand All @@ -4316,12 +4315,12 @@ def _kernel(dst, src, CACHE: tl.constexpr):
pgm = _kernel[(1, )](dst, src, CACHE=cache)

if is_hip():
target_arch = get_arch()
# TODO: support testing for remaining architectures
if 'gfx94' not in target_arch:
if not is_hip_cdna3() and not is_hip_cdna4():
return
amdgcn = pgm.asm['amdgcn']
cg_cache_modifier_str = 'nt'
cs_cache_modifier_str = 'sc0 nt'
cv_cache_modifier_str = 'sc0 sc1'
buffer_load_line = [line for line in amdgcn.splitlines() if "buffer_load" in line]
global_load_line = [line for line in amdgcn.splitlines() if "global_load" in line]
Expand All @@ -4330,20 +4329,19 @@ def _kernel(dst, src, CACHE: tl.constexpr):
assert cg_cache_modifier_str not in load_line
if cache == '.cg':
assert cg_cache_modifier_str in load_line
if cache == ".cs":
assert cs_cache_modifier_str in load_line
if cache == '.cv':
assert cv_cache_modifier_str in load_line

if is_cuda():
ptx = pgm.asm['ptx']
if cache == '':
assert 'ld.global.ca' not in ptx
assert 'ld.global.cg' not in ptx
if cache == '.cg':
assert 'ld.global.cg' in ptx
assert 'ld.global.ca' not in ptx
if cache == '.ca':
assert 'ld.global.ca' in ptx
assert 'ld.global.cg' not in ptx
all_modifiers = ['.ca', '.cg', '.cs', '.cv']
for modifier in all_modifiers:
if modifier == cache:
assert f'ld.global{modifier}' in ptx
else:
assert f'ld.global{modifier}' not in ptx


@pytest.mark.interpreter
Expand Down Expand Up @@ -4444,17 +4442,16 @@ def _kernel(dst, src, CACHE: tl.constexpr):
pgm = _kernel[(1, )](dst, src, CACHE=cache)

if is_hip():
target_arch = get_arch()
# TODO: support testing for remaining architectures
if 'gfx94' not in target_arch:
if not is_hip_cdna3() and not is_hip_cdna4():
return
amdgcn = pgm.asm['amdgcn']
cs_cache_modifier_str = 'nt'
wt_cache_modifier_str = 'sc0 sc1'
buffer_store_line = [line for line in amdgcn.splitlines() if "buffer_store" in line]
global_store_line = [line for line in amdgcn.splitlines() if "global_store" in line]
store_line = global_store_line[0] if global_store_line else buffer_store_line[0]
if cache == '' or cache == '.cg':
if cache == '' or cache == '.wb' or cache == '.cg':
assert cs_cache_modifier_str not in store_line
assert wt_cache_modifier_str not in store_line
if cache == '.cs':
Expand All @@ -4466,31 +4463,12 @@ def _kernel(dst, src, CACHE: tl.constexpr):

if is_cuda():
ptx = pgm.asm['ptx']
if cache == '':
assert 'st.global.wb' not in ptx
assert 'st.global.cg' not in ptx
assert 'st.global.cs' not in ptx
assert 'st.global.wt' not in ptx
if cache == '.wb':
assert 'st.global.wb' in ptx
assert 'st.global.cg' not in ptx
assert 'st.global.cs' not in ptx
assert 'st.global.wt' not in ptx
if cache == '.cg':
assert 'st.global.wb' not in ptx
assert 'st.global.cg' in ptx
assert 'st.global.cs' not in ptx
assert 'st.global.wt' not in ptx
if cache == '.cs':
assert 'st.global.wb' not in ptx
assert 'st.global.cg' not in ptx
assert 'st.global.cs' in ptx
assert 'st.global.wt' not in ptx
if cache == '.wt':
assert 'st.global.wb' not in ptx
assert 'st.global.cg' not in ptx
assert 'st.global.cs' not in ptx
assert 'st.global.wt' in ptx
all_modifiers = ['.wb', '.cg', '.cs', '.wt']
for modifier in all_modifiers:
if modifier == cache:
assert f'st.global{modifier}' in ptx
else:
assert f'st.global{modifier}' not in ptx


@pytest.mark.interpreter
Expand Down
5 changes: 0 additions & 5 deletions python/triton/_internal_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,6 @@ def is_xpu():
return False if target is None else target.backend == "xpu"


def get_arch():
    """Return the architecture of the current target as a string.

    An empty string is returned when ``get_current_target()`` yields ``None``.
    """
    current = get_current_target()
    if current is None:
        return ""
    return str(current.arch)


def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None):
"""
Override `rs` if you're calling this function twice and don't want the same
Expand Down
2 changes: 2 additions & 0 deletions python/triton/language/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,8 @@ def _str_to_load_cache_modifier(self, cache_modifier):
cache = ir.CACHE_MODIFIER.CA
elif cache_modifier == ".cg":
cache = ir.CACHE_MODIFIER.CG
elif cache_modifier == ".cs":
cache = ir.CACHE_MODIFIER.CS
elif cache_modifier == ".cv":
cache = ir.CACHE_MODIFIER.CV
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
.global()
.o("ca", op.getCache() == triton::CacheModifier::CA)
.o("cg", op.getCache() == triton::CacheModifier::CG)
.o("cs", op.getCache() == triton::CacheModifier::CS)
.o("cv", op.getCache() == triton::CacheModifier::CV)
.o("L1::evict_first",
op.getEvict() == triton::EvictionPolicy::EVICT_FIRST)
.o("L1::evict_last",
Expand Down
Loading