
Commit 7f2839b (1 parent: 4f657c3)

WIP/POC commit - may want to split apart into separate PR

File tree: 4 files changed, +75 −35 lines

launch/dynamo-run/src/input/http.rs

Lines changed: 1 addition & 0 deletions

@@ -32,6 +32,7 @@ pub async fn run(
         .port(flags.http_port)
         .enable_chat_endpoints(true)
         .enable_cmpl_endpoints(true)
+        .enable_embeddings_endpoints(true)
         .with_request_template(template)
         .build()?;
     match engine_config {
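Since this flag enables the HTTP service's OpenAI-style embeddings routes, a quick smoke test is possible once a server is up. A minimal sketch, assuming the server listens on localhost at the configured --http-port (8080 is a placeholder) and serves the standard /v1/embeddings path; the model name is hypothetical:

# Smoke test for the newly enabled embeddings endpoint.
# Host, port, and model name are placeholders, not values from this commit.
import requests

resp = requests.post(
    "http://localhost:8080/v1/embeddings",
    json={"model": "my-embedding-model", "input": "Hello, world"},
    timeout=30,
)
resp.raise_for_status()
body = resp.json()
# Expect the OpenAI-style list shape assembled by EmbeddingRequestHandler below:
# {"object": "list", "data": [{"index": 0, "object": "embedding", ...}], "usage": {...}}
print(len(body["data"][0]["embedding"]), "dims;", body["usage"]["prompt_tokens"], "prompt tokens")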

launch/dynamo-run/src/subprocess/sglang_inc.py

Lines changed: 42 additions & 6 deletions

@@ -77,6 +77,38 @@ async def generate(self, request):
             num_output_tokens_so_far = next_total_toks
 
 
+class EmbeddingRequestHandler(RequestHandler):
+    """
+    Request handler for the embedding endpoint
+    """
+
+    async def generate(self, request):
+        gen = await self.engine_client.async_encode(prompt=request["input"])
+        tokens = 0
+        embeddings = []
+        for idx, res in enumerate(gen):
+            embeddings.append(
+                {
+                    "index": idx,
+                    "object": "embedding",
+                    "embedding": res["embedding"],
+                }
+            )
+            tokens += res["meta_info"]["prompt_tokens"]
+
+        out = {
+            "object": "list",
+            "model": "TODO",
+            "data": embeddings,
+            "usage": {
+                "prompt_tokens": tokens,
+                "total_tokens": tokens,
+            },
+        }
+
+        yield out
+
+
 @dynamo_worker(static=False)
 async def worker(runtime: DistributedRuntime):
     await init(runtime, cmd_line_args())
@@ -94,8 +126,8 @@ async def init(runtime: DistributedRuntime, config: Config):
         "base_gpu_id": config.base_gpu_id,
     }
 
-    if config.kv_block_size:
-        arg_map["page_size"] = config.kv_block_size
+    # if config.kv_block_size:
+    #     arg_map["page_size"] = config.kv_block_size
 
     if config.context_length:
         arg_map["context_length"] = config.context_length
@@ -129,13 +161,18 @@ async def init(runtime: DistributedRuntime, config: Config):
     await component.create_service()
 
     endpoint = component.endpoint(config.endpoint)
-    await register_llm(
-        ModelType.Backend, endpoint, config.model_path, config.model_name
+    model_type = (
+        ModelType.Backend if not engine_args.is_embedding else ModelType.Embedding
     )
+    await register_llm(model_type, endpoint, config.model_path, config.model_name)
 
     # the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
     # after the lease is revoked
-    await endpoint.serve_endpoint(RequestHandler(engine_client).generate)
+    await endpoint.serve_endpoint(
+        RequestHandler(engine_client).generate
+        if not engine_args.is_embedding
+        else EmbeddingRequestHandler(engine_client).generate
+    )
 
 
 def cmd_line_args():
@@ -230,7 +267,6 @@ def cmd_line_args():
     config.node_rank = args.node_rank
     config.dist_init_addr = args.dist_init_addr
    config.extra_engine_args = args.extra_engine_args
-
     return config
 
 
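The new handler can also be exercised without a live SGLang engine by stubbing the one method generate() touches. A rough harness, assuming EmbeddingRequestHandler (from the diff above) is in scope and that the RequestHandler base class simply stores engine_client as self.engine_client; FakeEngine and its return shape are assumptions modeled on how generate() reads res, not part of the commit:

# Hypothetical harness; only EmbeddingRequestHandler comes from the diff above.
import asyncio

class FakeEngine:
    async def async_encode(self, prompt):
        # Mimics the assumed async_encode result: one dict per input, each
        # carrying "embedding" and "meta_info.prompt_tokens".
        return [
            {"embedding": [0.1, 0.2, 0.3], "meta_info": {"prompt_tokens": 4}},
            {"embedding": [0.4, 0.5, 0.6], "meta_info": {"prompt_tokens": 3}},
        ]

async def main():
    handler = EmbeddingRequestHandler(FakeEngine())
    async for out in handler.generate({"input": ["first text", "second text"]}):
        # usage.prompt_tokens should sum to 7 across the two inputs
        print(out["usage"], "entries:", len(out["data"]))

asyncio.run(main())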

lib/llm/src/http/service/service_v2.rs

Lines changed: 1 addition & 1 deletion

@@ -75,7 +75,7 @@ pub struct HttpServiceConfig {
     #[builder(default = "true")]
     enable_cmpl_endpoints: bool,
 
-    #[builder(default = "false")]
+    #[builder(default = "true")]
     enable_embeddings_endpoints: bool,
 
     #[builder(default = "None")]
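Flipping this builder default to true means every HttpService now serves the embeddings routes unless a caller explicitly opts out, which also makes the .enable_embeddings_endpoints(true) call added in http.rs above redundant; given the WIP/POC commit message, presumably one of the two changes gets dropped when this is split into a proper PR.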

lib/llm/src/preprocessor/prompt/template/formatters.rs

Lines changed: 31 additions & 28 deletions

@@ -42,9 +42,9 @@ impl HfTokenizerConfigJsonFormatter {
     pub fn new(config: ChatTemplate, mixins: ContextMixins) -> anyhow::Result<Self> {
         let mut env = JinjaEnvironment::default().env();
 
-        let chat_template = config.chat_template.as_ref().ok_or(anyhow::anyhow!(
-            "chat_template field is required in the tokenizer_config.json file"
-        ))?;
+        // let chat_template = config.chat_template.as_ref().ok_or(anyhow::anyhow!(
+        //     "chat_template field is required in the tokenizer_config.json file"
+        // ))?;
 
         // add pycompat
         // todo: should we use this: minijinja_contrib::add_to_environment(&mut env);
@@ -57,40 +57,43 @@ impl HfTokenizerConfigJsonFormatter {
 
         let mut supports_add_generation_prompt = None;
 
-        match &chat_template.0 {
-            Either::Left(x) => {
-                if x.contains("add_generation_prompt") {
-                    tracing::debug!("Chat template contains `add_generation_prompt` key. This model supports add_generation_prompt.");
-                    supports_add_generation_prompt = Some(true);
+        if let Some(chat_template) = config.chat_template.as_ref() {
+            match &chat_template.0 {
+                Either::Left(x) => {
+                    if x.contains("add_generation_prompt") {
+                        tracing::debug!("Chat template contains `add_generation_prompt` key. This model supports add_generation_prompt.");
+                        supports_add_generation_prompt = Some(true);
+                    }
+                    env.add_template_owned("default", x.to_string())?;
+                    env.add_template_owned("tool_use", x.to_string())?;
                 }
-                env.add_template_owned("default", x.to_string())?;
-                env.add_template_owned("tool_use", x.to_string())?;
-            }
-            Either::Right(map) => {
-                for t in map {
-                    for (k, v) in t.iter() {
-                        if v.contains("add_generation_prompt") {
-                            match supports_add_generation_prompt {
-                                Some(true) | None => {
-                                    tracing::debug!("Chat template contains `add_generation_prompt` key. This model supports add_generation_prompt.");
-                                    supports_add_generation_prompt = Some(true);
-                                }
-                                Some(false) => {
-                                    tracing::warn!("Not all templates contain `add_generation_prompt` key. This model does not support add_generation_prompt.");
+                Either::Right(map) => {
+                    for t in map {
+                        for (k, v) in t.iter() {
+                            if v.contains("add_generation_prompt") {
+                                match supports_add_generation_prompt {
+                                    Some(true) | None => {
+                                        tracing::debug!("Chat template contains `add_generation_prompt` key. This model supports add_generation_prompt.");
+                                        supports_add_generation_prompt = Some(true);
+                                    }
+                                    Some(false) => {
+                                        tracing::warn!("Not all templates contain `add_generation_prompt` key. This model does not support add_generation_prompt.");
+                                    }
                                 }
+                            } else {
+                                supports_add_generation_prompt = Some(false);
                             }
-                        } else {
-                            supports_add_generation_prompt = Some(false);
+                            env.add_template_owned(k.to_string(), v.to_string())?;
                         }
-                        env.add_template_owned(k.to_string(), v.to_string())?;
                     }
-                }
-                if env.templates().count() == 0 {
-                    anyhow::bail!("Chat template does not contain a `tool_use` or `default` key. Please ensure it contains at least a `default` key, although `tool_use` should be specified for using tools.");
+                    if env.templates().count() == 0 {
+                        anyhow::bail!("Chat template does not contain a `tool_use` or `default` key. Please ensure it contains at least a `default` key, although `tool_use` should be specified for using tools.");
+                    }
                 }
             }
         }
 
+
         Ok(HfTokenizerConfigJsonFormatter {
             env,
             config,
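Net effect of this change: a tokenizer_config.json without a chat_template field (typical for embedding models) no longer fails hard; the whole template-registration block is skipped and the minijinja environment is left empty, so any later chat-prompt rendering presumably has to tolerate a formatter with no templates.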
