Skip to content
14 changes: 14 additions & 0 deletions python/sglang/srt/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,20 @@ def forward_cpu(
positions, query, key, offsets, fused_set_kv_buffer_arg
)

def forward_hip(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# The fallback JIT kernel depends on CUDA toolchain discovery in tvm_ffi.
# On ROCm-only environments this can fail with missing CUDA_HOME.
return self.forward_native(
positions, query, key, offsets, fused_set_kv_buffer_arg
)

def forward_cuda(
self,
positions: torch.Tensor,
Expand Down
Loading