Skip to content

Commit c36b28e

Browse files
committed
Remove single List[str] input handling in serving_embedding.py
- The handling for a single string in a list can be removed as #7396 is merged. - Add UT cases in test_openai_server for such case
1 parent 61b352d commit c36b28e

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

python/sglang/srt/entrypoints/openai/serving_embedding.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,7 @@ def _convert_to_internal_request(
7979
prompt_kwargs = {"text": prompt}
8080
elif isinstance(prompt, list):
8181
if len(prompt) > 0 and isinstance(prompt[0], str):
82-
# List of strings - if it's a single string in a list, treat as single string
83-
if len(prompt) == 1:
84-
prompt_kwargs = {"text": prompt[0]}
85-
else:
86-
prompt_kwargs = {"text": prompt}
82+
prompt_kwargs = {"text": prompt}
8783
elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
8884
# Handle multimodal embedding inputs
8985
texts = []

test/srt/test_openai_server.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,31 @@ def test_embedding_batch(self):
535535
self.assertTrue(len(response.data[0].embedding) > 0)
536536
self.assertTrue(len(response.data[1].embedding) > 0)
537537

538+
def test_embedding_single_batch_str(self):
539+
"""Test embedding with a List[str] and length equals to 1"""
540+
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
541+
response = client.embeddings.create(model=self.model, input=["Hello world"])
542+
self.assertEqual(len(response.data), 1)
543+
self.assertTrue(len(response.data[0].embedding) > 0)
544+
545+
def test_embedding_single_int_list(self):
546+
"""Test embedding with a List[int] or List[List[int]]]"""
547+
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
548+
response = client.embeddings.create(
549+
model=self.model,
550+
input=[[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061]],
551+
)
552+
self.assertEqual(len(response.data), 1)
553+
self.assertTrue(len(response.data[0].embedding) > 0)
554+
555+
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
556+
response = client.embeddings.create(
557+
model=self.model,
558+
input=[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061],
559+
)
560+
self.assertEqual(len(response.data), 1)
561+
self.assertTrue(len(response.data[0].embedding) > 0)
562+
538563
def test_empty_string_embedding(self):
539564
"""Test embedding an empty string."""
540565

0 commit comments

Comments
 (0)