Skip to content
18 changes: 16 additions & 2 deletions python/sglang/srt/multimodal/processors/base_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,26 @@
MultimodalInputFormat,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import envs, is_npu, load_audio, load_image, load_video, logger
from sglang.srt.utils import (
envs,
is_cpu,
is_npu,
is_xpu,
load_audio,
load_image,
load_video,
logger,
)
from sglang.srt.utils.cuda_ipc_transport_utils import (
MM_FEATURE_CACHE_SIZE,
MM_ITEM_MEMORY_POOL_RECYCLE_INTERVAL,
CudaIpcTensorTransportProxy,
MmItemMemoryPool,
)

_is_cpu = is_cpu()
_is_npu = is_npu()
_is_xpu = is_xpu()

SGL_USE_CUDA_IPC = envs.SGLANG_USE_CUDA_IPC_TRANSPORT.get()

Expand Down Expand Up @@ -317,8 +328,10 @@ def process_mm_data(
and isinstance(processor.image_processor, BaseImageProcessorFast)
and not self.server_args.disable_fast_image_processor
):
if get_global_server_args().rl_on_policy_target is not None:
if _is_cpu or get_global_server_args().rl_on_policy_target is not None:
kwargs["device"] = "cpu"
elif _is_xpu:
kwargs["device"] = "xpu"
elif not _is_npu:
kwargs["device"] = "cuda"
elif processor.__class__.__name__ not in {
Expand All @@ -327,6 +340,7 @@ def process_mm_data(
}:
# Note: for qwen-vl, processor has some reshape issue because of dims restriction on Ascend.
kwargs["device"] = "npu"

result = processor.__call__(
text=[input_text],
padding=True,
Expand Down
Loading