add contiguous inside rmsnorm kernel #95
Conversation
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Pull request overview
This PR adds input tensor contiguity checks and enforcement directly within the RMSNorm kernel implementation, aligning with a similar change made in vLLM. The change moves contiguity handling from the Python level to the C++ level to prevent potential accuracy issues.
Key Changes:
- Adds contiguity validation for output and weight tensors
- Implements automatic conversion to a contiguous layout for the input tensor when its innermost stride is not 1
- Adds runtime checks to ensure proper tensor layouts before kernel execution
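
For orientation, a hedged sketch of how these pieces fit together in the updated entry point; the function name and signature are assumptions reconstructed from the diff hunks quoted below, not a verbatim copy of the PR:

```cpp
// Sketch only: name and signature assumed from the diff hunks below.
void rms_norm(torch::Tensor& out,     // [..., hidden_size]
              torch::Tensor& input,   // [..., hidden_size], any layout
              torch::Tensor& weight,  // [hidden_size]
              double epsilon) {
  // Layouts the kernel relies on.
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(weight.is_contiguous());

  // Enforce a unit innermost stride, copying only when necessary.
  if (input.stride(-1) != 1) {
    input = input.contiguous();
  }
  TORCH_CHECK(input.stride(-1) == 1);

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "call_rms_norm_kernel", [&] {
        vllm::call_rms_norm_kernel<scalar_t>(out, input, weight, epsilon);
      });
}
```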
```cpp
if (input.stride(-1) != 1) {
  input = input.contiguous();
}
TORCH_CHECK(input.stride(-1) == 1);
```
The check at line 211 is redundant. After the contiguous() call on line 209, the stride(-1) is guaranteed to be 1. This assertion will never fail and can be removed.
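
For illustration only (not part of the PR), a standalone libtorch snippet demonstrating why the assertion can never fire after the conversion:

```cpp
#include <torch/torch.h>

int main() {
  // transpose() returns a non-contiguous view: the strides are swapped.
  auto t = torch::rand({4, 8}).transpose(0, 1);
  TORCH_CHECK(t.stride(-1) != 1);  // innermost stride is 8 here, not 1

  // contiguous() materializes a row-major copy, so the innermost stride
  // is 1 by construction and the check below cannot fail.
  t = t.contiguous();
  TORCH_CHECK(t.stride(-1) == 1);
  return 0;
}
```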
Suggested change:
```diff
-TORCH_CHECK(input.stride(-1) == 1);
```
```cpp
if (input.stride(-1) != 1) {
  input = input.contiguous();
}
TORCH_CHECK(input.stride(-1) == 1);
TORCH_CHECK(weight.is_contiguous());
VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(), "call_rms_norm_kernel", [&] {
      vllm::call_rms_norm_kernel<scalar_t>(out, input, weight, epsilon);
```
Silently modifying the `input` parameter may lead to unexpected behavior for callers. Consider making `input` non-const to signal that it may be modified, documenting the behavior clearly, or making a local copy to avoid mutating the caller's reference.
Suggested change:
```diff
-if (input.stride(-1) != 1) {
-  input = input.contiguous();
-}
-TORCH_CHECK(input.stride(-1) == 1);
-TORCH_CHECK(weight.is_contiguous());
-VLLM_DISPATCH_FLOATING_TYPES(
-    input.scalar_type(), "call_rms_norm_kernel", [&] {
-      vllm::call_rms_norm_kernel<scalar_t>(out, input, weight, epsilon);
+auto input_ = input;
+if (input_.stride(-1) != 1) {
+  input_ = input_.contiguous();
+}
+TORCH_CHECK(input_.stride(-1) == 1);
+TORCH_CHECK(weight.is_contiguous());
+VLLM_DISPATCH_FLOATING_TYPES(
+    input_.scalar_type(), "call_rms_norm_kernel", [&] {
+      vllm::call_rms_norm_kernel<scalar_t>(out, input_, weight, epsilon);
```
CUDA may hit the same problem.
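
If the CUDA path does share the issue, the same local-copy pattern should carry over. A minimal sketch, assuming a CUDA entry point mirroring the one above (signature not verified against this repo, kernel launch elided):

```cpp
// Hedged sketch for a hypothetical CUDA-side rms_norm entry point.
void rms_norm(torch::Tensor& out,     // [..., hidden_size]
              torch::Tensor& input,   // [..., hidden_size]
              torch::Tensor& weight,  // [hidden_size]
              double epsilon) {
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(weight.is_contiguous());

  // torch::Tensor is a reference-counted handle, so this copies no data,
  // and reassigning input_ never mutates the caller's argument.
  auto input_ = input;
  if (input_.stride(-1) != 1) {
    input_ = input_.contiguous();  // allocates only when actually strided
  }
  TORCH_CHECK(input_.stride(-1) == 1);

  // ... dispatch on dtype and launch the CUDA kernel with input_ ...
}
```

The fast path (input already contiguous) then costs only the handle copy.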
Essential Elements of an Effective PR Description Checklist
- (Optional) The necessary documentation update, such as updating `supported_models.md` and `examples` for a new model.
Purpose
Add a contiguous conversion inside the RMSNorm kernel. vLLM made the same in-kernel change in vllm-project/vllm#28103 and removed the contiguous call at the Python level; without the in-kernel handling, this may cause accuracy issues.
Test Plan
Test Result
(Optional) Documentation Update