diff --git a/fastdeploy/cache_manager/cache_data.py b/fastdeploy/cache_manager/cache_data.py
index 638da70bcce..dc8ef406de7 100644
--- a/fastdeploy/cache_manager/cache_data.py
+++ b/fastdeploy/cache_manager/cache_data.py
@@ -21,6 +21,18 @@
 logger = get_logger("prefix_cache_manager", "prefix_cache_manager.log")
 
 
+DISABLE_PREFIX_CACHE_MM_MODEL: set[str] = {
+    "Ernie5ForCausalLM",
+}
+
+
+def is_mm_model_disable_prefix_cache(model_config):
+    """
+    check if the model architecture is in DISABLE_PREFIX_CACHE_MM_MODEL
+    """
+    return model_config._architecture in DISABLE_PREFIX_CACHE_MM_MODEL
+
+
 class CacheStatus(Enum):
     """
     cache status enum class
diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py
index fc1ee075163..10bc9d32ac5 100644
--- a/fastdeploy/entrypoints/engine_client.py
+++ b/fastdeploy/entrypoints/engine_client.py
@@ -86,6 +86,13 @@ def __init__(
         self.enable_splitwise = splitwise_role != "mixed"
         max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
 
+        if self.enable_mm and self.enable_prefix_caching:
+            from fastdeploy.cache_manager.cache_data import (
+                is_mm_model_disable_prefix_cache,
+            )
+
+            self.disable_prefix_mm = is_mm_model_disable_prefix_cache(model_config)
+
         if tensor_parallel_size <= max_chips_per_node:
             self.is_master = True
         else:
@@ -152,6 +159,16 @@ async def format_and_add_data(self, prompts: dict):
         await self.add_requests(prompts)
         return prompts["prompt_token_ids"]
 
+    def _check_mm_disable_prefix_cache(self, task):
+        is_multimodal_data = False
+        if self.disable_prefix_mm:
+            multimodal_inputs = task.get("multimodal_inputs", [])
+            if multimodal_inputs:
+                token_type_ids = multimodal_inputs.get("token_type_ids", [])
+                if token_type_ids:
+                    is_multimodal_data = np.sum(token_type_ids) > 0
+        return is_multimodal_data
+
     async def add_requests(self, task):
         """
         Add a new request to the queue.
@@ -174,6 +191,16 @@ async def add_requests(self, task):
         else:
             self.data_processor.process_request_dict(task, self.max_model_len)
 
+        if self.enable_mm and self.enable_prefix_caching:
+            if self._check_mm_disable_prefix_cache(task):
+                api_server_logger.error(
+                    "The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache"
+                )
+                raise EngineError(
+                    "The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache",
+                    error_code=400,
+                )
+
         task["prompt_token_ids_len"] = len(task["prompt_token_ids"])
         input_ids_len = task["prompt_token_ids_len"]
         task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens"))
diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py
index 7d87033e883..22ffe719d45 100644
--- a/fastdeploy/output/token_processor.py
+++ b/fastdeploy/output/token_processor.py
@@ -258,7 +258,7 @@ def _process_batch_output_use_zmq(self, receive_datas):
             if self.tokens_counter[task_id] == 0:
                 if task.messages is not None:
                     result.prompt = task.messages
-                    result.num_cached_tokens = task.num_cached_tokens
+                result.num_cached_tokens = task.num_cached_tokens
 
             is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
             result = self._process_per_token(task, i, token_ids, result, is_prefill)
diff --git a/tests/entrypoints/test_chat.py b/tests/entrypoints/test_chat.py
index 7167ce19aa0..75ff3a3e050 100644
--- a/tests/entrypoints/test_chat.py
+++ b/tests/entrypoints/test_chat.py
@@ -27,11 +27,13 @@
 class TestChat(unittest.TestCase):
     """Test case for chat functionality"""
 
+    COMMON_PREFIX = "I am a highly capable, compassionate, and trustworthy AI assistant dedicated to providing you with exceptional support. Whatever questions or challenges you may have, I will utilize my full capabilities to offer thoughtful and comprehensive assistance. As your intelligent companion, I consistently maintain honesty, transparency, and patience to ensure our interactions are both productive and enjoyable."
+
     PROMPTS = [
-        [{"content": "The color of tomato is ", "role": "user"}],
-        [{"content": "The equation 2+3= ", "role": "user"}],
-        [{"content": "The equation 4-1= ", "role": "user"}],
         [{"content": "PaddlePaddle is ", "role": "user"}],
+        [{"content": COMMON_PREFIX + "The color of tomato is ", "role": "user"}],
+        [{"content": COMMON_PREFIX + "The equation 2+3= ", "role": "user"}],
+        [{"content": COMMON_PREFIX + "The equation 4-1= ", "role": "user"}],
     ]
 
     @classmethod
@@ -58,6 +60,8 @@ def tearDownClass(cls):
     def test_chat(self):
         outputs = self.llm.chat(messages=self.PROMPTS, sampling_params=None)
         self.assertEqual(len(self.PROMPTS), len(outputs))
+        self.assertEqual(outputs[-1].num_cached_tokens, outputs[-2].num_cached_tokens)
+        self.assertEqual(outputs[-1].num_cached_tokens, 64)
 
     def test_chat_with_tools(self):
         """Test chat with tools:
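
Note (illustration only, not part of the patch): a minimal standalone sketch of the gating logic this diff adds, assuming token_type_ids marks multimodal tokens with nonzero values, which is what the np.sum(token_type_ids) > 0 check implies. The function name request_blocks_prefix_cache and the inline sample tasks are hypothetical and used only for this example.

    import numpy as np

    # Mirrors DISABLE_PREFIX_CACHE_MM_MODEL from the patch: architectures whose
    # prefix cache cannot be combined with multimodal inputs.
    DISABLE_PREFIX_CACHE_MM_MODEL = {"Ernie5ForCausalLM"}


    def request_blocks_prefix_cache(architecture: str, task: dict) -> bool:
        """Return True when a request should be rejected: the model disables
        prefix caching for multimodal data and the request actually contains
        multimodal tokens (nonzero entries in token_type_ids)."""
        if architecture not in DISABLE_PREFIX_CACHE_MM_MODEL:
            return False
        token_type_ids = task.get("multimodal_inputs", {}).get("token_type_ids", [])
        return bool(token_type_ids) and bool(np.sum(token_type_ids) > 0)


    # Text-only request: all token_type_ids are 0, so it is still allowed.
    print(request_blocks_prefix_cache(
        "Ernie5ForCausalLM", {"multimodal_inputs": {"token_type_ids": [0, 0, 0]}}))  # False
    # Mixed request: nonzero entries mark multimodal tokens, so the engine
    # would reject it with error_code=400 as in add_requests above.
    print(request_blocks_prefix_cache(
        "Ernie5ForCausalLM", {"multimodal_inputs": {"token_type_ids": [0, 1, 1]}}))  # True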