diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 76879c96c31e..6f5b14bd9fc4 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -116,7 +116,7 @@ async def detokenize(request: DetokenizeRequest): @app.get("/v1/models") async def show_available_models(): - models = await openai_serving_chat.show_available_models() + models = await openai_serving_completion.show_available_models() return JSONResponse(content=models.model_dump()) @@ -236,7 +236,8 @@ async def authentication(request: Request, call_next): args.lora_modules, args.chat_template) openai_serving_completion = OpenAIServingCompletion( - engine, model_config, served_model_names, args.lora_modules) + engine, model_config, served_model_names, args.lora_modules, + args.prompt_adapters) openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, served_model_names) app.root_path = args.root_path diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 59ad73bf097c..81c474ecc808 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -9,7 +9,8 @@ import ssl from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str -from vllm.entrypoints.openai.serving_engine import LoRAModulePath +from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, + PromptAdapterPath) from vllm.utils import FlexibleArgumentParser @@ -23,6 +24,16 @@ def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, lora_list) +class PromptAdapterParserAction(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + adapter_list = [] + for item in values: + name, path = item.split('=') + adapter_list.append(PromptAdapterPath(name, path)) + setattr(namespace, self.dest, adapter_list) + + def make_arg_parser(): parser = FlexibleArgumentParser( description="vLLM OpenAI-Compatible RESTful API server.") @@ -65,6 +76,14 @@ def make_arg_parser(): action=LoRAParserAction, help="LoRA module configurations in the format name=path. " "Multiple modules can be specified.") + parser.add_argument( + "--prompt-adapters", + type=nullable_str, + default=None, + nargs='+', + action=PromptAdapterParserAction, + help="Prompt adapter configurations in the format name=path. " + "Multiple adapters can be specified.") parser.add_argument("--chat-template", type=nullable_str, default=None, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 1bd095655388..bf0fc104165b 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -22,7 +22,8 @@ TokenizeResponse, UsageInfo) # yapf: enable from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, - OpenAIServing) + OpenAIServing, + PromptAdapterPath) from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) @@ -67,11 +68,13 @@ class OpenAIServingCompletion(OpenAIServing): def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]]): + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]]): super().__init__(engine=engine, model_config=model_config, served_model_names=served_model_names, - lora_modules=lora_modules) + lora_modules=lora_modules, + prompt_adapters=prompt_adapters) async def create_completion(self, request: CompletionRequest, raw_request: Request): @@ -102,6 +105,7 @@ async def create_completion(self, request: CompletionRequest, try: sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) + prompt_adapter_request = self._maybe_get_prompt_adapter(request) decoding_config = await self.engine.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend @@ -147,6 +151,7 @@ async def create_completion(self, request: CompletionRequest, sampling_params, f"{request_id}-{i}", lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, ) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 84e4127725bb..e6c14ac7bd72 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -16,12 +16,19 @@ ModelPermission, TokenizeRequest) from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import get_tokenizer logger = init_logger(__name__) +@dataclass +class PromptAdapterPath: + name: str + local_path: str + + @dataclass class LoRAModulePath: name: str @@ -30,9 +37,14 @@ class LoRAModulePath: class OpenAIServing: - def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, - served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]]): + def __init__( + self, + engine: AsyncLLMEngine, + model_config: ModelConfig, + served_model_names: List[str], + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]] = None, + ): super().__init__() self.engine = engine @@ -59,6 +71,19 @@ def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, ) for i, lora in enumerate(lora_modules, start=1) ] + self.prompt_adapter_requests = [] + if prompt_adapters is not None: + for i, prompt_adapter in enumerate(prompt_adapters, start=1): + with open(f"./{prompt_adapter.local_path}/adapter_config.json") as f: + adapter_config = json.load(f) + num_virtual_tokens = adapter_config["num_virtual_tokens"] + self.prompt_adapter_requests.append( + PromptAdapterRequest( + prompt_adapter_name=prompt_adapter.name, + prompt_adapter_id=i, + prompt_adapter_local_path=prompt_adapter.local_path, + prompt_adapter_num_virtual_tokens=num_virtual_tokens)) + async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" model_cards = [ @@ -74,6 +99,13 @@ async def show_available_models(self) -> ModelList: permission=[ModelPermission()]) for lora in self.lora_requests ] + prompt_adapter_cards = [ + ModelCard(id=prompt_adapter.prompt_adapter_name, + root=self.served_model_names[0], + permission=[ModelPermission()]) + for prompt_adapter in self.prompt_adapter_requests + ] + model_cards.extend(prompt_adapter_cards) model_cards.extend(lora_cards) return ModelList(data=model_cards) @@ -108,6 +140,11 @@ async def _check_model( return None if request.model in [lora.lora_name for lora in self.lora_requests]: return None + if request.model in [ + prompt_adapter.prompt_adapter_name + for prompt_adapter in self.prompt_adapter_requests + ]: + return None return self.create_error_response( message=f"The model `{request.model}` does not exist.", err_type="NotFoundError", @@ -122,8 +159,22 @@ def _maybe_get_lora( for lora in self.lora_requests: if request.model == lora.lora_name: return lora + return None + # if _check_model has been called earlier, this will be unreachable + #raise ValueError(f"The model `{request.model}` does not exist.") + + def _maybe_get_prompt_adapter( + self, request: Union[CompletionRequest, ChatCompletionRequest, + EmbeddingRequest] + ) -> Optional[PromptAdapterRequest]: + if request.model in self.served_model_names: + return None + for prompt_adapter in self.prompt_adapter_requests: + if request.model == prompt_adapter.prompt_adapter_name: + return prompt_adapter + return None # if _check_model has been called earlier, this will be unreachable - raise ValueError(f"The model `{request.model}` does not exist.") + #raise ValueError(f"The model `{request.model}` does not exist.") def _validate_prompt_and_tokenize( self,