From 92925a1a3230b908aa6118c56d26b60430223c05 Mon Sep 17 00:00:00 2001 From: yhyu13 Date: Sun, 24 Dec 2023 04:14:05 +0000 Subject: [PATCH 1/5] Raise exception when prompt exceeds truncation length --- modules/chat.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 5380f1ac8e..ef6ed077a7 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -170,14 +170,11 @@ def make_prompt(messages): # Handle truncation max_length = get_max_prompt_length(state) - while len(messages) > 0 and get_encoded_length(prompt) > max_length: - # Try to save the system message - if len(messages) > 1 and messages[0]['role'] == 'system': - messages.pop(1) - else: - messages.pop(0) - - prompt = make_prompt(messages) + encoded_length = get_encoded_length(prompt) + if len(messages) > 0 and encoded_length > max_length: + if shared.args.verbose: + logger.error(f'messages {messages}') + raise ValueError(f'Prompt encoded_length {encoded_length} > max_length {max_length}') if also_return_rows: return prompt, [message['content'] for message in messages] From 69f1efe79bc3c6fc3c029b3550867f86c170e996 Mon Sep 17 00:00:00 2001 From: yhyu13 Date: Sun, 24 Dec 2023 11:26:59 +0000 Subject: [PATCH 2/5] Throw HttpError so that openai client can resolve to openaierror; handle user prompt truncation --- extensions/openai/completions.py | 6 ++++- extensions/openai/embeddings.py | 2 +- extensions/openai/errors.py | 35 +++++++++++++------------- extensions/openai/images.py | 2 +- modules/chat.py | 43 +++++++++++++++++++++++++++++--- 5 files changed, 65 insertions(+), 23 deletions(-) diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py index c5ef36e650..8f278b7737 100644 --- a/extensions/openai/completions.py +++ b/extensions/openai/completions.py @@ -293,7 +293,11 @@ def chat_streaming_chunk(content): yield chat_streaming_chunk('') # generate reply ####################################### - prompt = 
generate_chat_prompt(user_input, generate_params) + try: + prompt = generate_chat_prompt(user_input, generate_params) + except ValueError as e: + raise InvalidRequestError(message=str(e), param=[user_input, generate_params]) + token_count = len(encode(prompt)[0]) debug_msg({'prompt': prompt, 'generate_params': generate_params}) diff --git a/extensions/openai/embeddings.py b/extensions/openai/embeddings.py index 1420879cc9..30472a5877 100644 --- a/extensions/openai/embeddings.py +++ b/extensions/openai/embeddings.py @@ -51,7 +51,7 @@ def load_embedding_model(model: str): print(f"Loaded embedding model: {model}") except Exception as e: embeddings_model = None - raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e)) + raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model} with exception {e}") def get_embeddings_model(): diff --git a/extensions/openai/errors.py b/extensions/openai/errors.py index 838d1e7cc6..7777f254b8 100644 --- a/extensions/openai/errors.py +++ b/extensions/openai/errors.py @@ -1,20 +1,16 @@ -class OpenAIError(Exception): - def __init__(self, message=None, code=500, internal_message=''): +import openai +from fastapi import HTTPException + +class OpenAIError(HTTPException): + def __init__(self, message, param, code): + body = { + 'message' : message, + 'code': code, + 'param' : param + } + super().__init__(status_code=code, detail=body) self.message = message self.code = code - self.internal_message = internal_message - - def __repr__(self): - return "%s(message=%r, code=%d)" % ( - self.__class__.__name__, - self.message, - self.code, - ) - - -class InvalidRequestError(OpenAIError): - def __init__(self, message, param, code=400, internal_message=''): - super().__init__(message, code, internal_message) self.param = param def __repr__(self): @@ -26,6 +22,11 @@ def __repr__(self): ) +class InvalidRequestError(OpenAIError): + def __init__(self, message, param, code=400): + 
super().__init__(message, param, code) + + class ServiceUnavailableError(OpenAIError): - def __init__(self, message="Service unavailable, please try again later.", code=503, internal_message=''): - super().__init__(message, code, internal_message) + def __init__(self, message="Service unavailable, please try again later.", param='', code=503): + super().__init__(message, param, code) diff --git a/extensions/openai/images.py b/extensions/openai/images.py index 92bd85f08b..2e02b97bf6 100644 --- a/extensions/openai/images.py +++ b/extensions/openai/images.py @@ -59,7 +59,7 @@ def generations(prompt: str, size: str, response_format: str, n: int): r = response.json() if response.status_code != 200 or 'images' not in r: print(r) - raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code, internal_message=r.get('errors', None)) + raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code) # r['parameters']... 
for b64_json in r['images']: if response_format == 'b64_json': diff --git a/modules/chat.py b/modules/chat.py index ef6ed077a7..d80b6ce1ba 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -172,9 +172,46 @@ def make_prompt(messages): max_length = get_max_prompt_length(state) encoded_length = get_encoded_length(prompt) if len(messages) > 0 and encoded_length > max_length: - if shared.args.verbose: - logger.error(f'messages {messages}') - raise ValueError(f'Prompt encoded_length {encoded_length} > max_length {max_length}') + orig_msg = messages + orig_enc_len = encoded_length + truncation_failed = False + truncation_retries = 0 + + while encoded_length > max_length and not truncation_failed: + truncation_retries += 1 + ratio_truncation_start = 1.0 - 0.95 * max_length / encoded_length + user_content_len = 1 + other_content_len = 0 + + # Count the length of the user messages and other roles messages + for i in range(len(messages)): + if messages[i]['role'] == 'user': + user_content_len = len(messages[i]['content']) + else: + other_content_len += len(messages[i]['content']) + # Scale the truncation start based on the length of the user messages relative to the total length + ratio_truncation_start *= 1.0 * (user_content_len + other_content_len) / user_content_len + + if ratio_truncation_start >= 1.0 or truncation_retries >= 5: + logger.warn(f'Truncation failed with ratio_truncation_start {ratio_truncation_start}, truncation_retries {truncation_retries}') + truncation_failed = True + + if not truncation_failed: + # Truncate the user messages from the start + for i in range(len(messages)): + if messages[i]['role'] == 'user': + if shared.args.verbose: + logger.debug(f'Truncating user message from len {len(messages[i]["content"])}') + messages[i]['content'] = messages[i]['content'][int(user_content_len * ratio_truncation_start):] + if shared.args.verbose: + logger.debug(f'to len {len(messages[i]["content"])}') + prompt = make_prompt(messages) + encoded_length = 
get_encoded_length(prompt) + + if truncation_failed: + if shared.args.verbose: + logger.error(f'Prompt encoded_length {orig_enc_len} > max_length {max_length} \nMessages failed from truncation:\n {orig_msg}') + raise ValueError(f'Prompt encoded_length {orig_enc_len} > max_length {max_length}') if also_return_rows: return prompt, [message['content'] for message in messages] From d5cae16a92b2b8dad42df9c24de1adf47add5d8a Mon Sep 17 00:00:00 2001 From: yhyu13 Date: Tue, 26 Dec 2023 04:04:53 +0000 Subject: [PATCH 3/5] Improve truncation code; Emit context length exceeding error similar to OpenAI --- extensions/openai/errors.py | 1 - modules/chat.py | 48 +++++++++++++++++++++++-------------- 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/extensions/openai/errors.py b/extensions/openai/errors.py index 7777f254b8..cf1fbc534a 100644 --- a/extensions/openai/errors.py +++ b/extensions/openai/errors.py @@ -1,4 +1,3 @@ -import openai from fastapi import HTTPException class OpenAIError(HTTPException): diff --git a/modules/chat.py b/modules/chat.py index d80b6ce1ba..09ddbd5844 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -176,19 +176,27 @@ def make_prompt(messages): orig_enc_len = encoded_length truncation_failed = False truncation_retries = 0 + other_content_len = 0 + user_idx = -1 + # Find the index of the user message, and calculate the length of the other messages + for i in range(len(messages)): + if messages[i]['role'] == 'user': + user_idx = i + else: + other_content_len += len(messages[i]['content']) + + # If there are no user messages, fail + if user_idx == -1: + logger.warn(f'Truncation failed with no user messages') + truncation_failed = True + + # Loop until we have a successful truncation while encoded_length > max_length and not truncation_failed: truncation_retries += 1 ratio_truncation_start = 1.0 - 0.95 * max_length / encoded_length - user_content_len = 1 - other_content_len = 0 - - # Count the length of the user messages and other roles 
messages - for i in range(len(messages)): - if messages[i]['role'] == 'user': - user_content_len = len(messages[i]['content']) - else: - other_content_len += len(messages[i]['content']) + user_content_len = max(len(messages[user_idx]['content']), 1) + # Scale the truncation start based on the length of the user messages relative to the total length ratio_truncation_start *= 1.0 * (user_content_len + other_content_len) / user_content_len @@ -197,21 +205,25 @@ def make_prompt(messages): truncation_failed = True if not truncation_failed: + if shared.args.verbose: + logger.debug(f'Truncating user message from len {len(messages[user_idx]["content"])}') # Truncate the user messages from the start - for i in range(len(messages)): - if messages[i]['role'] == 'user': - if shared.args.verbose: - logger.debug(f'Truncating user message from len {len(messages[i]["content"])}') - messages[i]['content'] = messages[i]['content'][int(user_content_len * ratio_truncation_start):] - if shared.args.verbose: - logger.debug(f'to len {len(messages[i]["content"])}') + messages[user_idx]['content'] = messages[user_idx]['content'][int(user_content_len * ratio_truncation_start):] + if shared.args.verbose: + logger.debug(f'to len {len(messages[user_idx]["content"])}') + + # Re-render the prompt prompt = make_prompt(messages) encoded_length = get_encoded_length(prompt) if truncation_failed: + ctx_len = state['truncation_length'] + completion_len = state['max_new_tokens'] + # Error message similar to openai's error message on exceeding max context length + err_msg = f"This model's maximum context length is {ctx_len} tokens. However, you requested {orig_enc_len + completion_len} tokens ({orig_enc_len} in the messages, {completion_len} in the completion). Please reduce the length of the messages or completion." 
if shared.args.verbose: - logger.error(f'Prompt encoded_length {orig_enc_len} > max_length {max_length} \nMessages failed from truncation:\n {orig_msg}') - raise ValueError(f'Prompt encoded_length {orig_enc_len} > max_length {max_length}') + logger.error(f'{err_msg} \nMessages failed from truncation:\n {orig_msg}') + raise ValueError(err_msg) if also_return_rows: return prompt, [message['content'] for message in messages] From 60722012d32a748120edb7fa0c468ff61024ec7f Mon Sep 17 00:00:00 2001 From: yhyu13 Date: Thu, 28 Dec 2023 17:26:52 +0000 Subject: [PATCH 4/5] Improve verbose logging for user input truncation --- modules/chat.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 09ddbd5844..216543ad2d 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -190,26 +190,30 @@ def make_prompt(messages): if user_idx == -1: logger.warn(f'Truncation failed with no user messages') truncation_failed = True + + _verbose_log = shared.args.verbose # Loop until we have a successful truncation while encoded_length > max_length and not truncation_failed: truncation_retries += 1 ratio_truncation_start = 1.0 - 0.95 * max_length / encoded_length user_content_len = max(len(messages[user_idx]['content']), 1) + if _verbose_log: + logger.debug(f'Truncating user message attemp #{truncation_retries} by {ratio_truncation_start* 100:.2f}% from len {user_content_len}') # Scale the truncation start based on the length of the user messages relative to the total length ratio_truncation_start *= 1.0 * (user_content_len + other_content_len) / user_content_len + if _verbose_log: + logger.debug(f'Ratio adjust to {ratio_truncation_start* 100:.2f}% after factoring in heuristics') if ratio_truncation_start >= 1.0 or truncation_retries >= 5: - logger.warn(f'Truncation failed with ratio_truncation_start {ratio_truncation_start}, truncation_retries {truncation_retries}') + logger.warn(f'Truncation failed with ratio 
{ratio_truncation_start}, retries {truncation_retries}') truncation_failed = True if not truncation_failed: - if shared.args.verbose: - logger.debug(f'Truncating user message from len {len(messages[user_idx]["content"])}') # Truncate the user messages from the start messages[user_idx]['content'] = messages[user_idx]['content'][int(user_content_len * ratio_truncation_start):] - if shared.args.verbose: + if _verbose_log: logger.debug(f'to len {len(messages[user_idx]["content"])}') # Re-render the prompt @@ -221,7 +225,7 @@ def make_prompt(messages): completion_len = state['max_new_tokens'] # Error message similar to openai's error message on exceeding max context length err_msg = f"This model's maximum context length is {ctx_len} tokens. However, you requested {orig_enc_len + completion_len} tokens ({orig_enc_len} in the messages, {completion_len} in the completion). Please reduce the length of the messages or completion." - if shared.args.verbose: + if _verbose_log: logger.error(f'{err_msg} \nMessages failed from truncation:\n {orig_msg}') raise ValueError(err_msg) From 2ffbf39009c97aa72676d62478b6ff169ab3240e Mon Sep 17 00:00:00 2001 From: yhyu13 Date: Fri, 5 Jan 2024 15:07:43 +0000 Subject: [PATCH 5/5] Relax exception throwing on truncation fail with kwarg; Move truncation logic to method --- extensions/openai/completions.py | 4 +- modules/chat.py | 117 ++++++++++++++++--------------- 2 files changed, 64 insertions(+), 57 deletions(-) diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py index 8f278b7737..0bcca9f3b4 100644 --- a/extensions/openai/completions.py +++ b/extensions/openai/completions.py @@ -294,9 +294,9 @@ def chat_streaming_chunk(content): # generate reply ####################################### try: - prompt = generate_chat_prompt(user_input, generate_params) + prompt = generate_chat_prompt(user_input, generate_params, except_on_truncation_fail=True) except ValueError as e: - raise InvalidRequestError(message=str(e), 
param=[user_input, generate_params]) + raise InvalidRequestError(message=str(e), param='messages') token_count = len(encode(prompt)[0]) debug_msg({'prompt': prompt, 'generate_params': generate_params}) diff --git a/modules/chat.py b/modules/chat.py index 216543ad2d..917aa7598b 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -80,6 +80,7 @@ def generate_chat_prompt(user_input, state, **kwargs): _continue = kwargs.get('_continue', False) also_return_rows = kwargs.get('also_return_rows', False) history = kwargs.get('history', state['history'])['internal'] + except_on_truncation_fail = kwargs.get('except_on_truncation_fail', False) # Templates chat_template = jinja_env.from_string(state['chat_template_str']) @@ -166,68 +167,74 @@ def make_prompt(messages): prompt = remove_extra_bos(prompt) return prompt - prompt = make_prompt(messages) - - # Handle truncation - max_length = get_max_prompt_length(state) - encoded_length = get_encoded_length(prompt) - if len(messages) > 0 and encoded_length > max_length: - orig_msg = messages - orig_enc_len = encoded_length - truncation_failed = False - truncation_retries = 0 - other_content_len = 0 - user_idx = -1 - - # Find the index of the user message, and calculate the length of the other messages - for i in range(len(messages)): - if messages[i]['role'] == 'user': - user_idx = i - else: - other_content_len += len(messages[i]['content']) - - # If there are no user messages, fail - if user_idx == -1: - logger.warn(f'Truncation failed with no user messages') - truncation_failed = True + def try_truncate_prompt(prompt): + # Handle truncation + max_length = get_max_prompt_length(state) + encoded_length = get_encoded_length(prompt) + if len(messages) > 0 and encoded_length > max_length: + orig_msg = messages + orig_enc_len = encoded_length + truncation_failed = False + truncation_retries = 0 + other_content_len = 0 + user_idx = -1 - _verbose_log = shared.args.verbose - - # Loop until we have a successful truncation - while 
encoded_length > max_length and not truncation_failed: - truncation_retries += 1 - ratio_truncation_start = 1.0 - 0.95 * max_length / encoded_length - user_content_len = max(len(messages[user_idx]['content']), 1) - if _verbose_log: - logger.debug(f'Truncating user message attemp #{truncation_retries} by {ratio_truncation_start* 100:.2f}% from len {user_content_len}') - - # Scale the truncation start based on the length of the user messages relative to the total length - ratio_truncation_start *= 1.0 * (user_content_len + other_content_len) / user_content_len - if _verbose_log: - logger.debug(f'Ratio adjust to {ratio_truncation_start* 100:.2f}% after factoring in heuristics') + # Find the index of the user message, and calculate the length of the other messages + for i in range(len(messages)): + if messages[i]['role'] == 'user': + user_idx = i + else: + other_content_len += len(messages[i]['content']) - if ratio_truncation_start >= 1.0 or truncation_retries >= 5: - logger.warn(f'Truncation failed with ratio {ratio_truncation_start}, retries {truncation_retries}') + # If there are no user messages, fail + if user_idx == -1: + logger.warn(f'Truncation failed with no user messages') truncation_failed = True - if not truncation_failed: - # Truncate the user messages from the start - messages[user_idx]['content'] = messages[user_idx]['content'][int(user_content_len * ratio_truncation_start):] + _verbose_log = shared.args.verbose + + # Loop until we have a successful truncation + while encoded_length > max_length and not truncation_failed: + truncation_retries += 1 + ratio_truncation_start = 1.0 - 0.95 * max_length / encoded_length + user_content_len = max(len(messages[user_idx]['content']), 1) if _verbose_log: - logger.debug(f'to len {len(messages[user_idx]["content"])}') + logger.debug(f'Truncating user message attemp #{truncation_retries} by {ratio_truncation_start* 100:.2f}% from len {user_content_len}') - # Re-render the prompt - prompt = make_prompt(messages) - 
encoded_length = get_encoded_length(prompt) + # Scale the truncation start based on the length of the user messages relative to the total length + ratio_truncation_start *= 1.0 * (user_content_len + other_content_len) / user_content_len + if _verbose_log: + logger.debug(f'Ratio adjust to {ratio_truncation_start* 100:.2f}% after factoring in heuristics') + + if ratio_truncation_start >= 1.0 or truncation_retries >= 0: + logger.warn(f'Truncation failed with ratio {ratio_truncation_start}, retries {truncation_retries}') + truncation_failed = True + + if not truncation_failed: + # Truncate the user messages from the start + messages[user_idx]['content'] = messages[user_idx]['content'][int(user_content_len * ratio_truncation_start):] + if _verbose_log: + logger.debug(f'to len {len(messages[user_idx]["content"])}') + + # Re-render the prompt + prompt = make_prompt(messages) + encoded_length = get_encoded_length(prompt) + + if truncation_failed: + ctx_len = state['truncation_length'] + completion_len = state['max_new_tokens'] + # Error message similar to openai's error message on exceeding max context length + err_msg = f"This model's maximum context length is {ctx_len} tokens. However, you requested {orig_enc_len + completion_len} tokens ({orig_enc_len} in the messages, {completion_len} in the completion). Please reduce the length of the messages or completion." + logger.error(f'{err_msg}') + if _verbose_log: + logger.error(f'Messages failed from truncation:\n {orig_msg}') + if except_on_truncation_fail: + raise ValueError(err_msg) - if truncation_failed: - ctx_len = state['truncation_length'] - completion_len = state['max_new_tokens'] - # Error message similar to openai's error message on exceeding max context length - err_msg = f"This model's maximum context length is {ctx_len} tokens. However, you requested {orig_enc_len + completion_len} tokens ({orig_enc_len} in the messages, {completion_len} in the completion). 
Please reduce the length of the messages or completion." - if _verbose_log: - logger.error(f'{err_msg} \nMessages failed from truncation:\n {orig_msg}') - raise ValueError(err_msg) + return prompt + + prompt = make_prompt(messages) + prompt = try_truncate_prompt(prompt) if also_return_rows: return prompt, [message['content'] for message in messages]