diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 1a9579f96791..adf2f7e6b999 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -1108,6 +1108,10 @@ def load_model(self):
         self.remote_instance_transfer_engine_weight_info = (
             self.loader.remote_instance_transfer_engine_weight_info
         )
+        # Cache needs to be cleared after loading model weights (in the self.loader.load_model function).
+        # To avoid conflict with memory_saver_adapter.region, empty_cache operation is now moved here.
+        if _is_npu:
+            torch.npu.empty_cache()
 
         monkey_patch_vllm_parallel_state(reverse=True)
 
         # Publish metadata to ModelExpress if running as seed source
diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py
index 27d189d65622..5998eb0234ab 100644
--- a/python/sglang/srt/model_loader/loader.py
+++ b/python/sglang/srt/model_loader/loader.py
@@ -707,8 +707,6 @@ def load_weights_and_postprocess(model, weights, target_device):
             # parameters onto device for processing and back off after.
             with device_loading_context(module, target_device):
                 quant_method.process_weights_after_loading(module)
-    if _is_npu:
-        torch.npu.empty_cache()
 
 
 class LayeredModelLoader(DefaultModelLoader):
diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py
index c682a99c797b..d7ea34c7681d 100644
--- a/python/sglang/srt/utils/common.py
+++ b/python/sglang/srt/utils/common.py
@@ -365,7 +365,7 @@ def get_int_env_var(name: str, default: int = 0) -> int:
 
 
 def support_triton(backend: str) -> bool:
-    return backend not in ["torch_native", "intel_amx"]
+    return backend not in ["torch_native", "intel_amx", "ascend"]
 
 
 _ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var(