From cb09eac8d2d6f3d7877c4a71b50d8601a0d6715a Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 3 Jun 2024 06:44:28 +0000 Subject: [PATCH 01/12] wip --- .../flashinfer/group_gemm/cutlass_wrapper.cuh | 21 +++++++++++++++++++ include/flashinfer/group_gemm/sgmv.cuh | 15 +++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 include/flashinfer/group_gemm/cutlass_wrapper.cuh create mode 100644 include/flashinfer/group_gemm/sgmv.cuh diff --git a/include/flashinfer/group_gemm/cutlass_wrapper.cuh b/include/flashinfer/group_gemm/cutlass_wrapper.cuh new file mode 100644 index 0000000000..6004ac205e --- /dev/null +++ b/include/flashinfer/group_gemm/cutlass_wrapper.cuh @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" diff --git a/include/flashinfer/group_gemm/sgmv.cuh b/include/flashinfer/group_gemm/sgmv.cuh new file mode 100644 index 0000000000..458fbd05a3 --- /dev/null +++ b/include/flashinfer/group_gemm/sgmv.cuh @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ From ddebf667ff4a60261932f6129cde9f436bea770a Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 3 Jun 2024 09:13:55 +0000 Subject: [PATCH 02/12] upd --- .../flashinfer/group_gemm/cutlass_wrapper.cuh | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/include/flashinfer/group_gemm/cutlass_wrapper.cuh b/include/flashinfer/group_gemm/cutlass_wrapper.cuh index 6004ac205e..7e0e4b8d5a 100644 --- a/include/flashinfer/group_gemm/cutlass_wrapper.cuh +++ b/include/flashinfer/group_gemm/cutlass_wrapper.cuh @@ -14,8 +14,31 @@ * limitations under the License. 
 */
 
+#include <cstdint>
 #include "cutlass/cutlass.h"
 #include "cutlass/gemm/device/gemm_grouped.h"
 #include "cutlass/gemm/kernel/default_gemm_grouped.h"
 #include "cutlass/layout/matrix.h"
 #include "cutlass/numeric_types.h"
+
+namespace flashinfer {
+
+template <typename DType>
+__global__ void group_gemm_args_kernel(
+    cutlass::gemm::GemmCoord *all_problems,
+    DType **ptr_y,
+    DType **ptr_x,
+    DType **ptr_w,
+    int64_t *ld_y,
+    int64_t *ld_x,
+    int64_t *ld_w,
+    DType *y,
+    DType *x,
+    DType **w,
+    int64_t d_in,
+    int64_t d_out
+) {
+  // TODO(Zihao)
+}
+
+}
\ No newline at end of file

From 9f8eed080f0a329c51555c0312753fc7be60c048 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Mon, 3 Jun 2024 19:03:10 +0000
Subject: [PATCH 03/12] wip

---
 include/flashinfer/allocator.h            | 44 +++++++++++++++++++
 include/flashinfer/attention/handler.cuh  | 27 +++-----------
 include/flashinfer/group_gemm/handler.cuh | 27 ++++++++++++++
 3 files changed, 75 insertions(+), 23 deletions(-)
 create mode 100644 include/flashinfer/allocator.h
 create mode 100644 include/flashinfer/group_gemm/handler.cuh

diff --git a/include/flashinfer/allocator.h b/include/flashinfer/allocator.h
new file mode 100644
index 0000000000..dbeb2c2123
--- /dev/null
+++ b/include/flashinfer/allocator.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2023 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FLASHINFER_ALLOCATOR_H_
+#define FLASHINFER_ALLOCATOR_H_
+
+#include <stdexcept>
+#include <memory>
+
+namespace flashinfer {
+
+struct AlignedAllocator {
+  void* ptr;
+  size_t space;
+  AlignedAllocator(void* buf, size_t space) : ptr(buf), space(space) {}
+  template <typename T>
+  T* aligned_alloc(size_t size, size_t alignment) {
+    if (std::align(alignment, size, ptr, space)) {
+      T* result = reinterpret_cast<T*>(ptr);
+      ptr = (char*)ptr + size;
+      space -= size;
+      return result;
+    } else {
+      throw std::runtime_error("RuntimeError: Out of workspace memory in AlignedAlloactor");
+    }
+    return nullptr;
+  }
+};
+
+}  // namespace flashinfer
+
+#endif  // FLASHINFER_ALLOCATOR_H_
diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh
index 2f880f1625..c162121679 100644
--- a/include/flashinfer/attention/handler.cuh
+++ b/include/flashinfer/attention/handler.cuh
@@ -13,19 +13,18 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
 */
-#ifndef FLASHINFER_HANDLER_CUH_
-#define FLASHINFER_HANDLER_CUH_
+#ifndef FLASHINFER_ATTENTION_HANDLER_CUH_
+#define FLASHINFER_ATTENTION_HANDLER_CUH_
 
 #include <algorithm>
 #include <cstddef>
-#include <memory>
 #include <sstream>
-#include <stdexcept>
 #include <vector>
 
 #include "../page.cuh"
 #include "../pos_enc.cuh"
 #include "../utils.cuh"
+#include "../allocator.h"
 
 namespace flashinfer {
 
@@ -241,24 +240,6 @@ cudaError_t PartitionPagedKVCacheComputeAuxiliaryInfo(
   return cudaSuccess;
 }
 
-struct AlignedAllocator {
-  void* ptr;
-  size_t space;
-  AlignedAllocator(void* buf, size_t space) : ptr(buf), space(space) {}
-  template <typename T>
-  T* aligned_alloc(size_t size, size_t alignment) {
-    if (std::align(alignment, size, ptr, space)) {
-      T* result = reinterpret_cast<T*>(ptr);
-      ptr = (char*)ptr + size;
-      space -= size;
-      return result;
-    } else {
-      throw std::runtime_error("RuntimeError: Out of workspace memory in AlignedAlloactor");
-    }
-    return nullptr;
-  }
-};
-
 class BatchDecodeHandler {
  public:
   template <typename DType>
@@ -584,4 +565,4 @@ class BatchPrefillHandler {
 };
 
 }  // namespace flashinfer
-#endif  // FLASHINFER_HANDLER_CUH_
+#endif  // FLASHINFER_ATTENTION_HANDLER_CUH_
diff --git a/include/flashinfer/group_gemm/handler.cuh b/include/flashinfer/group_gemm/handler.cuh
new file mode 100644
index 0000000000..297d3580f8
--- /dev/null
+++ b/include/flashinfer/group_gemm/handler.cuh
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2024 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FLASHINFER_GROUP_GEMM_HANDLER_H
+#define FLASHINFER_GROUP_GEMM_HANDLER_H
+
+namespace flashinfer {
+
+class GroupGEMMHandler {
+ public:
+};
+
+}  // namespace flashinfer
+
+#endif  //FLASHINFER_GROUP_GEMM_HANDLER_H
\ No newline at end of file

From 120c2fe875d6756f29bca892b9f27b04df417c5b Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Tue, 4 Jun 2024 10:08:47 +0000
Subject: [PATCH 04/12] upd

---
 include/flashinfer/allocator.h                |   2 +-
 include/flashinfer/attention/handler.cuh      |   2 +-
 .../group_gemm/group_gemm_cutlass.cuh         |  65 +++++++++++
 ...utlass_wrapper.cuh => group_gemm_lora.cuh} |  35 ++----
 include/flashinfer/group_gemm/group_gemv.cuh  |  29 +++++
 include/flashinfer/group_gemm/handler.cuh     | 110 +++++++++++++++++-
 include/flashinfer/group_gemm/wrapper.cuh     |  88 ++++++++++++++
 python/csrc/batch_prefill.cu                  |   4 +-
 .../sgmv.cuh => python/csrc/group_gemm.cu     |   2 +-
 python/flashinfer/__init__.py                 |   1 +
 python/flashinfer/group_gemm.py               |  83 +++++++++++++
 python/setup.py                               |   1 +
 python/tests/test_batch_prefill_kernels.py    |  29 +++--
 13 files changed, 404 insertions(+), 47 deletions(-)
 create mode 100644 include/flashinfer/group_gemm/group_gemm_cutlass.cuh
 rename include/flashinfer/group_gemm/{cutlass_wrapper.cuh => group_gemm_lora.cuh} (54%)
 create mode 100644 include/flashinfer/group_gemm/group_gemv.cuh
 create mode 100644 include/flashinfer/group_gemm/wrapper.cuh
 rename include/flashinfer/group_gemm/sgmv.cuh => python/csrc/group_gemm.cu (99%)
 create mode 100644 python/flashinfer/group_gemm.py

diff --git a/include/flashinfer/allocator.h b/include/flashinfer/allocator.h
index dbeb2c2123..e4840f167d 100644
--- a/include/flashinfer/allocator.h
+++ b/include/flashinfer/allocator.h
@@ -16,8 +16,8 @@
 #ifndef FLASHINFER_ALLOCATOR_H_
 #define FLASHINFER_ALLOCATOR_H_
 
-#include <stdexcept>
 #include <memory>
+#include <stdexcept>
 
 namespace flashinfer {
 
diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh
index c162121679..0fe1750a32 100644
--- a/include/flashinfer/attention/handler.cuh
+++ b/include/flashinfer/attention/handler.cuh
@@ -21,10 +21,10 @@
 #include <sstream>
 #include <vector>
 
+#include "../allocator.h"
 #include "../page.cuh"
 #include "../pos_enc.cuh"
 #include "../utils.cuh"
-#include "../allocator.h"
 
 namespace flashinfer {
 
diff --git a/include/flashinfer/group_gemm/group_gemm_cutlass.cuh b/include/flashinfer/group_gemm/group_gemm_cutlass.cuh
new file mode 100644
index 0000000000..3acb608963
--- /dev/null
+++ b/include/flashinfer/group_gemm/group_gemm_cutlass.cuh
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) 2024 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FLASHINFER_GROUP_GEMM_CUTLASS_CUH_
+#define FLASHINFER_GROUP_GEMM_CUTLASS_CUH_
+
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/device/gemm_grouped.h"
+#include "cutlass/gemm/kernel/default_gemm_grouped.h"
+#include "cutlass/layout/matrix.h"
+#include "cutlass/numeric_types.h"
+
+namespace flashinfer {
+
+namespace group_gemm {
+
+template <typename T>
+struct cutlass_dtype {
+  using type = T;
+};
+
+template <>
+struct cutlass_dtype<half> {
+  using type = cutlass::half_t;
+};
+
+template <>
+struct cutlass_dtype<nv_bfloat16> {
+  using type = cutlass::bfloat16_t;
+};
+
+template <typename T>
+__global__ void compute_cutlass_group_gemm_args(cutlass::gemm::GemmCoord* all_problems, T** ptr_y,
+                                                T** ptr_x, T** ptr_w, int64_t* ld_y, int64_t* ld_x,
+                                                int64_t* ld_w, T* y, T* x, T* w, int64_t* xy_indptr,
+                                                int64_t* w_indices, size_t d_in, size_t d_out,
+                                                bool w_column_major) {
+  int i = blockIdx.x;
+  int m = xy_indptr[i + 1] - xy_indptr[i], k = d_in, n = d_out;
+  all_problems[i] = cutlass::gemm::GemmCoord(m, n, k);
+  ptr_w[i] = w + w_indices[i] * d_in * d_out;
+  ptr_x[i] = x + xy_indptr[i] * d_in;
+  ptr_y[i] = y + xy_indptr[i] * d_out;
+  ld_x[i] = k;                       // m * k
+  ld_w[i] = w_column_major ? k : n;  // k * n if column major, n * k if row major
+  ld_y[i] = n;                       // m * n
+}
+
+}  // namespace group_gemm
+
+}  // namespace flashinfer
+
+#endif  // FLASHINFER_GROUP_GEMM_CUTLASS_WRAPPER_CUH_
\ No newline at end of file
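Note: the args kernel above only fills per-segment GEMM descriptors (problem shape, base pointers, leading dimensions); the actual math is done later by CUTLASS's grouped GEMM. A rough NumPy sketch of the computation those descriptors encode, not part of the patch (helper name and shapes are illustrative; w is assumed to be [num_weights, d_in, d_out] row-major, or [num_weights, d_out, d_in] when column-major):

    import numpy as np

    def segment_gemm_ref(x, w, xy_indptr, w_indices, w_column_major=False):
        """Rows [xy_indptr[i], xy_indptr[i+1]) of x are multiplied by weight w_indices[i]."""
        d_out = w.shape[1] if w_column_major else w.shape[2]
        y = np.zeros((x.shape[0], d_out), dtype=x.dtype)
        for i in range(len(xy_indptr) - 1):
            lo, hi = xy_indptr[i], xy_indptr[i + 1]
            w_i = w[w_indices[i]]  # one weight matrix per segment
            y[lo:hi] = x[lo:hi] @ (w_i.T if w_column_major else w_i)
        return y

This mirrors the pointer arithmetic in the kernel: ptr_x/ptr_y advance by xy_indptr[i] rows, ptr_w selects the w_indices[i]-th weight, and ld_w is k or n depending on the weight layout.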
diff --git a/include/flashinfer/group_gemm/cutlass_wrapper.cuh b/include/flashinfer/group_gemm/group_gemm_lora.cuh
similarity index 54%
rename from include/flashinfer/group_gemm/cutlass_wrapper.cuh
rename to include/flashinfer/group_gemm/group_gemm_lora.cuh
index 7e0e4b8d5a..517419da5d 100644
--- a/include/flashinfer/group_gemm/cutlass_wrapper.cuh
+++ b/include/flashinfer/group_gemm/group_gemm_lora.cuh
@@ -13,32 +13,17 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
-#include <cstdint>
-#include "cutlass/cutlass.h"
-#include "cutlass/gemm/device/gemm_grouped.h"
-#include "cutlass/gemm/kernel/default_gemm_grouped.h"
-#include "cutlass/layout/matrix.h"
-#include "cutlass/numeric_types.h"
+#ifndef FLASHINFER_GROUP_GEMM_LORA_CUH_
+#define FLASHINFER_GROUP_GEMM_LORA_CUH_
 
 namespace flashinfer {
 
-template <typename DType>
-__global__ void group_gemm_args_kernel(
-    cutlass::gemm::GemmCoord *all_problems,
-    DType **ptr_y,
-    DType **ptr_x,
-    DType **ptr_w,
-    int64_t *ld_y,
-    int64_t *ld_x,
-    int64_t *ld_w,
-    DType *y,
-    DType *x,
-    DType **w,
-    int64_t d_in,
-    int64_t d_out
-) {
-  // TODO(Zihao)
-}
+namespace group_gemm {
+
+// TODO(Zihao): port punica's sgmv kernel
+
+}  // namespace group_gemm
+
+}  // namespace flashinfer
 
-}
\ No newline at end of file
+#endif  // FLASHINFER_GROUP_GEMM_LORA_CUH_
diff --git a/include/flashinfer/group_gemm/group_gemv.cuh b/include/flashinfer/group_gemm/group_gemv.cuh
new file mode 100644
index 0000000000..f44fdee620
--- /dev/null
+++ b/include/flashinfer/group_gemm/group_gemv.cuh
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2024 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FLASHINFER_GROUP_GEMV_CUH_
+#define FLASHINFER_GROUP_GEMV_CUH_
+
+namespace flashinfer {
+
+namespace group_gemm {
+
+// TODO(Zihao): port punica's bgmv kernel
+
+}  // namespace group_gemm
+
+}  // namespace flashinfer
+
+#endif  // FLASHINFER_GROUP_GEMV_CUH_
\ No newline at end of file
diff --git a/include/flashinfer/group_gemm/handler.cuh b/include/flashinfer/group_gemm/handler.cuh
index 297d3580f8..ebc02c807c 100644
--- a/include/flashinfer/group_gemm/handler.cuh
+++ b/include/flashinfer/group_gemm/handler.cuh
@@ -13,15 +13,113 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef FLASHINFER_GROUP_GEMM_HANDLER_H
-#define FLASHINFER_GROUP_GEMM_HANDLER_H
+#ifndef FLASHINFER_GROUP_GEMM_HANDLER_CUH_
+#define FLASHINFER_GROUP_GEMM_HANDLER_CUH_
+
+#include <cstddef>
+
+#include "../allocator.h"
+#include "../utils.cuh"
+#include "group_gemm_cutlass.cuh"
+#include "group_gemm_lora.cuh"
+#include "group_gemv.cuh"
 
 namespace flashinfer {
 
-class GroupGEMMHandler {
- public:
+namespace group_gemm {
+
+enum class GroupGEMMKernelConfig {
+  kGeneral,  // large d_in, d_out
+  kShrink,   // large d_in, small d_out
+  kExpand,   // small d_in, large d_out
+};
+
+class CutlassSegmentGEMMHandler {
+ public:
+  cutlass::gemm::GemmCoord* GetProblemSizes() const { return problem_sizes_; }
+
+  template <typename DType>
+  DType** GetXPtr() const {
+    return static_cast<DType**>(x_data_);
+  }
+
+  template <typename DType>
+  DType** GetWPtr() const {
+    return static_cast<DType**>(w_data_);
+  }
+
+  template <typename DType>
+  DType** GetYPtr() const {
+    return static_cast<DType**>(y_data_);
+  }
+
+  int64_t* GetLdX() const { return static_cast<int64_t*>(ld_x_); }
+
+  int64_t* GetLdY() const { return static_cast<int64_t*>(ld_y_); }
+
+  int64_t* GetLdW() const { return static_cast<int64_t*>(ld_w_); }
+
+  int64_t GetBatchSize() const { return batch_size_; }
+
+  bool IsWeightColumnMajor() const { return w_column_major_; }
+
+  template <typename DType>
+  cudaError_t RegisterProblem(void* buffer, size_t workspace_size_in_bytes, int64_t* xy_indptr_d,
+                              int64_t* w_indices_d, size_t batch_size, size_t d_in, size_t d_out,
+                              bool weight_column_major) {
+    problem_registered_ = true;
+    batch_size_ = batch_size;
+    w_column_major_ = weight_column_major;
+
+    AlignedAllocator allocator(buffer_, workspace_size_in_bytes);
+    problem_size_ = allocator.aligned_alloc<cutlass::gemm::GemmCoord>(batch_size, 16);
+    x_data_ = allocator.aligned_alloc<DType*>(batch_size, 16);
+    w_data_ = allocator.aligned_alloc<DType*>(batch_size, 16);
+    y_data_ = allocator.aligned_alloc<DType*>(batch_size, 16);
+    ld_x_ = allocator.aligned_alloc<int64_t>(batch_size, 16);
+    ld_w_ = allocator.aligned_alloc<int64_t>(batch_size, 16);
+    ld_y_ = allocator.aligned_alloc<int64_t>(batch_size, 16);
+
+    auto compute_args_kernel = compute_cutlass_group_gemm_args<DType>;
+
+    void* args[] = {(void*)&problem_size_, (void*)&x_data_,     (void*)&w_data_,
+                    (void*)&y_data_,       (void*)&ld_x_,       (void*)&ld_w_,
+                    (void*)&ld_y_,         (void*)&xy_indptr_d, (void*)&w_indices_d,
+                    (void*)&d_in,          (void*)&d_out,       (void*)&w_column_major_};
+
+    FLASHINFER_CUDA_CALL(
+        cudaLaunchKernel((void*)compute_args_kernel, batch_size, 1, args, 0, stream_););
+
+    return cudaSuccess;
+  }
+
+  cudaStream_t GetCUDAStream() const { return stream_; }
+
+  void SetCUDAStream(cudaStream_t stream) { stream_ = stream; }
+
+  bool IsProblemRegistered() const { return problem_registered_; }
+
+  CutlassSegmentGEMMHandler() {}
+
+  ~CutlassSegmentGEMMHandler() {}
+
+ private:
+  bool problem_registered_;
+  void* buffer_;
+  cudaStream_t stream_;
+  bool w_column_major_;
+  size_t batch_size_;
+  cutlass::gemm::GemmCoord* problem_sizes_;
+  void* x_data_;
+  void* w_data_;
+  void* y_data_;
+  void* ld_x_;
+  void* ld_w_;
+  void* ld_y_;
 };
 
-}  // namespace flashinfer
+}  // namespace group_gemm
+
+}  // namespace flashinfer
 
-#endif  //FLASHINFER_GROUP_GEMM_HANDLER_H
\ No newline at end of file
+#endif  // FLASHINFER_GROUP_GEMM_HANDLER_CUH_
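Note: RegisterProblem carves seven 16-byte-aligned arrays out of one user-provided workspace buffer: a GemmCoord per problem, three device-pointer arrays, and three int64 leading-dimension arrays. A back-of-envelope sketch of the space this layout needs, not part of the patch (function name is hypothetical; GemmCoord is assumed to be three 32-bit ints, as in CUTLASS):

    def segment_gemm_workspace_bytes(batch_size: int) -> int:
        SIZEOF_GEMM_COORD = 12  # cutlass::gemm::GemmCoord: 3 x int32
        SIZEOF_PTR = 8          # device pointer on a 64-bit platform
        SIZEOF_INT64 = 8
        ALIGN = 16
        per_array = [
            batch_size * SIZEOF_GEMM_COORD,  # problem sizes
            batch_size * SIZEOF_PTR,         # x pointers
            batch_size * SIZEOF_PTR,         # w pointers
            batch_size * SIZEOF_PTR,         # y pointers
            batch_size * SIZEOF_INT64,       # ld_x
            batch_size * SIZEOF_INT64,       # ld_w
            batch_size * SIZEOF_INT64,       # ld_y
        ]
        # each sub-allocation may burn up to ALIGN - 1 bytes to realign
        return sum(size + ALIGN - 1 for size in per_array)

As written here the patch passes batch_size (an element count) to aligned_alloc, which expects bytes; patch 11 later makes the byte counts explicit.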
diff --git a/include/flashinfer/group_gemm/wrapper.cuh b/include/flashinfer/group_gemm/wrapper.cuh
new file mode 100644
index 0000000000..38e8eebfb0
--- /dev/null
+++ b/include/flashinfer/group_gemm/wrapper.cuh
@@ -0,0 +1,88 @@
+/*
+ * Copyright (c) 2024 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FLASHINFER_GROUP_GEMM_WRAPPER_CUH_
+#define FLASHINFER_GROUP_GEMM_WRAPPER_CUH_
+
+#include "handler.cuh"
+
+namespace flashinfer {
+
+namespace group_gemm {
+
+template <typename DType>
+cudaError_t CutlassSegmentGEMMWrapper(CutlassSegmentGEMMHandler* handler, DType* x, DType* w,
+                                      DType* y, cudaStream_t stream) {
+  using cutlass_t = typename cutlass_dtype<DType>::type;
+  if (handler->IsProblemRegistered()) {
+    using cutlass::epilogue::thread::LinearCombination;
+    using cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle;
+    using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
+        cutlass_t,                                                          // Element A
+        cutlass::layout::RowMajor,                                          // Layout A
+        cutlass::ComplexTransform::kNone,                                   //
+        8,                                                                  // Granularity A
+        cutlass_t,                                                          // Element B
+        cutlass::layout::RowMajor,                                          // Layout B
+        cutlass::ComplexTransform::kNone,                                   //
+        8,                                                                  // Granularity B
+        cutlass_t,                                                          // Element C&D
+        cutlass::layout::RowMajor,                                          // Layout C&D
+        float,                                                              // Element Accumulator
+        cutlass::arch::OpClassTensorOp,                                     // Operator Class Tag
+        cutlass::arch::Sm80,                                                // Architecture
+        cutlass::gemm::GemmShape<32, 128, 16>,                              // Thread Block Shape
+        cutlass::gemm::GemmShape<32, 64, 16>,                               // Warp Shape
+        cutlass::gemm::GemmShape<16, 8, 8>,                                 // Instruction Shape
+        cutlass::epilogue::thread::LinearCombination<cutlass_t, 8, float, float>,  // Epilogue
+        GemmIdentityThreadblockSwizzle<1>,                                  // Swizzling Operator
+        2                                                                   // Stages
+        >::GemmKernel;
+
+    using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp;
+    typename EpilogueOutputOp::Params epilogue_op(1.0, 1.0);
+
+    using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
+    typename GemmGrouped::Arguments args(handler->GetProblemSizes(), handler->GetBatchSize(), 512,
+                                         epilogue_op, handler->GetXPtr<cutlass_t>(),
+                                         handler->GetWPtr<cutlass_t>(), handler->GetYPtr<cutlass_t>(),
+                                         handler->GetYPtr<cutlass_t>(), handler->GetLdX(),
+                                         handler->GetLdW(), handler->GetLdY(), handler->GetLdY());
+
+    GemmGrouped gemm;
+    auto status = gemm.initialize(args, nullptr, stream);
+    if (status != cutlass::Status::kSuccess) {
+      fprintf(stderr, "sgmv_cutlass gemm.initialize failed: %s\n", cutlassGetStatusString(status));
+      return false;
+    }
+    status = gemm.run(stream);
+    if (status != cutlass::Status::kSuccess) {
+      fprintf(stderr, "sgmv_cutlass gemm.run failed: %s\n", cutlassGetStatusString(status));
+      return false;
+    }
+  } else {
+    std::ostringstream err_msg;
+    err_msg << "Please call CutlassSegmentGEMMHandler's
RegisterProblem() before calling " + "BatchDecodeWithPagedKVCacheWrapper()"; + throw std::runtime_error(err_msg.str()); + } + return cudaSuccess; +} + +} // namespace group_gemm + +} // namespace flashinfer + +#endif // FLASHINFER_GROUP_GEMM_WRAPPER_CUH_ diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index ec05cf7821..13ab21dfea 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -370,8 +370,8 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, torch::Tensor kv_indptr, torch::Tensor custom_mask, torch::Tensor qk_indptr, - unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, - float sm_scale, float rope_scale, float rope_theta, bool return_lse) { + unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, + float rope_theta, bool return_lse) { CHECK_INPUT(q); CHECK_INPUT(qo_indptr); CHECK_INPUT(k); diff --git a/include/flashinfer/group_gemm/sgmv.cuh b/python/csrc/group_gemm.cu similarity index 99% rename from include/flashinfer/group_gemm/sgmv.cuh rename to python/csrc/group_gemm.cu index 458fbd05a3..ab79927f0e 100644 --- a/include/flashinfer/group_gemm/sgmv.cuh +++ b/python/csrc/group_gemm.cu @@ -12,4 +12,4 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - */ + */ \ No newline at end of file diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py index 64ce3ce272..9dbf9dbe3a 100644 --- a/python/flashinfer/__init__.py +++ b/python/flashinfer/__init__.py @@ -45,6 +45,7 @@ chain_speculative_sampling, ) from .norm import rmsnorm +from .group_gemm import SegmentGEMMWrapper try: from ._build_meta import __version__ as __version__ diff --git a/python/flashinfer/group_gemm.py b/python/flashinfer/group_gemm.py new file mode 100644 index 0000000000..78ebbc87e0 --- /dev/null +++ b/python/flashinfer/group_gemm.py @@ -0,0 +1,83 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +from typing import Optional + +try: + from . import _kernels +except ImportError as e: + import os + import logging + + if os.environ.get("BUILD_DOC", "0") == "1": + _kernels = None + logging.warning("Kernels are not loaded in documentation build mode.") + else: + raise e + + +class SegmentGEMMWrapper: + r"""Wrapper for segment GEMM kernels.""" + + def __init__(self, workspace_buffer: torch.Tensor): + self._workspace_buffer = workspace_buffer + self._wrapper = _kernels.SegmentGEMMWrapper() + + def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor): + r"""Reset the workspace buffer. 
+ + Parameters + ---------- + new_workspace_buffer : torch.Tensor + The new workspace buffer, the device of the new workspace buffer should + be the same as the device of the input tensors. + """ + self._workspace_buffer = new_workspace_buffer + + def register_problem( + self, + batch_size: int, + d_in: int, + d_out: int, + weight_column_major: bool, + seg_lens: Optional[torch.Tensor] = None, + seg_indptr: Optional[torch.Tensor] = None, + weight_indices: Optional[torch.Tensor] = None, + ): + if seg_lens is None and seg_indptr is None: + raise ValueError("Either seg_lens or seg_indptr should be provided.") + if seg_indptr is None: + seg_indptr = torch.cat( + [ + torch.tensor([0], device=seg_lens.device, dtype=seg_lens.dtype), + seg_lens.cumsum(0), + ], + dim=0, + ) + self._wrapper.register_problem( + self._workspace_buffer, + batch_size, + d_in, + d_out, + weight_column_major, + seg_indptr, + weight_indices, + ) + + def forward(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + r"""Forward pass of segment GEMM.""" + return self._wrapper.forward(x, weights) diff --git a/python/setup.py b/python/setup.py index ddc035441e..d4be5c6342 100644 --- a/python/setup.py +++ b/python/setup.py @@ -386,6 +386,7 @@ def __init__(self, *args, **kwargs) -> None: "csrc/batch_prefill.cu", "csrc/sampling.cu", "csrc/norm.cu", + "csrc/group_gemm.cu", ] + get_instantiation_cu(), include_dirs=[ diff --git a/python/tests/test_batch_prefill_kernels.py b/python/tests/test_batch_prefill_kernels.py index a72704dc2c..57f280b013 100644 --- a/python/tests/test_batch_prefill_kernels.py +++ b/python/tests/test_batch_prefill_kernels.py @@ -43,7 +43,7 @@ def test_batch_prefill_with_paged_kv_cache( causal, kv_layout, pos_encoding_mode, - enable_cuda_graph + enable_cuda_graph, ): q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half() q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len @@ -81,23 +81,29 @@ def test_batch_prefill_with_paged_kv_cache( head_dim, page_size, ) - o = wrapper.forward(q, kv_data, causal=causal, pos_encoding_mode=pos_encoding_mode) + o = wrapper.forward( + q, kv_data, causal=causal, pos_encoding_mode=pos_encoding_mode + ) else: q_indptr_buffer = torch.empty(batch_size + 1).int().to(0) kv_indptr_buffer = torch.empty(batch_size + 1).int().to(0) kv_indices_buffer = torch.empty(total_num_pages).int().to(0) kv_last_page_len_buffer = torch.empty(batch_size).int().to(0) wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout, enable_cuda_graph=True, + workspace_buffer, + kv_layout, + enable_cuda_graph=True, qo_indptr_buf=q_indptr_buffer, paged_kv_indptr_buf=kv_indptr_buffer, paged_kv_indices_buf=kv_indices_buffer, - paged_kv_last_page_len_buf=kv_last_page_len_buffer + paged_kv_last_page_len_buf=kv_last_page_len_buffer, ) q_indptr_warmup = torch.arange(0, batch_size + 1).int() * qo_len kv_indptr_warmup = torch.arange(0, batch_size + 1).int() kv_indices_warmup = torch.arange(0, batch_size).int() - kv_last_page_len_warmup = torch.full((batch_size,), page_size, dtype=torch.int32) + kv_last_page_len_warmup = torch.full( + (batch_size,), page_size, dtype=torch.int32 + ) wrapper.begin_forward( q_indptr_warmup, kv_indptr_warmup, @@ -113,9 +119,7 @@ def test_batch_prefill_with_paged_kv_cache( s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(3): - o = wrapper.forward( - q, kv_data, pos_encoding_mode=pos_encoding_mode - ) + o = wrapper.forward(q, kv_data, pos_encoding_mode=pos_encoding_mode) 
torch.cuda.current_stream().wait_stream(s) # capture g = torch.cuda.CUDAGraph() @@ -148,7 +152,9 @@ def test_batch_prefill_with_paged_kv_cache( ( kv_data[kv_indptr_cpu[i + 1] - 1, 0, :, : kv_last_page_len_cpu[i]] if kv_layout == "HND" - else kv_data[kv_indptr_cpu[i + 1] - 1, 0, : kv_last_page_len_cpu[i], :] + else kv_data[ + kv_indptr_cpu[i + 1] - 1, 0, : kv_last_page_len_cpu[i], : + ] ) .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), @@ -163,7 +169,9 @@ def test_batch_prefill_with_paged_kv_cache( ( kv_data[kv_indptr_cpu[i + 1] - 1, 1, :, : kv_last_page_len_cpu[i]] if kv_layout == "HND" - else kv_data[kv_indptr_cpu[i + 1] - 1, 1, : kv_last_page_len_cpu[i], :] + else kv_data[ + kv_indptr_cpu[i + 1] - 1, 1, : kv_last_page_len_cpu[i], : + ] ) .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), @@ -381,4 +389,3 @@ def test_batch_prefill_with_ragged_kv_cache_custom_mask( ) test_batch_prefill_with_ragged_kv_cache(12, 54, 37, 8, 8, 128, True, "NONE") test_batch_prefill_with_ragged_kv_cache_custom_mask(12, 137, 137, 8, 8, 128, "NONE") - From aaad85241de3752bfa1dfa13b6b94f8f51b3ecee Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 4 Jun 2024 10:10:51 +0000 Subject: [PATCH 05/12] trailing empty lines --- include/flashinfer/group_gemm/group_gemm_cutlass.cuh | 2 +- include/flashinfer/group_gemm/group_gemv.cuh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/flashinfer/group_gemm/group_gemm_cutlass.cuh b/include/flashinfer/group_gemm/group_gemm_cutlass.cuh index 3acb608963..501a5b2d0b 100644 --- a/include/flashinfer/group_gemm/group_gemm_cutlass.cuh +++ b/include/flashinfer/group_gemm/group_gemm_cutlass.cuh @@ -62,4 +62,4 @@ __global__ void compute_cutlass_group_gemm_args(cutlass::gemm::GemmCoord* all_pr } // namespace flashinfer -#endif // FLASHINFER_GROUP_GEMM_CUTLASS_WRAPPER_CUH_ \ No newline at end of file +#endif // FLASHINFER_GROUP_GEMM_CUTLASS_WRAPPER_CUH_ diff --git a/include/flashinfer/group_gemm/group_gemv.cuh b/include/flashinfer/group_gemm/group_gemv.cuh index f44fdee620..4b439355e4 100644 --- a/include/flashinfer/group_gemm/group_gemv.cuh +++ b/include/flashinfer/group_gemm/group_gemv.cuh @@ -26,4 +26,4 @@ namespace group_gemm { } // namespace flashinfer -#endif // FLASHINFER_GROUP_GEMV_CUH_ \ No newline at end of file +#endif // FLASHINFER_GROUP_GEMV_CUH_ From f7bc35e6ced8bab3eb238122e1223936661fb751 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 4 Jun 2024 10:35:34 +0000 Subject: [PATCH 06/12] upd --- .../group_gemm/group_gemm_cutlass.cuh | 2 +- python/csrc/flashinfer_ops.cu | 4 ++ python/csrc/flashinfer_ops.h | 16 ++++++ python/csrc/group_gemm.cu | 56 ++++++++++++++++++- python/flashinfer/group_gemm.py | 10 ++++ 5 files changed, 86 insertions(+), 2 deletions(-) diff --git a/include/flashinfer/group_gemm/group_gemm_cutlass.cuh b/include/flashinfer/group_gemm/group_gemm_cutlass.cuh index 501a5b2d0b..5b927f6ae5 100644 --- a/include/flashinfer/group_gemm/group_gemm_cutlass.cuh +++ b/include/flashinfer/group_gemm/group_gemm_cutlass.cuh @@ -50,7 +50,7 @@ __global__ void compute_cutlass_group_gemm_args(cutlass::gemm::GemmCoord* all_pr int i = blockIdx.x; int m = xy_indptr[i + 1] - xy_indptr[i], k = d_in, n = d_out; all_problems[i] = cutlass::gemm::GemmCoord(m, n, k); - ptr_w[i] = w + w_indices[i] * d_in * d_out; + ptr_w[i] = w + (w_indices == nullptr ? 
i : w_indices[i]) * d_in * d_out;
   ptr_x[i] = x + xy_indptr[i] * d_in;
   ptr_y[i] = y + xy_indptr[i] * d_out;
   ld_x[i] = k;                       // m * k
diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu
index b088d07e26..d53603ddfe 100644
--- a/python/csrc/flashinfer_ops.cu
+++ b/python/csrc/flashinfer_ops.cu
@@ -72,4 +72,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
            &BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
       .def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward)
       .def("forward_custom_mask", &BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask);
+  py::class_<CutlassSegmentGEMMPyTorchWrapper>(m, "CutlassSegmentGEMMPyTorchWrapper")
+      .def(py::init())
+      .def("register_problem", &CutlassSegmentGEMMPyTorchWrapper::RegisterProblem)
+      .def("forward", &CutlassSegmentGEMMPyTorchWrapper::Forward);
 }
diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h
index a42cc1282c..724a3b794d 100644
--- a/python/csrc/flashinfer_ops.h
+++ b/python/csrc/flashinfer_ops.h
@@ -17,6 +17,7 @@
 #include <torch/extension.h>
 
 #include <flashinfer/attention/handler.cuh>
+#include <flashinfer/group_gemm/wrapper.cuh>
 #include <flashinfer/layout.cuh>
 #include <memory>
 
@@ -164,3 +165,18 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper {
   std::shared_ptr<flashinfer::BatchPrefillHandler> handler_;
   flashinfer::QKVLayout kv_layout_;
 };
+
+class CutlassSegmentGEMMPyTorchWrapper {
+ public:
+  void RegisterProblem(torch::Tensor workspace_buffer, unsigned int batch_size, unsigned int d_in,
+                       unsigned int d_out, bool weight_column_major, torch::Tensor seg_indptr,
+                       torch::Tensor weight_indices, torch::Tensor empty_data);
+
+  torch::Tensor Forward(torch::Tensor x, torch::Tensor w);
+
+  CutlassSegmentGEMMPyTorchWrapper()
+      : handler_(std::make_shared<CutlassSegmentGEMMHandler>()) {}
+
+ private:
+  std::shared_ptr<CutlassSegmentGEMMHandler> handler_;
+}
\ No newline at end of file
diff --git a/include/flashinfer/group_gemm/sgmv.cuh b/python/csrc/group_gemm.cu
similarity index 99%
rename from include/flashinfer/group_gemm/sgmv.cuh
rename to python/csrc/group_gemm.cu
index 458fbd05a3..ab79927f0e 100644
--- a/include/flashinfer/group_gemm/sgmv.cuh
+++ b/python/csrc/group_gemm.cu
@@ -12,4 +12,58 @@
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
- */
\ No newline at end of file
+ */
+#include <flashinfer/group_gemm/wrapper.cuh>
+
+#include "flashinfer_ops.h"
+#include "pytorch_extension_utils.h"
+
+using namespace flashinfer::group_gemm;
+
+void CutlassSegmentGEMMPyTorchWrapper::RegisterProblem(torch::Tensor workspace_buffer,
+                                                       unsigned int batch_size, unsigned int d_in,
+                                                       unsigned int d_out, bool weight_column_major,
+                                                       torch::Tensor seg_indptr,
+                                                       torch::Tensor weight_indices,
+                                                       torch::Tensor empty_data, ) {
+  CHECK_CUDA(workspace_buffer);
+  // TODO(Zihao): add more checks here
+  size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
+  // cast seg_indptr to int64
+  seg_indptr = seg_indptr.to(torch::kInt64).to(workspace_buffer.device());
+  bool weight_indices_defined = weight_indices.numel() > 0;
+  if (weight_indices_defined) {
+    weight_indices = weight_indices.to(torch::kInt64).to(workspace_buffer.device());
+  }
+
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] {
+    cudaError_t status = handler_->RegisterProblem<cutlass_t>(
+        static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
+        static_cast<int64_t*>(seg_indptr.data_ptr()),
+        weight_indices_defined ? static_cast<int64_t*>(weight_indices.data_ptr()) : nullptr,
+        batch_size, d_in, d_out, weight_column_major);
+    TORCH_CHECK(status == cudaSuccess, "Failed to register problem: ", cudaGetErrorString(status));
+  });
+}
+
+torch::Tensor CutlassSegmentGEMMPyTorchWrapper::Forward(torch::Tensor x, torch::Tensor weight) {
+  // TODO(Zihao): Add more checks here
+  CHECK_CUDA(x);
+  CHECK_CUDA(weight);
+  CHECK_DIM(2, x);       // x: [sum(m_i), d_in]
+  CHECK_DIM(2, weight);  // weight: [d_out, d_in] if weight_column_major, [d_in, d_out] otherwise
+  size_t cumulative_batch_size = x.size(0);
+  size_t d_out = handler_->IsWeightColumnMajor() ? weight.size(0) : weight.size(1);
+  size_t d_in = handler_->IsWeightColumnMajor() ? weight.size(1) : weight.size(0);
+  CHECK_EQ(x.size(1), d_in);
+  auto y = torch::empty({cumulative_batch_size, d_out}, x.options());
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE(x.scalar_type(), c_type, [&] {
+    cudaError_t status = CutlassSegmentGEMMWrapper<c_type>(
+        handler_, static_cast<c_type*>(x.data_ptr()), static_cast<c_type*>(weight.data_ptr()),
+        static_cast<c_type*>(y.data_ptr()), torch_current_stream);
+    TORCH_CHECK(status == cudaSuccess,
+                "Failed to run CutlassSegmentGEMM: ", cudaGetErrorString(status));
+  });
+}
\ No newline at end of file
diff --git a/python/flashinfer/group_gemm.py b/python/flashinfer/group_gemm.py
index 78ebbc87e0..3e9c2e3b8d 100644
--- a/python/flashinfer/group_gemm.py
+++ b/python/flashinfer/group_gemm.py
@@ -57,6 +57,7 @@ def register_problem(
         seg_lens: Optional[torch.Tensor] = None,
         seg_indptr: Optional[torch.Tensor] = None,
         weight_indices: Optional[torch.Tensor] = None,
+        dtype: torch.dtype = torch.float16,
     ):
         if seg_lens is None and seg_indptr is None:
             raise ValueError("Either seg_lens or seg_indptr should be provided.")
@@ -68,6 +69,14 @@ def register_problem(
                 ],
                 dim=0,
             )
+        if weight_indices is None:
+            # create an empty CPU tensor as placeholder
+            weight_indices = torch.empty(0, dtype=torch.int64)
+        # NOTE(Zihao): the following tensor acts as placeholder to pass dtype info
+        empty_data = torch.empty(
+            0,
+            dtype=(getattr(torch, dtype) if isinstance(dtype, str) else dtype),
+        )
         self._wrapper.register_problem(
             self._workspace_buffer,
             batch_size,
@@ -76,6 +85,7 @@ def register_problem(
             weight_column_major,
             seg_indptr,
             weight_indices,
+            empty_data,
         )
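Note: the zero-element empty_data tensor exists only to carry the dtype across the Python/C++ boundary, where DISPATCH_PYTORCH_DTYPE_TO_CTYPE switches on its scalar_type(); no data is ever read from it. The same trick in isolation (helper name is illustrative, not part of the patch):

    import torch

    def make_dtype_token(dtype) -> torch.Tensor:
        """Zero-element tensor whose only job is to carry a dtype into C++ dispatch."""
        dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype
        return torch.empty(0, dtype=dtype)

    token = make_dtype_token("bfloat16")
    assert token.numel() == 0 and token.dtype == torch.bfloat16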
From 7065fcc910e598dc7a37fe0918af4fe43a314e23 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Tue, 4 Jun 2024 10:45:59 +0000
Subject: [PATCH 07/12] bugfix

---
 include/flashinfer/group_gemm/handler.cuh | 12 ++++++------
 python/3rdparty                           |  1 +
 python/csrc/flashinfer_ops.cu             |  2 +-
 python/csrc/flashinfer_ops.h              |  8 ++++----
 python/setup.py                           |  1 +
 5 files changed, 13 insertions(+), 11 deletions(-)
 create mode 120000 python/3rdparty

diff --git a/include/flashinfer/group_gemm/handler.cuh b/include/flashinfer/group_gemm/handler.cuh
index ebc02c807c..98ca509110 100644
--- a/include/flashinfer/group_gemm/handler.cuh
+++ b/include/flashinfer/group_gemm/handler.cuh
@@ -72,7 +72,7 @@ class CutlassSegmentGEMMHandler {
     w_column_major_ = weight_column_major;
 
     AlignedAllocator allocator(buffer_, workspace_size_in_bytes);
-    problem_size_ = allocator.aligned_alloc<cutlass::gemm::GemmCoord>(batch_size, 16);
+    problem_sizes_ = allocator.aligned_alloc<cutlass::gemm::GemmCoord>(batch_size, 16);
     x_data_ = allocator.aligned_alloc<DType*>(batch_size, 16);
     w_data_ = allocator.aligned_alloc<DType*>(batch_size, 16);
     y_data_ = allocator.aligned_alloc<DType*>(batch_size, 16);
@@ -82,13 +82,13 @@ class CutlassSegmentGEMMHandler {
 
     auto compute_args_kernel = compute_cutlass_group_gemm_args<DType>;
 
-    void* args[] = {(void*)&problem_size_, (void*)&x_data_,     (void*)&w_data_,
-                    (void*)&y_data_,       (void*)&ld_x_,       (void*)&ld_w_,
-                    (void*)&ld_y_,         (void*)&xy_indptr_d, (void*)&w_indices_d,
-                    (void*)&d_in,          (void*)&d_out,       (void*)&w_column_major_};
+    void* args[] = {(void*)&problem_sizes_, (void*)&x_data_,     (void*)&w_data_,
+                    (void*)&y_data_,        (void*)&ld_x_,       (void*)&ld_w_,
+                    (void*)&ld_y_,          (void*)&xy_indptr_d, (void*)&w_indices_d,
+                    (void*)&d_in,           (void*)&d_out,       (void*)&w_column_major_};
 
     FLASHINFER_CUDA_CALL(
-        cudaLaunchKernel((void*)compute_args_kernel, batch_size, 1, args, 0, stream_););
+        cudaLaunchKernel((void*)compute_args_kernel, batch_size, 1, args, 0, stream_));
 
     return cudaSuccess;
   }
diff --git a/python/3rdparty b/python/3rdparty
new file mode 120000
index 0000000000..303a6484e6
--- /dev/null
+++ b/python/3rdparty
@@ -0,0 +1 @@
+../3rdparty
\ No newline at end of file
diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu
index d53603ddfe..6f90436ce7 100644
--- a/python/csrc/flashinfer_ops.cu
+++ b/python/csrc/flashinfer_ops.cu
@@ -73,7 +73,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       .def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward)
       .def("forward_custom_mask", &BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask);
   py::class_<CutlassSegmentGEMMPyTorchWrapper>(m, "CutlassSegmentGEMMPyTorchWrapper")
-      .def(py::init())
+      .def(py::init<>())
       .def("register_problem", &CutlassSegmentGEMMPyTorchWrapper::RegisterProblem)
       .def("forward", &CutlassSegmentGEMMPyTorchWrapper::Forward);
 }
diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h
index 724a3b794d..6df5d117c2 100644
--- a/python/csrc/flashinfer_ops.h
+++ b/python/csrc/flashinfer_ops.h
@@ -17,7 +17,7 @@
 #include <torch/extension.h>
 
 #include <flashinfer/attention/handler.cuh>
-#include <flashinfer/group_gemm/wrapper.cuh>
+#include <flashinfer/group_gemm/handler.cuh>
 #include <flashinfer/layout.cuh>
 #include <memory>
 
@@ -175,8 +175,8 @@ class CutlassSegmentGEMMPyTorchWrapper {
   torch::Tensor Forward(torch::Tensor x, torch::Tensor w);
 
   CutlassSegmentGEMMPyTorchWrapper()
-      : handler_(std::make_shared<CutlassSegmentGEMMHandler>()) {}
+      : handler_(std::make_shared<flashinfer::group_gemm::CutlassSegmentGEMMHandler>()) {}
 
  private:
-  std::shared_ptr<CutlassSegmentGEMMHandler> handler_;
-}
+  std::shared_ptr<flashinfer::group_gemm::CutlassSegmentGEMMHandler> handler_;
+};
diff --git a/python/setup.py b/python/setup.py
index d4be5c6342..5568190696 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -391,6 +391,7 @@ def __init__(self, *args, **kwargs) -> None:
         + get_instantiation_cu(),
         include_dirs=[
             str(root.resolve() / "include"),
+            str(root.resolve() / "3rdparty" / "cutlass" / "include")  # for group gemm
         ],
         extra_compile_args={
             "cxx": [
From 58ae4982e76ff84daab8917d456e4e42ef218759 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Tue, 4 Jun 2024 10:52:17 +0000
Subject: [PATCH 08/12] bugfix

---
 include/flashinfer/group_gemm/handler.cuh | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/include/flashinfer/group_gemm/handler.cuh b/include/flashinfer/group_gemm/handler.cuh
index 98ca509110..619e6f7457 100644
--- a/include/flashinfer/group_gemm/handler.cuh
+++ b/include/flashinfer/group_gemm/handler.cuh
@@ -16,8 +16,6 @@
 #ifndef FLASHINFER_GROUP_GEMM_HANDLER_CUH_
 #define FLASHINFER_GROUP_GEMM_HANDLER_CUH_
 
-#include <cstddef>
-
 #include "../allocator.h"
 #include "../utils.cuh"
 #include "group_gemm_cutlass.cuh"
@@ -110,9 +108,9 @@ class CutlassSegmentGEMMHandler {
   bool w_column_major_;
   size_t batch_size_;
   cutlass::gemm::GemmCoord* problem_sizes_;
-  void* x_data_;
-  void* w_data_;
-  void* y_data_;
+  void** x_data_;
+  void** w_data_;
+  void** y_data_;
   void* ld_x_;
   void* ld_w_;
   void* ld_y_;

From 526e63837223b63cb9d820bbae5bc5692641737a Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Tue, 4 Jun 2024 11:27:41 +0000
Subject: [PATCH 09/12] bugfix

---
 include/flashinfer/group_gemm/handler.cuh | 18 ++---
 include/flashinfer/group_gemm/wrapper.cuh | 90 +++++++++++++----------
 python/csrc/group_gemm.cu                 | 15 ++--
 3 files changed, 70 insertions(+), 53 deletions(-)

diff --git a/include/flashinfer/group_gemm/handler.cuh b/include/flashinfer/group_gemm/handler.cuh
index 619e6f7457..94b9e69967 100644
--- a/include/flashinfer/group_gemm/handler.cuh
+++ b/include/flashinfer/group_gemm/handler.cuh
@@ -37,17 +37,17 @@ class CutlassSegmentGEMMHandler {
   cutlass::gemm::GemmCoord* GetProblemSizes() const { return problem_sizes_; }
 
   template <typename DType>
-  DType** GetXPtr() const {
+  DType** GetXPtr() {
     return static_cast<DType**>(x_data_);
   }
 
   template <typename DType>
-  DType** GetWPtr() const {
+  DType** GetWPtr() {
     return static_cast<DType**>(w_data_);
   }
 
   template <typename DType>
-  DType** GetYPtr() const {
+  DType** GetYPtr() {
     return static_cast<DType**>(y_data_);
   }
 
@@ -71,9 +71,9 @@ class CutlassSegmentGEMMHandler {
 
     AlignedAllocator allocator(buffer_, workspace_size_in_bytes);
     problem_sizes_ = allocator.aligned_alloc<cutlass::gemm::GemmCoord>(batch_size, 16);
-    x_data_ = allocator.aligned_alloc<DType*>(batch_size, 16);
-    w_data_ = allocator.aligned_alloc<DType*>(batch_size, 16);
-    y_data_ = allocator.aligned_alloc<DType*>(batch_size, 16);
+    x_data_ = allocator.aligned_alloc<void*>(batch_size, 16);
+    w_data_ = allocator.aligned_alloc<void*>(batch_size, 16);
+    y_data_ = allocator.aligned_alloc<void*>(batch_size, 16);
     ld_x_ = allocator.aligned_alloc<int64_t>(batch_size, 16);
     ld_w_ = allocator.aligned_alloc<int64_t>(batch_size, 16);
     ld_y_ = allocator.aligned_alloc<int64_t>(batch_size, 16);
@@ -108,9 +108,9 @@ class CutlassSegmentGEMMHandler {
   bool w_column_major_;
   size_t batch_size_;
   cutlass::gemm::GemmCoord* problem_sizes_;
-  void** x_data_;
-  void** w_data_;
-  void** y_data_;
+  void* x_data_;
+  void* w_data_;
+  void* y_data_;
   void* ld_x_;
   void* ld_w_;
   void* ld_y_;
diff --git a/include/flashinfer/group_gemm/wrapper.cuh b/include/flashinfer/group_gemm/wrapper.cuh
index 38e8eebfb0..7e39a00598 100644
--- a/include/flashinfer/group_gemm/wrapper.cuh
+++ b/include/flashinfer/group_gemm/wrapper.cuh
@@ -29,48 +29,60 @@ cudaError_t CutlassSegmentGEMMWrapper(CutlassSegmentGEMMHandler* handler, DType*
   if (handler->IsProblemRegistered()) {
     using cutlass::epilogue::thread::LinearCombination;
     using cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle;
-    using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
-        cutlass_t,                              // Element A
-        cutlass::layout::RowMajor,              // Layout A
-        cutlass::ComplexTransform::kNone,       //
-        8,                                      // Granularity A
-        cutlass_t,                              // Element B
-        cutlass::layout::RowMajor,              // Layout B
-        cutlass::ComplexTransform::kNone,       //
-        8,                                      // Granularity B
-        cutlass_t,                              // Element C&D
-        cutlass::layout::RowMajor,              // Layout C&D
-        float,                                  // Element Accumulator
-        cutlass::arch::OpClassTensorOp,         // Operator Class Tag
-        cutlass::arch::Sm80,                    // Architecture
-        cutlass::gemm::GemmShape<32, 128, 16>,  // Thread Block Shape
-        cutlass::gemm::GemmShape<32, 64, 16>,   // Warp Shape
-        cutlass::gemm::GemmShape<16, 8, 8>,     // Instruction Shape
-        cutlass::epilogue::thread::LinearCombination<cutlass_t, 8, float, float>,  // Epilogue
-        GemmIdentityThreadblockSwizzle<1>,      // Swizzling Operator
-        2                                       // Stages
-        >::GemmKernel;
+    if (!handler->IsWeightColumnMajor()) {
+      // TODO(Zihao): investigate the difference between GroupScheduleMode::kDeviceOnly and
+      // GroupScheduleMode::kHostPrecompute
+      using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
+          cutlass_t,                              // Element A
+          cutlass::layout::RowMajor,              // Layout A
+          cutlass::ComplexTransform::kNone,       //
+          8,                                      // Granularity A
+          cutlass_t,                              // Element B
+          cutlass::layout::RowMajor,              // Layout B
+          cutlass::ComplexTransform::kNone,       //
+          8,                                      // Granularity B
+          cutlass_t,                              // Element C&D
+          cutlass::layout::RowMajor,              // Layout C&D
+          float,                                  // Element Accumulator
+          cutlass::arch::OpClassTensorOp,         // Operator Class Tag
+          cutlass::arch::Sm80,                    // Architecture
+          cutlass::gemm::GemmShape<32, 128, 16>,  // Thread Block Shape
+          cutlass::gemm::GemmShape<32, 64, 16>,   // Warp Shape
+          cutlass::gemm::GemmShape<16, 8, 8>,     // Instruction Shape
+          cutlass::epilogue::thread::LinearCombination<cutlass_t, 8, float, float>,  // Epilogue
+          GemmIdentityThreadblockSwizzle<1>,      // Swizzling Operator
+          2,                                      // Stages
+          cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly  // Group Schedule Mode
+          >::GemmKernel;
 
-    using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp;
-    typename EpilogueOutputOp::Params epilogue_op(1.0, 1.0);
+      using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp;
+      typename EpilogueOutputOp::Params epilogue_op(1.0, 1.0);
 
-    using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
-    typename GemmGrouped::Arguments args(handler->GetProblemSizes(), handler->GetBatchSize(), 512,
-                                         epilogue_op, handler->GetXPtr<cutlass_t>(),
-                                         handler->GetWPtr<cutlass_t>(), handler->GetYPtr<cutlass_t>(),
-                                         handler->GetYPtr<cutlass_t>(), handler->GetLdX(),
-                                         handler->GetLdW(), handler->GetLdY(), handler->GetLdY());
+      using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
+      typename GemmGrouped::Arguments args(
+          handler->GetProblemSizes(), handler->GetBatchSize(), 512, epilogue_op,
+          handler->GetXPtr<cutlass_t>(), handler->GetWPtr<cutlass_t>(),
+          handler->GetYPtr<cutlass_t>(), handler->GetYPtr<cutlass_t>(), handler->GetLdX(),
+          handler->GetLdW(), handler->GetLdY(), handler->GetLdY());
 
-    GemmGrouped gemm;
-    auto status = gemm.initialize(args, nullptr, stream);
-    if (status != cutlass::Status::kSuccess) {
-      fprintf(stderr, "sgmv_cutlass gemm.initialize failed: %s\n", cutlassGetStatusString(status));
-      return false;
-    }
-    status = gemm.run(stream);
-    if (status != cutlass::Status::kSuccess) {
-      fprintf(stderr, "sgmv_cutlass gemm.run failed: %s\n", cutlassGetStatusString(status));
-      return false;
+      GemmGrouped gemm;
+      auto status = gemm.initialize(args, nullptr, stream);
+      if (status != cutlass::Status::kSuccess) {
+        std::ostringstream err_msg;
+        err_msg << "sgmv_cutlass gemm.initialize failed: " << cutlassGetStatusString(status);
+        throw std::runtime_error(err_msg.str());
+      }
+      status = gemm.run(stream);
+      if (status != cutlass::Status::kSuccess) {
+        std::ostringstream err_msg;
+        err_msg << "sgmv_cutlass gemm.run failed: " << cutlassGetStatusString(status);
+        throw std::runtime_error(err_msg.str());
+      }
+    } else {
+      std::ostringstream err_msg;
+      // TODO: support column-major weight matrix
+      err_msg << "CutlassSegmentGEMMWrapper only supports row-major weight matrix";
+      throw std::runtime_error(err_msg.str());
     }
   } else {
     std::ostringstream err_msg;
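Note: the row-major/column-major distinction rejected above is just a question of which leading dimension each segment's B matrix advertises: k (d_in) when the weight is stored column-major, n (d_out) when row-major, exactly as the args kernel writes into ld_w. A small illustration of the equivalence in PyTorch, not part of the patch:

    import torch

    d_in, d_out = 4, 3
    w_row = torch.randn(d_in, d_out)  # row-major [d_in, d_out]: leading dim = d_out (n)
    w_col = w_row.t().contiguous()    # same weight stored [d_out, d_in]: leading dim = d_in (k)

    assert w_row.stride(0) == d_out   # advancing one k-row skips n elements
    assert w_col.stride(0) == d_in    # advancing one n-row skips k elements

    x = torch.randn(2, d_in)
    torch.testing.assert_close(x @ w_row, x @ w_col.t())  # both layouts give the same GEMM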
diff --git a/python/csrc/group_gemm.cu b/python/csrc/group_gemm.cu
index fea633ebf7..eabe62bae8 100644
--- a/python/csrc/group_gemm.cu
+++ b/python/csrc/group_gemm.cu
@@ -25,7 +25,7 @@ void CutlassSegmentGEMMPyTorchWrapper::RegisterProblem(torch::Tensor workspace_b
                                                        unsigned int d_out, bool weight_column_major,
                                                        torch::Tensor seg_indptr,
                                                        torch::Tensor weight_indices,
-                                                       torch::Tensor empty_data, ) {
+                                                       torch::Tensor empty_data) {
   CHECK_CUDA(workspace_buffer);
   // TODO(Zihao): add more checks here
@@ -37,12 +37,14 @@ void CutlassSegmentGEMMPyTorchWrapper::RegisterProblem(torch::Tensor workspace_b
   }
 
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] {
+    using cutlass_t = typename cutlass_dtype<c_type>::type;
     cudaError_t status = handler_->RegisterProblem<cutlass_t>(
         static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
         static_cast<int64_t*>(seg_indptr.data_ptr()),
         weight_indices_defined ? static_cast<int64_t*>(weight_indices.data_ptr()) : nullptr,
         batch_size, d_in, d_out, weight_column_major);
     TORCH_CHECK(status == cudaSuccess, "Failed to register problem: ", cudaGetErrorString(status));
+    return true;
   });
 }
 
@@ -52,18 +54,21 @@ torch::Tensor CutlassSegmentGEMMPyTorchWrapper::Forward(torch::Tensor x, torch::
   CHECK_CUDA(weight);
   CHECK_DIM(2, x);       // x: [sum(m_i), d_in]
   CHECK_DIM(2, weight);  // weight: [d_out, d_in] if weight_column_major, [d_in, d_out] otherwise
-  size_t cumulative_batch_size = x.size(0);
-  size_t d_out = handler_->IsWeightColumnMajor() ? weight.size(0) : weight.size(1);
-  size_t d_in = handler_->IsWeightColumnMajor() ? weight.size(1) : weight.size(0);
+  int64_t cumulative_batch_size = x.size(0);
+  int64_t d_out = handler_->IsWeightColumnMajor() ? weight.size(0) : weight.size(1);
+  int64_t d_in = handler_->IsWeightColumnMajor() ? weight.size(1) : weight.size(0);
   CHECK_EQ(x.size(1), d_in);
   auto y = torch::empty({cumulative_batch_size, d_out}, x.options());
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
 
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE(x.scalar_type(), c_type, [&] {
     cudaError_t status = CutlassSegmentGEMMWrapper<c_type>(
-        handler_, static_cast<c_type*>(x.data_ptr()), static_cast<c_type*>(weight.data_ptr()),
+        handler_.get(), static_cast<c_type*>(x.data_ptr()), static_cast<c_type*>(weight.data_ptr()),
         static_cast<c_type*>(y.data_ptr()), torch_current_stream);
     TORCH_CHECK(status == cudaSuccess,
                 "Failed to run CutlassSegmentGEMM: ", cudaGetErrorString(status));
+    return true;
   });
+
+  return y;
 }
\ No newline at end of file

From e82b19b104700fc97fb8794c60b5a42840336fd0 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Tue, 4 Jun 2024 11:46:36 +0000
Subject: [PATCH 10/12] buggy

---
 .../group_gemm/group_gemm_cutlass.cuh |  6 +--
 python/flashinfer/group_gemm.py       |  2 +-
 python/tests/test_group_gemm.py       | 46 +++++++++++++++++++
 3 files changed, 50 insertions(+), 4 deletions(-)
 create mode 100644 python/tests/test_group_gemm.py

diff --git a/include/flashinfer/group_gemm/group_gemm_cutlass.cuh b/include/flashinfer/group_gemm/group_gemm_cutlass.cuh
index 5b927f6ae5..80827cf4b2 100644
--- a/include/flashinfer/group_gemm/group_gemm_cutlass.cuh
+++ b/include/flashinfer/group_gemm/group_gemm_cutlass.cuh
@@ -42,9 +42,9 @@ struct cutlass_dtype<nv_bfloat16> {
 };
 
 template <typename T>
-__global__ void compute_cutlass_group_gemm_args(cutlass::gemm::GemmCoord* all_problems, T** ptr_y,
-                                                T** ptr_x, T** ptr_w, int64_t* ld_y, int64_t* ld_x,
-                                                int64_t* ld_w, T* y, T* x, T* w, int64_t* xy_indptr,
+__global__ void compute_cutlass_group_gemm_args(cutlass::gemm::GemmCoord* all_problems, T** ptr_x,
+                                                T** ptr_w, T** ptr_y, int64_t* ld_x, int64_t* ld_w,
+                                                int64_t* ld_y, T* y, T* x, T* w, int64_t* xy_indptr,
                                                 int64_t* w_indices, size_t d_in, size_t d_out,
                                                 bool w_column_major) {
   int i = blockIdx.x;
diff --git a/python/flashinfer/group_gemm.py b/python/flashinfer/group_gemm.py
index 3e9c2e3b8d..b1a0dfd919 100644
--- a/python/flashinfer/group_gemm.py
+++ b/python/flashinfer/group_gemm.py
@@ -35,7 +35,7 @@ class SegmentGEMMWrapper:
 
     def __init__(self, workspace_buffer: torch.Tensor):
         self._workspace_buffer = workspace_buffer
-        self._wrapper = _kernels.SegmentGEMMWrapper()
+        self._wrapper = _kernels.CutlassSegmentGEMMPyTorchWrapper()
 
     def reset_workspace_buffer(self,
new_workspace_buffer: torch.Tensor): r"""Reset the workspace buffer. diff --git a/python/tests/test_group_gemm.py b/python/tests/test_group_gemm.py new file mode 100644 index 0000000000..31efc13d3f --- /dev/null +++ b/python/tests/test_group_gemm.py @@ -0,0 +1,46 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import flashinfer +import torch +import pytest + +@pytest.mark.parametrize("batch_size", [1, 33, 77, 377]) +@pytest.mark.parametrize("num_rows_per_batch", [3, 10, 99]) +@pytest.mark.parametrize("d_in", [128, 1024, 4096]) +@pytest.mark.parametrize("d_out", [128, 1024, 4096]) +def test_segment_gemm( + batch_size, + num_rows_per_batch, + d_in, + d_out, +): + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) + segment_gemm = flashinfer.group_gemm.SegmentGEMMWrapper(workspace_buffer) + segment_gemm.register_problem( + batch_size, + d_in, + d_out, + weight_column_major=True, + seg_lens=torch.full((batch_size,), num_rows_per_batch), + seg_indptr=None, + weight_indices=None, + dtype=torch.float16, + ) + + +if __name__ == "__main__": + test_segment_gemm(1, 3, 128, 128) From dea6182529ea3836345bb41e4bd33f8973386482 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Wed, 5 Jun 2024 03:25:34 +0000 Subject: [PATCH 11/12] upd --- docs/api/python/decode.rst | 4 + docs/api/python/group_gemm.rst | 13 ++ docs/api/python/prefill.rst | 5 +- docs/index.rst | 1 + .../group_gemm/group_gemm_cutlass.cuh | 2 +- include/flashinfer/group_gemm/handler.cuh | 73 +-------- include/flashinfer/group_gemm/wrapper.cuh | 139 ++++++++++-------- include/flashinfer/utils.cuh | 11 ++ python/MANIFEST.in | 1 + python/csrc/flashinfer_ops.cu | 4 +- python/csrc/flashinfer_ops.h | 13 +- python/csrc/group_gemm.cu | 62 ++++---- python/csrc/pytorch_extension_utils.h | 1 + python/flashinfer/group_gemm.py | 110 ++++++++++---- python/flashinfer/utils.py | 7 + python/setup.py | 4 +- python/tests/test_group_gemm.py | 81 ++++++++-- 17 files changed, 322 insertions(+), 209 deletions(-) create mode 100644 docs/api/python/group_gemm.rst diff --git a/docs/api/python/decode.rst b/docs/api/python/decode.rst index f789859011..eb4d06a3c1 100644 --- a/docs/api/python/decode.rst +++ b/docs/api/python/decode.rst @@ -25,5 +25,9 @@ Batch Decoding .. autoclass:: BatchDecodeWithPagedKVCacheWrapper :members: + .. automethod:: __init__ + .. autoclass:: CUDAGraphBatchDecodeWithPagedKVCacheWrapper :members: + + .. automethod:: __init__ diff --git a/docs/api/python/group_gemm.rst b/docs/api/python/group_gemm.rst new file mode 100644 index 0000000000..b396a320e4 --- /dev/null +++ b/docs/api/python/group_gemm.rst @@ -0,0 +1,13 @@ +.. _apigroup_gemm: + +flashinfer.group_gemm +===================== + +This module provides a set of functions to group GEMM operations. + +.. currentmodule:: flashinfer.group_gemm + +.. autoclass:: SegmentGEMMWrapper + :members: + + .. 
automethod:: __init__
diff --git a/docs/api/python/prefill.rst b/docs/api/python/prefill.rst
index 9f50f1953e..aad6cbf653 100644
--- a/docs/api/python/prefill.rst
+++ b/docs/api/python/prefill.rst
@@ -22,6 +22,9 @@ Batch Prefill/Append Attention
 .. autoclass:: BatchPrefillWithPagedKVCacheWrapper
     :members:
 
+    .. automethod:: __init__
+
 .. autoclass:: BatchPrefillWithRaggedKVCacheWrapper
     :members:
-
\ No newline at end of file
+
+    .. automethod:: __init__
diff --git a/docs/index.rst b/docs/index.rst
index 334d3c8b6f..8851b7ff14 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -32,4 +32,5 @@ FlashInfer is a library for Language Languages Models that provides high-perform
    api/python/cascade
    api/python/page
    api/python/sampling
+   api/python/group_gemm
    api/python/norm
diff --git a/include/flashinfer/group_gemm/group_gemm_cutlass.cuh b/include/flashinfer/group_gemm/group_gemm_cutlass.cuh
index 80827cf4b2..a3422bef90 100644
--- a/include/flashinfer/group_gemm/group_gemm_cutlass.cuh
+++ b/include/flashinfer/group_gemm/group_gemm_cutlass.cuh
@@ -44,7 +44,7 @@ struct cutlass_dtype<nv_bfloat16> {
 template <typename T>
 __global__ void compute_cutlass_group_gemm_args(cutlass::gemm::GemmCoord* all_problems, T** ptr_x,
                                                 T** ptr_w, T** ptr_y, int64_t* ld_x, int64_t* ld_w,
-                                                int64_t* ld_y, T* y, T* x, T* w, int64_t* xy_indptr,
+                                                int64_t* ld_y, T* x, T* w, T* y, int64_t* xy_indptr,
                                                 int64_t* w_indices, size_t d_in, size_t d_out,
                                                 bool w_column_major) {
   int i = blockIdx.x;
diff --git a/include/flashinfer/group_gemm/handler.cuh b/include/flashinfer/group_gemm/handler.cuh
index 94b9e69967..39ef0f7838 100644
--- a/include/flashinfer/group_gemm/handler.cuh
+++ b/include/flashinfer/group_gemm/handler.cuh
@@ -16,6 +16,8 @@
 #ifndef FLASHINFER_GROUP_GEMM_HANDLER_CUH_
 #define FLASHINFER_GROUP_GEMM_HANDLER_CUH_
 
+#include
+
 #include "../allocator.h"
 #include "../utils.cuh"
 #include "group_gemm_cutlass.cuh"
@@ -34,86 +36,27 @@ enum class GroupGEMMKernelConfig {
 
 class CutlassSegmentGEMMHandler {
  public:
-  cutlass::gemm::GemmCoord* GetProblemSizes() const { return problem_sizes_; }
-
-  template <typename DType>
-  DType** GetXPtr() {
-    return static_cast<DType**>(x_data_);
+  void RegisterWorkspace(void* buffer, size_t size) {
+    buffer_ = buffer;
+    workspace_size_in_bytes_ = size;
   }
 
-  template <typename DType>
-  DType** GetWPtr() {
-    return static_cast<DType**>(w_data_);
-  }
-
-  template <typename DType>
-  DType** GetYPtr() {
-    return static_cast<DType**>(y_data_);
-  }
-
-  int64_t* GetLdX() const { return static_cast<int64_t*>(ld_x_); }
-
-  int64_t* GetLdY() const { return static_cast<int64_t*>(ld_y_); }
-
-  int64_t* GetLdW() const { return static_cast<int64_t*>(ld_w_); }
-
-  int64_t GetBatchSize() const { return batch_size_; }
+  void* GetWorkspace() const { return buffer_; }
 
-  bool IsWeightColumnMajor() const { return w_column_major_; }
-
-  template <typename DType>
-  cudaError_t RegisterProblem(void* buffer, size_t workspace_size_in_bytes, int64_t* xy_indptr_d,
-                              int64_t* w_indices_d, size_t batch_size, size_t d_in, size_t d_out,
-                              bool weight_column_major) {
-    problem_registered_ = true;
-    batch_size_ = batch_size;
-    w_column_major_ = weight_column_major;
-
-    AlignedAllocator allocator(buffer_, workspace_size_in_bytes);
-    problem_sizes_ = allocator.aligned_alloc<cutlass::gemm::GemmCoord>(batch_size, 16);
-    x_data_ = allocator.aligned_alloc<DType*>(batch_size, 16);
-    w_data_ = allocator.aligned_alloc<DType*>(batch_size, 16);
-    y_data_ = allocator.aligned_alloc<DType*>(batch_size, 16);
-    ld_x_ = allocator.aligned_alloc<int64_t>(batch_size, 16);
-    ld_w_ = allocator.aligned_alloc<int64_t>(batch_size, 16);
-    ld_y_ = allocator.aligned_alloc<int64_t>(batch_size, 16);
-
-    auto compute_args_kernel = compute_cutlass_group_gemm_args<DType>;
-
-    void* args[] = {(void*)&problem_sizes_, (void*)&x_data_,     (void*)&w_data_,
-                    (void*)&y_data_,        (void*)&ld_x_,       (void*)&ld_w_,
-                    (void*)&ld_y_,          (void*)&xy_indptr_d, (void*)&w_indices_d,
-                    (void*)&d_in,           (void*)&d_out,       (void*)&w_column_major_};
-
-    FLASHINFER_CUDA_CALL(
-        cudaLaunchKernel((void*)compute_args_kernel, batch_size, 1, args, 0, stream_));
-
-    return cudaSuccess;
-  }
+  size_t GetWorkspaceSizeInBytes() const { return workspace_size_in_bytes_; }
 
   cudaStream_t GetCUDAStream() const { return stream_; }
 
   void SetCUDAStream(cudaStream_t stream) { stream_ = stream; }
 
-  bool IsProblemRegistered() const { return problem_registered_; }
-
   CutlassSegmentGEMMHandler() {}
 
   ~CutlassSegmentGEMMHandler() {}
 
  private:
-  bool problem_registered_;
   void* buffer_;
+  size_t workspace_size_in_bytes_;
   cudaStream_t stream_;
-  bool w_column_major_;
-  size_t batch_size_;
-  cutlass::gemm::GemmCoord* problem_sizes_;
-  void* x_data_;
-  void* w_data_;
-  void* y_data_;
-  void* ld_x_;
-  void* ld_w_;
-  void* ld_y_;
 };
 
 }  // namespace group_gemm
diff --git a/include/flashinfer/group_gemm/wrapper.cuh b/include/flashinfer/group_gemm/wrapper.cuh
index 7e39a00598..adc1d077bc 100644
--- a/include/flashinfer/group_gemm/wrapper.cuh
+++ b/include/flashinfer/group_gemm/wrapper.cuh
@@ -16,80 +16,99 @@
 #ifndef FLASHINFER_GROUP_GEMM_WRAPPER_CUH_
 #define FLASHINFER_GROUP_GEMM_WRAPPER_CUH_
 
+#include
+
+#include "../allocator.h"
 #include "handler.cuh"
 
 namespace flashinfer {
 
 namespace group_gemm {
 
+#define DISPATCH_WEIGHT_LAYOUT(is_column_major, WEIGHT_LAYOUT, ...) \
+  if (is_column_major) {                                            \
+    using WEIGHT_LAYOUT = cutlass::layout::ColumnMajor;             \
+    __VA_ARGS__                                                     \
+  } else {                                                          \
+    using WEIGHT_LAYOUT = cutlass::layout::RowMajor;                \
+    __VA_ARGS__                                                     \
+  }
+
 template <typename DType>
 cudaError_t CutlassSegmentGEMMWrapper(CutlassSegmentGEMMHandler* handler, DType* x, DType* w,
-                                      DType* y, cudaStream_t stream) {
-  using cutlass_t = typename cutlass_dtype<DType>::type;
-  if (handler->IsProblemRegistered()) {
-    using cutlass::epilogue::thread::LinearCombination;
-    using cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle;
-    if (!handler->IsWeightColumnMajor()) {
-      // TODO(Zihao): investigate the difference between GroupScheduleMode::kDeviceOnly and
-      // GroupScheduleMode::kHostPrecompute
-      using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
-          cutlass_t,                              // Element A
-          cutlass::layout::RowMajor,              // Layout A
-          cutlass::ComplexTransform::kNone,       //
-          8,                                      // Granularity A
-          cutlass_t,                              // Element B
-          cutlass::layout::RowMajor,              // Layout B
-          cutlass::ComplexTransform::kNone,       //
-          8,                                      // Granularity B
-          cutlass_t,                              // Element C&D
-          cutlass::layout::RowMajor,              // Layout C&D
-          float,                                  // Element Accumulator
-          cutlass::arch::OpClassTensorOp,         // Operator Class Tag
-          cutlass::arch::Sm80,                    // Architecture
-          cutlass::gemm::GemmShape<32, 128, 16>,  // Thread Block Shape
-          cutlass::gemm::GemmShape<32, 64, 16>,   // Warp Shape
-          cutlass::gemm::GemmShape<16, 8, 8>,     // Instruction Shape
-          cutlass::epilogue::thread::LinearCombination<cutlass_t, 8, float, float>,  // Epilogue
-          GemmIdentityThreadblockSwizzle<1>,      // Swizzling Operator
-          2,                                      // Stages
-          cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly  // Group Schedule Mode
-          >::GemmKernel;
+                                      DType* y, int64_t* xy_indptr_d, int64_t* w_indices_d,
+                                      unsigned int batch_size, unsigned int d_in,
+                                      unsigned int d_out, bool weight_column_major,
+                                      cudaStream_t stream) {
+  AlignedAllocator allocator(handler->GetWorkspace(), handler->GetWorkspaceSizeInBytes());
+  cutlass::gemm::GemmCoord* problem_sizes_device =
+      allocator.aligned_alloc<cutlass::gemm::GemmCoord>(
+          batch_size * sizeof(cutlass::gemm::GemmCoord), 16);
+  DType** x_data = allocator.aligned_alloc<DType*>(batch_size * sizeof(DType*), 16);
+  DType** w_data = allocator.aligned_alloc<DType*>(batch_size * sizeof(DType*), 16);
+  DType** y_data = allocator.aligned_alloc<DType*>(batch_size * sizeof(DType*), 16);
+  int64_t* ld_x = allocator.aligned_alloc<int64_t>(batch_size * sizeof(int64_t), 16);
+  int64_t* ld_w = allocator.aligned_alloc<int64_t>(batch_size * sizeof(int64_t), 16);
+  int64_t* ld_y = allocator.aligned_alloc<int64_t>(batch_size * sizeof(int64_t), 16);
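+  // Workspace note (a rough bound, not an exact figure): the allocations above consume
+  // about batch_size * (sizeof(cutlass::gemm::GemmCoord) + 3 * sizeof(DType*) +
+  // 3 * sizeof(int64_t)) bytes, i.e. ~60B per segment plus 16-byte alignment padding
+  // per array, so a 1MB workspace comfortably covers thousands of segments.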
+
+  // NOTE(Zihao): launching this kernel through the cudaLaunchKernel API was not
+  // successful, so we launch it directly for now; this needs more investigation.
+  auto compute_args_kernel = compute_cutlass_group_gemm_args<DType>;
+  compute_args_kernel<<<batch_size, 1, 0, stream>>>(
+      problem_sizes_device, x_data, w_data, y_data, ld_x, ld_w, ld_y, (DType*)x, (DType*)w,
+      (DType*)y, xy_indptr_d, w_indices_d, d_in, d_out, weight_column_major);
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess) {
+    std::cerr << "Failed to launch kernel: " << cudaGetErrorString(err) << std::endl;
+    return err;
+  }
 
   using cutlass::epilogue::thread::LinearCombination;
   using cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle;
+  DISPATCH_WEIGHT_LAYOUT(weight_column_major, WEIGHT_LAYOUT, {
+    using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
+        DType,                                   // Element A
+        cutlass::layout::RowMajor,               // Layout A
+        cutlass::ComplexTransform::kNone,        //
+        8,                                       // Granularity A
+        DType,                                   // Element B
+        WEIGHT_LAYOUT,                           // Layout B
+        cutlass::ComplexTransform::kNone,        //
+        8,                                       // Granularity B
+        DType,                                   // Element C&D
+        cutlass::layout::RowMajor,               // Layout C&D
+        float,                                   // Element Accumulator
+        cutlass::arch::OpClassTensorOp,          // Operator Class Tag
+        cutlass::arch::Sm80,                     // Architecture
+        cutlass::gemm::GemmShape<128, 128, 32>,  // Thread Block Shape
+        cutlass::gemm::GemmShape<64, 64, 32>,    // Warp Shape
+        cutlass::gemm::GemmShape<16, 8, 16>,     // Instruction Shape
+        cutlass::epilogue::thread::LinearCombination<DType, 8, float, float>,  // Epilogue
+        cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,  // Swizzling Operator
+        8                                        // Stages
+        >::GemmKernel;
 
-    using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
-    typename GemmGrouped::Arguments args(
-        handler->GetProblemSizes(), handler->GetBatchSize(), 512, epilogue_op,
-        handler->GetXPtr<cutlass_t>(), handler->GetWPtr<cutlass_t>(),
-        handler->GetYPtr<cutlass_t>(), handler->GetYPtr<cutlass_t>(), handler->GetLdX(),
-        handler->GetLdW(), handler->GetLdY(), handler->GetLdY());
     using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp;
     typename EpilogueOutputOp::Params epilogue_op(1.0, 1.0);
+    using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
+    // The epilogue uses alpha = beta = 1 (i.e. y += x @ W), so the caller must
+    // zero-initialize y; the third argument of Arguments is the threadblock count.
+    typename GemmGrouped::Arguments args(problem_sizes_device, batch_size, 4, epilogue_op, x_data,
+                                         w_data, y_data, y_data, ld_x, ld_w, ld_y, ld_y);
 
     GemmGrouped gemm;
     auto status = gemm.initialize(args, nullptr, stream);
     if (status != cutlass::Status::kSuccess) {
       std::ostringstream err_msg;
-      // TODO: support column-major weight matrix
-      err_msg << "CutlassSegmentGEMMWrapper only supports row-major weight matrix";
+      err_msg << "cutlass group_gemm.initialize failed: " << cutlassGetStatusString(status);
       throw std::runtime_error(err_msg.str());
     }
-  } else {
-    std::ostringstream err_msg;
-    err_msg << "Please call CutlassSegmentGEMMHandler's RegisterProblem() before calling "
-               "BatchDecodeWithPagedKVCacheWrapper()";
-    throw std::runtime_error(err_msg.str());
-  }
+    status = gemm.run(stream);
+    if (status != cutlass::Status::kSuccess) {
+      std::ostringstream err_msg;
+      err_msg << "cutlass group_gemm.run failed: " << cutlassGetStatusString(status);
+      throw std::runtime_error(err_msg.str());
+    }
+  });
+
   return cudaSuccess;
 }
diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh
index 358674a6d1..2c977fec46 100644
--- a/include/flashinfer/utils.cuh
+++ b/include/flashinfer/utils.cuh
@@ -308,6 +308,17 @@ std::tuple, std::vector> split_qo_in
   return {num_frags_x, num_qo_tiles, std::move(request_indices), std::move(tile_indices)};
 }
 
+template <typename T>
+inline void DebugPrintCUDAArray(T* device_ptr, size_t size, std::string prefix = "") {
+  std::vector<T> host_array(size);
+  std::cout << prefix;
+  cudaMemcpy(host_array.data(), device_ptr, size * sizeof(T), cudaMemcpyDeviceToHost);
+  for (size_t i = 0; i < size; ++i) {
+    std::cout << host_array[i] << " ";
+  }
+  std::cout << std::endl;
+}
+
 }  // namespace flashinfer
 
 #endif  // FLASHINFER_UTILS_CUH_
diff --git a/python/MANIFEST.in b/python/MANIFEST.in
index 070a8eb79f..854badc804 100644
--- a/python/MANIFEST.in
+++ b/python/MANIFEST.in
@@ -10,6 +10,7 @@ include generate_single_prefill_inst.py
 include literal_map.py
 recursive-include include *
 recursive-include csrc *
+recursive-include 3rdparty/cutlass *
 
 # wheel-only
 exclude flashinfer/_build_meta.py
diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu
index 6f90436ce7..d784665d69 100644
--- a/python/csrc/flashinfer_ops.cu
+++ b/python/csrc/flashinfer_ops.cu
@@ -73,7 +73,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       .def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward)
       .def("forward_custom_mask", &BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask);
   py::class_<CutlassSegmentGEMMPyTorchWrapper>(m, "CutlassSegmentGEMMPyTorchWrapper")
-      .def(py::init<>())
-      .def("register_problem", &CutlassSegmentGEMMPyTorchWrapper::RegisterProblem)
+      .def(py::init<torch::Tensor>())
+      .def("register_workspace_buffer", &CutlassSegmentGEMMPyTorchWrapper::RegisterWorkspaceBuffer)
       .def("forward", &CutlassSegmentGEMMPyTorchWrapper::Forward);
 }
diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h
index 6df5d117c2..b16b6a5704 100644
--- a/python/csrc/flashinfer_ops.h
+++ b/python/csrc/flashinfer_ops.h
@@ -168,14 +168,15 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper {
 
 class CutlassSegmentGEMMPyTorchWrapper {
  public:
-  void RegisterProblem(torch::Tensor workspace_buffer, unsigned int batch_size, unsigned int d_in,
-                       unsigned int d_out, bool weight_column_major, torch::Tensor seg_indptr,
-                       torch::Tensor weight_indices, torch::Tensor empty_data);
+  void RegisterWorkspaceBuffer(torch::Tensor workspace_buffer);
 
-  torch::Tensor Forward(torch::Tensor x, torch::Tensor w);
+  torch::Tensor Forward(torch::Tensor seg_indptr, torch::Tensor weight_indices, torch::Tensor x,
+                        torch::Tensor weight, unsigned int batch_size, bool weight_column_major);
 
-  CutlassSegmentGEMMPyTorchWrapper()
-      : handler_(std::make_shared<flashinfer::group_gemm::CutlassSegmentGEMMHandler>()) {}
+  CutlassSegmentGEMMPyTorchWrapper(torch::Tensor workspace_buffer)
+      : handler_(std::make_shared<flashinfer::group_gemm::CutlassSegmentGEMMHandler>()) {
+    RegisterWorkspaceBuffer(workspace_buffer);
+  }
 
  private:
   std::shared_ptr<flashinfer::group_gemm::CutlassSegmentGEMMHandler> handler_;
diff --git a/python/csrc/group_gemm.cu b/python/csrc/group_gemm.cu
index eabe62bae8..f8ee438875 100644
--- a/python/csrc/group_gemm.cu
+++ b/python/csrc/group_gemm.cu
@@ -20,51 +20,45 @@
 
 using namespace flashinfer::group_gemm;
 
-void CutlassSegmentGEMMPyTorchWrapper::RegisterProblem(torch::Tensor workspace_buffer,
-                                                       unsigned int batch_size, unsigned int d_in,
-                                                       unsigned int d_out, bool weight_column_major,
-                                                       torch::Tensor seg_indptr,
-                                                       torch::Tensor weight_indices,
-                                                       torch::Tensor empty_data) {
-  CHECK_CUDA(workspace_buffer);
-  // TODO(Zihao): add more checks here
-  size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
-  // cast seg_indptr to int64
-  seg_indptr = seg_indptr.to(torch::kInt64).to(workspace_buffer.device());
-  bool weight_indices_defined = weight_indices.numel() > 0;
-  if (weight_indices_defined) {
-    weight_indices = weight_indices.to(torch::kInt64).to(workspace_buffer.device());
-  }
-
-  DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] {
-    using cutlass_t = typename cutlass_dtype<c_type>::type;
-    cudaError_t status = handler_->RegisterProblem<cutlass_t>(
-        static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
-        static_cast<int64_t*>(seg_indptr.data_ptr()),
-        weight_indices_defined ? static_cast<int64_t*>(weight_indices.data_ptr()) : nullptr,
-        batch_size, d_in, d_out, weight_column_major);
-    TORCH_CHECK(status == cudaSuccess, "Failed to register problem: ", cudaGetErrorString(status));
-    return true;
-  });
+void CutlassSegmentGEMMPyTorchWrapper::RegisterWorkspaceBuffer(torch::Tensor workspace_buffer) {
+  handler_->RegisterWorkspace(static_cast<void*>(workspace_buffer.data_ptr()),
+                              workspace_buffer.size(0) * workspace_buffer.element_size());
 }
 
-torch::Tensor CutlassSegmentGEMMPyTorchWrapper::Forward(torch::Tensor x, torch::Tensor weight) {
+torch::Tensor CutlassSegmentGEMMPyTorchWrapper::Forward(torch::Tensor seg_indptr,
+                                                        torch::Tensor weight_indices,
+                                                        torch::Tensor x, torch::Tensor weight,
+                                                        unsigned int batch_size,
+                                                        bool weight_column_major) {
   // TODO(Zihao): Add more checks here
+  CHECK_CUDA(seg_indptr);
   CHECK_CUDA(x);
   CHECK_CUDA(weight);
   CHECK_DIM(2, x);       // x: [sum(m_i), d_in]
-  CHECK_DIM(2, weight);  // weight: [d_out, d_in] if weight_column_major, [d_in, d_out] otherwise
+  CHECK_DIM(3, weight);  // weight: [num_weights, d_out, d_in] if weight_column_major, [num_weights,
+                         // d_in, d_out] otherwise
   int64_t cumulative_batch_size = x.size(0);
-  int64_t d_out = handler_->IsWeightColumnMajor() ? weight.size(0) : weight.size(1);
-  int64_t d_in = handler_->IsWeightColumnMajor() ? weight.size(1) : weight.size(0);
+  int64_t d_out = weight_column_major ? weight.size(1) : weight.size(2);
+  int64_t d_in = weight_column_major ? weight.size(2) : weight.size(1);
   CHECK_EQ(x.size(1), d_in);
-  auto y = torch::empty({cumulative_batch_size, d_out}, x.options());
+  // The CUTLASS epilogue accumulates into y (beta = 1), so y must start from zeros
+  // rather than uninitialized memory.
+  auto y = torch::zeros({cumulative_batch_size, d_out}, x.options());
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
+  seg_indptr = seg_indptr.to(torch::kInt64);
+
+  bool weight_indices_defined = weight_indices.numel() > 0;
+  if (weight_indices_defined) {
+    CHECK_CUDA(weight_indices);
+    weight_indices = weight_indices.to(torch::kInt64);
+  }
 
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE(x.scalar_type(), c_type, [&] {
-    cudaError_t status = CutlassSegmentGEMMWrapper<c_type>(
-        handler_.get(), static_cast<c_type*>(x.data_ptr()), static_cast<c_type*>(weight.data_ptr()),
-        static_cast<c_type*>(y.data_ptr()), torch_current_stream);
+    using cutlass_t = typename cutlass_dtype<c_type>::type;
+    auto status = CutlassSegmentGEMMWrapper<cutlass_t>(
+        handler_.get(), static_cast<cutlass_t*>(x.data_ptr()),
+        static_cast<cutlass_t*>(weight.data_ptr()), static_cast<cutlass_t*>(y.data_ptr()),
+        static_cast<int64_t*>(seg_indptr.data_ptr()),
+        weight_indices_defined ? static_cast<int64_t*>(weight_indices.data_ptr()) : nullptr,
+        batch_size, d_in, d_out, weight_column_major, torch_current_stream);
     TORCH_CHECK(status == cudaSuccess,
                 "Failed to run CutlassSegmentGEMM: ", cudaGetErrorString(status));
     return true;
diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h
index 8d5a6952e4..814b1df708 100644
--- a/python/csrc/pytorch_extension_utils.h
+++ b/python/csrc/pytorch_extension_utils.h
@@ -16,6 +16,7 @@
 #pragma once
 #include
 #include
+#include
 #include
 #include
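For completeness, the C++ pieces above can also be driven without the Python layer; a
minimal sketch, assuming device buffers d_x, d_w, d_y, d_seg_indptr and a raw device
workspace_ptr have already been prepared (these names and the sizes are illustrative,
not part of the patch):

    // One handler per workspace; all metadata is carved out of workspace_ptr.
    flashinfer::group_gemm::CutlassSegmentGEMMHandler handler;
    handler.RegisterWorkspace(workspace_ptr, 32 * 1024 * 1024);
    auto status = flashinfer::group_gemm::CutlassSegmentGEMMWrapper<cutlass::half_t>(
        &handler, d_x, d_w, d_y, d_seg_indptr, /*w_indices_d=*/nullptr,
        /*batch_size=*/batch_size, /*d_in=*/128, /*d_out=*/256,
        /*weight_column_major=*/false, /*stream=*/nullptr);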
diff --git a/python/flashinfer/group_gemm.py b/python/flashinfer/group_gemm.py
index b1a0dfd919..7bc8e58fa7 100644
--- a/python/flashinfer/group_gemm.py
+++ b/python/flashinfer/group_gemm.py
@@ -16,6 +16,7 @@
 
 import torch
 from typing import Optional
+from .utils import get_indptr
 
 try:
     from . import _kernels
@@ -34,8 +35,18 @@ class SegmentGEMMWrapper:
     r"""Wrapper for segment GEMM kernels."""
 
     def __init__(self, workspace_buffer: torch.Tensor):
+        r"""Initialize the wrapper.
+
+        Parameters
+        ----------
+        workspace_buffer : torch.Tensor
+            The workspace buffer used to store metadata for the segment GEMM kernels.
+            Its size grows with the number of segments (batch size); a 1MB buffer is
+            enough for most cases.
+        """
         self._workspace_buffer = workspace_buffer
-        self._wrapper = _kernels.CutlassSegmentGEMMPyTorchWrapper()
+        self._wrapper = _kernels.CutlassSegmentGEMMPyTorchWrapper(
+            self._workspace_buffer
+        )
 
     def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor):
         r"""Reset the workspace buffer.
@@ -47,47 +58,88 @@ def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor):
             be the same as the device of the input tensors.
         """
         self._workspace_buffer = new_workspace_buffer
+        self._wrapper.register_workspace_buffer(new_workspace_buffer)
 
-    def register_problem(
+    def forward(
         self,
+        x: torch.Tensor,
+        weights: torch.Tensor,
         batch_size: int,
-        d_in: int,
-        d_out: int,
         weight_column_major: bool,
         seg_lens: Optional[torch.Tensor] = None,
         seg_indptr: Optional[torch.Tensor] = None,
         weight_indices: Optional[torch.Tensor] = None,
-        dtype: torch.dtype = torch.float16,
-    ):
+    ) -> torch.Tensor:
+        r"""Forward pass of segment GEMM.
+
+        Compute the matrix product between a batch of input tensors (each with a
+        variable number of rows but a fixed number of columns) and a batch of weight
+        tensors with fixed numbers of rows and columns:
+
+        .. math::
+
+            y[i] = x[i] \times W[i]
+
+        If :attr:`weight_indices` is provided, the weight tensor for each segment is
+        selected by the indices in :attr:`weight_indices`:
+
+        .. math::
+
+            y[i] = x[i] \times W[\text{weight_indices}[i]]
+
+        We use a ragged tensor to represent the input tensor :attr:`x` and the output
+        tensor :attr:`y`, where each ``x[i]`` is a segment of the concatenated tensor.
+        Please see :ref:`Ragged Tensor tutorial ` for more details. We use a
+        ``seg_lens`` or ``seg_indptr`` tensor (either works) to indicate the start and
+        end of each segment, where ``seg_indptr`` is the cumulative sum of ``seg_lens``
+        (with an additional 0 at the beginning):
+
+        .. math::
+
+            \text{seg_indptr}[i] = \sum_{j=0}^{i-1} \text{seg_lens}[j], \quad \text{seg_indptr}[0] = 0
+
+        - If ``seg_lens`` is provided, then :attr:`x` has shape ``(sum(seg_lens), d_in)``
+          and :attr:`y` has shape ``(sum(seg_lens), d_out)``, where ``d_in`` is the number
+          of columns of the input tensor and ``d_out`` is the number of columns of the
+          output tensor.
+        - If ``seg_indptr`` is provided, then :attr:`x` has shape ``(seg_indptr[-1], d_in)``
+          and :attr:`y` has shape ``(seg_indptr[-1], d_out)``.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            The input tensor with shape ``(sum(seg_lens), d_in)``.
+        weights : torch.Tensor
+            The 3D weight tensor with shape ``(num_weights, d_in, d_out)`` if
+            :attr:`weight_column_major` is ``False``, or ``(num_weights, d_out, d_in)``
+            if :attr:`weight_column_major` is ``True``.
+        batch_size : int
+            The number of segments.
+        weight_column_major : bool
+            Whether the weight tensor is column major.
+        seg_lens : Optional[torch.Tensor]
+            The length of each segment, with shape ``(batch_size,)``; expects a 1D
+            tensor of dtype ``torch.int64``.
+        seg_indptr : Optional[torch.Tensor]
+            The indptr of the segments, with shape ``(batch_size + 1,)``; expects a 1D
+            tensor of dtype ``torch.int64``. If provided, :attr:`seg_lens` is ignored;
+            otherwise ``seg_indptr`` is computed internally from :attr:`seg_lens`.
+        weight_indices : Optional[torch.Tensor]
+            The indices of the weight tensor to be selected for each segment, with
+            shape ``(batch_size,)``; expects a 1D tensor of dtype ``torch.int64``.
+            If provided, the weight tensor is selected by the indices in this tensor.
+
+        Returns
+        -------
+        torch.Tensor
+            The output tensor with shape ``(sum(seg_lens), d_out)``.
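+
+        Examples
+        --------
+        A minimal sketch (sizes are illustrative; this mirrors the unit test updated
+        later in this patch):
+
+        >>> import torch
+        >>> import flashinfer
+        >>> workspace = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
+        >>> segment_gemm = flashinfer.group_gemm.SegmentGEMMWrapper(workspace)
+        >>> x = torch.randn(7 + 5, 128, dtype=torch.float16, device="cuda:0")
+        >>> weights = torch.randn(2, 128, 256, dtype=torch.float16, device="cuda:0")
+        >>> y = segment_gemm.forward(
+        ...     x,
+        ...     weights,
+        ...     batch_size=2,
+        ...     weight_column_major=False,
+        ...     seg_lens=torch.tensor([7, 5], dtype=torch.int64),
+        ... )
+        >>> y.shape
+        torch.Size([12, 256])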
+ """ if seg_lens is None and seg_indptr is None: raise ValueError("Either seg_lens or seg_indptr should be provided.") if seg_indptr is None: - seg_indptr = torch.cat( - [ - torch.tensor([0], device=seg_lens.device, dtype=seg_lens.dtype), - seg_lens.cumsum(0), - ], - dim=0, - ) + seg_indptr = get_indptr(seg_lens.to(x)) if weight_indices is None: # create an empty CPU tensor as placeholder weight_indices = torch.empty(0, dtype=torch.int64) - # NOTE(Zihao): the following tensor acts as placeholder to pass dtype info - empty_data = torch.empty( - 0, - dtype=(getattr(torch, dtype) if isinstance(dtype, str) else dtype), - ) - self._wrapper.register_problem( - self._workspace_buffer, - batch_size, - d_in, - d_out, - weight_column_major, + return self._wrapper.forward( seg_indptr, weight_indices, - empty_data, + x, + weights, + batch_size, + weight_column_major, ) - - def forward(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: - r"""Forward pass of segment GEMM.""" - return self._wrapper.forward(x, weights) diff --git a/python/flashinfer/utils.py b/python/flashinfer/utils.py index 664ac879d9..c175365242 100644 --- a/python/flashinfer/utils.py +++ b/python/flashinfer/utils.py @@ -57,3 +57,10 @@ def check_kv_layout(kv_layout: str): def is_float8(x: torch.Tensor): return x.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + + +def get_indptr(x: torch.Tensor): + x = x.to(torch.int64) + ret = torch.zeros(x.shape[0] + 1, dtype=x.dtype, device=x.device) + ret[1:] = x.cumsum(0) + return ret diff --git a/python/setup.py b/python/setup.py index 5568190696..fff016b753 100644 --- a/python/setup.py +++ b/python/setup.py @@ -391,7 +391,9 @@ def __init__(self, *args, **kwargs) -> None: + get_instantiation_cu(), include_dirs=[ str(root.resolve() / "include"), - str(root.resolve() / "3rdparty" / "cutlass" / "include") # for group gemm + str( + root.resolve() / "3rdparty" / "cutlass" / "include" + ), # for group gemm ], extra_compile_args={ "cxx": [ diff --git a/python/tests/test_group_gemm.py b/python/tests/test_group_gemm.py index 31efc13d3f..3fb4fb8ebd 100644 --- a/python/tests/test_group_gemm.py +++ b/python/tests/test_group_gemm.py @@ -15,32 +15,93 @@ """ import flashinfer +import numpy as np import torch import pytest -@pytest.mark.parametrize("batch_size", [1, 33, 77, 377]) + +@pytest.mark.parametrize("batch_size", [1, 77, 199]) @pytest.mark.parametrize("num_rows_per_batch", [3, 10, 99]) @pytest.mark.parametrize("d_in", [128, 1024, 4096]) @pytest.mark.parametrize("d_out", [128, 1024, 4096]) +@pytest.mark.parametrize("use_weight_indices", [False, True]) +@pytest.mark.parametrize("column_major", [False, True]) def test_segment_gemm( batch_size, num_rows_per_batch, d_in, d_out, + use_weight_indices, + column_major, ): + torch.manual_seed(42) workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) segment_gemm = flashinfer.group_gemm.SegmentGEMMWrapper(workspace_buffer) - segment_gemm.register_problem( + x = ( + (torch.randn(batch_size * num_rows_per_batch, d_in) / 10) + .to(0) + .to(torch.float16) + ) + if use_weight_indices: + num_weights = 1024 + if column_major: + weight = ( + (torch.randn(num_weights, d_out, d_in) / 10).to(0).to(torch.float16) + ) + else: + weight = ( + (torch.randn(num_weights, d_in, d_out) / 10).to(0).to(torch.float16) + ) + else: + weight = (torch.randn(batch_size, d_in, d_out) / 10).to(0).to(torch.float16) + y = segment_gemm.forward( + x, + weight, batch_size, - d_in, - d_out, - weight_column_major=True, - seg_lens=torch.full((batch_size,), 
-        seg_indptr=None,
-        weight_indices=None,
-        dtype=torch.float16,
+        weight_column_major=column_major,
+        seg_lens=torch.full((batch_size,), num_rows_per_batch, dtype=torch.int64),
+        weight_indices=(
+            (torch.arange(0, batch_size) % num_weights).to(0)
+            if use_weight_indices
+            else None
+        ),
     )
+
+    if use_weight_indices:
+        for i in range(batch_size):
+            np.testing.assert_allclose(
+                y[i * num_rows_per_batch : (i + 1) * num_rows_per_batch].cpu().numpy(),
+                torch.matmul(
+                    x[i * num_rows_per_batch : (i + 1) * num_rows_per_batch],
+                    (
+                        weight[i % num_weights].T
+                        if column_major
+                        else weight[i % num_weights]
+                    ),
+                )
+                .cpu()
+                .numpy(),
+                rtol=1e-3,
+                atol=1e-3,
+                err_msg="assertion failed at batch {}".format(i),
+            )
+    else:
+        np.testing.assert_allclose(
+            y.cpu().numpy(),
+            torch.matmul(
+                x.view(batch_size, num_rows_per_batch, d_in),
+                weight.transpose(-1, -2) if column_major else weight,
+            )
+            .view(batch_size * num_rows_per_batch, d_out)
+            .cpu()
+            .numpy(),
+            rtol=1e-3,
+            atol=1e-3,
+        )
 
 
 if __name__ == "__main__":
-    test_segment_gemm(1, 3, 128, 128)
+    test_segment_gemm(199, 99, 128, 128, False, False)
+    test_segment_gemm(199, 99, 128, 128, False, True)
+    test_segment_gemm(199, 99, 128, 128, True, False)
+    test_segment_gemm(199, 99, 128, 128, True, True)

From 7cb650d7e4f5191665f54aff52d29f25a455be7c Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Wed, 5 Jun 2024 03:26:52 +0000
Subject: [PATCH 12/12] upd

---
 python/csrc/pytorch_extension_utils.h | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h
index 814b1df708..8d5a6952e4 100644
--- a/python/csrc/pytorch_extension_utils.h
+++ b/python/csrc/pytorch_extension_utils.h
@@ -16,7 +16,6 @@
 #pragma once
 #include
 #include
-#include
 #include
 #include