Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion extensions/openai/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,11 @@ def chat_streaming_chunk(content):
yield chat_streaming_chunk('')

# generate reply #######################################
prompt = generate_chat_prompt(user_input, generate_params)
try:
prompt = generate_chat_prompt(user_input, generate_params, except_on_truncation_fail=True)
except ValueError as e:
raise InvalidRequestError(message=str(e), param='messages')

token_count = len(encode(prompt)[0])
debug_msg({'prompt': prompt, 'generate_params': generate_params})

Expand Down
2 changes: 1 addition & 1 deletion extensions/openai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def load_embedding_model(model: str):
print(f"Loaded embedding model: {model}")
except Exception as e:
embeddings_model = None
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model} with exception {e}")


def get_embeddings_model():
Expand Down
34 changes: 17 additions & 17 deletions extensions/openai/errors.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
class OpenAIError(Exception):
def __init__(self, message=None, code=500, internal_message=''):
from fastapi import HTTPException

class OpenAIError(HTTPException):
def __init__(self, message, param, code):
body = {
'message' : message,
'code': code,
'param' : param
}
super().__init__(status_code=code, detail=body)
self.message = message
self.code = code
self.internal_message = internal_message

def __repr__(self):
return "%s(message=%r, code=%d)" % (
self.__class__.__name__,
self.message,
self.code,
)


class InvalidRequestError(OpenAIError):
def __init__(self, message, param, code=400, internal_message=''):
super().__init__(message, code, internal_message)
self.param = param

def __repr__(self):
Expand All @@ -26,6 +21,11 @@ def __repr__(self):
)


class InvalidRequestError(OpenAIError):
def __init__(self, message, param, code=400):
super().__init__(message, param, code)


class ServiceUnavailableError(OpenAIError):
def __init__(self, message="Service unavailable, please try again later.", code=503, internal_message=''):
super().__init__(message, code, internal_message)
def __init__(self, message="Service unavailable, please try again later.", param='', code=503):
super().__init__(message, param, code)
2 changes: 1 addition & 1 deletion extensions/openai/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def generations(prompt: str, size: str, response_format: str, n: int):
r = response.json()
if response.status_code != 200 or 'images' not in r:
print(r)
raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code, internal_message=r.get('errors', None))
raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code)
# r['parameters']...
for b64_json in r['images']:
if response_format == 'b64_json':
Expand Down
79 changes: 68 additions & 11 deletions modules/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
_continue = kwargs.get('_continue', False)
also_return_rows = kwargs.get('also_return_rows', False)
history = kwargs.get('history', state['history'])['internal']
except_on_truncation_fail = kwargs.get('except_on_truncation_fail', False)

# Templates
chat_template = jinja_env.from_string(state['chat_template_str'])
Expand Down Expand Up @@ -166,18 +167,74 @@ def make_prompt(messages):
prompt = remove_extra_bos(prompt)
return prompt

def try_truncate_prompt(prompt):
# Handle truncation
max_length = get_max_prompt_length(state)
encoded_length = get_encoded_length(prompt)
if len(messages) > 0 and encoded_length > max_length:
orig_msg = messages
orig_enc_len = encoded_length
truncation_failed = False
truncation_retries = 0
other_content_len = 0
user_idx = -1

# Find the index of the user message, and calculate the length of the other messages
for i in range(len(messages)):
if messages[i]['role'] == 'user':
user_idx = i
else:
other_content_len += len(messages[i]['content'])

# If there are no user messages, fail
if user_idx == -1:
logger.warn(f'Truncation failed with no user messages')
truncation_failed = True

_verbose_log = shared.args.verbose

# Loop until we have a successful truncation
while encoded_length > max_length and not truncation_failed:
truncation_retries += 1
ratio_truncation_start = 1.0 - 0.95 * max_length / encoded_length
user_content_len = max(len(messages[user_idx]['content']), 1)
if _verbose_log:
logger.debug(f'Truncating user message attemp #{truncation_retries} by {ratio_truncation_start* 100:.2f}% from len {user_content_len}')

# Scale the truncation start based on the length of the user messages relative to the total length
ratio_truncation_start *= 1.0 * (user_content_len + other_content_len) / user_content_len
if _verbose_log:
logger.debug(f'Ratio adjust to {ratio_truncation_start* 100:.2f}% after factoring in heuristics')

if ratio_truncation_start >= 1.0 or truncation_retries >= 0:
logger.warn(f'Truncation failed with ratio {ratio_truncation_start}, retries {truncation_retries}')
truncation_failed = True

if not truncation_failed:
# Truncate the user messages from the start
messages[user_idx]['content'] = messages[user_idx]['content'][int(user_content_len * ratio_truncation_start):]
if _verbose_log:
logger.debug(f'to len {len(messages[user_idx]["content"])}')

# Re-render the prompt
prompt = make_prompt(messages)
encoded_length = get_encoded_length(prompt)

if truncation_failed:
ctx_len = state['truncation_length']
completion_len = state['max_new_tokens']
# Error message similar to openai's error message on exceeding max context length
err_msg = f"This model's maximum context length is {ctx_len} tokens. However, you requested {orig_enc_len + completion_len} tokens ({orig_enc_len} in the messages, {completion_len} in the completion). Please reduce the length of the messages or completion."
logger.error(f'{err_msg}')
if _verbose_log:
logger.error(f'Messages failed from truncation:\n {orig_msg}')
if except_on_truncation_fail:
raise ValueError(err_msg)

return prompt

prompt = make_prompt(messages)

# Handle truncation
max_length = get_max_prompt_length(state)
while len(messages) > 0 and get_encoded_length(prompt) > max_length:
# Try to save the system message
if len(messages) > 1 and messages[0]['role'] == 'system':
messages.pop(1)
else:
messages.pop(0)

prompt = make_prompt(messages)
prompt = try_truncate_prompt(prompt)

if also_return_rows:
return prompt, [message['content'] for message in messages]
Expand Down