Adjust the limit for conversations (#1278)
merrymercy authored May 16, 2023
1 parent c86618f commit e276c2f
Showing 11 changed files with 113 additions and 49 deletions.
13 changes: 11 additions & 2 deletions fastchat/constants.py
@@ -1,12 +1,21 @@
 from enum import IntEnum
 
 
+# For the gradio web server
+SERVER_ERROR_MSG = (
+    "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+)
+MODERATION_MSG = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE FIX YOUR INPUT AND TRY AGAIN."
+CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION."
+INPUT_CHAR_LEN_LIMIT = 2560
+CONVERSATION_LEN_LIMIT = 50
+LOGDIR = "."
+
+# For the controller and workers
 CONTROLLER_HEART_BEAT_EXPIRATION = 90
 WORKER_HEART_BEAT_INTERVAL = 30
 WORKER_API_TIMEOUT = 20
 
-LOGDIR = "."
-
 
 class ErrorCode(IntEnum):
     """
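Taken together, the two new limits cap a conversation at 50 rounds and a single message at 2,560 characters. A minimal sketch of the enforcement pattern the diffs below apply (the helper name and the num_rounds argument are illustrative, not part of the commit):

from typing import Optional

from fastchat.constants import CONVERSATION_LEN_LIMIT, INPUT_CHAR_LEN_LIMIT

def accept_user_text(text: str, num_rounds: int) -> Optional[str]:
    # Refuse input once the conversation hits the round limit; the caller
    # is expected to surface CONVERSATION_LIMIT_MSG to the user.
    if num_rounds >= CONVERSATION_LEN_LIMIT:
        return None
    # Otherwise hard-truncate overly long input, as add_text does below.
    return text[:INPUT_CHAR_LEN_LIMIT]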
2 changes: 1 addition & 1 deletion fastchat/conversation.py
@@ -45,7 +45,7 @@ class Conversation:
     stop_token_ids: List[int] = None
 
     # Used for the state in the gradio servers.
-    # TODO(lmzheng): refactor this
+    # TODO(lmzheng): move this out of this class.
     conv_id: Any = None
     skip_next: bool = False
     model_name: str = None
2 changes: 2 additions & 0 deletions fastchat/model/chatglm_model.py
@@ -45,6 +45,8 @@ def chatglm_generate_stream(
 
     input_echo_len = stream_chat_token_num(tokenizer, query, hist)
 
+    output = ""
+    i = 0
     for i, (response, new_hist) in enumerate(
         model.stream_chat(tokenizer, query, hist, **gen_kwargs)
     ):
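Pre-initializing output and i guards whatever runs after the loop: if model.stream_chat yields no chunks, both names would otherwise be unbound. A standalone illustration of the pitfall (empty_stream is a made-up generator):

def empty_stream():
    # A generator that terminates without yielding anything.
    return
    yield

i = 0        # Without these defaults, reading i or output after the
output = ""  # loop would raise NameError when nothing was yielded.
for i, output in enumerate(empty_stream()):
    pass
print(i, repr(output))  # prints: 0 ''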
2 changes: 1 addition & 1 deletion fastchat/model/model_registry.py
@@ -46,7 +46,7 @@ def get_model_info(name: str) -> ModelInfo:
     ["bard"],
     "Bard",
     "https://bard.google.com/",
-    "Bard by Google",
+    "Bard based on the PaLM 2 Chat API by Google",
 )
 register_model_info(
     ["vicuna-13b"],
2 changes: 1 addition & 1 deletion fastchat/serve/api_provider.py
@@ -138,7 +138,7 @@ def palm_api_stream_iter(chat, message, temperature, top_p, max_new_tokens):
     while pos < len(content):
         # This is a fancy way to simulate token generation latency combined
         # with a Poisson process.
-        pos += random.randint(1, 5)
+        pos += random.randint(10, 20)
         time.sleep(random.expovariate(50))
         data = {
             "text": content[:pos],
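For context, random.expovariate(50) draws exponentially distributed delays with a mean of 1/50 s (20 ms), so chunk arrivals approximate a Poisson process; raising the step from randint(1, 5) to randint(10, 20) reveals the pre-computed PaLM response in larger chunks, i.e. a noticeably faster simulated stream. A self-contained sketch of the same idea:

import random
import time

def simulate_stream(content: str):
    # Reveal a pre-computed response in 10-20 character chunks, pausing
    # for an exponentially distributed interval (mean 20 ms) between
    # chunks so that arrivals roughly follow a Poisson process.
    pos = 0
    while pos < len(content):
        pos += random.randint(10, 20)
        time.sleep(random.expovariate(50))
        yield content[:pos]

for partial in simulate_stream("A canned response, streamed chunk by chunk."):
    print(partial)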
25 changes: 13 additions & 12 deletions fastchat/serve/controller.py
@@ -18,8 +18,9 @@
 import requests
 import uvicorn
 
-from fastchat.constants import CONTROLLER_HEART_BEAT_EXPIRATION, ErrorCode
-from fastchat.utils import build_logger, server_error_msg
+from fastchat.constants import (CONTROLLER_HEART_BEAT_EXPIRATION, ErrorCode,
+                                SERVER_ERROR_MSG)
+from fastchat.utils import build_logger
 
 
 logger = build_logger("controller", "controller.log")
@@ -196,26 +197,26 @@ def remove_stale_workers_by_expiration(self):
         for worker_name in to_delete:
             self.remove_worker(worker_name)
 
-    def handle_no_worker(params, server_error_msg):
+    def handle_no_worker(params):
         logger.info(f"no worker: {params['model']}")
         ret = {
-            "text": server_error_msg,
+            "text": SERVER_ERROR_MSG,
             "error_code": ErrorCode.CONTROLLER_NO_WORKER,
         }
         return json.dumps(ret).encode() + b"\0"
 
-    def handle_worker_timeout(worker_address, server_error_msg):
+    def handle_worker_timeout(worker_address):
         logger.info(f"worker timeout: {worker_address}")
         ret = {
-            "text": server_error_msg,
+            "text": SERVER_ERROR_MSG,
             "error_code": ErrorCode.CONTROLLER_WORKER_TIMEOUT,
         }
         return json.dumps(ret).encode() + b"\0"
 
     def worker_api_generate_stream(self, params):
         worker_addr = self.get_worker_address(params["model"])
         if not worker_addr:
-            yield self.handle_no_worker(params, server_error_msg)
+            yield self.handle_no_worker(params)
 
         try:
             response = requests.post(
@@ -228,12 +229,12 @@ def worker_api_generate_stream(self, params):
                 if chunk:
                     yield chunk + b"\0"
         except requests.exceptions.RequestException as e:
-            yield self.handle_worker_timeout(worker_addr, server_error_msg)
+            yield self.handle_worker_timeout(worker_addr)
 
     def worker_api_generate_completion(self, params):
         worker_addr = self.get_worker_address(params["model"])
         if not worker_addr:
-            return self.handle_no_worker(params, server_error_msg)
+            return self.handle_no_worker(params)
 
         try:
             response = requests.post(
@@ -243,12 +244,12 @@ def worker_api_generate_completion(self, params):
             )
             return response.json()
         except requests.exceptions.RequestException as e:
-            return self.handle_worker_timeout(worker_addr, server_error_msg)
+            return self.handle_worker_timeout(worker_addr)
 
     def worker_api_embeddings(self, params):
         worker_addr = self.get_worker_address(params["model"])
         if not worker_addr:
-            return self.handle_no_worker(params, server_error_msg)
+            return self.handle_no_worker(params)
 
         try:
             response = requests.post(
@@ -258,7 +259,7 @@ def worker_api_embeddings(self, params):
             )
             return response.json()
         except requests.exceptions.RequestException as e:
-            return self.handle_worker_timeout(worker_addr, server_error_msg)
+            return self.handle_worker_timeout(worker_addr)
 
     # Let the controller act as a worker to achieve hierarchical
     # management. This can be used to connect isolated sub networks.
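Note that both error helpers frame their JSON payload with a trailing null byte, the same delimiter the streaming endpoints put between chunks, so callers can treat errors like any other chunk. A minimal sketch of how a consumer might split such a stream (the raw bytes below are fabricated for illustration):

import json

# Two null-delimited JSON messages, as they would appear on the
# generate stream.
raw = b'{"text": "Hel", "error_code": 0}\x00{"text": "Hello", "error_code": 0}\x00'

for chunk in raw.split(b"\x00"):
    if chunk:  # skip the empty slice after the trailing delimiter
        message = json.loads(chunk)
        print(message["text"], message["error_code"])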
31 changes: 26 additions & 5 deletions fastchat/serve/gradio_block_arena_anony.py
@@ -9,6 +9,12 @@
 import gradio as gr
 import numpy as np
 
+from fastchat.constants import (
+    MODERATION_MSG,
+    CONVERSATION_LIMIT_MSG,
+    INPUT_CHAR_LEN_LIMIT,
+    CONVERSATION_LEN_LIMIT,
+)
 from fastchat.model.model_adapter import get_conversation_template
 from fastchat.serve.gradio_patch import Chatbot as grChatbot
 from fastchat.serve.gradio_web_server import (
@@ -22,7 +28,6 @@
 from fastchat.utils import (
     build_logger,
     violates_moderation,
-    moderation_msg,
 )
 
 logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
@@ -223,14 +228,30 @@ def add_text(state0, state1, text, request: gr.Request):
             return (
                 states
                 + [x.to_gradio_chatbot() for x in states]
-                + [moderation_msg]
+                + [MODERATION_MSG]
                 + [
                     no_change_btn,
                 ]
                 * 6
             )
 
-    text = text[:1536]  # Hard cut-off
+    if (len(states[0].messages) - states[0].offset) // 2 >= CONVERSATION_LEN_LIMIT:
+        logger.info(
+            f"hit conversation length limit. ip: {request.client.host}. text: {text}"
+        )
+        for i in range(num_models):
+            states[i].skip_next = True
+        return (
+            states
+            + [x.to_gradio_chatbot() for x in states]
+            + [CONVERSATION_LIMIT_MSG]
+            + [
+                no_change_btn,
+            ]
+            * 6
+        )
+
+    text = text[:INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
     for i in range(num_models):
         states[i].append_message(states[i].roles[0], text)
         states[i].append_message(states[i].roles[1], None)
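The new guard counts completed rounds: each round appends one user and one assistant message, and offset skips the messages seeded by the conversation template, so (len(messages) - offset) // 2 is the number of rounds so far. A quick numeric check with made-up values:

CONVERSATION_LEN_LIMIT = 50
offset = 2           # e.g. two messages seeded by the template
num_messages = 102   # 2 seeded + 50 user/assistant pairs
rounds = (num_messages - offset) // 2
print(rounds)                            # 50
print(rounds >= CONVERSATION_LEN_LIMIT)  # True -> request is rejected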
@@ -312,7 +333,7 @@ def build_side_by_side_ui_anony(models):
 ### Rules
 - Chat with two anonymous models side-by-side and vote for which one is better!
 - You can do multiple rounds of conversations before voting.
-- The names of the models will be revealed after your vote.
+- The names of the models will be revealed after your vote. Do not ask for chatbot names as conversations with identity keywords (e.g., ChatGPT, Bard, Vicuna) will not count towards the leaderboard.
 - Click "Clear history" to start a new round.
 - [[Blog](https://lmsys.org/blog/2023-05-03-arena/)] [[GitHub]](https://github.com/lm-sys/FastChat) [[Twitter]](https://twitter.com/lmsysorg) [[Discord]](https://discord.gg/h6kCZb72G7)
@@ -383,7 +404,7 @@ def build_side_by_side_ui_anony(models):
                 label="Top P",
             )
             max_output_tokens = gr.Slider(
-                minimum=0,
+                minimum=16,
                 maximum=1024,
                 value=512,
                 step=64,
29 changes: 25 additions & 4 deletions fastchat/serve/gradio_block_arena_named.py
@@ -9,6 +9,12 @@
 import gradio as gr
 import numpy as np
 
+from fastchat.constants import (
+    MODERATION_MSG,
+    CONVERSATION_LIMIT_MSG,
+    INPUT_CHAR_LEN_LIMIT,
+    CONVERSATION_LEN_LIMIT,
+)
 from fastchat.model.model_adapter import get_conversation_template
 from fastchat.serve.gradio_patch import Chatbot as grChatbot
 from fastchat.serve.gradio_web_server import (
@@ -23,7 +29,6 @@
 from fastchat.utils import (
     build_logger,
     violates_moderation,
-    moderation_msg,
 )
 
 
@@ -174,14 +179,30 @@ def add_text(state0, state1, text, request: gr.Request):
             return (
                 states
                 + [x.to_gradio_chatbot() for x in states]
-                + [moderation_msg]
+                + [MODERATION_MSG]
                 + [
                     no_change_btn,
                 ]
                 * 6
             )
 
-    text = text[:1536]  # Hard cut-off
+    if (len(states[0].messages) - states[0].offset) // 2 >= CONVERSATION_LEN_LIMIT:
+        logger.info(
+            f"hit conversation length limit. ip: {request.client.host}. text: {text}"
+        )
+        for i in range(num_models):
+            states[i].skip_next = True
+        return (
+            states
+            + [x.to_gradio_chatbot() for x in states]
+            + [CONVERSATION_LIMIT_MSG]
+            + [
+                no_change_btn,
+            ]
+            * 6
+        )
+
+    text = text[:INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
     for i in range(num_models):
         states[i].append_message(states[i].roles[0], text)
         states[i].append_message(states[i].roles[1], None)
@@ -356,7 +377,7 @@ def build_side_by_side_ui_named(models):
                 label="Top P",
             )
             max_output_tokens = gr.Slider(
-                minimum=0,
+                minimum=16,
                 maximum=1024,
                 value=512,
                 step=64,
34 changes: 25 additions & 9 deletions fastchat/serve/gradio_web_server.py
@@ -15,7 +15,16 @@
 import requests
 
 from fastchat.conversation import SeparatorStyle
-from fastchat.constants import LOGDIR, WORKER_API_TIMEOUT, ErrorCode
+from fastchat.constants import (
+    LOGDIR,
+    WORKER_API_TIMEOUT,
+    ErrorCode,
+    MODERATION_MSG,
+    CONVERSATION_LIMIT_MSG,
+    SERVER_ERROR_MSG,
+    INPUT_CHAR_LEN_LIMIT,
+    CONVERSATION_LEN_LIMIT,
+)
 from fastchat.model.model_adapter import get_conversation_template
 from fastchat.model.model_registry import model_info
 from fastchat.serve.api_provider import (
@@ -29,9 +38,7 @@
 from fastchat.serve.gradio_css import code_highlight_css
 from fastchat.utils import (
     build_logger,
-    server_error_msg,
     violates_moderation,
-    moderation_msg,
     get_window_url_params_js,
 )
 
@@ -188,11 +195,20 @@ def add_text(state, text, request: gr.Request):
         if flagged:
             logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
             state.skip_next = True
-            return (state, state.to_gradio_chatbot(), moderation_msg) + (
+            return (state, state.to_gradio_chatbot(), MODERATION_MSG) + (
                 no_change_btn,
             ) * 5
 
-    text = text[:1536]  # Hard cut-off
+    if (len(state.messages) - state.offset) // 2 >= CONVERSATION_LEN_LIMIT:
+        logger.info(
+            f"hit conversation length limit. ip: {request.client.host}. text: {text}"
+        )
+        state.skip_next = True
+        return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG) + (
+            no_change_btn,
+        ) * 5
+
+    text = text[:INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
     state.append_message(state.roles[0], text)
     state.append_message(state.roles[1], None)
     state.skip_next = False
@@ -299,7 +315,7 @@ def http_bot(
 
     # No available worker
     if worker_addr == "":
-        state.messages[-1][-1] = server_error_msg
+        state.messages[-1][-1] = SERVER_ERROR_MSG
         yield (
             state,
             state.to_gradio_chatbot(),
@@ -346,7 +362,7 @@ def http_bot(
                 time.sleep(0.02)
     except requests.exceptions.RequestException as e:
         state.messages[-1][-1] = (
-            f"{server_error_msg}\n\n"
+            f"{SERVER_ERROR_MSG}\n\n"
             f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
         )
         yield (state, state.to_gradio_chatbot()) + (
@@ -359,7 +375,7 @@ def http_bot(
         return
     except Exception as e:
         state.messages[-1][-1] = (
-            f"{server_error_msg}\n\n"
+            f"{SERVER_ERROR_MSG}\n\n"
            f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
         )
         yield (state, state.to_gradio_chatbot()) + (
@@ -490,7 +506,7 @@ def build_single_model_ui(models):
                 label="Top P",
             )
             max_output_tokens = gr.Slider(
-                minimum=0,
+                minimum=16,
                 maximum=1024,
                 value=512,
                 step=64,
(The remaining 2 changed files are not shown in this view.)
