Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ int main(int argc, char* argv[])
exit(0);
}

int group_count = 4;
int group_count = rand() % 16 + 1;

// GEMM shape
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
Expand Down Expand Up @@ -189,12 +189,17 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};

// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();

// do GEMM
auto argument =
gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);

DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));

gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());

if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
Expand Down
2 changes: 2 additions & 0 deletions include/ck/tensor_operation/gpu/device/device_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ struct BaseOperator

virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }

virtual void SetWorkSpacePointer(BaseArgument*, void*) const {}
Comment thread
asroy marked this conversation as resolved.

virtual ~BaseOperator() {}
};

Expand Down
206 changes: 97 additions & 109 deletions include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,57 +24,33 @@ template <typename GridwiseGemm,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop,
index_t MaxGroupCount>
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdlops_v2r3(
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_descs,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op)
kernel_grouped_gemm_xdlops_v2r3(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

const index_t block_id = get_block_1d_id();

#if 1
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(block_id >= gemm_descs[i].BlockStart_ && block_id < gemm_descs[i].BlockEnd_ &&
i < group_count)
{
auto group_id = i;

GridwiseGemm::template Run<HasMainKBlockLoop>(
gemm_descs[group_id].a_ptr,
gemm_descs[group_id].b_ptr,
gemm_descs[group_id].c_ptr,
p_shared,
gemm_descs[group_id].a_grid_desc_k0_m_k1_,
gemm_descs[group_id].b_grid_desc_k0_n_k1_,
gemm_descs[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
a_element_op,
b_element_op,
c_element_op,
gemm_descs[group_id].grouped_gemm_block_2_ctile_map_);
}
});
#else
const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_descs);
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));

index_t group_id = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
group_id = (block_id >= gemm_descs[i].BlockStart && block_id < gemm_descs[i].BlockEnd &&
i < group_count)
? i
: group_id;
});

const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
for(index_t i = 0; i < group_count; i++)
{
group_id =
(block_id >= gemm_desc_ptr[i].BlockStart_ && block_id < gemm_desc_ptr[i].BlockEnd_)
? i
: group_id;
}

GridwiseGemm::template Run<HasMainKBlockLoop>(
gemm_desc_ptr[group_id].a_ptr,
Expand All @@ -87,9 +63,7 @@ __global__ void
a_element_op,
b_element_op,
c_element_op,
gemm_desc_ptr[group_id].block_2_ctile_map_,
block_id_grp);
#endif
gemm_desc_ptr[group_id].grouped_gemm_block_2_ctile_map_);
#else
ignore = gemm_descs;
ignore = group_count;
Expand Down Expand Up @@ -389,6 +363,8 @@ struct DeviceGroupedGemmXdl
{
grid_size_ = 0;

gemm_descs_args_workspace_ = nullptr;

group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());

if(!(group_count_ == ck::type_convert<ck::index_t>(p_a.size()) &&
Expand Down Expand Up @@ -463,6 +439,8 @@ struct DeviceGroupedGemmXdl

std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;

void* gemm_descs_args_workspace_;

index_t grid_size_;
};

Expand All @@ -473,49 +451,49 @@ struct DeviceGroupedGemmXdl

float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_args;

bool has_main_k_block_loop = true;

static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(i < arg.gemm_desc_kernel_arg_.size())
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}";

std::cout << ", arg.b_grid_desc_k0_n_k1_{"
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}";

std::cout << ", arg.c_grid_desc_m_n_{ "
<< arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;

if(!GridwiseGemm::CheckValidity(
arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_,
arg.gemm_desc_kernel_arg_[i].grouped_gemm_block_2_ctile_map_))
{
gemm_desc_kernel_args(i) = arg.gemm_desc_kernel_arg_[i];

std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
<< gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}";

std::cout << ", arg.b_grid_desc_k0_n_k1_{"
<< gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}";

std::cout << ", arg.c_grid_desc_m_n_{ "
<< gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I0) << ", "
<< gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;

if(!GridwiseGemm::CheckValidity(
gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_,
gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_,
gemm_desc_kernel_args[i].c_grid_desc_m_n_,
gemm_desc_kernel_args[i].grouped_gemm_block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}

const auto K = gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) *
gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2);

if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
{
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
}
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
});

const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) *
arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2);

if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
{
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
}
}

hipGetErrorString(
hipMemcpy(arg.gemm_descs_args_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg),
hipMemcpyHostToDevice));

float ave_time = 0;

Expand All @@ -525,47 +503,47 @@ struct DeviceGroupedGemmXdl
kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<GemmDescKernelArg>,
GemmDescKernelArg,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
true,
MaxGroupCount>;

ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
gemm_desc_kernel_args,
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
true>;

ave_time = launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
Comment thread
asroy marked this conversation as resolved.
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
}
else
{
const auto kernel =
kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<GemmDescKernelArg>,
GemmDescKernelArg,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
false,
MaxGroupCount>;

ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
gemm_desc_kernel_args,
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
false>;

ave_time = launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
}

return ave_time;
Expand Down Expand Up @@ -654,6 +632,16 @@ struct DeviceGroupedGemmXdl

return str.str();
}

size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmDescKernelArg);
}

void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
{
dynamic_cast<Argument*>(p_arg)->gemm_descs_args_workspace_ = workspace_ptr;
}
};

} // namespace device
Expand Down
7 changes: 6 additions & 1 deletion test/grouped_gemm/grouped_gemm_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,15 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
auto c_element_op = PassThrough{};

// do GEMM
auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer();
auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer();

auto argument_ptr = groupedGemmPtr->MakeArgumentPointer(
p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);

DeviceMem gemm_desc_workspace(groupedGemmPtr->GetWorkSpaceSize(argument_ptr.get()));

groupedGemmPtr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());

invoker_ptr->Run(argument_ptr.get());

for(std::size_t i = 0; i < gemm_shapes.size(); i++)
Expand Down