Commit d59a541

feat: implement reverse proxy for custom command services

Signed-off-by: Frost Ming <[email protected]>
1 parent 3c08403 commit d59a541

9 files changed: +257 −125 lines changed

docs/source/build-with-bentoml/services.rst

Lines changed: 22 additions & 1 deletion

@@ -306,7 +306,28 @@ Alternatively, compute the command at runtime:
 
 Use this method when there are parameters whose values can only be determined at runtime.
 
-When a custom command is provided, BentoML will launch a single process for that Service using your command. It will set the ``PORT`` environment variable (and ``BENTOML_HOST``/``BENTOML_PORT`` for the entry Service). Your process must listen on the provided ``PORT`` and serve HTTP endpoints. Server-level options like CORS/SSL/timeouts defined in BentoML won't apply automatically—configure them in your own server if needed.
+BentoML starts a proxy service that forwards all requests to the HTTP server launched by the custom command. The default proxy port is ``8000``; specify a different one if the custom command listens on another port:
+
+.. code-block:: python
+
+    @bentoml.service(cmd=["myserver", "--port", "$PORT"], http={"proxy_port": 9000})
+    class ExternalServer:
+        pass
+
+Metrics Rewriting
+-----------------
+
+When starting a server with a custom command, you may want to include that server's metrics in the output, or modify the metrics exposed by the Prometheus exporter.
+To do so, implement the ``__metrics__`` method in your Service class. It takes the original metrics text as input and returns the modified metrics text:
+
+.. code-block:: python
+
+    @bentoml.service(cmd=["myserver", "--port", "$PORT"])
+    class ExternalServer:
+        def __metrics__(self, original_metrics: str) -> str:
+            # Modify the original metrics as needed
+            modified_metrics = original_metrics.replace('sglang', 'vllm')
+            return modified_metrics
 
 .. _bentoml-tasks:
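The "compute the command at runtime" section referenced in the hunk header maps to the ``__command__`` hook that ``proxy.py`` (below) resolves via ``cmd_getter``. A minimal sketch, assuming ``myserver`` is a placeholder binary and that ``has_custom_command()`` also recognizes the hook:

    import bentoml

    @bentoml.service(http={"proxy_port": 9000})
    class ExternalServer:
        def __command__(self) -> list[str]:
            # Each entry is passed through expand_envs() before the
            # process is spawned, so "$PORT" is substituted from the
            # environment at startup.
            return ["myserver", "--port", "$PORT"]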

src/_bentoml_impl/server/app.py

Lines changed: 25 additions & 4 deletions

@@ -19,13 +19,15 @@
 from starlette.responses import Response
 from starlette.staticfiles import StaticFiles
 
+from _bentoml_impl.server.proxy import create_proxy_app
 from _bentoml_sdk import Service
 from _bentoml_sdk.service import set_current_service
 from bentoml._internal.container import BentoMLContainer
 from bentoml._internal.marshal.dispatcher import CorkDispatcher
 from bentoml._internal.resource import system_resources
 from bentoml._internal.server.base_app import BaseAppFactory
 from bentoml._internal.server.http_app import log_exception
+from bentoml._internal.utils import is_async_callable
 from bentoml._internal.utils.metrics import exponential_buckets
 from bentoml.exceptions import BentoMLException
 from bentoml.exceptions import ServiceUnavailable
@@ -178,6 +180,10 @@ def __call__(self) -> Starlette:
         app = super().__call__()
         app.add_route("/schema.json", self.schema_view, name="schema")
 
+        if self.service.has_custom_command():
+            # This catch-all proxy may shadow all routes mounted after it, which is expected.
+            self.service.mount_asgi_app(create_proxy_app(self.service), name="proxy")
+
         for mount_app, path, name in self.service.mount_apps:
             app.router.routes.append(PassiveMount(path, mount_app, name=name))
 
@@ -418,6 +424,22 @@ async def readyz(self, _: Request) -> Response:
 
         return PlainTextResponse("\n", status_code=200)
 
+    async def metrics(self, _: Request) -> Response:  # type: ignore[override]
+        metrics_client = BentoMLContainer.metrics_client.get()
+        metrics_content = await anyio.to_thread.run_sync(metrics_client.generate_latest)
+        if hasattr(self.service.inner, "__metrics__"):
+            func = getattr(self._service_instance, "__metrics__")
+            if not is_async_callable(func):
+                func = functools.partial(anyio.to_thread.run_sync, func)
+            metrics_content = (await func(metrics_content.decode("utf-8"))).encode(
+                "utf-8"
+            )
+        return Response(
+            metrics_content,
+            status_code=200,
+            media_type=metrics_client.CONTENT_TYPE_LATEST,
+        )
+
     @contextlib.asynccontextmanager
     async def lifespan(self, app: Starlette) -> t.AsyncGenerator[None, None]:
         from starlette.applications import Starlette
@@ -430,12 +452,11 @@ async def lifespan(self, app: Starlette) -> t.AsyncGenerator[None, None]:
 
         for mount_app, *_ in self.service.mount_apps:
             if isinstance(mount_app, Starlette):
-                maybe_state = await stack.enter_async_context(
+                _ = await stack.enter_async_context(
                     mount_app.router.lifespan_context(mount_app)
                 )
-                if maybe_state is not None:
-                    mount_app.state.update(maybe_state)
-            # TODO: support other ASGI apps
+            else:
+                pass  # TODO: support other ASGI apps
         yield
 
     async def schema_view(self, request: Request) -> Response:
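Because the ``metrics`` handler above awaits ``__metrics__`` when it is a coroutine function (the ``is_async_callable`` branch) and otherwise runs it in a worker thread, the hook from the docs can also be declared async. A minimal sketch, with ``myserver`` as a placeholder command:

    import bentoml

    @bentoml.service(cmd=["myserver", "--port", "$PORT"])
    class ExternalServer:
        async def __metrics__(self, original_metrics: str) -> str:
            # Awaited directly by the /metrics endpoint; sync hooks are
            # instead dispatched via anyio.to_thread.run_sync.
            return original_metrics.replace("sglang", "vllm")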

src/_bentoml_impl/server/proxy.py

Lines changed: 129 additions & 0 deletions

@@ -0,0 +1,129 @@
+from __future__ import annotations
+
+import contextlib
+import logging
+import typing as t
+from typing import cast
+
+import anyio
+import httpx
+from starlette.requests import Request
+
+from _bentoml_sdk import Service
+from bentoml import get_current_service
+from bentoml._internal.utils import expand_envs
+from bentoml.exceptions import BentoMLConfigException
+
+if t.TYPE_CHECKING:
+    from starlette.applications import Starlette
+
+logger = logging.getLogger("bentoml.server")
+
+
+async def _check_health(client: httpx.AsyncClient, health_endpoint: str) -> bool:
+    try:
+        response = await client.get(health_endpoint, timeout=5.0)
+        if response.status_code == 404:
+            raise BentoMLConfigException(
+                f"Health endpoint {health_endpoint} not found (404). Please make sure the health "
+                "endpoint is correctly configured in the service config."
+            )
+        return response.is_success
+    except (httpx.HTTPError, httpx.RequestError):
+        return False
+
+
+def create_proxy_app(service: Service[t.Any]) -> Starlette:
+    """A reverse proxy that forwards all requests to the HTTP server started
+    by the custom command.
+    """
+    import fastapi
+    from fastapi.responses import StreamingResponse
+
+    health_endpoint = service.config.get("endpoints", {}).get("livez", "/health")
+
+    @contextlib.asynccontextmanager
+    async def lifespan(
+        app: fastapi.FastAPI,
+    ) -> t.AsyncGenerator[dict[str, t.Any], None]:
+        server_instance = get_current_service()
+        assert server_instance is not None, "Current service is not initialized"
+        async with contextlib.AsyncExitStack() as stack:
+            if cmd_getter := getattr(server_instance, "__command__", None):
+                if not callable(cmd_getter):
+                    raise TypeError(
+                        f"__command__ must be a callable that returns a list of strings, got {type(cmd_getter)}"
+                    )
+                cmd = cast("list[str]", cmd_getter())
+            else:
+                cmd = service.cmd
+                assert cmd is not None, "must have a command"
+            cmd = [expand_envs(c) for c in cmd]
+            logger.info("Running service with command: %s", " ".join(cmd))
+            if (
+                instance_client := getattr(server_instance, "client", None)
+            ) is not None and isinstance(instance_client, httpx.AsyncClient):
+                # TODO: support aiohttp client
+                client = instance_client
+            else:
+                proxy_port = service.config.get("http", {}).get("proxy_port")
+                if proxy_port is None:
+                    raise BentoMLConfigException(
+                        "proxy_port must be set in service config to use custom command"
+                    )
+                proxy_url = f"http://localhost:{proxy_port}"
+                client = await stack.enter_async_context(
+                    httpx.AsyncClient(base_url=proxy_url, timeout=None)
+                )
+            proc = await anyio.open_process(cmd, stdout=None, stderr=None)
+            while proc.returncode is None:
+                if await _check_health(client, health_endpoint):
+                    break
+                await anyio.sleep(0.5)
+            else:
+                raise RuntimeError(
+                    "Service process exited before becoming healthy, see the error above"
+                )
+
+            app.state.client = client
+            try:
+                state = {"proc": proc, "client": client}
+                service.context.state.update(state)
+                yield state
+            finally:
+                proc.terminate()
+                await proc.wait()
+
+    assert service.has_custom_command(), "Service does not have custom command"
+    app = fastapi.FastAPI(lifespan=lifespan)
+
+    # TODO: support websocket endpoints
+    @app.api_route(
+        "/{path:path}",
+        methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"],
+    )
+    async def reverse_proxy(request: Request, path: str):
+        url = httpx.URL(
+            path=f"/{path}", query=request.url.query.encode("utf-8") or None
+        )
+        client = t.cast(httpx.AsyncClient, app.state.client)
+        headers = dict(request.headers)
+        headers.pop("host", None)
+        req = client.build_request(
+            method=request.method, url=url, headers=headers, content=request.stream()
+        )
+        try:
+            resp = await client.send(req, stream=True)
+        except httpx.ConnectError:
+            return fastapi.Response(status_code=503)
+        except httpx.RequestError:
+            return fastapi.Response(status_code=500)
+
+        return StreamingResponse(
+            resp.aiter_raw(),
+            status_code=resp.status_code,
+            headers=resp.headers,
+            background=resp.aclose,
+        )
+
+    return app
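``_check_health`` polls the path read from ``service.config["endpoints"]["livez"]`` (default ``/health``) every 0.5 seconds until the child process responds successfully, and a 404 surfaces as a ``BentoMLConfigException`` rather than being retried. A minimal sketch of overriding that path, assuming the ``endpoints`` key can be passed through ``@bentoml.service`` the same way as the ``http`` options and that ``myserver`` is a placeholder binary:

    import bentoml

    @bentoml.service(
        cmd=["myserver", "--port", "$PORT"],
        endpoints={"livez": "/healthz"},  # polled until it returns 2xx
        http={"proxy_port": 9000},        # where the proxy's httpx client connects
    )
    class ExternalServer:
        pass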
