Merged
4 changes: 3 additions & 1 deletion .github/workflows/ci.yml
@@ -373,6 +373,7 @@ jobs:
./python

# AMD ROCm tests
# runtime and transform tests needs to repair, then rm it from ignore list

⚠️ Potential issue | 🟡 Minor

Fix grammar and add tracking reference.

The comment has a grammatical error: "needs to repair" should be "need to be repaired" or "need repair". Additionally, consider adding a TODO marker or issue reference to track when these directories can be un-ignored.

🔎 Suggested fix
-      # runtime and transform tests needs to repair, then rm it from ignore list
+      # TODO(#XXXX): runtime and transform tests need to be repaired, then remove from ignore list

Replace #XXXX with an actual issue number. Would you like me to help create a tracking issue for this technical debt?

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# runtime and transform tests needs to repair, then rm it from ignore list
# TODO(#XXXX): runtime and transform tests need to be repaired, then remove from ignore list
🤖 Prompt for AI Agents
In .github/workflows/ci.yml around line 395, the inline comment "runtime and
transform tests needs to repair, then rm it from ignore list" is grammatically
incorrect and lacks tracking: change it to a clear TODO such as "# TODO: runtime
and transform tests need to be repaired; remove from ignore list — see
ISSUE-XXXX" (or "# TODO: runtime and transform tests need repair; remove from
ignore list — see #1234"), and replace ISSUE-XXXX/#1234 with an actual issue
number or create one and reference it.

- name: Run ROCm tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }})
id: rocm-tests
if: contains(matrix.runner.toolkit, 'ROCm')
@@ -383,7 +384,8 @@ jobs:
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
)
"${PYTEST[@]}" --maxfail=3 --numprocesses=4 \
./python/amd
--ignore=./python/runtime --ignore=./python/transform \
./python

# Apple Metal tests
- name: Run Metal tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }})
2 changes: 2 additions & 0 deletions src/op/gemm.cc
@@ -153,6 +153,8 @@ std::pair<int, int> GemmWarpPolicyNode::computeWarpPartition(
int kNPerWarp = 8; // Columns processed by a single warp
if (TargetIsVolta(target)) {
kNPerWarp = 16;
} else if (TargetIsCDNA(target)) {
kNPerWarp = 16;
}
ICHECK(M % kMPerWarp == 0)
<< "M must be divisible by " << kMPerWarp << ", but got " << M;
6 changes: 4 additions & 2 deletions src/op/logical.cc
@@ -42,14 +42,16 @@ TVM_REGISTER_OP("tl.any_of")
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure))
.set_attr<TScriptPrinterName>("TScriptPrinterName", "any_of")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", any_of_op);
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", any_of_op)
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic", any_of_op);

TVM_REGISTER_OP("tl.all_of")
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure))
.set_attr<TScriptPrinterName>("TScriptPrinterName", "all_of")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", all_of_op);
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", all_of_op)
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic", all_of_op);

} // namespace tl
} // namespace tvm
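Registering the hip.FLowerIntrinsic alongside the cuda one lets tl.any_of / tl.all_of lower on ROCm as well. A plausible device-side counterpart is the tl::Any / tl::All linear scans added to src/tl_templates/hip/common.h later in this PR; the exact expression produced by any_of_op / all_of_op is an assumption here. A self-contained HIP sketch of that shape:

#include <hip/hip_runtime.h>

// Same linear-scan shape as tl::Any / tl::All in common.h, re-declared so the
// sketch compiles on its own.
template <typename T> __device__ bool any_of(const T *a, int n) {
  for (int i = 0; i < n; ++i)
    if (a[i]) return true;
  return false;
}

template <typename T> __device__ bool all_of(const T *a, int n) {
  for (int i = 0; i < n; ++i)
    if (!a[i]) return false;
  return true;
}

// Illustrative kernel: one thread publishes whether any / all flags are set.
__global__ void reduce_flags(const int *flags, int n, int *out_any, int *out_all) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    *out_any = any_of(flags, n);
    *out_all = all_of(flags, n);
  }
}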
44 changes: 29 additions & 15 deletions src/target/codegen_hip.cc
@@ -767,24 +767,26 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
}
this->stream << ");\n";
};
if (op->op.same_as(builtin::ptx_cp_async())) {
if (op->op.same_as(builtin::ptx_cp_async()) ||
op->op.same_as(tl::ptx_cp_async())) {
// args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes,
// args[3] = predicate (optional)
ICHECK(op->args.size() == 3 || op->args.size() == 4)
<< "ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, "
"src_access_ptr, bytes, [predicate])";
std::string dst = this->PrintExpr(op->args[0]);
std::string dst_offset = this->PrintExpr(op->args[1]);
std::string src = this->PrintExpr(op->args[2]);
std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]);
// use size of argument list to indicate whether or not to use predicated
// cp.async
if (op->args.size() == 5) {
this->PrintIndent();
this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+"
<< dst_offset << ", " << src << "+" << src_offset << ");\n";
std::string src = this->PrintExpr(op->args[1]);
std::string size = this->PrintExpr(op->args[2]);
this->PrintIndent();
if (op->args.size() == 3) {
// Non-predicated version
this->stream << "tl::cp_async_gs<" << size << ">(" << dst << ", " << src
<< ");\n";
} else {
std::string condition = this->PrintExpr(op->args[5]);
this->PrintIndent();
// Predicated version
std::string condition = this->PrintExpr(op->args[3]);
this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst
<< "+" << dst_offset << ", " << src << "+" << src_offset
<< ", " << condition << ");\n";
<< ", " << src << ", " << condition << ");\n";
}
} else if (op->op.same_as(builtin::ptx_commit_group())) {
print_extern_call_stmt("tl::cp_async_commit");
@@ -942,6 +944,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
{"float32x4", "float32x4"},
{"float8_e4m3fnuzx4", "fp8_e4_4_t"},
{"float8_e4m3fnuzx8", "long"},
{"float8_e5m2fnuzx4", "fp8_e5_4_t"},
{"float8_e5m2fnuzx8", "long"},
Comment on lines +947 to +948

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's look at the file context around lines 945-946
sed -n '935,955p' src/target/codegen_hip.cc

Repository: tile-ai/tilelang

Length of output: 890


🏁 Script executed:

# Also search for other FP8 mappings to understand the pattern
rg -A5 -B5 "float8_e5m2fnuzx" src/target/codegen_hip.cc

Repository: tile-ai/tilelang

Length of output: 638


🏁 Script executed:

# Search for type definitions related to fp8_e5_4_t
rg "fp8_e5_4_t|fp8.*e5" --type cc --type h

Repository: tile-ai/tilelang

Length of output: 86


🌐 Web query:

ROCm MFMA FP8 e5m2 data type support and mappings hip header

💡 Result:

Short summary of ROCm FP8 (E5M2) / MFMA HIP mappings and header names

  • HIP C typedefs / builtins:

    • __hip_fp8_e5m2 — standard OCP E5M2 FP8 type. [1]
    • __hip_fp8_e5m2_fnuz — FNUZ variant (no inf, signed-zero handling per AMD). [1][2]
  • hipDataType enum entries (library_types.h):

    • HIP_R_8F_E5M2 = 29 (maps to __hip_fp8_e5m2). [1]
    • HIP_R_8F_E5M2_FNUZ = 1001 (maps to __hip_fp8_e5m2_fnuz). [1]
  • Notes about MFMA / hardware support:

    • FP8 (E5M2) is supported as an internal compute type on CDNA3 / MI300-class Matrix cores (MFMA/Matrix units); library support (rocWMMA/rocBLAS/hipBLAS/etc.) depends on hardware and ROCm release. [2][1]

References
[1] ROCm “Data types and precision support” / hip type & hipDataType listings.
[2] ROCm “Precision support” (notes on E5M2, FNUZ and MI300/Matrix cores).


🏁 Script executed:

# Search for fp8_e5_4_t and fp8_e4_4_t type definitions
rg "fp8_e[45]_4_t" src/

Repository: tile-ai/tilelang

Length of output: 2784


🏁 Script executed:

# Check if these types are defined in HIP headers or locally
fd -e h | xargs grep -l "fp8_e[45]_4_t" 2>/dev/null || echo "Not found in headers"

Repository: tile-ai/tilelang

Length of output: 123


🏁 Script executed:

# Search more broadly for FP8 type definitions
rg "typedef.*fp8|struct.*fp8" --type-list | head -20

Repository: tile-ai/tilelang

Length of output: 449


🏁 Script executed:

# Look for where these mappings are used/applied
rg "float8_e5m2fnuzx4|float8_e4m3fnuzx4" src/

Repository: tile-ai/tilelang

Length of output: 201


🏁 Script executed:

# Read the hip_fp8.h file to see all FP8 type definitions
cat -n src/tl_templates/hip/hip_fp8.h | head -100

Repository: tile-ai/tilelang

Length of output: 3006


🏁 Script executed:

# Check if there's an 8-element FP8 type
rg "fp8_e[45]_8_t|e[45]m[23].*8" src/tl_templates/hip/hip_fp8.h

Repository: tile-ai/tilelang

Length of output: 321


🏁 Script executed:

# Look at the full structure of fp8_e5_8_t and understand the x8 mapping
sed -n '1,200p' src/tl_templates/hip/hip_fp8.h

Repository: tile-ai/tilelang

Length of output: 5884


🏁 Script executed:

# Look at broader context around the dtype map definition
sed -n '920,960p' src/target/codegen_hip.cc

Repository: tile-ai/tilelang

Length of output: 1932


🏁 Script executed:

# Search for any comments or documentation about x8 → long mapping
rg -B10 "float8.*fnuzx8.*long" src/target/codegen_hip.cc

Repository: tile-ai/tilelang

Length of output: 525


🏁 Script executed:

# Check if there are MFMA builtin requirements documented
rg "mfma.*float8|MFMA.*FP8" src/target/codegen_hip.cc | head -20

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Search for any test or usage of these x8 mappings
rg "float8_e[45]m[23]fnuzx8" src/ --type cc --type h

Repository: tile-ai/tilelang

Length of output: 86


Fix FP8 x8 dtype mappings to use proper struct types.

The x8 FP8 dtype mappings use long instead of the semantically correct struct types defined in hip_fp8.h. The correct mappings should be:

  • float8_e4m3fnuzx8 → fp8_e4_8_t (currently maps to long)
  • float8_e5m2fnuzx8 → fp8_e5_8_t (currently maps to long)

Both fp8_e4_8_t and fp8_e5_8_t are fully defined in src/tl_templates/hip/hip_fp8.h with proper member access and alignment. Using long is inconsistent with how x4 variants are correctly mapped to their struct types (fp8_e4_4_t and fp8_e5_4_t) and loses type safety for MFMA operations.

🤖 Prompt for AI Agents
In src/target/codegen_hip.cc around lines 945-946 the FP8 x8 dtype mappings
incorrectly use "long"; replace the two mappings so "float8_e4m3fnuzx8" maps to
"fp8_e4_8_t" and "float8_e5m2fnuzx8" maps to "fp8_e5_8_t" and ensure the file
includes the hip_fp8.h header (or the header that defines fp8_e4_8_t and
fp8_e5_8_t) so the struct types are available for MFMA/type-safe operations.
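A sketch of the replacement entries the review asks for, mirroring how the x4 variants already map to their struct types (fp8_e4_8_t and fp8_e5_8_t are defined in src/tl_templates/hip/hip_fp8.h):

{"float8_e4m3fnuzx8", "fp8_e4_8_t"}, // was: "long"
{"float8_e5m2fnuzx8", "fp8_e5_8_t"}, // was: "long"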

{"float32x16", "float32x16"}};
std::string call_mfma_code = R"({
*((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}),
Expand Down Expand Up @@ -980,6 +984,16 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
// HIP doesn't need explicit register management like CUDA
// This is a no-op for HIP
return;
} else if (op->op.same_as(tl::warp_reduce_sum())) {
os << "tl::warp_reduce_sum(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::warp_reduce_max())) {
os << "tl::warp_reduce_max(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::warp_reduce_min())) {
os << "tl::warp_reduce_min(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::warp_reduce_bitand())) {
os << "tl::warp_reduce_bitand(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::warp_reduce_bitor())) {
os << "tl::warp_reduce_bitor(" << PrintExpr(op->args[0]) << ")";
} else {
CodeGenC::VisitExpr_(op, os);
}
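The tl::warp_reduce_* device helpers themselves are not part of this diff; one plausible shape, built on the tl::shfl_xor wrapper added to src/tl_templates/hip/common.h in this PR, is sketched below. This is an assumption, not the shipped header: the 64-lane butterfly assumes a CDNA wavefront, and the other reductions would differ only in the combining operation.

#include "common.h" // tl::shfl_xor and the TL_DEVICE macro (added in this PR)

namespace tl {
// Hedged sketch of a possible warp_reduce_sum: xor-butterfly across 64 lanes.
template <typename T> TL_DEVICE T warp_reduce_sum(T val) {
#pragma unroll
  for (int offset = 32; offset > 0; offset >>= 1)
    val += shfl_xor(val, offset);
  return val;
}
} // namespace tl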
8 changes: 4 additions & 4 deletions src/tl_templates/cuda/reduce.h
@@ -175,15 +175,15 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
static_assert(threads == 1024 or threads == 512 or threads == 256 or
threads == 128 or threads == 64 or threads == 32);
template <typename T, int SEG = 32>
static TL_DEVICE T run(const T *__restrict__ src, T *__restrict__ dst, int H,
int W) {
static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst,
int H, int W) {
Comment on lines +178 to +179

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# First, let's look at the actual file to confirm the current signature
head -n 200 src/tl_templates/cuda/reduce.h | tail -n 30

Repository: tile-ai/tilelang

Length of output: 1065


🏁 Script executed:

# Search for all calls to CumSum2D::run
rg -n 'CumSum2D.*::run\s*\(' --type=cpp --type=h -C3

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Also search more broadly for CumSum2D usage
rg -n 'CumSum2D' --type=cpp --type=h

Repository: tile-ai/tilelang

Length of output: 365


🏁 Script executed:

# Look for any assignment from CumSum2D::run calls
rg -n 'CumSum2D.*run\s*\(' . --type=cpp --type=h -A2 -B2

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check the reduce.cc file around line 551 to see how CumSum2D is used
sed -n '540,560p' src/op/reduce.cc

Repository: tile-ai/tilelang

Length of output: 919


🏁 Script executed:

# Search for actual invocations/calls of CumSum2D::run
rg 'CumSum2D' -A5 -B5 --type=cc --type=h | head -80

Repository: tile-ai/tilelang

Length of output: 86


🏁 Script executed:

# Look for any kernel launches or calls that might use CumSum2D
rg -n '\.run\s*\(' src/op/reduce.cc -A3 -B3

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check the full context of CumSum2D in reduce.cc
sed -n '545,570p' src/op/reduce.cc

Repository: tile-ai/tilelang

Length of output: 1101


🏁 Script executed:

# Check the exact lines again with more context
sed -n '545,570p' src/op/reduce.cc

Repository: tile-ai/tilelang

Length of output: 1101


🏁 Script executed:

# Look for any mismatch between Call return type and function signature
rg 'Evaluate\(Call\(' src/op/reduce.cc -B10 | head -40

Repository: tile-ai/tilelang

Length of output: 567


🏁 Script executed:

# Check if there are TODO/FIXME comments related to CumSum
rg -i 'TODO|FIXME' src/op/reduce.cc | grep -i cumsum

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Verify the current signature of CumSum1D and CumSum2D to understand the pattern
rg 'CumSum1D.*run' -A2 src/tl_templates/cuda/reduce.h

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check the full CumSum implementation pattern to understand intent
cat src/tl_templates/cuda/reduce.h | head -250 | tail -80

Repository: tile-ai/tilelang

Length of output: 2419


🏁 Script executed:

# Look for other uses of call_extern with different return types
rg 'call_extern' src/op/reduce.cc -B3 -A3

Repository: tile-ai/tilelang

Length of output: 1198


🏁 Script executed:

# Check if there are any type conversion or void handling in the call
rg -A20 'CumSum2D.*::run' src/tl_templates/cuda/reduce.h | head -30

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check CumSum1D to see if it also returns void
rg -B2 -A5 'static TL_DEVICE.*run.*CumSum1D' src/tl_templates/cuda/reduce.h

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Look at the full context around the Call to understand what dst->dtype represents
sed -n '520,575p' src/op/reduce.cc

Repository: tile-ai/tilelang

Length of output: 2195


🏁 Script executed:

# Find CumSum1D::run signature
rg -A3 'struct CumSum1D' src/tl_templates/cuda/reduce.h | head -30

Repository: tile-ai/tilelang

Length of output: 298


🏁 Script executed:

# Check if there are any recent changes or git history
git log --oneline src/op/reduce.cc | head -10

Repository: tile-ai/tilelang

Length of output: 90


🏁 Script executed:

# Check for any compilation errors or warnings
rg 'void' src/tl_templates/cuda/reduce.h | grep -i cumsum

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Get the full CumSum1D::run signature
rg -A8 'struct CumSum1D' src/tl_templates/cuda/reduce.h | grep -A5 'static TL_DEVICE'

Repository: tile-ai/tilelang

Length of output: 203


🏁 Script executed:

# Check git history to see if this was a recent change
git log -p --follow -- src/tl_templates/cuda/reduce.h | grep -B5 -A5 'CumSum.*void' | head -30

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Look at the last commit that touched this file
git log --oneline -n 5 -- src/tl_templates/cuda/reduce.h

Repository: tile-ai/tilelang

Length of output: 90


🏁 Script executed:

# Check if there's special handling for void in call_extern
rg 'call_extern' src/ -B2 -A2 | grep -B5 -A5 'void'

Repository: tile-ai/tilelang

Length of output: 1550


🏁 Script executed:

# Look for any Cast operations around the CumSum call
sed -n '535,575p' src/op/reduce.cc

Repository: tile-ai/tilelang

Length of output: 1532


🏁 Script executed:

# Check the type of builtin::call_extern to understand expected return type
rg 'builtin::call_extern' src/ -B3 | head -30

Repository: tile-ai/tilelang

Length of output: 2119


🏁 Script executed:

# Verify if there's a mismatch between function signature and codegen
rg -A1 'Call.*dst.*dtype.*call_extern' src/op/reduce.cc

Repository: tile-ai/tilelang

Length of output: 139


🏁 Script executed:

# Look for other examples of Call with void return type
rg 'Call\(.*void' src/ --type=cpp | head -20

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check if DataType::Void() exists
rg 'DataType::Void\|void_dtype' src/ --type=cpp

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Examine how other void-returning functions are called via call_extern
rg -B5 'DataType::Handle' src/op/ | grep -A3 'call_extern'

Repository: tile-ai/tilelang

Length of output: 457


🏁 Script executed:

# Check if the mismatch between Call(dst->dtype) and void return has any handling
git log -p --all -- src/op/reduce.cc | grep -B10 -A10 'CumSum2D.*void' | head -40

Repository: tile-ai/tilelang

Length of output: 42


Fix code generator to match void return type of CumSum functions.

Both CumSum1D::run and CumSum2D::run now return void, writing results directly to dst. However, the code generator in src/op/reduce.cc line 570 incorrectly declares the return type as dst->dtype. Update the call to use DataType::Handle() instead, matching the pattern used elsewhere for void-returning call_extern invocations (see src/op/gemm.cc and src/op/copy.cc).

Current code generation (reduce.cc:570)
return Evaluate(Call(dst->dtype, builtin::call_extern(), args));
🤖 Prompt for AI Agents
In src/op/reduce.cc around line 570, the code generator incorrectly emits a
return type of dst->dtype for CumSum call_extern invocations even though
CumSum1D::run and CumSum2D::run are void and write directly into dst; change the
emitted return type to DataType::Handle() (the void/handle pattern used in
src/op/gemm.cc and src/op/copy.cc) so the generated call uses DataType::Handle()
for call_extern and the Evaluate(...) call matches the void-returning extern
function.
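A sketch of the one-line change the comment asks for at reduce.cc:570, with DataType::Handle() replacing dst->dtype and the argument list left unchanged:

return Evaluate(Call(DataType::Handle(), builtin::call_extern(), args));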


constexpr int TILE_H = threads / SEG;
constexpr unsigned MASK = 0xffffffff;
const int num_blocks = (H + TILE_H - 1) / TILE_H;
const int tid = threadIdx.x;
const int lane = tid % 32;
const int row = tid / 32;
const int lane = tid % SEG;
const int row = tid / SEG;

for (int b = 0; b < num_blocks; ++b) {
const int gRow = b * TILE_H + row;
104 changes: 104 additions & 0 deletions src/tl_templates/hip/atomic.h
@@ -0,0 +1,104 @@
#pragma once

#include <hip/hip_runtime.h>

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
template <typename T1, typename T2>
__forceinline__ __device__ void AtomicAdd(T1 *address, T2 val,
int memory_order = 0) {
atomicAdd(reinterpret_cast<T1 *>(address), static_cast<T1>(val));
}

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
// Overload for when the first argument is a value instead of a pointer
template <typename T1, typename T2>
__forceinline__ __device__ void AtomicAdd(T1 &address, T2 val,
int memory_order = 0) {
atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
}

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
template <typename T1, typename T2>
__forceinline__ __device__ T1 AtomicAddRet(T1 *ref, T2 val,
int memory_order = 0) {
return atomicAdd(ref, static_cast<T1>(val));
}

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
template <typename T1, typename T2>
__forceinline__ __device__ void AtomicMax(T1 *address, T2 val,
int memory_order = 0) {
atomicMax(reinterpret_cast<T1 *>(address), static_cast<T1>(val));
}

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
// Overload for when the first argument is a value instead of a pointer
template <typename T1, typename T2>
__forceinline__ __device__ void AtomicMax(T1 &address, T2 val,
int memory_order = 0) {
atomicMax(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
}

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
template <typename T1, typename T2>
__forceinline__ __device__ void AtomicMin(T1 *address, T2 val,
int memory_order = 0) {
atomicMin(reinterpret_cast<T1 *>(address), static_cast<T1>(val));
}

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
// Overload for when the first argument is a value instead of a pointer
template <typename T1, typename T2>
__forceinline__ __device__ void AtomicMin(T1 &address, T2 val,
int memory_order = 0) {
atomicMin(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
}

__forceinline__ __device__ void AtomicAddx2(float *ref, float *val,
int memory_order = 0) {
float2 add_val = *reinterpret_cast<float2 *>(val);
atomicAdd(ref + 0, add_val.x);
atomicAdd(ref + 1, add_val.y);
}

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
__forceinline__ __device__ float2 AtomicAddx2Ret(float *ref, float *val,
int memory_order = 0) {
float2 add_val = *reinterpret_cast<float2 *>(val);
float2 ret;
ret.x = atomicAdd(ref + 0, add_val.x);
ret.y = atomicAdd(ref + 1, add_val.y);
return ret;
}

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
__forceinline__ __device__ void AtomicAddx4(float *ref, float *val,
int memory_order = 0) {
float4 add_val = *reinterpret_cast<float4 *>(val);
atomicAdd(ref + 0, add_val.x);
atomicAdd(ref + 1, add_val.y);
atomicAdd(ref + 2, add_val.z);
atomicAdd(ref + 3, add_val.w);
}

// Add an extra unused input to accommodate the additional 'memory_order'
// argument during lowering.
__forceinline__ __device__ float4 AtomicAddx4Ret(float *ref, float *val,
int memory_order = 0) {
float4 add_val = *reinterpret_cast<float4 *>(val);
float4 ret;
ret.x = atomicAdd(ref + 0, add_val.x);
ret.y = atomicAdd(ref + 1, add_val.y);
ret.z = atomicAdd(ref + 2, add_val.z);
ret.w = atomicAdd(ref + 3, add_val.w);
return ret;
}
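A short usage sketch for the vectorized helpers above (kernel, buffer names, and launch shape are illustrative; vals must be 16-byte aligned for the float4 load inside AtomicAddx4):

#include <hip/hip_runtime.h>
#include "atomic.h"

// Illustrative only: each thread folds its float4 slice of `vals` into a
// 4-wide accumulator and bumps a scalar counter.
__global__ void accumulate4(float *acc /* >= 4 floats */, float *count,
                            float *vals /* n * 4 floats */, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    AtomicAddx4(acc, vals + 4 * i); // expands to four scalar atomicAdds
    AtomicAdd(count, 1.0f);         // scalar overload; memory_order defaults to 0
  }
}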
98 changes: 88 additions & 10 deletions src/tl_templates/hip/common.h
@@ -1,6 +1,8 @@
#pragma once

#include "atomic.h"
#include <ck_tile/core.hpp>
#include <hip/amd_detail/amd_warp_functions.h>
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
@@ -105,18 +107,94 @@ TL_DEVICE unsigned __pack_bfloat162(const bfloat16_t x, const bfloat16_t y) {
return (v1 << 16) | v0;
}

template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 *address, T2 val) {
atomicAdd(reinterpret_cast<T1 *>(address), static_cast<T1>(val));
namespace tl {

// Any
template <typename T> TL_DEVICE bool Any(T *a, int size) {
for (int i = 0; i < size; i++) {
if (a[i]) {
return true;
}
}
return false;
}

// All
template <typename T> TL_DEVICE bool All(T *a, int size) {
for (int i = 0; i < size; i++) {
if (!a[i]) {
return false;
}
}
return true;
}

// TODO(gong): support shfl_sync(rocm 7.1.1 provide shfl_sync)
// shfl_sync func
template <typename T> TL_DEVICE T shfl_xor(T val, int delta) {
return __shfl_xor(val, delta);
}

template <typename T> TL_DEVICE T shfl_down(T val, int delta) {
return __shfl_down(val, delta);
}

template <typename T> TL_DEVICE T shfl_up(T val, int delta) {
return __shfl_up(val, delta);
}

template <typename T> TL_DEVICE T shfl(T val, int srcLane) {
return __shfl(val, srcLane);
}

// specialize half_t
template <> TL_DEVICE half_t shfl_xor(half_t val, int delta) {
float f = static_cast<float>(val);
float r = __shfl_xor(f, delta);
return half_t(r);
}

template <> TL_DEVICE half_t shfl_down(half_t val, int delta) {
float f = static_cast<float>(val);
float r = __shfl_down(f, delta);
return half_t(r);
}

template <> TL_DEVICE half_t shfl_up(half_t val, int delta) {
float f = static_cast<float>(val);
float r = __shfl_up(f, delta);
return half_t(r);
}

template <> TL_DEVICE half_t shfl(half_t val, int srcLane) {
float f = static_cast<float>(val);
float r = __shfl(f, srcLane);
return half_t(r);
}

// specialize bfloat16_t
template <> TL_DEVICE bfloat16_t shfl_xor(bfloat16_t val, int laneMask) {
float f = static_cast<float>(val);
float r = __shfl_xor(f, laneMask);
return bfloat16_t(r);
}

// Overload for when the first argument is a value instead of a pointer
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 address, T2 val) {
atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
template <> TL_DEVICE bfloat16_t shfl_down(bfloat16_t val, int delta) {
float f = static_cast<float>(val);
float r = __shfl_down(f, delta);
return bfloat16_t(r);
}

template <typename T1, typename T2>
TL_DEVICE T1 AtomicAddRet(T1 *address, T2 val) {
return atomicAdd(reinterpret_cast<T1 *>(address), static_cast<T1>(val));
template <> TL_DEVICE bfloat16_t shfl_up(bfloat16_t val, int delta) {
float f = static_cast<float>(val);
float r = __shfl_up(f, delta);
return bfloat16_t(r);
}

template <> TL_DEVICE bfloat16_t shfl(bfloat16_t val, int srcLane) {
float f = static_cast<float>(val);
float r = __shfl(f, srcLane);
return bfloat16_t(r);
}

} // namespace tl
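A small usage sketch of the shfl specializations above (the kernel is illustrative; half_t is the half-precision alias already used in this header, and the broadcast relies on the default HIP shuffle width):

#include "common.h" // tl::shfl and the half_t alias used in this header

// Illustrative kernel: broadcast lane 0's half-precision value across the wavefront.
__global__ void broadcast_lane0(half_t *data) {
  half_t mine = data[threadIdx.x];
  data[threadIdx.x] = tl::shfl(mine, 0); // routed through float by the specialization
}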