23 changes: 22 additions & 1 deletion docs/source/build-with-bentoml/services.rst
@@ -306,7 +306,28 @@ Alternatively, compute the command at runtime:

Use this method when there are parameters whose values can only be determined at runtime.
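
For reference, computing the command at runtime is done through a ``__command__`` method that returns the argument list; the proxy implementation in this PR calls it in place of a static ``cmd``. A minimal sketch (``myserver`` and its ``--workers`` flag are illustrative):

.. code-block:: python

    import os

    import bentoml

    @bentoml.service(cmd=["myserver", "--port", "$PORT"])
    class ExternalServer:
        def __command__(self) -> list[str]:
            # Takes precedence over the static ``cmd``; the worker count
            # here is only known on the running host.
            workers = str(os.cpu_count() or 1)
            return ["myserver", "--workers", workers, "--port", "$PORT"]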

BentoML operates by establishing a proxy service that forwards all requests to the HTTP server started by the custom command. The default proxy port is ``8000``; specify a different one if the custom command listens on another port:

.. code-block:: python

@bentoml.service(cmd=["myserver", "--port", "$PORT"], http={"proxy_port": 9000})
class ExternalServer:
pass
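
The custom command can point at any HTTP server that binds to the port it is given. A minimal sketch of what the ``myserver`` entry point might look like (``myserver.py`` is an illustrative stand-in, assuming FastAPI and uvicorn are installed):

.. code-block:: python

    # myserver.py - a stand-in for the "myserver" command above
    import argparse

    import uvicorn
    from fastapi import FastAPI

    app = FastAPI()

    @app.get("/health")
    def health():
        # The proxy polls the configured livez endpoint (``/health`` by default)
        # until the server reports success.
        return {"status": "ok"}

    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("--port", type=int, required=True)
        args = parser.parse_args()
        uvicorn.run(app, host="0.0.0.0", port=args.port)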

Metrics Rewriting
-----------------

When starting a server with a custom command, you may want to expose metrics from that server, or rewrite the metrics produced by the Prometheus exporter.
To do so, implement the ``__metrics__`` method in your Service class. It takes the original metrics text as input and returns the modified metrics text:

.. code-block:: python

@bentoml.service(cmd=["myserver", "--port", "$PORT"])
class ExternalServer:
def __metrics__(self, original_metrics: str) -> str:
# Modify the original metrics as needed
modified_metrics = original_metrics.replace('sglang', 'vllm')
return modified_metrics
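
``__metrics__`` may also be declared ``async``; a synchronous implementation is run in a worker thread. A sketch of an async variant that appends an extra sample (``my_custom_requests_total`` is a made-up metric name):

.. code-block:: python

    @bentoml.service(cmd=["myserver", "--port", "$PORT"])
    class ExternalServer:
        async def __metrics__(self, original_metrics: str) -> str:
            # Append a hand-written sample in the Prometheus text format
            return original_metrics + "\nmy_custom_requests_total 42\n"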

.. _bentoml-tasks:

33 changes: 29 additions & 4 deletions src/_bentoml_impl/server/app.py
@@ -19,13 +19,16 @@
from starlette.responses import Response
from starlette.staticfiles import StaticFiles

from _bentoml_impl.server.proxy import create_proxy_app
from _bentoml_sdk import Service
from _bentoml_sdk.service import set_current_service
from bentoml._internal.container import BentoMLContainer
from bentoml._internal.marshal.dispatcher import CorkDispatcher
from bentoml._internal.resource import system_resources
from bentoml._internal.server.base_app import BaseAppFactory
from bentoml._internal.server.http_app import log_exception
from bentoml._internal.types import LazyType
from bentoml._internal.utils import is_async_callable
from bentoml._internal.utils.metrics import exponential_buckets
from bentoml.exceptions import BentoMLException
from bentoml.exceptions import ServiceUnavailable
@@ -178,6 +181,10 @@ def __call__(self) -> Starlette:
app = super().__call__()
app.add_route("/schema.json", self.schema_view, name="schema")

if self.service.has_custom_command():
            # The proxy app is mounted at the root and may shadow all routes
            # registered above; this is expected.
self.service.mount_asgi_app(create_proxy_app(self.service), name="proxy")

for mount_app, path, name in self.service.mount_apps:
app.router.routes.append(PassiveMount(path, mount_app, name=name))

@@ -418,6 +425,22 @@ async def readyz(self, _: Request) -> Response:

return PlainTextResponse("\n", status_code=200)

async def metrics(self, _: Request) -> Response: # type: ignore[override]
metrics_client = BentoMLContainer.metrics_client.get()
metrics_content = await anyio.to_thread.run_sync(metrics_client.generate_latest)
if hasattr(self.service.inner, "__metrics__"):
func = getattr(self._service_instance, "__metrics__")
if not is_async_callable(func):
func = functools.partial(anyio.to_thread.run_sync, func)
metrics_content = (await func(metrics_content.decode("utf-8"))).encode(
"utf-8"
)
return Response(
metrics_content,
status_code=200,
media_type=metrics_client.CONTENT_TYPE_LATEST,
)

@contextlib.asynccontextmanager
async def lifespan(self, app: Starlette) -> t.AsyncGenerator[None, None]:
from starlette.applications import Starlette
@@ -430,12 +453,14 @@ async def lifespan(self, app: Starlette) -> t.AsyncGenerator[None, None]:

for mount_app, *_ in self.service.mount_apps:
if isinstance(mount_app, Starlette):
                await stack.enter_async_context(
                    mount_app.router.lifespan_context(mount_app)
                )
elif LazyType("quart.Quart").isinstance(mount_app):
await mount_app.startup() # type: ignore[attr-defined]
stack.push_async_callback(mount_app.shutdown) # type: ignore[attr-defined]
else:
pass # TODO: support other ASGI apps
yield

async def schema_view(self, request: Request) -> Response:
125 changes: 125 additions & 0 deletions src/_bentoml_impl/server/proxy.py
@@ -0,0 +1,125 @@
from __future__ import annotations

import contextlib
import logging
import typing as t

import anyio
import httpx
from typing import cast
from starlette.requests import Request

from _bentoml_sdk import Service
from bentoml import get_current_service
from bentoml._internal.utils import expand_envs
from bentoml.exceptions import BentoMLConfigException

if t.TYPE_CHECKING:
from starlette.applications import Starlette

logger = logging.getLogger("bentoml.server")


async def _check_health(client: httpx.AsyncClient, health_endpoint: str) -> bool:
try:
response = await client.get(health_endpoint, timeout=5.0)
if response.status_code == 404:
raise BentoMLConfigException(
f"Health endpoint {health_endpoint} not found (404). Please make sure the health "
"endpoint is correctly configured in the service config."
)
return response.is_success
except (httpx.HTTPError, httpx.RequestError):
return False


def create_proxy_app(service: Service[t.Any]) -> Starlette:
"""A reverse-proxy that forwards all requests to the HTTP server started
by the custom command.
"""
import fastapi
from fastapi.responses import StreamingResponse

health_endpoint = service.config.get("endpoints", {}).get("livez", "/health")

@contextlib.asynccontextmanager
async def lifespan(
app: fastapi.FastAPI,
) -> t.AsyncGenerator[dict[str, t.Any], None]:
server_instance = get_current_service()
assert server_instance is not None, "Current service is not initialized"
async with contextlib.AsyncExitStack() as stack:
if cmd_getter := getattr(server_instance, "__command__", None):
if not callable(cmd_getter):
raise TypeError(
f"__command__ must be a callable that returns a list of strings, got {type(cmd_getter)}"
)
cmd = cast("list[str]", cmd_getter())
else:
cmd = service.cmd
assert cmd is not None, "must have a command"
cmd = [expand_envs(c) for c in cmd]
logger.info("Running service with command: %s", " ".join(cmd))
if (
instance_client := getattr(server_instance, "client", None)
) is not None and isinstance(instance_client, httpx.AsyncClient):
# TODO: support aiohttp client
client = instance_client
else:
proxy_port = service.config.get("http", {}).get("proxy_port", 8000)
proxy_url = f"http://localhost:{proxy_port}"
client = await stack.enter_async_context(
httpx.AsyncClient(base_url=proxy_url, timeout=None)
)
            proc = await anyio.open_process(cmd, stdout=None, stderr=None)
            # Poll the health endpoint until the server is ready; the while/else
            # fires only if the process exits before ever becoming healthy.
            while proc.returncode is None:
                if await _check_health(client, health_endpoint):
                    break
                await anyio.sleep(0.5)
            else:
                raise RuntimeError(
                    "Service process exited before becoming healthy, see the error above"
                )

app.state.client = client
try:
state = {"proc": proc, "client": client}
service.context.state.update(state)
yield state
finally:
proc.terminate()
await proc.wait()

    assert service.has_custom_command(), "Service does not have a custom command"
app = fastapi.FastAPI(lifespan=lifespan)

# TODO: support websocket endpoints
@app.api_route(
"/{path:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"],
)
async def reverse_proxy(request: Request, path: str):
url = httpx.URL(
path=f"/{path}", query=request.url.query.encode("utf-8") or None
)
client = t.cast(httpx.AsyncClient, app.state.client)
headers = dict(request.headers)
headers.pop("host", None)
req = client.build_request(
method=request.method, url=url, headers=headers, content=request.stream()
)
try:
resp = await client.send(req, stream=True)
        except httpx.ConnectError:
            return fastapi.Response(status_code=503)
        except httpx.RequestError:
            return fastapi.Response(status_code=500)

return StreamingResponse(
resp.aiter_raw(),
status_code=resp.status_code,
headers=resp.headers,
background=resp.aclose,
)

return app