[AMD] Fix RotaryEmbedding crash on AMD/ROCm (regression from #17934)#18903
Conversation
…P/ROCm support in RotaryEmbedding class
Summary of ChangesHello @michaelzhang-ai, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses a significant regression that led to RotaryEmbedding JIT kernel failures on AMD/ROCm platforms. The change introduces a dedicated Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Changelog
Ignored Files
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request addresses a regression that caused crashes on AMD/ROCm platforms by introducing a forward_hip method in the RotaryEmbedding class. This new method correctly bypasses the NVIDIA-specific JIT kernel path by dispatching to the pure PyTorch forward_native implementation. The change is well-contained, correct, and effectively resolves the issue for ROCm users. The accompanying docstring clearly explains the rationale for the fix. The implementation looks solid.
…ROCm 7.2; refactor forward_hip method in RotaryEmbedding class to accept *args/**kwargs for better subclass compatibility.
…ect#17934) (sgl-project#18903) Co-authored-by: michaelzhang-ai <michaelzhang-ai@users.noreply.github.com>
…ect#17934) (sgl-project#18903) Co-authored-by: michaelzhang-ai <michaelzhang-ai@users.noreply.github.com>
Motivation
#17934 (@BBuf) changed the HIP rotary embedding fallback from
sgl_kernel.rotary_embeddingtosglang.jit_kernel.pos_enc.rotary_embedding, which usestvm_ffiJIT compilation that requiresnvidia-smi/CUDA_HOME— unavailable on AMD GPUs.This breaks all models on ROCm since
_is_cuda=Falseon HIP always routes through the JIT fallback path.Failing nightlies:
Fix: Add
forward_hip()override toRotaryEmbeddingthat routes toforward_native()(pure PyTorch), bypassing the NVIDIA-specific JIT path. Uses*args, **kwargsbecause subclasses (MRotaryEmbedding,DeepseekScalingRotaryEmbedding, etc.) have differentforward_native()signatures.Pls help review @yctseng0211 @bingxche
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label,/rerun-failed-ci,/tag-and-rerun-ci