[Perf] Use 2D-grid to eliminate divmod in W8W8 group quant #42153
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
e2e benchmark on DeepSeek V4 Flash TP4 Concurrency 1 OPT
e2e benchmark on DeepSeek V4 Flash TP4 Concurrency 1024 OPT
Code Review
This pull request refactors the 8-bit packed per-token group quantization kernel to use a 2D grid layout and template-based block dimensions. It introduces the GetGroupsPerBlockX helper function to determine optimal tiling and updates the kernel launch logic to support varying X and Y block configurations. Feedback suggests improving the robustness of the kernel launch macro by explicitly checking for all supported tiling factors and adding an error path for unsupported values.
#define LAUNCH_REG_KERNEL(T, DST_DTYPE) \
do { \
if (kx == 16 && ry == 1) { \
LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 16, 1); \
} else if (kx == 8 && ry == 2) { \
LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 8, 2); \
} else { \
LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 4, 4); \
} \
} while (0)
The LAUNCH_REG_KERNEL macro assumes that kx can only be 16, 8, or 4 based on the GetGroupsPerBlockX implementation. While this is currently true because padded_groups_per_row is always a multiple of 4, it would be safer and more robust to include a static assertion or a more explicit check to ensure that kx * ry always equals the expected number of groups per block (16) for which the kernel is optimized. If kx were to ever be a value not handled by the if-else chain, it would default to the 4, 4 case which might be incorrect for the actual kx value.
#define LAUNCH_REG_KERNEL(T, DST_DTYPE) \
do { \
if (kx == 16) { \
LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 16, 1); \
} else if (kx == 8) { \
LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 8, 2); \
} else if (kx == 4) { \
LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 4, 4); \
} else { \
STD_TORCH_CHECK(false, "Unsupported kx value ", kx); \
} \
} while (0)
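The review's concern can be illustrated with a small host-side sketch. It is not the PR's code: `PickTile`, `LaunchInst`, `Dispatch`, and `kGroupsPerBlock` are hypothetical names, and the selection rule (take the widest of 16, 8, 4 that divides the padded row) is an assumption about what `GetGroupsPerBlockX` does, since its body is not shown here. The point is only that a `static_assert` on `KX * RY` plus an explicit error branch makes an unsupported tile fail loudly instead of silently falling through to the 4x4 case.

```cpp
#include <initializer_list>
#include <stdexcept>
#include <string>

// Assumption: the kernel is tuned for 16 groups per block, as the review states.
constexpr int kGroupsPerBlock = 16;

struct Tile {
  int kx;  // groups per block along X
  int ry;  // rows per block along Y
};

// Hypothetical stand-in for GetGroupsPerBlockX: pick the widest supported
// kx that divides the padded row, and derive ry so that kx * ry == 16.
inline Tile PickTile(int padded_groups_per_row) {
  for (int kx : {16, 8, 4}) {
    if (padded_groups_per_row % kx == 0) {
      return Tile{kx, kGroupsPerBlock / kx};
    }
  }
  throw std::invalid_argument("Unsupported padded_groups_per_row " +
                              std::to_string(padded_groups_per_row));
}

// Placeholder for a templated kernel launch. The return value only encodes
// which instantiation was chosen so the dispatch can be checked on the host.
template <int KX, int RY>
int LaunchInst() {
  static_assert(KX * RY == kGroupsPerBlock, "tile must cover 16 groups");
  return KX * 100 + RY;
}

// Explicit dispatch with an error path instead of a catch-all else branch.
inline int Dispatch(const Tile& t) {
  switch (t.kx) {
    case 16: return LaunchInst<16, 1>();
    case 8:  return LaunchInst<8, 2>();
    case 4:  return LaunchInst<4, 4>();
    default:
      throw std::invalid_argument("Unsupported kx value " +
                                  std::to_string(t.kx));
  }
}
```

Under these assumptions, `Dispatch(PickTile(24))` selects the `<8, 2>` instantiation, while a row width that is not a multiple of 4 raises an exception rather than launching a mismatched kernel.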
Hi @jiahanc, the pre-commit checks have failed. Please run:
uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files
Then, commit the changes and push to your branch.
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: Yongye Zhu <zyy1102000@gmail.com> (cherry picked from commit dd6b3a5)
…ect#42153) Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
Purpose
Replace the 1D-grid divmod (global_group_id % padded_gpr, global_group_id / padded_gpr) with a 2D grid + template-constant index unpack. Since the tile sizes are compile-time constants, the compiler can lower the coordinate math to simple bit operations instead of emitting runtime division/modulo setup.
Micro-benchmark on DeepSeek V4 Pro shape
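The index math described under Purpose can be sketched on the host as follows. This is an illustration under stated assumptions, not the kernel from this PR: `GroupCoord`, `Unpack1D`, `Unpack2D`, and `MatchesDivmod` are hypothetical names, `block_x` / `block_y` stand in for `blockIdx.x` / `blockIdx.y`, and `local` stands in for a group's index within its block. The 1D version divides by a runtime value, while the 2D version divides only by the template constant `KX`, which a compiler can typically reduce to a mask and a shift when `KX` is a power of two.

```cpp
#include <cstdint>

struct GroupCoord {
  int64_t row;  // token row
  int64_t col;  // group index within the (padded) row
};

// 1D-grid style: the divisor padded_gpr is only known at run time, so the
// compiler has to emit a general integer division and modulo.
inline GroupCoord Unpack1D(int64_t global_group_id, int64_t padded_gpr) {
  return GroupCoord{global_group_id / padded_gpr,
                    global_group_id % padded_gpr};
}

// 2D-grid style: block coordinates come from the grid, and the only
// division left uses the compile-time constant KX.
template <int KX, int RY>
inline GroupCoord Unpack2D(int block_x, int block_y, int local) {
  static_assert(KX > 0 && (KX & (KX - 1)) == 0, "KX must be a power of two");
  const int lx = local % KX;  // reducible to local & (KX - 1)
  const int ly = local / KX;  // reducible to local >> log2(KX)
  return GroupCoord{static_cast<int64_t>(block_y) * RY + ly,
                    static_cast<int64_t>(block_x) * KX + lx};
}

// Host-side check that both schemes address the same (row, col) pairs.
// Assumes rows % RY == 0 and padded_gpr % KX == 0.
template <int KX, int RY>
bool MatchesDivmod(int rows, int padded_gpr) {
  for (int by = 0; by < rows / RY; ++by) {
    for (int bx = 0; bx < padded_gpr / KX; ++bx) {
      for (int local = 0; local < KX * RY; ++local) {
        const GroupCoord c2 = Unpack2D<KX, RY>(bx, by, local);
        const GroupCoord c1 = Unpack1D(c2.row * padded_gpr + c2.col, padded_gpr);
        if (c1.row != c2.row || c1.col != c2.col) return false;
      }
    }
  }
  return true;
}
```

For example, `Unpack2D<4, 4>(1, 2, 5)` yields row 9 and column 5, the same pair that `Unpack1D` recovers from the flattened id for any padded row width that is a multiple of 4.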
Test Plan
lm_eval gsm8k
Test Result