diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py index c5ef36e650..0bcca9f3b4 100644 --- a/extensions/openai/completions.py +++ b/extensions/openai/completions.py @@ -293,7 +293,11 @@ def chat_streaming_chunk(content): yield chat_streaming_chunk('') # generate reply ####################################### - prompt = generate_chat_prompt(user_input, generate_params) + try: + prompt = generate_chat_prompt(user_input, generate_params, except_on_truncation_fail=True) + except ValueError as e: + raise InvalidRequestError(message=str(e), param='messages') + token_count = len(encode(prompt)[0]) debug_msg({'prompt': prompt, 'generate_params': generate_params}) diff --git a/extensions/openai/embeddings.py b/extensions/openai/embeddings.py index 1420879cc9..30472a5877 100644 --- a/extensions/openai/embeddings.py +++ b/extensions/openai/embeddings.py @@ -51,7 +51,7 @@ def load_embedding_model(model: str): print(f"Loaded embedding model: {model}") except Exception as e: embeddings_model = None - raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e)) + raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model} with exception {e}") def get_embeddings_model(): diff --git a/extensions/openai/errors.py b/extensions/openai/errors.py index 838d1e7cc6..cf1fbc534a 100644 --- a/extensions/openai/errors.py +++ b/extensions/openai/errors.py @@ -1,20 +1,15 @@ -class OpenAIError(Exception): - def __init__(self, message=None, code=500, internal_message=''): +from fastapi import HTTPException + +class OpenAIError(HTTPException): + def __init__(self, message, param, code): + body = { + 'message' : message, + 'code': code, + 'param' : param + } + super().__init__(status_code=code, detail=body) self.message = message self.code = code - self.internal_message = internal_message - - def __repr__(self): - return "%s(message=%r, code=%d)" % ( - self.__class__.__name__, - self.message, - self.code, - ) - - -class InvalidRequestError(OpenAIError): - def __init__(self, message, param, code=400, internal_message=''): - super().__init__(message, code, internal_message) self.param = param def __repr__(self): @@ -26,6 +21,11 @@ def __repr__(self): ) +class InvalidRequestError(OpenAIError): + def __init__(self, message, param, code=400): + super().__init__(message, param, code) + + class ServiceUnavailableError(OpenAIError): - def __init__(self, message="Service unavailable, please try again later.", code=503, internal_message=''): - super().__init__(message, code, internal_message) + def __init__(self, message="Service unavailable, please try again later.", param='', code=503): + super().__init__(message, param, code) diff --git a/extensions/openai/images.py b/extensions/openai/images.py index 92bd85f08b..2e02b97bf6 100644 --- a/extensions/openai/images.py +++ b/extensions/openai/images.py @@ -59,7 +59,7 @@ def generations(prompt: str, size: str, response_format: str, n: int): r = response.json() if response.status_code != 200 or 'images' not in r: print(r) - raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code, internal_message=r.get('errors', None)) + raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code) # r['parameters']... for b64_json in r['images']: if response_format == 'b64_json': diff --git a/modules/chat.py b/modules/chat.py index 5380f1ac8e..917aa7598b 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -80,6 +80,7 @@ def generate_chat_prompt(user_input, state, **kwargs): _continue = kwargs.get('_continue', False) also_return_rows = kwargs.get('also_return_rows', False) history = kwargs.get('history', state['history'])['internal'] + except_on_truncation_fail = kwargs.get('except_on_truncation_fail', False) # Templates chat_template = jinja_env.from_string(state['chat_template_str']) @@ -166,18 +167,74 @@ def make_prompt(messages): prompt = remove_extra_bos(prompt) return prompt + def try_truncate_prompt(prompt): + # Handle truncation + max_length = get_max_prompt_length(state) + encoded_length = get_encoded_length(prompt) + if len(messages) > 0 and encoded_length > max_length: + orig_msg = messages + orig_enc_len = encoded_length + truncation_failed = False + truncation_retries = 0 + other_content_len = 0 + user_idx = -1 + + # Find the index of the user message, and calculate the length of the other messages + for i in range(len(messages)): + if messages[i]['role'] == 'user': + user_idx = i + else: + other_content_len += len(messages[i]['content']) + + # If there are no user messages, fail + if user_idx == -1: + logger.warn(f'Truncation failed with no user messages') + truncation_failed = True + + _verbose_log = shared.args.verbose + + # Loop until we have a successful truncation + while encoded_length > max_length and not truncation_failed: + truncation_retries += 1 + ratio_truncation_start = 1.0 - 0.95 * max_length / encoded_length + user_content_len = max(len(messages[user_idx]['content']), 1) + if _verbose_log: + logger.debug(f'Truncating user message attemp #{truncation_retries} by {ratio_truncation_start* 100:.2f}% from len {user_content_len}') + + # Scale the truncation start based on the length of the user messages relative to the total length + ratio_truncation_start *= 1.0 * (user_content_len + other_content_len) / user_content_len + if _verbose_log: + logger.debug(f'Ratio adjust to {ratio_truncation_start* 100:.2f}% after factoring in heuristics') + + if ratio_truncation_start >= 1.0 or truncation_retries >= 0: + logger.warn(f'Truncation failed with ratio {ratio_truncation_start}, retries {truncation_retries}') + truncation_failed = True + + if not truncation_failed: + # Truncate the user messages from the start + messages[user_idx]['content'] = messages[user_idx]['content'][int(user_content_len * ratio_truncation_start):] + if _verbose_log: + logger.debug(f'to len {len(messages[user_idx]["content"])}') + + # Re-render the prompt + prompt = make_prompt(messages) + encoded_length = get_encoded_length(prompt) + + if truncation_failed: + ctx_len = state['truncation_length'] + completion_len = state['max_new_tokens'] + # Error message similar to openai's error message on exceeding max context length + err_msg = f"This model's maximum context length is {ctx_len} tokens. However, you requested {orig_enc_len + completion_len} tokens ({orig_enc_len} in the messages, {completion_len} in the completion). Please reduce the length of the messages or completion." + logger.error(f'{err_msg}') + if _verbose_log: + logger.error(f'Messages failed from truncation:\n {orig_msg}') + if except_on_truncation_fail: + raise ValueError(err_msg) + + return prompt + prompt = make_prompt(messages) - - # Handle truncation - max_length = get_max_prompt_length(state) - while len(messages) > 0 and get_encoded_length(prompt) > max_length: - # Try to save the system message - if len(messages) > 1 and messages[0]['role'] == 'system': - messages.pop(1) - else: - messages.pop(0) - - prompt = make_prompt(messages) + prompt = try_truncate_prompt(prompt) if also_return_rows: return prompt, [message['content'] for message in messages]