Commit 6f4258a: p2l stuff (#3660)
Parent: 8664268
4 files changed (+115, -5 lines)

fastchat/model/model_adapter.py (+1, -1)

```diff
@@ -2489,7 +2489,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 
 class NoSystemAdapter(BaseModelAdapter):
     def match(self, model_path: str):
-        keyword_list = ["athene-70b"]
+        keyword_list = ["athene-70b", "p2l"]
 
         for keyword in keyword_list:
             if keyword == model_path.lower():
```
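The match here is an exact, case-insensitive comparison rather than a substring test, so only a model path named exactly `p2l` (in any casing) picks up the adapter. A standalone sketch of that check, with the `BaseModelAdapter` machinery elided:

```python
# Standalone sketch of the keyword check in NoSystemAdapter.match.
def matches_no_system(model_path: str) -> bool:
    keyword_list = ["athene-70b", "p2l"]
    for keyword in keyword_list:
        # Exact, case-insensitive equality: "P2L" matches, "p2l-7b" does not.
        if keyword == model_path.lower():
            return True
    return False

assert matches_no_system("P2L")
assert not matches_no_system("p2l-7b")
```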

fastchat/serve/api_provider.py (+76)

```diff
@@ -246,6 +246,17 @@ def get_api_provider_stream_iter(
             api_key=model_api_dict["api_key"],
             conversation_id=state.conv_id,
         )
+    elif model_api_dict["api_type"] == "p2l":
+        prompt = conv.to_openai_api_messages()
+        stream_iter = p2l_api_stream_iter(
+            model_api_dict["model_name"],
+            prompt,
+            temperature,
+            top_p,
+            max_new_tokens,
+            api_base=model_api_dict["api_base"],
+            api_key=model_api_dict["api_key"],
+        )
     else:
         raise NotImplementedError()
```
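For this branch to fire, the endpoint registry entry needs `api_type` set to `"p2l"` and must supply the keys read above. A hypothetical entry; the model name, URL, and key are placeholders, not part of this commit:

```python
# Hypothetical registry entry: only the "p2l" api_type and the four keys
# read by the dispatch code are prescribed; the values are illustrative.
model_api_dict = {
    "model_name": "p2l-router",              # placeholder
    "api_type": "p2l",
    "api_base": "http://localhost:9090/v1",  # placeholder
    "api_key": "EMPTY",                      # placeholder
}
```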

```diff
@@ -412,6 +423,71 @@ def column_api_stream_iter(
     }
 
 
+def p2l_api_stream_iter(
+    model_name,
+    messages,
+    temperature,
+    top_p,
+    max_new_tokens,
+    api_base=None,
+    api_key=None,
+):
+    import openai
+
+    client = openai.OpenAI(
+        base_url=api_base,
+        api_key=api_key or "-",
+        timeout=180,
+    )
+
+    # Make requests for logging
+    text_messages = []
+    for message in messages:
+        if type(message["content"]) == str:  # text-only model
+            text_messages.append(message)
+        else:  # vision model
+            filtered_content_list = [
+                content for content in message["content"] if content["type"] == "text"
+            ]
+            text_messages.append(
+                {"role": message["role"], "content": filtered_content_list}
+            )
+
+    gen_params = {
+        "model": model_name,
+        "prompt": text_messages,
+        "temperature": None,
+        "top_p": None,
+        "max_new_tokens": max_new_tokens,
+    }
+    logger.info(f"==== request ====\n{gen_params}")
+
+    res = client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=max_new_tokens,
+        stream=True,
+    )
+    text = ""
+    for chunk_idx, chunk in enumerate(res):
+        if len(chunk.choices) > 0:
+            text += chunk.choices[0].delta.content or ""
+
+        data = {
+            "text": text,
+            "error_code": 0,
+        }
+
+        if chunk_idx == 0:
+            if hasattr(chunk.choices[0].delta, "model"):
+                data["ans_model"] = chunk.choices[0].delta.model
+
+            if hasattr(chunk, "router_outputs"):
+                data["router_outputs"] = chunk.router_outputs
+
+        yield data
+
+
 def upload_openai_file_to_gcs(file_id):
     import openai
     from google.cloud import storage
```
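The generator yields the accumulated text on every chunk and, on the first chunk only, attaches `ans_model` and `router_outputs` when the server reports them. A minimal consumption sketch, assuming a reachable P2L endpoint; the model name and URL are placeholders:

```python
# Sketch only: endpoint and model name are placeholders.
messages = [{"role": "user", "content": "What is 2 + 2?"}]

first, last = None, None
for data in p2l_api_stream_iter(
    "p2l-router",
    messages,
    temperature=0.7,
    top_p=1.0,
    max_new_tokens=64,
    api_base="http://localhost:9090/v1",
    api_key="EMPTY",
):
    first = first or data  # keep the first chunk, which may carry metadata
    last = data            # keep the latest accumulated text

print(last["text"])                 # full response text
print(first.get("ans_model"))       # model the router answered with, if reported
print(first.get("router_outputs"))  # per-model router scores, if reported
```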

fastchat/serve/gradio_web_server.py (+36, -2)

```diff
@@ -11,7 +11,7 @@
 import random
 import time
 import uuid
-from typing import List
+from typing import List, Dict
 
 import gradio as gr
 import requests
```
```diff
@@ -119,6 +119,8 @@ def __init__(self, model_name, is_vision=False):
         self.model_name = model_name
         self.oai_thread_id = None
         self.is_vision = is_vision
+        self.ans_models = []
+        self.router_outputs = []
 
         # NOTE(chris): This could be sort of a hack since it assumes the user only uploads one image. If they can upload multiple, we should store a list of image hashes.
         self.has_csam_image = False
```

```diff
@@ -128,6 +130,12 @@ def __init__(self, model_name, is_vision=False):
         self.regen_support = False
         self.init_system_prompt(self.conv, is_vision)
 
+    def update_ans_models(self, ans: str) -> None:
+        self.ans_models.append(ans)
+
+    def update_router_outputs(self, outputs: Dict[str, float]) -> None:
+        self.router_outputs.append(outputs)
+
     def init_system_prompt(self, conv, is_vision):
         system_prompt = conv.get_system_message(is_vision)
         if len(system_prompt) == 0:
```
```diff
@@ -154,6 +162,20 @@ def dict(self):
             }
         )
 
+        if self.ans_models:
+            base.update(
+                {
+                    "ans_models": self.ans_models,
+                }
+            )
+
+        if self.router_outputs:
+            base.update(
+                {
+                    "router_outputs": self.router_outputs,
+                }
+            )
+
         if self.is_vision:
             base.update({"has_csam_image": self.has_csam_image})
         return base
```
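Taken together, `State` now records one answering model and one router-output dict per turn, and `dict()` emits the new keys only when the lists are non-empty. An illustrative sketch (values made up; in practice `bot_response` fills these from the first stream chunk):

```python
# Illustrative only: bot_response normally records these, not the caller.
state = State("p2l-router")  # hypothetical model name
state.update_ans_models("gpt-4o-mini")
state.update_router_outputs({"gpt-4o-mini": 0.83, "llama-3-70b": 0.17})

logged = state.dict()
assert logged["ans_models"] == ["gpt-4o-mini"]
assert "router_outputs" in logged  # omitted entirely when the list is empty
```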
```diff
@@ -420,7 +442,7 @@ def is_limit_reached(model_name, ip):
 
 
 def bot_response(
-    state,
+    state: State,
     temperature,
     top_p,
     max_new_tokens,
```
```diff
@@ -532,6 +554,18 @@ def bot_response(
     try:
         data = {"text": ""}
         for i, data in enumerate(stream_iter):
+            # Change for P2L:
+            if i == 0:
+                if "ans_model" in data:
+                    ans_model = data.get("ans_model")
+
+                    state.update_ans_models(ans_model)
+
+                if "router_outputs" in data:
+                    router_outputs = data.get("router_outputs")
+
+                    state.update_router_outputs(router_outputs)
+
             if data["error_code"] == 0:
                 output = data["text"].strip()
                 conv.update_last_message(output + "▌")
```
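The contract between the stream iterator and this loop is one plain dict per chunk; only a first chunk from the p2l provider carries the extra keys. An illustrative first-chunk payload (values made up):

```python
# Illustrative first-chunk payload from p2l_api_stream_iter.
data = {
    "text": "4",
    "error_code": 0,
    "ans_model": "gpt-4o-mini",  # optional, first chunk only
    "router_outputs": {"gpt-4o-mini": 0.83, "llama-3-70b": 0.17},  # optional
}
```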

pyproject.toml (+2, -2)

```diff
@@ -19,8 +19,8 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf"]
-webui = ["gradio>=4.10"]
+model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf", "openai", "anthropic"]
+webui = ["gradio>=4.10", "plotly", "scipy"]
 train = ["einops", "flash-attn>=2.0", "wandb"]
 llm_judge = ["openai<1", "anthropic>=0.3", "ray"]
 dev = ["black==23.3.0", "pylint==2.8.2"]
```
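With these extras in place, a source checkout picks up the new worker and web UI dependencies via, e.g., `pip install -e ".[model_worker,webui]"`.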
