Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions include/ck_tile/core/arch/arch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,34 @@ CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
// https://llvm.org/docs/AMDGPU/gfx9_waitcnt.html
struct waitcnt_arg
{
#if defined(__gfx12__)
// gfx12 waits with s_wait_loadcnt_dscnt; in its immediate operand the ds counter is
// bits [5:0] and the mem (load) counter is bits [13:8].
// MAX is the mask of every counter bit that is valid in that immediate.
CK_TILE_DEVICE static constexpr index_t MAX = 0b00'111111'00'111111;

// Maximum counter values; they are the default template arguments of s_waitcnt() and
// s_waitcnt_barrier().
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111;
CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0b111;
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b111111;

// Encode the vm (memory load) counter into bits [13:8] of the wait immediate.
// cnt must lie in [0..63]; anything else is rejected at compile time.
template <index_t cnt>
CK_TILE_DEVICE static constexpr index_t from_vmcnt()
{
    static_assert(0 <= cnt && cnt < 64, "valid range is [0..63]");
    constexpr index_t positioned = cnt << 8;
    return positioned & MAX;
}

template <index_t cnt>
CK_TILE_DEVICE static constexpr index_t from_expcnt()
{
return 0; // s_wait_loadcnt_dscnt has no export-counter field, so expcnt is ignored on gfx12
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean Navi Series?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, expcnt is only used on graphics part; such as position export in vertex shader, color export in pixel shader.

}

template <index_t cnt>
CK_TILE_DEVICE static constexpr index_t from_lgkmcnt()
{
static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]");
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could I know if there is a document that we could find that difference between MI and Navi?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just check shader programming guide.

return MAX & cnt;
}
#else
// bit numbers (hex) -------------------------> FE'DC'BA98'7'654'3210
// [V]M [E]XP [L]GKM counters and [U]NUSED ---> VV'UU'LLLL'U'EEE'VVVV
CK_TILE_DEVICE static constexpr index_t MAX = 0b11'00'1111'0'111'1111;
Expand Down Expand Up @@ -167,25 +195,50 @@ struct waitcnt_arg
static_assert(cnt >= 0 && !(cnt >> 4), "valid range is [0..15]");
return MAX & (cnt << 8);
}
#endif
};

template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
index_t expcnt = waitcnt_arg::kMaxExpCnt,
index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
CK_TILE_DEVICE void s_waitcnt()
{
#if defined(__gfx12__)
// GFX12 doesn't use __builtin_amdgcn_s_waitcnt
constexpr index_t wait_mask = waitcnt_arg::from_vmcnt<vmcnt>() |
waitcnt_arg::from_expcnt<expcnt>() |
waitcnt_arg::from_lgkmcnt<lgkmcnt>();

asm volatile("s_wait_loadcnt_dscnt %0" : : "n"(wait_mask) : "memory");
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a builtin for this instruction? We could also ask compiler to add one, if the answer is no.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no builtin for s_wait_loadcnt_dscnt.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the fix! I will add a ticket for the compiler.

#else
__builtin_amdgcn_s_waitcnt(waitcnt_arg::from_vmcnt<vmcnt>() |
waitcnt_arg::from_expcnt<expcnt>() |
waitcnt_arg::from_lgkmcnt<lgkmcnt>());
#endif
}

// Wait until the selected counters have drained to the requested values, then
// synchronize at a barrier.
template <index_t vmcnt   = waitcnt_arg::kMaxVmCnt,
          index_t expcnt  = waitcnt_arg::kMaxExpCnt,
          index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
CK_TILE_DEVICE void s_waitcnt_barrier()
{
#if !defined(__gfx12__)
    // Generic path: regular counter wait followed by the barrier builtin.
    s_waitcnt<vmcnt, expcnt, lgkmcnt>();
    __builtin_amdgcn_s_barrier();
#else
    // gfx12 path: the wait and the barrier signal/wait pair are written by hand, because
    // __builtin_amdgcn_s_barrier inserts an extra s_wait_loadcnt_dscnt 0x0, which costs
    // performance.
    constexpr index_t vm_bits   = waitcnt_arg::from_vmcnt<vmcnt>();
    constexpr index_t exp_bits  = waitcnt_arg::from_expcnt<expcnt>();
    constexpr index_t lgkm_bits = waitcnt_arg::from_lgkmcnt<lgkmcnt>();
    constexpr index_t imm       = vm_bits | exp_bits | lgkm_bits;

    asm volatile("s_wait_loadcnt_dscnt %0\n"
                 "s_barrier_signal -1\n"
                 "s_barrier_wait -1"
                 :
                 : "n"(imm)
                 : "memory");
#endif
}

template <index_t lgkmcnt = 0>
Expand Down
4 changes: 2 additions & 2 deletions include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ struct MoeSortingKernel
else
smem_tokens(curr_token_id, eid)++;
}
__builtin_amdgcn_s_waitcnt(0xc07f);
s_waitcnt<waitcnt_arg::kMaxVmCnt, waitcnt_arg::kMaxExpCnt, 0>();
}
__syncthreads(); // make sure different i_token iteration not overlap by different wave
}
Expand Down Expand Up @@ -922,7 +922,7 @@ struct MoeSortingKernel
// NOTE: this waitcnt is a must: the compiler will not generate a waitcnt lgkmcnt()
// for the above write, and using __syncthreads instead would put a barrier on waves
// other than 0 (which is not what we want)
__builtin_amdgcn_s_waitcnt(0xc07f);
s_waitcnt<waitcnt_arg::kMaxVmCnt, waitcnt_arg::kMaxExpCnt, 0>();
}
if((lid + i_e_ - get_warp_size()) == (num_experts - 1))
{
Expand Down
Loading