feat: dynamic endpoint registration #3418
New file (SGLang native API handler), `@@ -0,0 +1,118 @@`:

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# SGLang Native APIs: https://docs.sglang.ai/basic_usage/native_api.html
# Code: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server.py

import asyncio
import logging
from typing import List, Optional, Tuple

import sglang as sgl
from sglang.srt.managers.io_struct import ProfileReqInput

from dynamo._core import Component


class NativeApiHandler:
    """Handler to add sglang native API endpoints to workers"""

    def __init__(
        self,
        component: Component,
        engine: sgl.Engine,
        metrics_labels: Optional[List[Tuple[str, str]]] = None,
    ):
        self.component = component
        self.engine = engine
        self.metrics_labels = metrics_labels
        self.native_api_tasks = []

    async def init_native_apis(
        self,
    ) -> List[asyncio.Task]:
        """
        Initialize and register native API endpoints.
        Returns list of tasks to be gathered.
        """
        logging.info("Initializing native SGLang API endpoints")

        self.tm = self.engine.tokenizer_manager

        tasks = []

        model_info_ep = self.component.endpoint("get_model_info")
        start_profile_ep = self.component.endpoint("start_profile")
        stop_profile_ep = self.component.endpoint("stop_profile")
        tasks.extend(
            [
                model_info_ep.serve_endpoint(
                    self.get_model_info,
                    graceful_shutdown=True,
                    metrics_labels=self.metrics_labels,
                    http_endpoint_path="/get_model_info",
                ),
                start_profile_ep.serve_endpoint(
                    self.start_profile,
                    graceful_shutdown=True,
                    metrics_labels=self.metrics_labels,
                    http_endpoint_path="/start_profile",
                ),
                stop_profile_ep.serve_endpoint(
                    self.stop_profile,
                    graceful_shutdown=True,
                    metrics_labels=self.metrics_labels,
                    http_endpoint_path="/stop_profile",
                ),
            ]
        )

        self.native_api_tasks = tasks
        logging.info(f"Registered {len(tasks)} native API endpoints")
        return tasks

    async def get_model_info(self, request: dict):
        result = {
            "model_path": self.tm.server_args.model_path,
            "tokenizer_path": self.tm.server_args.tokenizer_path,
            "preferred_sampling_params": self.tm.server_args.preferred_sampling_params,
            "weight_version": self.tm.server_args.weight_version,
        }

        yield {"data": [result]}

    async def start_profile(self, request: dict):
        try:
            obj = ProfileReqInput.model_validate(request)
        except Exception:
            obj = None

        if obj is None:
            obj = ProfileReqInput()

        output_dir = obj.output_dir or f"profile_{self.tm.server_args.model_path}"

        await self.tm.start_profile(
            output_dir=output_dir,
            start_step=obj.start_step,
            num_steps=obj.num_steps,
            activities=obj.activities,
            with_stack=obj.with_stack,
            record_shapes=obj.record_shapes,
            profile_by_stage=obj.profile_by_stage,
        )

        yield {"data": [{"status": "started profile"}]}

    async def stop_profile(self, request: dict):
        asyncio.create_task(self.tm.stop_profile())
        yield {
            "data": [
                {
                    "status": (
                        "Stopped profile. This might take a long time to complete. "
                        f"Results should be available in the 'profile_{self.tm.server_args.model_path}' directory."
                    )
                }
            ]
        }
```
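As a quick illustration of how this handler might be wired into a worker, here is a minimal sketch. It assumes `component` and `engine` already exist from the worker's normal startup path, and the metrics label shown is only an example value, not part of this PR.

```python
import asyncio


async def serve_native_apis(component, engine):
    # Minimal sketch, assuming `component` and `engine` come from the
    # worker's existing setup; the label below is a placeholder.
    handler = NativeApiHandler(
        component, engine, metrics_labels=[("model", "example-model")]
    )
    tasks = await handler.init_native_apis()
    # serve_endpoint() returns long-running tasks; run them alongside the
    # worker's main serving tasks.
    await asyncio.gather(*tasks)
```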
New file (HTTP frontend dynamic endpoint router), `@@ -0,0 +1,112 @@`:

```rust
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use super::{RouteDoc, service_v2};
use crate::types::Annotated;
use axum::{
    Json, Router,
    http::{Method, StatusCode},
    response::IntoResponse,
    routing::post,
};
use dynamo_runtime::instances::list_all_instances;
use dynamo_runtime::{DistributedRuntime, Runtime, component::Client};
use dynamo_runtime::{pipeline::PushRouter, stream::StreamExt};
use std::sync::Arc;

pub fn dynamic_endpoint_router(
    state: Arc<service_v2::State>,
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    let wildcard_path = "/{*path}";
    let path = path.unwrap_or_else(|| wildcard_path.to_string());

    let docs: Vec<RouteDoc> = vec![RouteDoc::new(Method::POST, &path)];

    let router = Router::new()
        .route(&path, post(dynamic_endpoint_handler))
        .with_state(state);

    (docs, router)
}
```
Review thread anchored at `inner_dynamic_endpoint_handler` (shown below):

**Contributor:** A general design question; correct me if I'm not understanding the changes properly. Currently, this change more or less lets you call an arbitrary route. That's my understanding of the current changes; assuming it's roughly accurate, my next question is why not use something more discovery-oriented? I'm a little hesitant about the additional per-request checking here.

**Contributor:** This is not a blocking comment yet, just looking to understand the approach better and get more context on the use case.

**Contributor:** Agreeing with Ryan here. I would rather attach a route for the endpoint when it appears, instead of attaching a wildcard route and hitting etcd on every request.

**Contributor (author):** I believe it will go to the specific handler. This is actually how I implemented it before, but I feel like it made things a little messy. Why have a separate map when it can all just live under a single endpoint entry? If the endpoint goes down, that also takes the etcd entry down, which in turn takes the exposed HTTP path down with it. Endpoints implemented this way are not meant to serve heavy traffic by any means, and checking etcd here doesn't seem to cost much. Why have another watcher?

**Contributor:** I think right now it's a "hey, you can expose any endpoint here" feature, so I won't be surprised at all if someone calls a native endpoint for some use case that Dynamo doesn't natively support yet but the framework does, as a stopgap until we do support it. And if so, we lose that assumption, right? I'm less worried about optimizing for an extreme heavy-load case on one of these custom endpoints and more about the general code smell of doing unnecessary work on every request when we could instead do the bare minimum when necessary (on discovery). At the end of the day we have limited resources (threads, CPUs, etc.), and the less we use them here, the more the heavy-load paths (chat, completions, NATS, etcd, etc.) have to work freely with, and the less we have to worry about later. For example, if any of these native endpoints don't get heavy load but do get polled every second or every few seconds, that could become non-trivial at some point. Though since this implementation is completely custom support for anything, I can't really guess what it will all be used for. Do you have a draft/commit to refer to for the original solution?

**Contributor (author):** Not complete, but check out b317939.

**Contributor (author):** Spoke to Ryan offline. I think my earlier approach stemmed from a lack of understanding of the watcher pattern and how our ModelWatcher works. Refactored in the latest commit.
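For context on the discovery-oriented alternative discussed above, a rough sketch of the idea follows. This is Python-flavored pseudocode of the general pattern, not the project's actual watcher implementation; `watch_instances`, `add_route`, and `remove_route` are hypothetical helpers standing in for the runtime's instance watcher and the HTTP router.

```python
# Conceptual sketch only: react to instance discovery events instead of
# querying etcd inside every HTTP request.
# watch_instances(), add_route(), and remove_route() are hypothetical helpers.
async def route_watcher(watch_instances, add_route, remove_route):
    registered: set[str] = set()
    async for instances in watch_instances():  # stream of current instance lists
        live = {i.http_endpoint_path for i in instances if i.http_endpoint_path}
        for path in live - registered:
            add_route(path)      # endpoint appeared: attach its route once
        for path in registered - live:
            remove_route(path)   # endpoint disappeared: detach its route
        registered = live
```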
Continuation of the same file:

```rust
async fn inner_dynamic_endpoint_handler(
    state: Arc<service_v2::State>,
    path: String,
) -> Result<impl IntoResponse, &'static str> {
    let etcd_client = state.etcd_client().ok_or("Failed to get etcd client")?;

    let instances = list_all_instances(etcd_client)
        .await
        .map_err(|_| "Failed to get instances")?;

    let dynamic_endpoints = instances
        .iter()
        .filter_map(|instance| instance.http_endpoint_path.clone())
        .collect::<Vec<String>>();

    let fmt_path = format!("/{}", &path);
    if !dynamic_endpoints.contains(&fmt_path) {
        return Err("Dynamic endpoint not found");
    }

    let rt = Runtime::from_current().map_err(|_| "Failed to get runtime")?;
    let drt = DistributedRuntime::from_settings(rt)
        .await
        .map_err(|_| "Failed to get distributed runtime")?;

    let target_instances = instances
        .iter()
        .filter(|instance| instance.http_endpoint_path == Some(fmt_path.clone()))
        .collect::<Vec<_>>();

    let mut target_clients: Vec<Client> = Vec::new();
    for instance in target_instances {
        let ns = drt
            .namespace(instance.namespace.clone())
            .map_err(|_| "Failed to get namespace")?;
        let c = ns
            .component(instance.component.clone())
            .map_err(|_| "Failed to get component")?;
        let ep = c.endpoint(path.clone());
        let client = ep.client().await.map_err(|_| "Failed to get client")?;
        target_clients.push(client);
    }

    let mut all_responses = Vec::new();
    for client in target_clients {
        let router =
            PushRouter::<(), Annotated<serde_json::Value>>::from_client(client, Default::default())
                .await
                .map_err(|_| "Failed to get router")?;

        let mut stream = router
            .round_robin(().into())
            .await
            .map_err(|_| "Failed to route")?;

        while let Some(resp) = stream.next().await {
            all_responses.push(resp);
        }
    }

    Ok(Json(serde_json::json!({
        "responses": all_responses
    })))
}

async fn dynamic_endpoint_handler(
    axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
    axum::extract::Path(path): axum::extract::Path<String>,
) -> impl IntoResponse {
    inner_dynamic_endpoint_handler(state, path)
        .await
        .map_err(|err_string| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({
                    "message": err_string
                })),
            )
        })
}
```
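For completeness, here is a hedged sketch of how a registered path could be exercised through the HTTP frontend. The frontend address is an assumption; the response envelope follows the `"responses"` key built by `inner_dynamic_endpoint_handler` above, and `/get_model_info` is one of the paths registered by the `NativeApiHandler` in the first file.

```python
import requests

FRONTEND = "http://localhost:8000"  # assumed frontend address

# POST to a path that a worker registered via http_endpoint_path; the
# wildcard route forwards the call to every matching worker instance.
resp = requests.post(f"{FRONTEND}/get_model_info", json={})
resp.raise_for_status()
for item in resp.json().get("responses", []):
    print(item)
```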