Skip to content

Commit

Permalink
Introduce support for buffer operations
Browse files Browse the repository at this point in the history
  • Loading branch information
giuseros committed Sep 30, 2024
1 parent 99a66ec commit 6dafa47
Show file tree
Hide file tree
Showing 9 changed files with 512 additions and 49 deletions.
1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace mlir::triton {
inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
// clang-format off
"AMDGCN_ENABLE_DUMP",
"AMDGCN_USE_BUFFER_OPS",
"DISABLE_FAST_REDUCTION",
"DISABLE_LLVM_OPT",
"DISABLE_MMA_V3",
Expand Down
39 changes: 39 additions & 0 deletions test/Conversion/amd/buffer_load_store.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: AMDGCN_USE_BUFFER_OPS=1 triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 --convert-builtin-func-to-llvm | FileCheck %s

#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: buffer_load_store_vec8
tt.func @buffer_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
%3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0>
%4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
%7 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
// Load 8 elements from A with two vectorized load instruction
// CHECK-COUNT-5: llvm.select
// CHECK: %[[mask0:.*]] = llvm.select
// CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[mask0]]
// CHECK: %[[mask1:.*]] = llvm.select
// CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[mask1]]
// CHECK: %[[mask2:.*]] = llvm.select
// CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[mask2]]
// CHECK: %[[mask3:.*]] = llvm.select
// CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[mask3]]
%9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #blocked0>
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #blocked0>
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
// CHECK: %[[mask4:.*]] = llvm.select
// CHECK: rocdl.raw.ptr.buffer.store{{.*}}, {{.*}}, %[[mask4]]
// CHECK: %[[mask5:.*]] = llvm.select
// CHECK: rocdl.raw.ptr.buffer.store{{.*}}, {{.*}}, %[[mask5]]
%12 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
tt.store %13, %11 : tensor<256x!tt.ptr<f32>, #blocked0>
tt.return
}
}
38 changes: 37 additions & 1 deletion third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from triton.backends.compiler import BaseBackend, GPUTarget
from triton.backends.compiler import BaseBackend, GPUTarget, AttrsDescriptor
from triton._C.libtriton import ir, passes, llvm, amd
from dataclasses import dataclass
from typing import Any, Dict, Tuple
from types import ModuleType
import hashlib
import tempfile
import os
import sys
import re
import subprocess
import functools
Expand Down Expand Up @@ -71,6 +72,34 @@ def hash(self):
return hashlib.sha256(key.encode("utf-8")).hexdigest()


class HIPAttrsDescriptor(AttrsDescriptor):
__slots__ = ("pointer_range")

@staticmethod
def is_within2gb(arg):
if "torch.Tensor" in str(type(arg)) and hasattr(arg, "untyped_storage"):
return sys.getsizeof(arg.untyped_storage()) < 2**31 - 1
return False

def _add_backend_properties(self, params=None, values=None):
self.property_values["tt.pointer_range"] = 32
if (params is None or values is None):
return

# tt.pointer_range: does the pointer space fit in 2GB
self.arg_properties["tt.pointer_range"] = [
param.num for param, arg in zip(params, values) if HIPAttrsDescriptor.is_within2gb(arg)
and not param.do_not_specialize and not param.do_not_specialize_on_alignment
]

@staticmethod
def get_property_key(val, align):
generic_key = AttrsDescriptor.get_property_key(val, align)
hip_key = "S" if HIPAttrsDescriptor.is_within2gb(val) else "N"
key = (generic_key + hip_key).replace("N", "")
return key if key else "N"


class HIPBackend(BaseBackend):

@staticmethod
Expand Down Expand Up @@ -117,6 +146,13 @@ def get_module_map(self) -> Dict[str, ModuleType]:
def load_dialects(self, ctx):
amd.load_dialects(ctx)

def get_attrs_descriptor(self, params, args):
return HIPAttrsDescriptor(params, args)

@staticmethod
def compute_spec_key(arg, align):
return HIPAttrsDescriptor.get_property_key(arg, align)

@staticmethod
def path_to_rocm_lld():
# Check env path for ld.lld
Expand Down
Loading

0 comments on commit 6dafa47

Please sign in to comment.