-
Notifications
You must be signed in to change notification settings - Fork 1.1k
feat: MxInt4 x Bf16 TRT-LLM Gen MoE support #2159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
613d5d9
f021624
1b0fc7a
6513090
7e9ff16
8222437
a5b7681
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -137,6 +137,41 @@ int computeSFIndex(int rowIdx, int colIdx, int totalRow, int totalColumn, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void blockScaleInterleaveHost(TensorView blockScale, TensorView interleavedBlockScale) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto blockScaleShape = blockScale.sizes(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto num_experts = blockScaleShape.size() == 3 ? blockScaleShape[0] : 1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto rows = blockScaleShape.size() == 3 ? blockScaleShape[1] : blockScaleShape[0]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto cols = blockScaleShape.size() == 3 ? blockScaleShape[2] : blockScaleShape[1]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto expert_out_size = tensorrt_llm::computeSwizzledLayoutSFSize(rows, cols); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto rows_padded = PadUpFn(rows, 128); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto cols_padded = PadUpFn(cols, 4); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int eIdx = 0; eIdx < static_cast<int>(num_experts); eIdx++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T* interleavedBlockScalePtr = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<T*>(interleavedBlockScale.data_ptr()) + eIdx * expert_out_size; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int rIdx = 0; rIdx < static_cast<int>(rows_padded); ++rIdx) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto globalRowIdx = eIdx * rows + rIdx; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| T* blockScalePtr = static_cast<T*>(blockScale.data_ptr()) + globalRowIdx * cols; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int cIdx = 0; cIdx < static_cast<int>(cols_padded); ++cIdx) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint8_t sf_ori = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (rIdx < static_cast<int>(rows) && cIdx < static_cast<int>(cols)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sf_ori = blockScalePtr[cIdx]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int sf_index = computeSFIndex(rIdx, cIdx, rows, cols, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| interleavedBlockScalePtr[sf_index] = sf_ori; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+151
to
+167
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid out-of-bounds pointer arithmetic for padded rows in For padded rows ( You can also make zero-initialization of A safer layout: - for (int eIdx = 0; eIdx < static_cast<int>(num_experts); eIdx++) {
- T* interleavedBlockScalePtr =
- static_cast<T*>(interleavedBlockScale.data_ptr()) + eIdx * expert_out_size;
- for (int rIdx = 0; rIdx < static_cast<int>(rows_padded); ++rIdx) {
- auto globalRowIdx = eIdx * rows + rIdx;
- T* blockScalePtr = static_cast<T*>(blockScale.data_ptr()) + globalRowIdx * cols;
- for (int cIdx = 0; cIdx < static_cast<int>(cols_padded); ++cIdx) {
- T sf_ori = 0;
- if (rIdx < static_cast<int>(rows) && cIdx < static_cast<int>(cols)) {
- sf_ori = blockScalePtr[cIdx];
- }
- int sf_index = computeSFIndex(rIdx, cIdx, rows, cols,
- tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4);
- interleavedBlockScalePtr[sf_index] = sf_ori;
- }
- }
- }
+ T* blockScaleBasePtr = static_cast<T*>(blockScale.data_ptr());
+ for (int eIdx = 0; eIdx < static_cast<int>(num_experts); ++eIdx) {
+ T* interleavedBlockScalePtr =
+ static_cast<T*>(interleavedBlockScale.data_ptr()) + eIdx * expert_out_size;
+ T* blockScaleExpertBasePtr = blockScaleBasePtr + eIdx * rows * cols;
+ for (int rIdx = 0; rIdx < static_cast<int>(rows_padded); ++rIdx) {
+ bool const valid_row = rIdx < static_cast<int>(rows);
+ T* blockScaleRowPtr = valid_row ? blockScaleExpertBasePtr + rIdx * cols : nullptr;
+ for (int cIdx = 0; cIdx < static_cast<int>(cols_padded); ++cIdx) {
+ T sf_ori{};
+ if (valid_row && cIdx < static_cast<int>(cols)) {
+ sf_ori = blockScaleRowPtr[cIdx];
+ }
+ int sf_index = computeSFIndex(rIdx, cIdx, rows, cols,
+ tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4);
+ interleavedBlockScalePtr[sf_index] = sf_ori;
+ }
+ }
+ }This keeps behavior the same while avoiding any out-of-bounds pointer values and strengthens default initialization for all template types. π Committable suggestion
Suggested change
π€ Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| template void blockScaleInterleaveHost<uint8_t>(TensorView blockScale, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TensorView interleavedBlockScale); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| template void blockScaleInterleaveHost<__nv_bfloat16>(TensorView blockScale, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TensorView interleavedBlockScale); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Interleave (and possibly pad) the weights block scaling factor. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // blockScale: [num_experts, rows, cols] or [rows, cols] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Return: num_experts * pad_up(rows, 128) * pad_up(cols, 4) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -148,7 +183,8 @@ void BlockScaleInterleave(TensorView blockScale, TensorView interleavedBlockScal | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| CHECK_CPU(blockScale); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| CHECK_CONTIGUOUS(blockScale); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| CHECK_INPUT_TYPE(blockScale, dl_uint8); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TVM_FFI_ICHECK(blockScale.dtype() == dl_uint8 || blockScale.dtype() == dl_bfloat16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| << "Block Scale must be uint8 or bfloat16."; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto blockScaleShape = blockScale.sizes(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TVM_FFI_ICHECK(blockScaleShape.size() == 2 || blockScaleShape.size() == 3) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| << "Block Scale should be 2D or 3D tensor."; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -166,27 +202,28 @@ void BlockScaleInterleave(TensorView blockScale, TensorView interleavedBlockScal | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const thread_local int smCount = tensorrt_llm::common::getMultiProcessorCount(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const cudaStream_t stream = get_stream(blockScale.device()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tensorrt_llm::kernels::invokeBlockScaleInterleave( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_experts, rows, rows_padded, cols, cols_padded, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<uint8_t*>(blockScale.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<uint8_t*>(interleavedBlockScale.data_ptr()), smCount, stream); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (blockScale.dtype() == dl_uint8) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tensorrt_llm::kernels::invokeBlockScaleInterleave( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_experts, rows, rows_padded, cols, cols_padded, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<uint8_t*>(blockScale.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<uint8_t*>(interleavedBlockScale.data_ptr()), smCount, stream); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else if (blockScale.dtype() == dl_bfloat16) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tensorrt_llm::kernels::invokeBlockScaleInterleave( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_experts, rows, rows_padded, cols, cols_padded, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<__nv_bfloat16*>(blockScale.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<__nv_bfloat16*>(interleavedBlockScale.data_ptr()), smCount, stream); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TVM_FFI_LOG_AND_THROW(NotImplementedError) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| << "block_scale_interleave only supports uint8 and bfloat16."; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int eIdx = 0; eIdx < static_cast<int>(num_experts); eIdx++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint8_t* interleavedBlockScalePtr = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<uint8_t*>(interleavedBlockScale.data_ptr()) + eIdx * expert_out_size; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int rIdx = 0; rIdx < static_cast<int>(rows_padded); ++rIdx) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto globalRowIdx = eIdx * rows + rIdx; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint8_t* blockScalePtr = static_cast<uint8_t*>(blockScale.data_ptr()) + globalRowIdx * cols; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int cIdx = 0; cIdx < static_cast<int>(cols_padded); ++cIdx) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint8_t sf_ori = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (rIdx < static_cast<int>(rows) && cIdx < static_cast<int>(cols)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sf_ori = blockScalePtr[cIdx]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int sf_index = computeSFIndex(rIdx, cIdx, rows, cols, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| interleavedBlockScalePtr[sf_index] = sf_ori; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (blockScale.dtype() == dl_uint8) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| blockScaleInterleaveHost<uint8_t>(blockScale, interleavedBlockScale); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else if (blockScale.dtype() == dl_bfloat16) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| blockScaleInterleaveHost<__nv_bfloat16>(blockScale, interleavedBlockScale); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TVM_FFI_LOG_AND_THROW(NotImplementedError) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| << "blockScaleInterleaveHost only supports uint8 and bfloat16."; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.