From 4bc170c350bb122d25c4645ed97b454dc58467a9 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Fri, 10 Oct 2025 21:09:34 +0800 Subject: [PATCH 01/22] [Feature] support prefix cache in DP --- fastdeploy/engine/args_utils.py | 4 +-- fastdeploy/engine/common_engine.py | 11 ++++---- fastdeploy/engine/expert_service.py | 39 ++++++++++++++++++++++++--- fastdeploy/worker/gpu_model_runner.py | 4 +-- 4 files changed, 46 insertions(+), 12 deletions(-) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index cc5d56d524c..fcb7088d603 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -198,7 +198,7 @@ class EngineArgs: The amount of CPU memory to offload to. """ - cache_queue_port: int = 8003 + cache_queue_port: str = "8003" """ Port for cache queue. """ @@ -741,7 +741,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: cache_group.add_argument( "--cache-queue-port", - type=int, + type=lambda s: [int(item.strip()) for item in s.split(",")] if s else None, default=EngineArgs.cache_queue_port, help="port for cache queue", ) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index c132a71df09..068f5172c45 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -69,6 +69,11 @@ def __init__(self, cfg, start_queue=True): """ self.cfg = cfg + if self.cfg.cache_config.enable_prefix_caching: + self.cfg.cache_config.cache_queue_port = self.cfg.cache_config.cache_queue_port[ + self.cfg.parallel_config.local_data_parallel_id + ] + if self.cfg.parallel_config.enable_expert_parallel: self.llm_logger = get_logger( "fastdeploy", f"fastdeploy_rank{self.cfg.parallel_config.local_data_parallel_id}.log" @@ -251,11 +256,7 @@ def start_worker_queue_service(self, start_queue): local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, ) - if ( - self.cfg.cache_config.enable_prefix_caching - or self.cfg.scheduler_config.splitwise_role != "mixed" - and self.cfg.parallel_config.local_data_parallel_id == 0 - ): + if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed": self.cache_task_queue = EngineCacheQueue( address=( self.cfg.master_ip, diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index f26e1ef9dc9..d05e9b3b004 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -57,6 +57,11 @@ def __init__(self, cfg, local_data_parallel_id, start_queue=True): llm_logger.info(f"local_data_parallel_id: {local_data_parallel_id}") self.cfg.disaggregate_info = None + if self.cfg.cache_config.num_gpu_blocks_override is None: + self.do_profile = True + else: + self.do_profile = False + if cfg.scheduler_config.splitwise_role != "mixed": if len(self.cfg.cache_config.pd_comm_port) == 1: self.cfg.cache_config.pd_comm_port[0] = ( @@ -97,10 +102,30 @@ def start( ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0] llm_logger.info(f"start expert service {local_data_parallel_id}") - if self.cfg.scheduler_config.splitwise_role != "mixed": - ipc_signal_suffix_cache = self.cfg.parallel_config.engine_worker_queue_port[local_data_parallel_id] - self.engine.start_cache_service(self.cfg.local_device_ids, ipc_signal_suffix_cache) + if self.cfg.splitwise_role != "mixed" or self.cfg.cache_config.enable_prefix_caching: + if self.do_profile: + get_profile_block_num = np.zeros([1], dtype=np.int32) + while True: + try: + self.get_profile_block_num_signal = IPCSignal( + name="get_profile_block_num", + array=get_profile_block_num, + dtype=np.int32, + suffix=int(self.cfg.engine_worker_queue_port[0]), + create=False, + ) + break + except: + time.sleep(1) + self.reset_kvcache_blocks() + ipc_signal_suffix_cache = self.cfg.parallel_config.engine_worker_queue_port[local_data_parallel_id] + self.cache_manager_processes = self.engine.start_cache_service( + self.cfg.local_device_ids, ipc_signal_suffix_cache + ) + if self.cfg.splitwise_role != "mixed": + self.engine.split_mode_get_tasks() + if self.cfg.scheduler_config.name == "splitwise": self.cfg.init_cache_info() role = self.cfg.scheduler_config.splitwise_role @@ -134,6 +159,14 @@ def start( f"Worker processes(rank {local_rank}) are launched with {time.time() - start_time} seconds." ) return True + + def reset_kvcache_blocks(self): + self.do_profile = 0 + while self.get_profile_block_num_signal.value[0] == 0: + time.sleep(1) + num_gpu_blocks = self.get_profile_block_num_signal.value[0] + self.cfg.cache_config.reset(num_gpu_blocks) + self.engine.resource_manager.reset_cache_config(self.cfg.cache_config) def _exit_sub_services(self): """ diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index ece6331f117..7eda4f8e1bf 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1185,7 +1185,7 @@ def initialize_kv_cache(self, profile: bool = False) -> None: if not create_cache_tensor: logger.info(f"Waiting for cache managers to create kv cache.. {cache_ready_signal.value}") - while cache_ready_signal.value[self.local_rank] != 1: + while cache_ready_signal.value[local_rank] != 1: time.sleep(1) logger.info(f"OK! Stop waiting. {cache_ready_signal.value}") @@ -1236,7 +1236,7 @@ def initialize_kv_cache(self, profile: bool = False) -> None: self.share_inputs["caches"] = cache_kvs_list if not profile and create_cache_tensor: - cache_ready_signal.value[self.local_rank] = 1 + cache_ready_signal.value[local_rank] = 1 logger.info(f"✅ kv cache is ready! {cache_ready_signal.value}") paddle.device.cuda.empty_cache() From 1851e802f87f43ec807e343343867e6cddff0732 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Fri, 10 Oct 2025 21:15:55 +0800 Subject: [PATCH 02/22] fix --- fastdeploy/engine/expert_service.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index d05e9b3b004..4a08fe07514 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -125,7 +125,7 @@ def start( ) if self.cfg.splitwise_role != "mixed": self.engine.split_mode_get_tasks() - + if self.cfg.scheduler_config.name == "splitwise": self.cfg.init_cache_info() role = self.cfg.scheduler_config.splitwise_role @@ -159,7 +159,7 @@ def start( f"Worker processes(rank {local_rank}) are launched with {time.time() - start_time} seconds." ) return True - + def reset_kvcache_blocks(self): self.do_profile = 0 while self.get_profile_block_num_signal.value[0] == 0: From 2a9e046f088a41e2b8d53c6bf2814e1f4a160395 Mon Sep 17 00:00:00 2001 From: ltd0924 <32387785+ltd0924@users.noreply.github.com> Date: Fri, 10 Oct 2025 22:59:16 +0800 Subject: [PATCH 03/22] Update common_engine.py --- fastdeploy/engine/common_engine.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 068f5172c45..0a9e86f1aeb 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -68,11 +68,9 @@ def __init__(self, cfg, start_queue=True): cfg (Config): Config object containing all the configuration parameters. """ self.cfg = cfg - - if self.cfg.cache_config.enable_prefix_caching: - self.cfg.cache_config.cache_queue_port = self.cfg.cache_config.cache_queue_port[ - self.cfg.parallel_config.local_data_parallel_id - ] + self.cfg.cache_config.cache_queue_port = self.cfg.cache_config.cache_queue_port[ + self.cfg.parallel_config.local_data_parallel_id + ] if self.cfg.parallel_config.enable_expert_parallel: self.llm_logger = get_logger( From f5733dd96c892d1bc064c39e7272c8d490ace5dc Mon Sep 17 00:00:00 2001 From: ltd0924 <32387785+ltd0924@users.noreply.github.com> Date: Fri, 10 Oct 2025 23:46:27 +0800 Subject: [PATCH 04/22] Update common_engine.py --- fastdeploy/engine/common_engine.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 0a9e86f1aeb..75615e809ae 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -68,9 +68,10 @@ def __init__(self, cfg, start_queue=True): cfg (Config): Config object containing all the configuration parameters. """ self.cfg = cfg - self.cfg.cache_config.cache_queue_port = self.cfg.cache_config.cache_queue_port[ - self.cfg.parallel_config.local_data_parallel_id - ] + if isinstance(self.cfg.cache_config.cache_queue_port, list): + self.cfg.cache_config.cache_queue_port = self.cfg.cache_config.cache_queue_port[ + self.cfg.parallel_config.local_data_parallel_id + ] if self.cfg.parallel_config.enable_expert_parallel: self.llm_logger = get_logger( From 427cf4753a128b0a6a939052ff183a02c611c394 Mon Sep 17 00:00:00 2001 From: ltd0924 <32387785+ltd0924@users.noreply.github.com> Date: Fri, 10 Oct 2025 23:50:11 +0800 Subject: [PATCH 05/22] Update common_engine.py --- fastdeploy/engine/common_engine.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 75615e809ae..6f093988922 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -68,10 +68,12 @@ def __init__(self, cfg, start_queue=True): cfg (Config): Config object containing all the configuration parameters. """ self.cfg = cfg + if isinstance(self.cfg.cache_config.cache_queue_port, str): + self.cfg.cache_config.cache_queue_port = self.cfg.cache_config.cache_queue_port.split(",") if isinstance(self.cfg.cache_config.cache_queue_port, list): - self.cfg.cache_config.cache_queue_port = self.cfg.cache_config.cache_queue_port[ + self.cfg.cache_config.cache_queue_port = int(self.cfg.cache_config.cache_queue_port[ self.cfg.parallel_config.local_data_parallel_id - ] + ]) if self.cfg.parallel_config.enable_expert_parallel: self.llm_logger = get_logger( From 00bdcc2c6db3fcf3f488b107cee69a3d5986d89b Mon Sep 17 00:00:00 2001 From: ltd0924 <32387785+ltd0924@users.noreply.github.com> Date: Fri, 10 Oct 2025 23:52:14 +0800 Subject: [PATCH 06/22] Update common_engine.py --- fastdeploy/engine/common_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 6f093988922..e10bc93c487 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -71,9 +71,9 @@ def __init__(self, cfg, start_queue=True): if isinstance(self.cfg.cache_config.cache_queue_port, str): self.cfg.cache_config.cache_queue_port = self.cfg.cache_config.cache_queue_port.split(",") if isinstance(self.cfg.cache_config.cache_queue_port, list): - self.cfg.cache_config.cache_queue_port = int(self.cfg.cache_config.cache_queue_port[ - self.cfg.parallel_config.local_data_parallel_id - ]) + self.cfg.cache_config.cache_queue_port = int( + self.cfg.cache_config.cache_queue_port[self.cfg.parallel_config.local_data_parallel_id] + ) if self.cfg.parallel_config.enable_expert_parallel: self.llm_logger = get_logger( From 0822d6361b9f53545aea1045b7f38b37a68bed73 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Sat, 11 Oct 2025 17:23:44 +0800 Subject: [PATCH 07/22] [BugFix] fix workers more than 1 --- fastdeploy/entrypoints/openai/api_server.py | 35 +++++++++++++++------ 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 235f1cccd26..10b875b7981 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -29,6 +29,7 @@ from fastapi import FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, Response, StreamingResponse +from gunicorn.app.base import BaseApplication from prometheus_client import CONTENT_TYPE_LATEST from fastdeploy.engine.args_utils import EngineArgs @@ -81,6 +82,21 @@ llm_engine = None +class StandaloneApplication(BaseApplication): + def __init__(self, app, options=None): + self.application = app + self.options = options or {} + super().__init__() + + def load_config(self): + config = {key: value for key, value in self.options.items() if key in self.cfg.settings and value is not None} + for key, value in config.items(): + self.cfg.set(key.lower(), value) + + def load(self): + return self.application + + def load_engine(): """ load engine @@ -414,16 +430,17 @@ def launch_api_server() -> None: api_server_logger.info(f"args: {args.__dict__}") fd_start_span("FD_START") + options = { + "bind": f"{args.host}:{args.port}", + "workers": args.workers, + "worker_class": "uvicorn.workers.UvicornWorker", + "loglevel": "info", + "log_config": UVICORN_CONFIG, + "timeout_graceful_shutdown": args.timeout_graceful_shutdown, + } + try: - uvicorn.run( - app="fastdeploy.entrypoints.openai.api_server:app", - host=args.host, - port=args.port, - workers=args.workers, - log_config=UVICORN_CONFIG, - log_level="info", - timeout_graceful_shutdown=args.timeout_graceful_shutdown, - ) # set log level to error to avoid log + StandaloneApplication(app, options).run() except Exception as e: api_server_logger.error(f"launch sync http server error, {e}, {str(traceback.format_exc())}") From 0acf059a365b45649ca1f728700979c3f55f831f Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Sat, 11 Oct 2025 20:25:14 +0800 Subject: [PATCH 08/22] fix --- requirements.txt | 1 + requirements_dcu.txt | 1 + requirements_iluvatar.txt | 1 + requirements_metaxgpu.txt | 1 + 4 files changed, 4 insertions(+) diff --git a/requirements.txt b/requirements.txt index 8eb02b628a6..32acf860de3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,6 +30,7 @@ use-triton-in-paddle crcmod fastsafetensors==0.1.14 msgpack +gunicorn modelscope opentelemetry-api>=1.24.0 opentelemetry-sdk>=1.24.0 diff --git a/requirements_dcu.txt b/requirements_dcu.txt index 79bac3a6223..a622320a9e0 100644 --- a/requirements_dcu.txt +++ b/requirements_dcu.txt @@ -28,6 +28,7 @@ use-triton-in-paddle crcmod fastsafetensors==0.1.14 msgpack +gunicorn opentelemetry-api>=1.24.0 opentelemetry-sdk>=1.24.0 opentelemetry-instrumentation-redis diff --git a/requirements_iluvatar.txt b/requirements_iluvatar.txt index d481e3febb1..7983b3b5843 100644 --- a/requirements_iluvatar.txt +++ b/requirements_iluvatar.txt @@ -29,6 +29,7 @@ use-triton-in-paddle crcmod fastsafetensors==0.1.14 msgpack +gunicorn opentelemetry-api>=1.24.0 opentelemetry-sdk>=1.24.0 opentelemetry-instrumentation-redis diff --git a/requirements_metaxgpu.txt b/requirements_metaxgpu.txt index 26f6de09548..c17f3b3545b 100644 --- a/requirements_metaxgpu.txt +++ b/requirements_metaxgpu.txt @@ -30,6 +30,7 @@ use-triton-in-paddle crcmod fastsafetensors==0.1.14 msgpack +gunicorn modelscope opentelemetry-api>=1.24.0 opentelemetry-sdk>=1.24.0 From 667d146aa1ac0a99c7b7935fc8bac05131e18df8 Mon Sep 17 00:00:00 2001 From: ltd0924 <32387785+ltd0924@users.noreply.github.com> Date: Sun, 12 Oct 2025 17:13:58 +0800 Subject: [PATCH 09/22] Update api_server.py --- fastdeploy/entrypoints/openai/api_server.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 10b875b7981..8ef77bf203a 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -105,10 +105,10 @@ def load_engine(): if llm_engine is not None: return llm_engine - api_server_logger.info(f"FastDeploy LLM API server starting... {os.getpid()}") + api_server_logger.info(f"FastDeploy LLM API server starting... {os.getpid()}, port: {args.port}") engine_args = EngineArgs.from_cli_args(args) engine = LLMEngine.from_engine_args(engine_args) - if not engine.start(api_server_pid=os.getpid()): + if not engine.start(api_server_pid=args.port): api_server_logger.error("Failed to initialize FastDeploy LLM engine, service exit now!") return None @@ -129,12 +129,12 @@ def load_data_service(): global llm_engine if llm_engine is not None: return llm_engine - api_server_logger.info(f"FastDeploy LLM API server starting... {os.getpid()}") + api_server_logger.info(f"FastDeploy LLM API server starting... {os.getpid()}, port: {args.port}") engine_args = EngineArgs.from_cli_args(args) config = engine_args.create_engine_config() api_server_logger.info(f"local_data_parallel_id: {config.parallel_config}") expert_service = ExpertService(config, config.parallel_config.local_data_parallel_id) - if not expert_service.start(os.getpid(), config.parallel_config.local_data_parallel_id): + if not expert_service.start(args.port, config.parallel_config.local_data_parallel_id): api_server_logger.error("Failed to initialize FastDeploy LLM expert service, service exit now!") return None llm_engine = expert_service @@ -149,10 +149,7 @@ async def lifespan(app: FastAPI): if args.tokenizer is None: args.tokenizer = args.model - if current_process().name != "MainProcess": - pid = os.getppid() - else: - pid = os.getpid() + pid = args.port api_server_logger.info(f"{pid}") if args.served_model_name is not None: From a03dfe637d2c30b957263d138c7eb86aacc98af2 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Sun, 12 Oct 2025 17:16:51 +0800 Subject: [PATCH 10/22] fix --- fastdeploy/entrypoints/openai/api_server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 8ef77bf203a..ec9eeef2acc 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -22,7 +22,6 @@ import traceback from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from multiprocessing import current_process import uvicorn import zmq From 141abd7f92505b8dc81f0a330e1583c29dd1db2f Mon Sep 17 00:00:00 2001 From: ltd0924 <32387785+ltd0924@users.noreply.github.com> Date: Mon, 13 Oct 2025 21:19:47 +0800 Subject: [PATCH 11/22] Update api_server.py --- fastdeploy/entrypoints/openai/api_server.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index ec9eeef2acc..10f5a984b31 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -145,6 +145,21 @@ async def lifespan(app: FastAPI): """ async context manager for FastAPI lifespan """ + import logging + uvicorn_access = logging.getLogger("uvicorn.access") + uvicorn_access.handlers.clear() + + # 使用 gunicorn 的格式 + formatter = logging.Formatter( + '[%(asctime)s] [%(process)d] [INFO] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + + handler = logging.StreamHandler() + handler.setFormatter(formatter) + uvicorn_access.addHandler(handler) + uvicorn_access.propagate = False + if args.tokenizer is None: args.tokenizer = args.model From 1f07ecdd201e3a8e98d6f80c40d82c20df6c66d8 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Mon, 13 Oct 2025 21:22:43 +0800 Subject: [PATCH 12/22] fix --- fastdeploy/entrypoints/openai/api_server.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 10f5a984b31..7c22354b1fe 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -146,20 +146,17 @@ async def lifespan(app: FastAPI): async context manager for FastAPI lifespan """ import logging + uvicorn_access = logging.getLogger("uvicorn.access") uvicorn_access.handlers.clear() - + # 使用 gunicorn 的格式 - formatter = logging.Formatter( - '[%(asctime)s] [%(process)d] [INFO] %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) - + formatter = logging.Formatter("[%(asctime)s] [%(process)d] [INFO] %(message)s", datefmt="%Y-%m-%d %H:%M:%S") + handler = logging.StreamHandler() handler.setFormatter(formatter) uvicorn_access.addHandler(handler) uvicorn_access.propagate = False - if args.tokenizer is None: args.tokenizer = args.model From 90cd313fe22ed343b3d12805dc8896dcdb21c637 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Thu, 16 Oct 2025 17:39:29 +0800 Subject: [PATCH 13/22] [Fearture] Support mm model close prefix cache --- fastdeploy/entrypoints/engine_client.py | 20 ++++++++++++++++++++ fastdeploy/multimodal/registry.py | 11 +++++++++++ fastdeploy/output/token_processor.py | 2 +- 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index 525498ed5ce..f8d60ef73b6 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -73,6 +73,7 @@ def __init__( architectures = ModelConfig({"model": model_name_or_path}).architectures[0] if MultimodalRegistry.contains_model(architectures): self.enable_mm = True + self.disable_prefix_mm = MultimodalRegistry.contains_mm_disable_prefix_cache_model(architectures) else: self.enable_mm = False @@ -158,6 +159,16 @@ async def format_and_add_data(self, prompts: dict): await self.add_requests(prompts) return prompts["prompt_token_ids"] + def _check_mm_disable_prefix_cache(self, task): + is_multimodal_data = False + if self.disable_prefix_mm: + multimodal_inputs = task.get("multimodal_inputs", []) + if multimodal_inputs: + token_type_ids = multimodal_inputs.get("token_type_ids", []) + if token_type_ids: + is_multimodal_data = np.sum(token_type_ids) > 0 + return is_multimodal_data + async def add_requests(self, task): """ Add a new request to the queue. @@ -180,6 +191,15 @@ async def add_requests(self, task): else: self.data_processor.process_request_dict(task, self.max_model_len) + if self.enable_mm and self.enable_prefix_caching: + if self._check_mm_disable_prefix_cache(task): + api_server_logger.error( + f"Current model doesn't support multimodal data with prefix caching, {task}" + ) + raise EngineError( + "Current model doesn't support multimodal data with prefix caching", error_code=400 + ) + task["prompt_token_ids_len"] = len(task["prompt_token_ids"]) input_ids_len = task["prompt_token_ids_len"] task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens")) diff --git a/fastdeploy/multimodal/registry.py b/fastdeploy/multimodal/registry.py index f014ba55532..c3ea30015df 100644 --- a/fastdeploy/multimodal/registry.py +++ b/fastdeploy/multimodal/registry.py @@ -27,9 +27,20 @@ class MultimodalRegistry: "Ernie5ForCausalLM", } + mm_disable_prefix_cache_models: set[str] = { + "Ernie5ForCausalLM", + } + @classmethod def contains_model(cls, name: str) -> bool: """ Check if the given name exists in registry. """ return name in cls.mm_models + + @classmethod + def contains_mm_disable_prefix_cache_model(cls, name: str) -> bool: + """ + Check if the given name exists in registry. + """ + return name in cls.mm_disable_prefix_cache_models diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 7d87033e883..22ffe719d45 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -258,7 +258,7 @@ def _process_batch_output_use_zmq(self, receive_datas): if self.tokens_counter[task_id] == 0: if task.messages is not None: result.prompt = task.messages - result.num_cached_tokens = task.num_cached_tokens + result.num_cached_tokens = task.num_cached_tokens is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill" result = self._process_per_token(task, i, token_ids, result, is_prefill) From 0053f1765ab5c2b5279c1b2e6cfff8d9016cdf34 Mon Sep 17 00:00:00 2001 From: ltd0924 <32387785+ltd0924@users.noreply.github.com> Date: Thu, 16 Oct 2025 19:43:43 +0800 Subject: [PATCH 14/22] Update api_server.py --- fastdeploy/entrypoints/openai/api_server.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 05fd7dd26b4..766a07b6acf 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -92,21 +92,6 @@ connection_semaphore = StatefulSemaphore(MAX_CONCURRENT_CONNECTIONS) -class StandaloneApplication(BaseApplication): - def __init__(self, app, options=None): - self.application = app - self.options = options or {} - super().__init__() - - def load_config(self): - config = {key: value for key, value in self.options.items() if key in self.cfg.settings and value is not None} - for key, value in config.items(): - self.cfg.set(key.lower(), value) - - def load(self): - return self.application - - class StandaloneApplication(BaseApplication): def __init__(self, app, options=None): self.application = app From 03d9f22423fe4506df3d9d72a637818011dcdec3 Mon Sep 17 00:00:00 2001 From: ltd0924 <32387785+ltd0924@users.noreply.github.com> Date: Mon, 20 Oct 2025 16:34:31 +0800 Subject: [PATCH 15/22] Update engine_client.py --- fastdeploy/entrypoints/engine_client.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index cd84525083c..219a7a40b9c 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -194,10 +194,11 @@ async def add_requests(self, task): if self.enable_mm and self.enable_prefix_caching: if self._check_mm_disable_prefix_cache(task): api_server_logger.error( - f"Current model doesn't support multimodal data with prefix caching, {task}" + f"The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache" ) raise EngineError( - "Current model doesn't support multimodal data with prefix caching", error_code=400 + "The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache", + error_code=400 ) task["prompt_token_ids_len"] = len(task["prompt_token_ids"]) From 9ade15e4fe65d79831fa5d2af582d567495ed897 Mon Sep 17 00:00:00 2001 From: ltd0924 <32387785+ltd0924@users.noreply.github.com> Date: Mon, 20 Oct 2025 16:36:14 +0800 Subject: [PATCH 16/22] Update engine_client.py --- fastdeploy/entrypoints/engine_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index 219a7a40b9c..c6fbe282ed4 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -194,11 +194,11 @@ async def add_requests(self, task): if self.enable_mm and self.enable_prefix_caching: if self._check_mm_disable_prefix_cache(task): api_server_logger.error( - f"The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache" + "The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache" ) raise EngineError( "The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache", - error_code=400 + error_code=400, ) task["prompt_token_ids_len"] = len(task["prompt_token_ids"]) From 7a93f0c5dea5543dd07a716cf474dffcff951440 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Mon, 20 Oct 2025 19:30:16 +0800 Subject: [PATCH 17/22] add test --- tests/entrypoints/test_chat.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/entrypoints/test_chat.py b/tests/entrypoints/test_chat.py index 0078cd8a18e..afef26a16b9 100644 --- a/tests/entrypoints/test_chat.py +++ b/tests/entrypoints/test_chat.py @@ -26,11 +26,13 @@ class TestChat(unittest.TestCase): """Test case for chat functionality""" + COMMON_PREFIX = "I am a highly capable, compassionate, and trustworthy AI assistant dedicated to providing you with exceptional support. Whatever questions or challenges you may have, I will utilize my full capabilities to offer thoughtful and comprehensive assistance. As your intelligent companion, I consistently maintain honesty, transparency, and patience to ensure our interactions are both productive and enjoyable." + PROMPTS = [ - [{"content": "The color of tomato is ", "role": "user"}], - [{"content": "The equation 2+3= ", "role": "user"}], - [{"content": "The equation 4-1= ", "role": "user"}], [{"content": "PaddlePaddle is ", "role": "user"}], + [{"content": COMMON_PREFIX + "The color of tomato is ", "role": "user"}], + [{"content": COMMON_PREFIX + "The equation 2+3= ", "role": "user"}], + [{"content": COMMON_PREFIX + "The equation 4-1= ", "role": "user"}], ] @classmethod @@ -57,6 +59,8 @@ def tearDownClass(cls): def test_chat(self): outputs = self.llm.chat(messages=self.PROMPTS, sampling_params=None) self.assertEqual(len(self.PROMPTS), len(outputs)) + self.assertEqual(outputs[-1].num_cached_tokens == outputs[-2].num_cached_tokens) + self.assertEqual(outputs[-1].num_cached_token , 64) if __name__ == "__main__": From 842cde78239967e86d85a14c5acbb6e4017c075d Mon Sep 17 00:00:00 2001 From: ltd0924 <32387785+ltd0924@users.noreply.github.com> Date: Mon, 20 Oct 2025 19:39:56 +0800 Subject: [PATCH 18/22] Update test_chat.py --- tests/entrypoints/test_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/entrypoints/test_chat.py b/tests/entrypoints/test_chat.py index cf701d47fe6..8cc54938344 100644 --- a/tests/entrypoints/test_chat.py +++ b/tests/entrypoints/test_chat.py @@ -61,7 +61,7 @@ def test_chat(self): outputs = self.llm.chat(messages=self.PROMPTS, sampling_params=None) self.assertEqual(len(self.PROMPTS), len(outputs)) self.assertEqual(outputs[-1].num_cached_tokens == outputs[-2].num_cached_tokens) - self.assertEqual(outputs[-1].num_cached_token , 64) + self.assertEqual(outputs[-1].num_cached_token, 64) def test_chat_with_tools(self): """Test chat with tools: From bd4ec3cf164a49869a96b6cf0f7fcfd1f71ec7d8 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Mon, 20 Oct 2025 21:37:30 +0800 Subject: [PATCH 19/22] fix --- fastdeploy/entrypoints/engine_client.py | 14 ++------ fastdeploy/multimodal/registry.py | 47 ------------------------- 2 files changed, 3 insertions(+), 58 deletions(-) delete mode 100644 fastdeploy/multimodal/registry.py diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index 904281019bd..b2a8ab18155 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -37,7 +37,6 @@ ZmqIpcClient, ) from fastdeploy.metrics.work_metrics import work_process_metrics -from fastdeploy.multimodal.registry import MultimodalRegistry from fastdeploy.platforms import current_platform from fastdeploy.utils import ( EngineError, @@ -62,7 +61,6 @@ def __init__( port, limit_mm_per_prompt, mm_processor_kwargs, - # enable_mm=False, reasoning_parser=None, data_parallel_size=1, enable_logprob=False, @@ -71,21 +69,15 @@ def __init__( enable_prefix_caching=None, splitwise_role=None, ): - architectures = ModelConfig({"model": model_name_or_path}).architectures[0] - if MultimodalRegistry.contains_model(architectures): - self.enable_mm = True - self.disable_prefix_mm = MultimodalRegistry.contains_mm_disable_prefix_cache_model(architectures) - else: - self.enable_mm = False - + model_config = ModelConfig({"model": model_name_or_path}) input_processor = InputPreprocessor( - tokenizer, + model_config, reasoning_parser, limit_mm_per_prompt, mm_processor_kwargs, - self.enable_mm, tool_parser, ) + self.enable_mm = model_config.enable_mm self.enable_logprob = enable_logprob self.reasoning_parser = reasoning_parser self.data_processor = input_processor.create_processor() diff --git a/fastdeploy/multimodal/registry.py b/fastdeploy/multimodal/registry.py deleted file mode 100644 index 8297a2bd160..00000000000 --- a/fastdeploy/multimodal/registry.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License" -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" - - -class MultimodalRegistry: - """ - A registry for multimodal models - """ - - mm_models: set[str] = { - "Ernie4_5_VLMoeForConditionalGeneration", - "Ernie5MoeForCausalLM", - "Qwen2_5_VLForConditionalGeneration", - "Ernie5ForCausalLM", - "Ernie4_5_VLMoeForProcessRewardModel", - } - - mm_disable_prefix_cache_models: set[str] = { - "Ernie5ForCausalLM", - } - - @classmethod - def contains_model(cls, name: str) -> bool: - """ - Check if the given name exists in registry. - """ - return name in cls.mm_models - - @classmethod - def contains_mm_disable_prefix_cache_model(cls, name: str) -> bool: - """ - Check if the given name exists in registry. - """ - return name in cls.mm_disable_prefix_cache_models From 33d9093e16b6010f7843143a68f44f646f31f087 Mon Sep 17 00:00:00 2001 From: ltd0924 Date: Mon, 20 Oct 2025 21:44:42 +0800 Subject: [PATCH 20/22] fix --- fastdeploy/cache_manager/cache_data.py | 12 ++++++++++++ fastdeploy/entrypoints/engine_client.py | 7 +++++++ 2 files changed, 19 insertions(+) diff --git a/fastdeploy/cache_manager/cache_data.py b/fastdeploy/cache_manager/cache_data.py index 638da70bcce..dc8ef406de7 100644 --- a/fastdeploy/cache_manager/cache_data.py +++ b/fastdeploy/cache_manager/cache_data.py @@ -21,6 +21,18 @@ logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log") +DISABLE_PREFIX_CACHE_MM_MODEL: set[str] = { + "Ernie5ForCausalLM", +} + + +def is_mm_model_disable_prefix_cache(model_config): + """ + check if the model architecture is in DISABLE_PREFIX_CACHE_MM_MODEL + """ + return model_config._architecture in DISABLE_PREFIX_CACHE_MM_MODEL + + class CacheStatus(Enum): """ cache status enum class diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index b2a8ab18155..10bc9d32ac5 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -86,6 +86,13 @@ def __init__( self.enable_splitwise = splitwise_role != "mixed" max_chips_per_node = 16 if current_platform.is_iluvatar() else 8 + if self.enable_mm and self.enable_prefix_caching: + from fastdeploy.cache_manager.cache_data import ( + is_mm_model_disable_prefix_cache, + ) + + self.disable_prefix_mm = is_mm_model_disable_prefix_cache(model_config) + if tensor_parallel_size <= max_chips_per_node: self.is_master = True else: From 334d29aea3a7026505eb3c6220ba872f3abcf813 Mon Sep 17 00:00:00 2001 From: ltd0924 <32387785+ltd0924@users.noreply.github.com> Date: Mon, 20 Oct 2025 23:20:45 +0800 Subject: [PATCH 21/22] Update test_chat.py --- tests/entrypoints/test_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/entrypoints/test_chat.py b/tests/entrypoints/test_chat.py index 8cc54938344..450bb43403f 100644 --- a/tests/entrypoints/test_chat.py +++ b/tests/entrypoints/test_chat.py @@ -60,7 +60,7 @@ def tearDownClass(cls): def test_chat(self): outputs = self.llm.chat(messages=self.PROMPTS, sampling_params=None) self.assertEqual(len(self.PROMPTS), len(outputs)) - self.assertEqual(outputs[-1].num_cached_tokens == outputs[-2].num_cached_tokens) + self.assertEqual(outputs[-1].num_cached_tokens, outputs[-2].num_cached_tokens) self.assertEqual(outputs[-1].num_cached_token, 64) def test_chat_with_tools(self): From 6957bdca36536e965c2a37685ee5cc3e9ee8015c Mon Sep 17 00:00:00 2001 From: ltd0924 <32387785+ltd0924@users.noreply.github.com> Date: Tue, 21 Oct 2025 00:31:50 +0800 Subject: [PATCH 22/22] Update test_chat.py --- tests/entrypoints/test_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/entrypoints/test_chat.py b/tests/entrypoints/test_chat.py index 450bb43403f..75ff3a3e050 100644 --- a/tests/entrypoints/test_chat.py +++ b/tests/entrypoints/test_chat.py @@ -61,7 +61,7 @@ def test_chat(self): outputs = self.llm.chat(messages=self.PROMPTS, sampling_params=None) self.assertEqual(len(self.PROMPTS), len(outputs)) self.assertEqual(outputs[-1].num_cached_tokens, outputs[-2].num_cached_tokens) - self.assertEqual(outputs[-1].num_cached_token, 64) + self.assertEqual(outputs[-1].num_cached_tokens, 64) def test_chat_with_tools(self): """Test chat with tools: