diff --git a/csrc/layernorm.cpp b/csrc/layernorm.cpp index 3e2cebd8f..bdbafd5fe 100644 --- a/csrc/layernorm.cpp +++ b/csrc/layernorm.cpp @@ -204,6 +204,12 @@ void rms_norm( torch::Tensor& input, torch::Tensor& weight, double epsilon) { + TORCH_CHECK(out.is_contiguous()); + if (input.stride(-1) != 1) { + input = input.contiguous(); + } + TORCH_CHECK(input.stride(-1) == 1); + TORCH_CHECK(weight.is_contiguous()); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "call_rms_norm_kernel", [&] { vllm::call_rms_norm_kernel(out, input, weight, epsilon);