diff --git a/changes/3282.feature.md b/changes/3282.feature.md new file mode 100644 index 00000000000..e12c3f04043 --- /dev/null +++ b/changes/3282.feature.md @@ -0,0 +1 @@ +Add model service process restart (`POST /{service_id}/routings/{route_id}/restart`) API diff --git a/src/ai/backend/agent/agent.py b/src/ai/backend/agent/agent.py index 907a31e5276..6d631c3ace7 100644 --- a/src/ai/backend/agent/agent.py +++ b/src/ai/backend/agent/agent.py @@ -50,7 +50,6 @@ import aiotools import attrs import pkg_resources -import yaml import zmq import zmq.asyncio from async_timeout import timeout @@ -63,11 +62,10 @@ stop_after_delay, wait_fixed, ) -from trafaret import DataError +from ai.backend.agent.model_service import ModelServiceManager from ai.backend.common import msgpack, redis_helper from ai.backend.common.bgtask import BackgroundTaskManager -from ai.backend.common.config import model_definition_iv from ai.backend.common.defs import REDIS_STAT_DB, REDIS_STREAM_DB from ai.backend.common.docker import MAX_KERNELSPEC, MIN_KERNELSPEC, ImageRef from ai.backend.common.events import ( @@ -105,7 +103,6 @@ from ai.backend.common.plugin.monitor import ErrorPluginContext, StatsPluginContext from ai.backend.common.service_ports import parse_service_ports from ai.backend.common.types import ( - MODEL_SERVICE_RUNTIME_PROFILES, AbuseReportValue, AcceleratorMetadata, AgentId, @@ -162,6 +159,7 @@ ContainerStatus, KernelLifecycleStatus, LifecycleEvent, + ModelServiceInfo, MountInfo, ) from .utils import generate_local_instance_id, get_arch_name @@ -341,6 +339,8 @@ async def prepare_container( environ: Mapping[str, str], service_ports, cluster_info: ClusterInfo, + *, + model_service_info: ModelServiceInfo | None = None, ) -> KernelObjectType: raise NotImplementedError @@ -2013,12 +2013,7 @@ async def create_kernel( }) model_definition: Optional[Mapping[str, Any]] = None - # Read model config - model_folders = [ - folder - for folder in vfolder_mounts - if folder.usage_mode == VFolderUsageMode.MODEL - ] + model_service_info: ModelServiceInfo | None = None if ctx.kernel_config["cluster_role"] in ("main", "master"): for sport in parse_service_ports( @@ -2063,14 +2058,45 @@ async def create_kernel( exposed_ports.append(port) log.debug("exposed ports: {!r}", exposed_ports) if kernel_config["session_type"] == SessionTypes.INFERENCE: - model_definition = await self.load_model_definition( - RuntimeVariant( - (kernel_config["internal_data"] or {}).get("runtime_variant", "custom") - ), - model_folders, - environ, - service_ports, - kernel_config, + # Read model config + model_folders = [ + folder + for folder in vfolder_mounts + if folder.usage_mode == VFolderUsageMode.MODEL + ] + + if len(model_folders) == 0: + raise AgentError("No model folder loaded for inference session") + model_folder = model_folders[0] + runtime_variant = RuntimeVariant( + (kernel_config["internal_data"] or {}).get("runtime_variant", "custom") + ) + + image_command = await self.extract_image_command( + kernel_config["image"]["canonical"] + ) + model_definition_path: str | None = ( + kernel_config.get("internal_data") or {} + ).get("model_definition_path") + + model_service_manager = ModelServiceManager( + runtime_variant, model_folder, model_definition_path=model_definition_path + ) + model_definition = await model_service_manager.load_model_definition( + image_command=image_command, + ) + environ.update(model_service_manager.create_environs(model_definition)) + service_ports.extend( + model_service_manager.create_service_port_definitions( + 
model_definition, + service_ports, + ) + ) + + model_service_info = ModelServiceInfo( + runtime_variant, + model_folder, + model_definition_path, ) runtime_type = image_labels.get("ai.backend.runtime-type", "app") @@ -2128,6 +2154,7 @@ async def create_kernel( environ, service_ports, cluster_info, + model_service_info=model_service_info, ) async with self.registry_lock: self.kernel_registry[kernel_id] = kernel_obj @@ -2240,6 +2267,7 @@ async def create_kernel( asyncio.create_task( self.start_and_monitor_model_service_health(kernel_obj, model) ) + kernel_obj.current_model_definition = model_definition # Finally we are done. await self.produce_event( @@ -2282,148 +2310,40 @@ async def start_and_monitor_model_service_health( ) ) - async def load_model_definition( + async def restart_model_service( self, - runtime_variant: RuntimeVariant, - model_folders: list[VFolderMount], - environ: MutableMapping[str, Any], - service_ports: list[ServicePort], - kernel_config: KernelCreationConfig, - ) -> Any: - image_command = await self.extract_image_command(kernel_config["image"]["canonical"]) - if runtime_variant != RuntimeVariant.CUSTOM and not image_command: + kernel_id: KernelId, + ) -> None: + try: + kernel_obj = self.kernel_registry[kernel_id] + except KeyError: + raise AgentError(f"Kernel {kernel_id} not found") + if not kernel_obj.model_service_info: raise AgentError( - "image should have its own command when runtime variant is set to values other than CUSTOM!" + "Model service info is not recorded for this kernel. It was likely created by an agent version that does not support model service restart." ) - assert len(model_folders) > 0 - model_folder: VFolderMount = model_folders[0] - - match runtime_variant: - case RuntimeVariant.VLLM: - _model = { - "name": "vllm-model", - "model_path": model_folder.kernel_path.as_posix(), - "service": { - "start_command": image_command, - "port": MODEL_SERVICE_RUNTIME_PROFILES[runtime_variant].port, - "health_check": { - "path": MODEL_SERVICE_RUNTIME_PROFILES[ - runtime_variant - ].health_check_endpoint, - }, - }, - } - raw_definition = {"models": [_model]} - - case RuntimeVariant.HUGGINGFACE_TGI: - _model = { - "name": "tgi-model", - "model_path": model_folder.kernel_path.as_posix(), - "service": { - "start_command": image_command, - "port": MODEL_SERVICE_RUNTIME_PROFILES[runtime_variant].port, - "health_check": { - "path": MODEL_SERVICE_RUNTIME_PROFILES[ - runtime_variant - ].health_check_endpoint, - }, - }, - } - raw_definition = {"models": [_model]} - - case RuntimeVariant.NIM: - _model = { - "name": "nim-model", - "model_path": model_folder.kernel_path.as_posix(), - "service": { - "start_command": image_command, - "port": MODEL_SERVICE_RUNTIME_PROFILES[runtime_variant].port, - "health_check": { - "path": MODEL_SERVICE_RUNTIME_PROFILES[ - runtime_variant - ].health_check_endpoint, - }, - }, - } - raw_definition = {"models": [_model]} - - case RuntimeVariant.CMD: - _model = { - "name": "image-model", - "model_path": model_folder.kernel_path.as_posix(), - "service": { - "start_command": image_command, - "port": 8000, - }, - } - raw_definition = {"models": [_model]} - case RuntimeVariant.CUSTOM: - if _fname := (kernel_config.get("internal_data") or {}).get( - "model_definition_path" - ): - model_definition_candidates = [_fname] - else: - model_definition_candidates = [ - "model-definition.yaml", - "model-definition.yml", - ] + model_service_manager = ModelServiceManager( + kernel_obj.model_service_info.runtime_variant, + kernel_obj.model_service_info.model_folder, +
model_definition_path=kernel_obj.model_service_info.model_definition_path, + ) - model_definition_path = None - for filename in model_definition_candidates: - if (Path(model_folder.host_path) / filename).is_file(): - model_definition_path = Path(model_folder.host_path) / filename - break + image_command = await self.extract_image_command(kernel_obj.image.canonical) + model_definition = await model_service_manager.load_model_definition( + image_command=image_command, + ) - if not model_definition_path: - raise AgentError( - f"Model definition file ({" or ".join(model_definition_candidates)}) does not exist under vFolder" - f" {model_folder.name} (ID {model_folder.vfid})", - ) - try: - model_definition_yaml = await asyncio.get_running_loop().run_in_executor( - None, model_definition_path.read_text - ) - except FileNotFoundError as e: - raise AgentError( - "Model definition file (model-definition.yml) does not exist under" - f" vFolder {model_folder.name} (ID {model_folder.vfid})", - ) from e - try: - raw_definition = yaml.load(model_definition_yaml, Loader=yaml.FullLoader) - except yaml.error.YAMLError as e: - raise AgentError(f"Invalid YAML syntax: {e}") from e - try: - model_definition = model_definition_iv.check(raw_definition) - assert model_definition is not None - for model in model_definition["models"]: - if "BACKEND_MODEL_NAME" not in environ: - environ["BACKEND_MODEL_NAME"] = model["name"] - environ["BACKEND_MODEL_PATH"] = model["model_path"] - if service := model.get("service"): - if service["port"] in (2000, 2001): - raise AgentError("Port 2000 and 2001 are reserved for internal use") - overlapping_services = [ - s for s in service_ports if service["port"] in s["container_ports"] - ] - if len(overlapping_services) > 0: - raise AgentError( - f"Port {service["port"]} overlaps with built-in service" - f" {overlapping_services[0]["name"]}" - ) - service_ports.append({ - "name": f"{model["name"]}-{service["port"]}", - "protocol": ServicePortProtocols.PREOPEN, - "container_ports": (service["port"],), - "host_ports": (None,), - "is_inference": True, - }) - return model_definition - except DataError as e: - raise AgentError( - "Failed to validate model definition from vFolder" - f" {model_folder.name} (ID {model_folder.vfid})", - ) from e + for model_info in kernel_obj.current_model_definition["models"]: + log.debug("Shutting down model service {}", model_info["name"]) + await kernel_obj.shutdown_model_service(model_info) + + for model_info in model_definition["models"]: + log.debug("Starting model service {}", model_info["name"]) + await self.start_and_monitor_model_service_health( + cast(KernelObjectType, kernel_obj), model_info + ) + kernel_obj.current_model_definition = model_definition def get_public_service_ports(self, service_ports: list[ServicePort]) -> list[ServicePort]: return [port for port in service_ports if port["protocol"] != ServicePortProtocols.INTERNAL] diff --git a/src/ai/backend/agent/docker/agent.py b/src/ai/backend/agent/docker/agent.py index 3153d88b9fd..8d5220d7dbb 100644 --- a/src/ai/backend/agent/docker/agent.py +++ b/src/ai/backend/agent/docker/agent.py @@ -88,7 +88,15 @@ from ..resources import AbstractComputePlugin, KernelResourceSpec, Mount, known_slot_types from ..scratch import create_loop_filesystem, destroy_loop_filesystem from ..server import get_extra_volumes -from ..types import AgentEventData, Container, ContainerStatus, LifecycleEvent, MountInfo, Port +from ..types import ( + AgentEventData, + Container, + ContainerStatus, + LifecycleEvent, + 
ModelServiceInfo, + MountInfo, + Port, +) from ..utils import ( closing_async, container_pid_to_host_pid, @@ -640,6 +648,8 @@ async def prepare_container( environ: Mapping[str, str], service_ports: List[ServicePort], cluster_info: ClusterInfo, + *, + model_service_info: ModelServiceInfo | None = None, ) -> DockerKernel: loop = current_loop() @@ -764,6 +774,7 @@ def _populate_ssh_config(): resource_spec=resource_spec, environ=environ, data={}, + model_service_info=model_service_info, ) return kernel_obj diff --git a/src/ai/backend/agent/docker/kernel.py b/src/ai/backend/agent/docker/kernel.py index 427ca6cfa30..beddaf080ff 100644 --- a/src/ai/backend/agent/docker/kernel.py +++ b/src/ai/backend/agent/docker/kernel.py @@ -31,7 +31,7 @@ from ..kernel import AbstractCodeRunner, AbstractKernel from ..resources import KernelResourceSpec -from ..types import AgentEventData +from ..types import AgentEventData, ModelServiceInfo from ..utils import closing_async, get_arch_name log = BraceStyleAdapter(logging.getLogger(__spec__.name)) @@ -58,6 +58,7 @@ def __init__( service_ports: Any, # TODO: type-annotation environ: Mapping[str, Any], data: Dict[str, Any], + model_service_info: Optional[ModelServiceInfo], ) -> None: super().__init__( kernel_id, @@ -71,6 +72,7 @@ def __init__( service_ports=service_ports, data=data, environ=environ, + model_service_info=model_service_info, ) self.network_driver = network_driver @@ -153,6 +155,10 @@ async def shutdown_service(self, service: str): assert self.runner is not None await self.runner.feed_shutdown_service(service) + async def shutdown_model_service(self, model_service: Mapping[str, Any]): + assert self.runner is not None + await self.runner.feed_shutdown_model_service(model_service) + async def get_service_apps(self): assert self.runner is not None result = await self.runner.feed_service_apps() diff --git a/src/ai/backend/agent/dummy/agent.py b/src/ai/backend/agent/dummy/agent.py index 9450a0d34cb..69b09e23979 100644 --- a/src/ai/backend/agent/dummy/agent.py +++ b/src/ai/backend/agent/dummy/agent.py @@ -41,7 +41,7 @@ from ..exception import UnsupportedResource from ..kernel import AbstractKernel from ..resources import AbstractComputePlugin, KernelResourceSpec, Mount, known_slot_types -from ..types import Container, ContainerStatus, MountInfo +from ..types import Container, ContainerStatus, ModelServiceInfo, MountInfo from .config import DEFAULT_CONFIG_PATH, dummy_local_config from .kernel import DummyKernel from .resources import load_resources, scan_available_resources @@ -167,6 +167,8 @@ async def prepare_container( environ: Mapping[str, str], service_ports, cluster_info: ClusterInfo, + *, + model_service_info: ModelServiceInfo | None = None, ) -> DummyKernel: delay = self.creation_ctx_config["delay"]["spawn"] await asyncio.sleep(delay) @@ -183,6 +185,7 @@ async def prepare_container( environ=environ, data={}, dummy_config=self.dummy_config, + model_service_info=model_service_info, ) async def start_container( diff --git a/src/ai/backend/agent/dummy/config.py b/src/ai/backend/agent/dummy/config.py index 0c22259924a..7affe130895 100644 --- a/src/ai/backend/agent/dummy/config.py +++ b/src/ai/backend/agent/dummy/config.py @@ -55,6 +55,7 @@ t.Key("start-service", default=0.1): tx.Delay, t.Key("start-model-service", default=0.1): tx.Delay, t.Key("shutdown-service", default=0.1): tx.Delay, + t.Key("shutdown-model-service", default=0.1): tx.Delay, t.Key("commit", default=5.0): tx.Delay, t.Key("get-service-apps", default=0.1): tx.Delay, t.Key("accept-file", 
default=0.1): tx.Delay, diff --git a/src/ai/backend/agent/dummy/kernel.py b/src/ai/backend/agent/dummy/kernel.py index b47c83e9e76..e7cf6427701 100644 --- a/src/ai/backend/agent/dummy/kernel.py +++ b/src/ai/backend/agent/dummy/kernel.py @@ -3,7 +3,7 @@ import asyncio import os from collections import OrderedDict -from typing import Any, Dict, FrozenSet, Mapping, Sequence, override +from typing import Any, Dict, FrozenSet, Mapping, Optional, Sequence, override from ai.backend.common.docker import ImageRef from ai.backend.common.events import EventProducer @@ -11,7 +11,7 @@ from ..kernel import AbstractCodeRunner, AbstractKernel, NextResult, ResultRecord from ..resources import KernelResourceSpec -from ..types import AgentEventData +from ..types import AgentEventData, ModelServiceInfo class DummyKernel(AbstractKernel): @@ -32,6 +32,7 @@ def __init__( environ: Mapping[str, Any], data: Dict[str, Any], dummy_config: Mapping[str, Any], + model_service_info: Optional[ModelServiceInfo], ) -> None: super().__init__( kernel_id, @@ -45,6 +46,7 @@ def __init__( service_ports=service_ports, data=data, environ=environ, + model_service_info=model_service_info, ) self.is_commiting = False self.dummy_config = dummy_config @@ -115,6 +117,11 @@ async def shutdown_service(self, service): delay = self.dummy_kernel_cfg["delay"]["shutdown-service"] await asyncio.sleep(delay) + async def shutdown_model_service(self, model_service: Mapping[str, Any]): + delay = self.dummy_kernel_cfg["delay"]["shutdown-model-service"] + await asyncio.sleep(delay) + return {} + async def check_duplicate_commit(self, kernel_id, subdir) -> CommitStatus: if self.is_commiting: return CommitStatus.ONGOING diff --git a/src/ai/backend/agent/kernel.py b/src/ai/backend/agent/kernel.py index 198ebe18f3b..70f1bf1819c 100644 --- a/src/ai/backend/agent/kernel.py +++ b/src/ai/backend/agent/kernel.py @@ -56,7 +56,7 @@ from .exception import UnsupportedBaseDistroError from .resources import KernelResourceSpec -from .types import AgentEventData, KernelLifecycleStatus +from .types import AgentEventData, KernelLifecycleStatus, ModelServiceInfo log = BraceStyleAdapter(logging.getLogger(__spec__.name)) @@ -181,6 +181,11 @@ class AbstractKernel(UserDict, aobject, metaclass=ABCMeta): environ: Mapping[str, Any] status: KernelLifecycleStatus + model_service_info: Optional[ModelServiceInfo] + """Set only if kernel is `INFERENCE` type""" + + current_model_definition: Any + _tasks: Set[asyncio.Task] runner: Optional[AbstractCodeRunner] @@ -199,6 +204,7 @@ def __init__( service_ports: Any, # TODO: type-annotation data: Dict[Any, Any], environ: Mapping[str, Any], + model_service_info: Optional[ModelServiceInfo], ) -> None: self.agent_config = agent_config self.kernel_id = kernel_id @@ -219,6 +225,9 @@ def __init__( self.runner = None self.container_id = None self.state = KernelLifecycleStatus.PREPARING + self.model_service_info = model_service_info + + self.current_model_definition = {} async def init(self, event_producer: EventProducer) -> None: log.debug( @@ -242,6 +251,10 @@ def __setstate__(self, props) -> None: # Used when a `Kernel` object is loaded from pickle data. if "state" not in props: props["state"] = KernelLifecycleStatus.RUNNING + if "model_service_info" not in props: + props["model_service_info"] = None + if "current_model_definition" not in props: + props["current_model_definition"] = {} self.__dict__.update(props) # agent_config is set by the pickle.loads() caller. 
self.clean_event = None @@ -307,6 +320,10 @@ async def start_model_service(self, model_service): async def shutdown_service(self, service): raise NotImplementedError + @abstractmethod + async def shutdown_model_service(self, model_service): + raise NotImplementedError + @abstractmethod async def check_duplicate_commit(self, kernel_id, subdir) -> CommitStatus: raise NotImplementedError @@ -442,6 +459,7 @@ class AbstractCodeRunner(aobject, metaclass=ABCMeta): completion_queue: asyncio.Queue[bytes] service_queue: asyncio.Queue[bytes] model_service_queue: asyncio.Queue[bytes] + shutdown_model_service_queue: asyncio.Queue[bytes] service_apps_info_queue: asyncio.Queue[bytes] status_queue: asyncio.Queue[bytes] output_queue: Optional[asyncio.Queue[ResultRecord]] @@ -482,6 +500,7 @@ def __init__( self.completion_queue = asyncio.Queue(maxsize=128) self.service_queue = asyncio.Queue(maxsize=128) self.model_service_queue = asyncio.Queue(maxsize=128) + self.shutdown_model_service_queue = asyncio.Queue(maxsize=128) self.service_apps_info_queue = asyncio.Queue(maxsize=128) self.status_queue = asyncio.Queue(maxsize=128) self.output_queue = None @@ -513,6 +532,7 @@ def __getstate__(self): del props["completion_queue"] del props["service_queue"] del props["model_service_queue"] + del props["shutdown_model_service_queue"] del props["service_apps_info_queue"] del props["status_queue"] del props["output_queue"] @@ -535,6 +555,7 @@ def __setstate__(self, props): self.completion_queue = asyncio.Queue(maxsize=128) self.service_queue = asyncio.Queue(maxsize=128) self.model_service_queue = asyncio.Queue(maxsize=128) + self.shutdown_model_service_queue = asyncio.Queue(maxsize=128) self.service_apps_info_queue = asyncio.Queue(maxsize=128) self.status_queue = asyncio.Queue(maxsize=128) self.output_queue = None @@ -726,6 +747,23 @@ async def feed_shutdown_service(self, service_name: str): json.dumps(service_name).encode("utf8"), ]) + async def feed_shutdown_model_service(self, model_info): + if self.input_sock.closed: + raise asyncio.CancelledError + await self.input_sock.send_multipart([ + b"shutdown-model-service", + json.dumps(model_info).encode("utf8"), + ]) + try: + with timeout(60): + result = await self.shutdown_model_service_queue.get() + self.shutdown_model_service_queue.task_done() + return json.loads(result) + except asyncio.CancelledError: + return {"status": "failed", "error": "cancelled"} + except asyncio.TimeoutError: + return {"status": "failed", "error": "timeout"} + async def feed_service_apps(self): await self.input_sock.send_multipart([ b"get-apps", @@ -984,6 +1022,8 @@ async def read_output(self) -> None: await self.service_queue.put(msg_data) case b"model-service-result": await self.model_service_queue.put(msg_data) + case b"shutdown-model-service-result": + await self.shutdown_model_service_queue.put(msg_data) case b"model-service-status": response = json.loads(msg_data) event = ModelServiceStatusEvent( diff --git a/src/ai/backend/agent/kubernetes/agent.py b/src/ai/backend/agent/kubernetes/agent.py index 8db627eb487..05c5c6e2d48 100644 --- a/src/ai/backend/agent/kubernetes/agent.py +++ b/src/ai/backend/agent/kubernetes/agent.py @@ -62,7 +62,7 @@ from ..exception import K8sError, UnsupportedResource from ..kernel import AbstractKernel, KernelFeatures from ..resources import AbstractComputePlugin, KernelResourceSpec, Mount, known_slot_types -from ..types import Container, ContainerStatus, MountInfo, Port +from ..types import Container, ContainerStatus, ModelServiceInfo, MountInfo, Port from .kernel 
import KubernetesKernel from .kube_object import ( ConfigMap, @@ -525,6 +525,8 @@ async def prepare_container( environ: Mapping[str, str], service_ports, cluster_info: ClusterInfo, + *, + model_service_info: ModelServiceInfo | None = None, ) -> KubernetesKernel: loop = current_loop() if self.restarting: @@ -655,6 +657,7 @@ def _write_config(file_name: str, content: str): resource_spec=resource_spec, environ=environ, data={}, + model_service_info=model_service_info, ) return kernel_obj diff --git a/src/ai/backend/agent/kubernetes/kernel.py b/src/ai/backend/agent/kubernetes/kernel.py index 1cd21fb0107..ead1c4b6af8 100644 --- a/src/ai/backend/agent/kubernetes/kernel.py +++ b/src/ai/backend/agent/kubernetes/kernel.py @@ -25,7 +25,7 @@ from ..kernel import AbstractCodeRunner, AbstractKernel from ..resources import KernelResourceSpec -from ..types import AgentEventData +from ..types import AgentEventData, ModelServiceInfo log = BraceStyleAdapter(logging.getLogger(__spec__.name)) @@ -47,6 +47,7 @@ def __init__( service_ports: Any, # TODO: type-annotation data: Dict[str, Any], environ: Mapping[str, Any], + model_service_info: Optional[ModelServiceInfo], ) -> None: super().__init__( kernel_id, @@ -60,6 +61,7 @@ def __init__( service_ports=service_ports, data=data, environ=environ, + model_service_info=model_service_info, ) self.deployment_name = f"kernel-{kernel_id}" @@ -209,6 +211,10 @@ async def shutdown_service(self, service: str): assert self.runner is not None await self.runner.feed_shutdown_service(service) + async def shutdown_model_service(self, model_service: Mapping[str, Any]): + assert self.runner is not None + await self.runner.feed_shutdown_model_service(model_service) + async def get_service_apps(self): assert self.runner is not None result = await self.runner.feed_service_apps() diff --git a/src/ai/backend/agent/model_service.py b/src/ai/backend/agent/model_service.py new file mode 100644 index 00000000000..37653123123 --- /dev/null +++ b/src/ai/backend/agent/model_service.py @@ -0,0 +1,190 @@ +import asyncio +from pathlib import Path +from typing import Any, Mapping + +import yaml +from trafaret import DataError + +from ai.backend.agent.exception import AgentError +from ai.backend.common.config import model_definition_iv +from ai.backend.common.types import ( + MODEL_SERVICE_RUNTIME_PROFILES, + RuntimeVariant, + ServicePort, + ServicePortProtocols, + VFolderMount, +) + + +class ModelServiceManager: + runtime_variant: RuntimeVariant + model_folder: VFolderMount + model_definition_path: str | None + + def __init__( + self, + runtime_variant: RuntimeVariant, + model_folder: VFolderMount, + *, + model_definition_path: str | None = None, + ) -> None: + self.runtime_variant = runtime_variant + self.model_folder = model_folder + self.model_definition_path = model_definition_path + + async def load_model_definition( + self, + image_command: str | None = None, + ) -> Any: + """ + Generates a model definition config (see `model_definition_iv` for the schema) based on + the runtime configuration of the kernel. When `runtime_variant` is set to `CUSTOM`, the + model definition is populated based on the YAML file located at `model_definition_path` + (`model-definition.yaml` or `model-definition.yml` by default). + """ + + if self.runtime_variant != RuntimeVariant.CUSTOM and not image_command: + raise AgentError( + "image should have its own command when runtime variant is set to values other than CUSTOM!" 
) + + match self.runtime_variant: + case RuntimeVariant.VLLM: + _model = { + "name": "vllm-model", + "model_path": self.model_folder.kernel_path.as_posix(), + "service": { + "start_command": image_command, + "port": MODEL_SERVICE_RUNTIME_PROFILES[self.runtime_variant].port, + "health_check": { + "path": MODEL_SERVICE_RUNTIME_PROFILES[ + self.runtime_variant + ].health_check_endpoint, + }, + }, + } + raw_definition = {"models": [_model]} + + case RuntimeVariant.HUGGINGFACE_TGI: + _model = { + "name": "tgi-model", + "model_path": self.model_folder.kernel_path.as_posix(), + "service": { + "start_command": image_command, + "port": MODEL_SERVICE_RUNTIME_PROFILES[self.runtime_variant].port, + "health_check": { + "path": MODEL_SERVICE_RUNTIME_PROFILES[ + self.runtime_variant + ].health_check_endpoint, + }, + }, + } + raw_definition = {"models": [_model]} + + case RuntimeVariant.NIM: + _model = { + "name": "nim-model", + "model_path": self.model_folder.kernel_path.as_posix(), + "service": { + "start_command": image_command, + "port": MODEL_SERVICE_RUNTIME_PROFILES[self.runtime_variant].port, + "health_check": { + "path": MODEL_SERVICE_RUNTIME_PROFILES[ + self.runtime_variant + ].health_check_endpoint, + }, + }, + } + raw_definition = {"models": [_model]} + + case RuntimeVariant.CMD: + _model = { + "name": "image-model", + "model_path": self.model_folder.kernel_path.as_posix(), + "service": { + "start_command": image_command, + "port": 8000, + }, + } + raw_definition = {"models": [_model]} + + case RuntimeVariant.CUSTOM: + if self.model_definition_path: + model_definition_candidates = [self.model_definition_path] + else: + model_definition_candidates = [ + "model-definition.yaml", + "model-definition.yml", + ] + + _def_path: Path | None = None + for filename in model_definition_candidates: + if (Path(self.model_folder.host_path) / filename).is_file(): + _def_path = Path(self.model_folder.host_path) / filename + break + + if not _def_path: + raise AgentError( + f"Model definition file ({" or ".join([str(x) for x in model_definition_candidates])}) does not exist under vFolder" + f" {self.model_folder.name} (ID {self.model_folder.vfid})", + ) + try: + model_definition_yaml = await asyncio.get_running_loop().run_in_executor( + None, _def_path.read_text + ) + except FileNotFoundError as e: + raise AgentError( + "Model definition file (model-definition.yml) does not exist under" + f" vFolder {self.model_folder.name} (ID {self.model_folder.vfid})", + ) from e + try: + raw_definition = yaml.load(model_definition_yaml, Loader=yaml.FullLoader) + except yaml.error.YAMLError as e: + raise AgentError(f"Invalid YAML syntax: {e}") from e + try: + model_definition = model_definition_iv.check(raw_definition) + assert model_definition is not None + except DataError as e: + raise AgentError( + "Failed to validate model definition from vFolder" + f" {self.model_folder.name} (ID {self.model_folder.vfid})", + ) from e + return model_definition + + def create_environs(self, model_definition: Any) -> Mapping[str, Any]: + environ: dict[str, Any] = {} + for model in model_definition["models"]: + if "BACKEND_MODEL_NAME" not in environ: + environ["BACKEND_MODEL_NAME"] = model["name"] + environ["BACKEND_MODEL_PATH"] = model["model_path"] + return environ + + def create_service_port_definitions( + self, model_definition: Any, existing_service_ports: list[ServicePort] + ) -> list[ServicePort]: + """ + Extracts service port definitions for model services. Requires a definition generated by + `ModelServiceManager.load_model_definition()`. 
+ """ + new_service_ports: list[ServicePort] = [] + for model in model_definition["models"]: + if service := model.get("service"): + if service["port"] in (2000, 2001): + raise AgentError("Port 2000 and 2001 are reserved for internal use") + overlapping_services = [ + s for s in existing_service_ports if service["port"] in s["container_ports"] + ] + if len(overlapping_services) > 0: + raise AgentError( + f"Port {service["port"]} overlaps with built-in service" + f" {overlapping_services[0]["name"]}" + ) + new_service_ports.append({ + "name": f"{model["name"]}-{service["port"]}", + "protocol": ServicePortProtocols.PREOPEN, + "container_ports": (service["port"],), + "host_ports": (None,), + "is_inference": True, + }) + + return new_service_ports diff --git a/src/ai/backend/agent/server.py b/src/ai/backend/agent/server.py index c29fd8110b5..35b70448cba 100644 --- a/src/ai/backend/agent/server.py +++ b/src/ai/backend/agent/server.py @@ -693,6 +693,12 @@ async def restart_kernel( cast(KernelCreationConfig, updated_config), ) + @rpc_function + @collect_error + async def restart_model_service(self, kernel_id: str): + log.info("rpc::restart_model_service(k:{0})", kernel_id) + return await self.agent.restart_model_service(KernelId(UUID(kernel_id))) + @rpc_function @collect_error async def execute( diff --git a/src/ai/backend/agent/types.py b/src/ai/backend/agent/types.py index 8d7d6ffe018..f15c0c6443e 100644 --- a/src/ai/backend/agent/types.py +++ b/src/ai/backend/agent/types.py @@ -1,5 +1,6 @@ import asyncio import enum +from dataclasses import dataclass from pathlib import Path from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence @@ -8,7 +9,14 @@ from aiohttp.typedefs import Handler from ai.backend.common.events import KernelLifecycleEventReason -from ai.backend.common.types import ContainerId, KernelId, MountTypes, SessionId +from ai.backend.common.types import ( + ContainerId, + KernelId, + MountTypes, + RuntimeVariant, + SessionId, + VFolderMount, +) class AgentBackend(enum.StrEnum): @@ -113,3 +121,10 @@ def __str__(self): [web.Request, Handler], Awaitable[web.StreamResponse], ] + + +@dataclass +class ModelServiceInfo: + runtime_variant: RuntimeVariant + model_folder: VFolderMount + model_definition_path: str | None diff --git a/src/ai/backend/kernel/base.py b/src/ai/backend/kernel/base.py index 762f49582d0..cd2f0a38c2e 100644 --- a/src/ai/backend/kernel/base.py +++ b/src/ai/backend/kernel/base.py @@ -731,6 +731,26 @@ async def start_model_service(self, model_info): ), ]) + async def shutdown_model_service(self, model_info) -> None: + try: + if self._health_check_task: + self._health_check_task.cancel() + await asyncio.sleep(0) + if not self._health_check_task.done(): + await self._health_check_task + await self._shutdown_service(model_info["name"]) + log.info("shutdown_model_service(): shutdown completed") + await self.outsock.send_multipart([ + b"shutdown-model-service-result", + json.dumps({"success": True}).encode("utf-8"), + ]) + except Exception as e: + log.info("shutdown_model_service(): Failed to shutdown model service: {}", e) + await self.outsock.send_multipart([ + b"shutdown-model-service-result", + json.dumps({"success": False, "error": str(e)}).encode("utf-8"), + ]) + async def check_model_health(self, model_name, model_service_info): health_check_info = model_service_info.get("health_check") health_check_endpoint = ( @@ -1144,6 +1164,9 @@ async def main_loop(self, cmdargs): elif op_type == "start-service": # activate a service port data = json.loads(text) 
asyncio.create_task(self._start_service_and_feed_result(data)) + elif op_type == "shutdown-model-service": # shutdown a running model service process + data = json.loads(text) + await self.shutdown_model_service(data) elif op_type == "shutdown-service": # shutdown the service by its name data = json.loads(text) await self._shutdown_service(data) diff --git a/src/ai/backend/kernel/service_actions.py b/src/ai/backend/kernel/service_actions.py index 1bd1566725c..853a2c2caaf 100644 --- a/src/ai/backend/kernel/service_actions.py +++ b/src/ai/backend/kernel/service_actions.py @@ -41,10 +41,12 @@ async def write_tempfile( async def run_command( variables: Mapping[str, Any], command: Iterable[str], - echo=False, + echo=True, ) -> Optional[MutableMapping[str, str]]: + concrete_command = [str(piece).format_map(variables) for piece in command] + logger.info(f"run_command(): executing command {concrete_command}") proc = await create_subprocess_exec( - *(str(piece).format_map(variables) for piece in command), + *concrete_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) diff --git a/src/ai/backend/manager/api/service.py b/src/ai/backend/manager/api/service.py index a58040bbe78..b3300784554 100644 --- a/src/ai/backend/manager/api/service.py +++ b/src/ai/backend/manager/api/service.py @@ -1025,6 +1025,51 @@ async def delete_route(request: web.Request) -> SuccessResponseModel: return SuccessResponseModel() +@auth_required +@server_status_required(READ_ALLOWED) +@pydantic_response_api_handler +async def restart_model_service(request: web.Request) -> SuccessResponseModel: + """ + Restarts the model service process while retaining the container. Kernels created before Backend.AI 24.12.0 do not support this operation. + """ + root_ctx: RootContext = request.app["_root.context"] + access_key = request["keypair"]["access_key"] + service_id = uuid.UUID(request.match_info["service_id"]) + route_id = uuid.UUID(request.match_info["route_id"]) + + log.info( + "SERVE.RESTART_MODEL_SERVICE (email:{}, ak:{}, s:{})", + request["user"]["email"], + access_key, + service_id, + ) + async with root_ctx.db.begin_readonly_session() as db_sess: + try: + route = await RoutingRow.get(db_sess, route_id, load_endpoint=True) + except NoResultFound: + raise ObjectNotFound + if route.endpoint != service_id: + raise ObjectNotFound + + await get_user_uuid_scopes(request, {"owner_uuid": route.endpoint_row.session_owner}) + if route.status in (RouteStatus.TERMINATING, RouteStatus.FAILED_TO_START): + raise InvalidAPIParameters(f"Cannot restart route in {route.status.name} status") + assert route.session + + session = await SessionRow.get_session( + db_sess, + route.session, + kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY, + ) + + await root_ctx.registry.restart_model_service( + session.main_kernel.agent, + session.main_kernel.id, + ) + + return SuccessResponseModel() + + class TokenRequestModel(BaseModel): duration: tv.TimeDuration | None = Field( default=None, description="The lifetime duration of the token." 
@@ -1265,5 +1310,6 @@ def create_app( cors.add(add_route("POST", "/{service_id}/sync", sync)) cors.add(add_route("PUT", "/{service_id}/routings/{route_id}", update_route)) cors.add(add_route("DELETE", "/{service_id}/routings/{route_id}", delete_route)) + cors.add(add_route("POST", "/{service_id}/routings/{route_id}/restart", restart_model_service)) cors.add(add_route("POST", "/{service_id}/token", generate_token)) return app, [] diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index db1437719fc..efec88ca710 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -3678,6 +3678,14 @@ async def get_abusing_report( "abuse_report": result, } + async def restart_model_service( + self, + agent_id: AgentId, + kernel_id: KernelId, + ) -> None: + async with self.agent_cache.rpc_context(agent_id) as rpc: + await rpc.call.restart_model_service(str(kernel_id)) + async def update_appproxy_endpoint_routes( self, db_sess: AsyncSession, endpoint: EndpointRow, active_routes: list[RoutingRow] ) -> None:
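For reference, a minimal client-side sketch of invoking the new restart endpoint. It assumes the service API is mounted under the usual `/services` prefix and that `auth_headers` already carries a signed Backend.AI keypair authorization; the request signing itself is elided, so `auth_headers` is a placeholder and not part of this diff.

import aiohttp


async def restart_route(
    manager_url: str,
    service_id: str,
    route_id: str,
    auth_headers: dict[str, str],
) -> None:
    # POST /services/{service_id}/routings/{route_id}/restart (added by this PR)
    url = f"{manager_url}/services/{service_id}/routings/{route_id}/restart"
    async with aiohttp.ClientSession() as http:
        async with http.post(url, headers=auth_headers) as resp:
            # On success the manager replies with a SuccessResponseModel JSON body.
            resp.raise_for_status()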
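And a sketch of the structure that `ModelServiceManager.load_model_definition()` returns once `model_definition_iv` validation passes, distilled from the runtime-variant arms above. The concrete values here are illustrative only; `start_command` is whatever `extract_image_command()` yields for the image, while a custom `model-definition.yaml` may supply its own command.

# Illustrative shape of a validated model definition (values are made up).
example_model_definition = {
    "models": [
        {
            "name": "vllm-model",           # exported as BACKEND_MODEL_NAME by create_environs()
            "model_path": "/models/llama",  # exported as BACKEND_MODEL_PATH
            "service": {
                "start_command": "python -m vllm.entrypoints.openai.api_server",  # illustrative
                "port": 8000,               # 2000/2001 are reserved; clashes with built-in ports are rejected
                "health_check": {
                    "path": "/health",      # polled by start_and_monitor_model_service_health()
                },
            },
        },
    ],
}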