Improve performance of prefill mode FP8 Grouped Gemm #3522

Closed
wants to merge 1 commit into from
9 changes: 9 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
@@ -275,6 +275,10 @@ def main(args: Any):
M = [int(m) for m in args.M.strip().split(",")]
N = [int(n) for n in args.N.strip().split(",")]
K = [int(k) for k in args.K.strip().split(",")]
if args.groups:
M = M * int(args.groups)
N = N * int(args.groups)
K = K * int(args.groups)
if args.B is None:
B = [1] * len(M)
else:
@@ -384,6 +388,11 @@ def invoke_main() -> None:
help="If set, do grouped gemm. In this mode, M, N, and K are interpreted "
"as the size of groups. The length of each must be the same.",
)
parser.add_argument(
"--groups",
default=None,
help="If set with grouped mode, repeat input shapes this many times.",
)
parser.add_argument(
"--no_cuda_graph",
default=False,
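For context on the new `--groups` flag: it simply tiles the comma-separated shape lists so a single problem size can be benchmarked as many identical groups. A minimal sketch of that parsing behavior, with illustrative values and a hypothetical command line (the exact flag spellings for M/N/K are not shown in this diff):

```python
# Sketch of how --groups expands the benchmark shape lists (illustrative values).
def expand_shapes(m_str: str, n_str: str, k_str: str, groups: int):
    M = [int(m) for m in m_str.strip().split(",")]
    N = [int(n) for n in n_str.strip().split(",")]
    K = [int(k) for k in k_str.strip().split(",")]
    if groups:
        # Repeat each list so every group reuses the same problem size,
        # e.g. M=[128] with groups=4 becomes [128, 128, 128, 128].
        M = M * groups
        N = N * groups
        K = K * groups
    return M, N, K

# A hypothetical "--grouped --groups 4" run with M=128, N=4096, K=4096
# would benchmark a 4-group GEMM of identical 128x4096x4096 problems.
print(expand_shapes("128", "4096", "4096", 4))
```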
2 changes: 1 addition & 1 deletion fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -505,7 +505,7 @@ def quantize_fixed_nk(self, x, w):
wq,
x_scale,
w_scale,
torch.tensor(m_values).to(dtype=torch.int32, device=xq[0].device),
torch.tensor(m_values).to(dtype=torch.int64, device=xq[0].device),
output,
)

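The `quantize_ops.py` change mirrors the kernel-side switch of `prepad_M` from `int32_t*` to `int64_t*`: the per-group M tensor handed to the grouped GEMM must now be int64 on the input device. A minimal sketch under that assumption, where `grouped_gemm_op` is a placeholder for the actual FBGEMM binding and the argument order follows the diff context:

```python
import torch

# Hypothetical wrapper; `grouped_gemm_op` stands in for the real FBGEMM op.
def run_grouped(grouped_gemm_op, xq, wq, x_scale, w_scale, m_values, output):
    # m_values must be int64 and live on the same device as the inputs,
    # matching the kernel's int64_t* prepad_M argument after this PR.
    prepad_m = torch.tensor(m_values).to(dtype=torch.int64, device=xq[0].device)
    return grouped_gemm_op(xq, wq, x_scale, w_scale, prepad_m, output)
```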
@@ -80,14 +80,107 @@ RowwiseGroupedKernel rowwise_grouped_heuristic_dispatch(int M, int N, int K) {
return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
}

__global__ void set_kernel_args_kernel(
KernelArguments* kernel_args,
ADataType* XQ,
BDataType* WQ,
D0DataType* w_scale,
D1DataType* x_scale,
EDataType* output,
int M,
int N,
int K) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
// Each kernel annoyingly can only set the kernel args for one group.
// This could only be avoided with complicated memory management.
if (idx == 0) {
// Write kernel arguments directly to memory.
KernelArguments kernel_group_args = {
XQ, WQ, {w_scale, x_scale}, output, M, N, K, K, K, {0, 0}, N};
kernel_args[0] = kernel_group_args;
}
}

void set_static_kernel_args(
at::Tensor kernel_args,
at::TensorList XQ,
at::TensorList WQ,
at::TensorList x_scale,
at::TensorList w_scale,
std::vector<at::Tensor> output) {
// Get current cuda stream.
auto stream = at::cuda::getCurrentHIPStream().stream();
int group_count = XQ.size();
// When group count is large, we can more efficiently initialize
// by doing host setup and a memcpy. This is only viable if cuda
// graphs aren't being used.
if (group_count >= 16 && stream == 0) {
std::vector<KernelArguments> ggemm_kargs;
ggemm_kargs.reserve(group_count);

// Iterate over inputs and get group information.
for (int i = 0; i < group_count; i++) {
int M = XQ[i].size(0);
int K = XQ[i].size(1);
int N = WQ[i].size(0);
KernelArguments group_args = {
reinterpret_cast<ADataType*>(XQ[i].data_ptr()),
reinterpret_cast<BDataType*>(WQ[i].data_ptr()),
{reinterpret_cast<D0DataType*>(w_scale[i].data_ptr()),
reinterpret_cast<D1DataType*>(x_scale[i].data_ptr())},
reinterpret_cast<EDataType*>(output[i].data_ptr()),
M,
N,
K,
K,
K,
{0, 0},
N};
ggemm_kargs.push_back(group_args);
}
// Copy data onto device.
hipMemcpy(
kernel_args.data_ptr(), // Destination
ggemm_kargs.data(), // Source
sizeof(KernelArguments) * group_count, // Number of bytes
hipMemcpyHostToDevice); // Copy Type
} else {
// We use the smallest reasonable block size since we effectively need only
// 1 thread.
int blockSize = 32;
int numBlocks = 1;
// Launch a kernel for each group to set kernel memory on device.
// Using multiple kernels this way allows us to support arbitrary M,N,K.
// For some reason, this approach is faster than using hipMemcpy.
for (int i = 0; i < group_count; i++) {
int M = XQ[i].size(0);
int K = XQ[i].size(1);
int N = WQ[i].size(0);
// Launch kernel to set kernel arguments.
set_kernel_args_kernel<<<numBlocks, blockSize, 0, stream>>>(
reinterpret_cast<KernelArguments*>(
reinterpret_cast<char*>(kernel_args.data_ptr()) +
(i * sizeof(KernelArguments))),
reinterpret_cast<ADataType*>(XQ[i].data_ptr()),
reinterpret_cast<BDataType*>(WQ[i].data_ptr()),
reinterpret_cast<D0DataType*>(w_scale[i].data_ptr()),
reinterpret_cast<D1DataType*>(x_scale[i].data_ptr()),
reinterpret_cast<EDataType*>(output[i].data_ptr()),
M,
N,
K);
}
}
}

__global__ void set_kernel_args_fixed_nk_kernel(
KernelArguments* kernel_args,
ADataType* XQ,
BDataType* WQ,
D0DataType* w_scale,
D1DataType* x_scale,
EDataType* output,
int32_t* prepad_M,
int64_t* prepad_M,
int M,
int N,
int K,
@@ -114,6 +207,77 @@ __global__ void set_kernel_args_fixed_nk_kernel(
}
}

void set_dynamic_kernel_args(
at::Tensor kernel_args,
at::TensorList XQ,
at::TensorList WQ,
at::TensorList x_scale,
at::TensorList w_scale,
std::vector<at::Tensor> output,
at::Tensor zero_start_index_M) {
// Get current cuda stream.
auto stream = at::cuda::getCurrentHIPStream().stream();
int group_count = XQ.size();
// Confirm M is on the proper device.
TORCH_CHECK(
XQ[0].device() == zero_start_index_M.device(),
"zero_start_index_M and inputs must be on the same device.");
TORCH_CHECK(
zero_start_index_M.size(0) == group_count,
"zero_start_index_M must have an entry for each group.");
TORCH_CHECK(
zero_start_index_M.dtype() == at::kLong,
"zero_start_index_M must be int64.");

// We assume that M, N, and K are fixed across groups.
// The actual m values are stored in the passed M tensor.
int M = XQ[0].size(0);
int K = XQ[0].size(1);
int N = WQ[0].size(0);

// Make sure that inputs are allocated in sequential memory as required by
// this mode.
for (int i = 1; i < group_count; i++) {
// Check that all inputs are allocated directly following preceding input.
TORCH_CHECK(
XQ[i].data_ptr() ==
(reinterpret_cast<ADataType*>(XQ[i - 1].data_ptr()) + (M * K)),
"Inputs must be sequential in memory to support dynamic M, but XQ is not.");
TORCH_CHECK(
WQ[i].data_ptr() ==
(reinterpret_cast<BDataType*>(WQ[i - 1].data_ptr()) + (N * K)),
"Inputs must be sequential in memory to support dynamic M, but WQ is not.");
TORCH_CHECK(
x_scale[i].data_ptr() ==
(reinterpret_cast<D0DataType*>(x_scale[i - 1].data_ptr()) + (M)),
"Inputs must be sequential in memory to support dynamic M, but x_scale is not.");
TORCH_CHECK(
w_scale[i].data_ptr() ==
(reinterpret_cast<D1DataType*>(w_scale[i - 1].data_ptr()) + (N)),
"Inputs must be sequential in memory to support dynamic M, but w_scale is not.");
TORCH_CHECK(
output[i].data_ptr() ==
(reinterpret_cast<EDataType*>(output[i - 1].data_ptr()) + (M * N)),
"Inputs must be sequential in memory to support dynamic M, but output is not.");
}

// Launch a kernel that sets kernel argument memory.
int const blockSize = std::min(1024, group_count);
int const numBlocks = (group_count + blockSize - 1) / blockSize;
set_kernel_args_fixed_nk_kernel<<<numBlocks, blockSize, 0, stream>>>(
reinterpret_cast<KernelArguments*>(kernel_args.data_ptr()),
reinterpret_cast<ADataType*>(XQ[0].data_ptr()),
reinterpret_cast<BDataType*>(WQ[0].data_ptr()),
reinterpret_cast<D0DataType*>(w_scale[0].data_ptr()),
reinterpret_cast<D1DataType*>(x_scale[0].data_ptr()),
reinterpret_cast<EDataType*>(output[0].data_ptr()),
reinterpret_cast<int64_t*>(zero_start_index_M.data_ptr()),
M,
N,
K,
group_count);
}

at::Tensor get_grouped_kernel_args(
at::TensorList XQ,
at::TensorList WQ,
Expand All @@ -122,8 +286,6 @@ at::Tensor get_grouped_kernel_args(
std::optional<at::Tensor> zero_start_index_M,
std::vector<at::Tensor> output) {
int group_count = XQ.size();
auto stream = at::cuda::getCurrentHIPStream().stream();

// Get space on device for the kernel argument tensor.
at::Tensor kernel_args = at::empty(
{static_cast<long>(group_count * sizeof(KernelArguments))},
@@ -139,104 +301,17 @@ at::Tensor get_grouped_kernel_args(
// using zero_start_index_M.

if (zero_start_index_M.has_value()) {
// Make sure zero_start_index_M is configured properly.
at::Tensor prepad_M = zero_start_index_M.value();
// Confirm M is on the proper device.
TORCH_CHECK(
XQ[0].device() == prepad_M.device(),
"zero_start_index_M and inputs must be on the same device.");
TORCH_CHECK(
prepad_M.size(0) == group_count,
"zero_start_index_M must have an entry for each group.");

// We assume that M, N, and K are fixed across groups.
// The actual m values are stored in the passed M tensor.
int M = XQ[0].size(0);
int K = XQ[0].size(1);
int N = WQ[0].size(0);

// Make sure that inputs are allocated in sequential memory as required by
// this mode.
for (int i = 1; i < group_count; i++) {
// Check that all inputs are allocated directly following preceding input.
TORCH_CHECK(
XQ[i].data_ptr() ==
(reinterpret_cast<ADataType*>(XQ[i - 1].data_ptr()) + (M * K)),
"Inputs must be sequential in memory to support dynamic M, but XQ is not.");
TORCH_CHECK(
WQ[i].data_ptr() ==
(reinterpret_cast<BDataType*>(WQ[i - 1].data_ptr()) + (N * K)),
"Inputs must be sequential in memory to support dynamic M, but WQ is not.");
TORCH_CHECK(
x_scale[i].data_ptr() ==
(reinterpret_cast<D0DataType*>(x_scale[i - 1].data_ptr()) + (M)),
"Inputs must be sequential in memory to support dynamic M, but x_scale is not.");
TORCH_CHECK(
w_scale[i].data_ptr() ==
(reinterpret_cast<D1DataType*>(w_scale[i - 1].data_ptr()) + (N)),
"Inputs must be sequential in memory to support dynamic M, but w_scale is not.");
TORCH_CHECK(
output[i].data_ptr() ==
(reinterpret_cast<EDataType*>(output[i - 1].data_ptr()) +
(M * N)),
"Inputs must be sequential in memory to support dynamic M, but output is not.");
}

// Launch a kernel that sets kernel argument memory.
int const blockSize = std::min(1024, group_count);
int const numBlocks = (group_count + blockSize - 1) / blockSize;
set_kernel_args_fixed_nk_kernel<<<numBlocks, blockSize, 0, stream>>>(
reinterpret_cast<KernelArguments*>(kernel_args.data_ptr()),
reinterpret_cast<ADataType*>(XQ[0].data_ptr()),
reinterpret_cast<BDataType*>(WQ[0].data_ptr()),
reinterpret_cast<D0DataType*>(w_scale[0].data_ptr()),
reinterpret_cast<D1DataType*>(x_scale[0].data_ptr()),
reinterpret_cast<EDataType*>(output[0].data_ptr()),
reinterpret_cast<int32_t*>(prepad_M.data_ptr()),
M,
N,
K,
group_count);
return kernel_args;
set_dynamic_kernel_args(
kernel_args,
XQ,
WQ,
x_scale,
w_scale,
output,
zero_start_index_M.value());
} else {
// When running in eager mode, we assume we can directly interact with host
// values.
// Note that this version is not supported with cuda graphs.
TORCH_CHECK(
stream == 0,
"f8f8bf16_rowwise_grouped eager mode is not supported with cuda graphs.");

std::vector<KernelArguments> ggemm_kargs;
ggemm_kargs.reserve(group_count);

// Iterate over inputs and get group information.
for (int i = 0; i < group_count; i++) {
int M = XQ[i].size(0);
int K = XQ[i].size(1);
int N = WQ[i].size(0);
KernelArguments group_args = {
reinterpret_cast<ADataType*>(XQ[i].data_ptr()),
reinterpret_cast<BDataType*>(WQ[i].data_ptr()),
{reinterpret_cast<D0DataType*>(w_scale[i].data_ptr()),
reinterpret_cast<D1DataType*>(x_scale[i].data_ptr())},
reinterpret_cast<EDataType*>(output[i].data_ptr()),
M,
N,
K,
K,
K,
{0, 0},
N};
ggemm_kargs.push_back(group_args);
}
// Copy data onto device.
hipMemcpy(
kernel_args.data_ptr(), // Destination
ggemm_kargs.data(), // Source
sizeof(KernelArguments) * group_count, // Number of bytes
hipMemcpyHostToDevice); // Copy Type
set_static_kernel_args(kernel_args, XQ, WQ, x_scale, w_scale, output);
}

return kernel_args;
}
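Taken together, the refactor splits argument setup into a static path (host-side `hipMemcpy` or per-group kernels) and a dynamic prefill path, `set_dynamic_kernel_args`, which fills every group's arguments with a single kernel launch. The dynamic path assumes M, N, and K are fixed across groups and that each group's XQ, WQ, scales, and output are allocated back-to-back, which is exactly what its TORCH_CHECKs verify. A minimal caller-side sketch under those assumptions (shapes and dtypes are illustrative, not taken from the PR, and assume a recent PyTorch with FP8 dtypes):

```python
import torch

# Sketch of the memory layout the dynamic-M path expects: each group's tensor
# must directly follow the previous one, so allocate one stacked buffer per
# operand and hand out per-group views.
G, M, N, K = 16, 128, 4096, 4096
xq_buf = torch.empty((G, M, K), dtype=torch.float8_e4m3fnuz, device="cuda")
wq_buf = torch.empty((G, N, K), dtype=torch.float8_e4m3fnuz, device="cuda")
x_scale_buf = torch.empty((G, M), dtype=torch.float32, device="cuda")
w_scale_buf = torch.empty((G, N), dtype=torch.float32, device="cuda")
out_buf = torch.empty((G, M, N), dtype=torch.bfloat16, device="cuda")

# View i starts exactly M*K (or N*K, M, N, M*N) elements after view i-1,
# which is what the pointer checks in set_dynamic_kernel_args require.
XQ = [xq_buf[i] for i in range(G)]
WQ = [wq_buf[i] for i in range(G)]
x_scale = [x_scale_buf[i] for i in range(G)]
w_scale = [w_scale_buf[i] for i in range(G)]
output = [out_buf[i] for i in range(G)]

# zero_start_index_M must be int64, one entry per group, on the same device.
zero_start_index_M = torch.full((G,), M, dtype=torch.int64, device="cuda")
```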
