Skip to content
24 changes: 24 additions & 0 deletions python/sglang/srt/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,20 @@ def forward_cpu(
positions, query, key, offsets, fused_set_kv_buffer_arg
)

def forward_hip(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# The fallback JIT kernel depends on CUDA toolchain discovery in tvm_ffi.
# On ROCm-only environments this can fail with missing CUDA_HOME.
return self.forward_native(
positions, query, key, offsets, fused_set_kv_buffer_arg
)

def forward_cuda(
self,
positions: torch.Tensor,
Expand Down Expand Up @@ -1728,6 +1742,16 @@ def forward_native(
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key

def forward_hip(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
return self.forward_native(positions, query, key, fused_set_kv_buffer_arg)

def forward_cuda(
self,
positions: torch.Tensor,
Expand Down
Loading