[Perf] Gemma4: output int64 from fused routing kernel to avoid redundant dtype copy#40565
yintong-lu wants to merge 2 commits into
@claude review
Code Review
This pull request modifies the routing logic in custom_routing_router.py to handle topk_ids dtypes and updates the Gemma4 model to use int64 for routing IDs. A review comment suggests refining the dtype conversion in custom_routing_router.py to ensure that the requested indices_type is strictly respected, which prevents redundant memory copies and potential type inconsistencies in downstream functions.
    if topk_ids.dtype != target_dtype and topk_ids.dtype != torch.int64:
        topk_ids = topk_ids.to(target_dtype)

    return topk_weights.to(torch.float32), topk_ids
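To make the review comment concrete, here is a minimal sketch contrasting the condition in the snippet above with a stricter variant that always honors the requested dtype. The function names `cast_ids_pr_style` and `cast_ids_strict` are hypothetical and are not part of custom_routing_router.py; only the condition in the first function is taken from the diff.

```python
import torch


def cast_ids_pr_style(topk_ids: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor:
    # Condition as written in the diff: int64 ids are left untouched
    # even when a different target dtype was requested.
    if topk_ids.dtype != target_dtype and topk_ids.dtype != torch.int64:
        topk_ids = topk_ids.to(target_dtype)
    return topk_ids


def cast_ids_strict(topk_ids: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor:
    # Hypothetical stricter variant: convert whenever the dtype differs,
    # so the caller always receives exactly the dtype it asked for.
    if topk_ids.dtype != target_dtype:
        topk_ids = topk_ids.to(target_dtype)
    return topk_ids


ids64 = torch.tensor([[0, 3], [1, 2]], dtype=torch.int64)

# With an int32 request, the two variants disagree on the result dtype.
pr_result = cast_ids_pr_style(ids64, torch.int32)
strict_result = cast_ids_strict(ids64, torch.int32)

# When the kernel already emits the requested dtype, no copy is made at all.
same = cast_ids_strict(ids64, torch.int64)
```

In both variants, a kernel that already emits the requested dtype skips the conversion and the input tensor is returned as-is, which is the case this PR targets.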
is it suitable for other cases?
I think so. I will evaluate on CUDA and check for regressions.
Verified on CUDA: there is no improvement, because the dtype conversions are fused into captured graph nodes (CompiledFxGraph) rather than appearing as standalone aten::copy_ calls.
No regression is observed on CUDA either.
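For anyone who wants to reproduce this kind of check, the following is a rough sketch of counting standalone aten::copy_ calls with the PyTorch profiler. It runs in eager mode on CPU, so it does not reproduce the CUDA graph-capture behavior described above. The helper `count_copy_calls` and the layer count of 30 are illustrative assumptions, not code from this PR.

```python
import torch
from torch.profiler import ProfilerActivity, profile

NUM_MOE_LAYERS = 30  # illustrative; mirrors the "~30 copies per forward pass" figure


def count_copy_calls(fn) -> int:
    # Profile fn() on CPU and sum the recorded aten::copy_ invocations.
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        fn()
    return sum(evt.count for evt in prof.key_averages() if evt.key == "aten::copy_")


# Stand-ins for router output: 16 tokens, top-8 experts each (made-up shapes).
ids_int32 = torch.randint(0, 128, (16, 8), dtype=torch.int32)
ids_int64 = ids_int32.to(torch.int64)


def downstream(ids: torch.Tensor) -> None:
    # Each "layer" asks for int64 ids, as the downstream consumers do.
    for _ in range(NUM_MOE_LAYERS):
        ids.to(torch.int64)


copies_from_int32 = count_copy_calls(lambda: downstream(ids_int32))
copies_from_int64 = count_copy_calls(lambda: downstream(ids_int64))
print(f"int32 ids: {copies_from_int32} copies, int64 ids: {copies_from_int64} copies")
```

When the ids are already int64, `Tensor.to` returns the input tensor unchanged, so no copy kernel is dispatched; with int32 ids every layer pays for one conversion.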
@claude review
@jikunshang can you review?
Signed-off-by: yintong-lu <yintong.lu@intel.com>
2fa6dc5 to 2901451
Summary:
This PR is a further optimization on top of PR [https://github.com//pull/39083].
The Gemma4 fused Triton routing kernel (#39083) outputs topk_ids as int32, but all downstream consumers require int64:
- remap_hidden_states C++ kernel uses int64 indexing
- cpu_fused_moe.py: topk_ids.to(torch.int64) for scatter
- gpt_oss_triton_kernels_moe.py: topk_ids_raw.to(torch.long)
- compressed_tensors_moe: topk_ids.to(torch.long)

This mismatch causes ~30 unnecessary aten::copy_ calls per forward pass (one int32→int64 conversion per MoE layer), which becomes a measurable bottleneck on bandwidth-constrained platforms.

Performance (Intel XPU × 4, TP=4, Gemma4-26B-A4B-it)
Relative to PR #39083 baseline (commit 45232a454):
copy_ calls | copy_ XPU%