47 changes: 41 additions & 6 deletions modules/chat.py
@@ -166,18 +166,53 @@ def make_prompt(messages):
         prompt = remove_extra_bos(prompt)
         return prompt
 
-    prompt = make_prompt(messages)
-
     # Handle truncation
     max_length = get_max_prompt_length(state)
-    while len(messages) > 0 and get_encoded_length(prompt) > max_length:
-        # Try to save the system message
-        if len(messages) > 1 and messages[0]['role'] == 'system':
+    prompt = make_prompt(messages)
+    encoded_length = get_encoded_length(prompt)
+
+    while len(messages) > 0 and encoded_length > max_length:
+
+        # Remove old message, save system message
+        if len(messages) > 2 and messages[0]['role'] == 'system':
             messages.pop(1)
-        else:
+
+        # Remove old message when no system message is present
+        elif len(messages) > 1 and messages[0]['role'] != 'system':
             messages.pop(0)
+
+        # Resort to truncating the user input
+        else:
+
+            user_message = messages[-1]['content']
+
+            # Bisect the truncation point
+            left, right = 0, len(user_message) - 1
+
+            while right - left > 1:
+                mid = (left + right) // 2
+
+                messages[-1]['content'] = user_message[mid:]
+                prompt = make_prompt(messages)
+                encoded_length = get_encoded_length(prompt)
+
+                if encoded_length <= max_length:
+                    right = mid
+                else:
+                    left = mid
+
+            messages[-1]['content'] = user_message[right:]
+            prompt = make_prompt(messages)
+            encoded_length = get_encoded_length(prompt)
+            if encoded_length > max_length:
+                logger.error(f"Failed to build the chat prompt. The input is too long for the available context length.\n\nTruncation length: {state['truncation_length']}\nmax_new_tokens: {state['max_new_tokens']} (is it too high?)\nAvailable context length: {max_length}\n")
+                raise ValueError
+            else:
+                logger.warning(f"The input has been truncated. Context length: {state['truncation_length']}, max_new_tokens: {state['max_new_tokens']}.")
+                break
 
         prompt = make_prompt(messages)
+        encoded_length = get_encoded_length(prompt)
 
     if also_return_rows:
         return prompt, [message['content'] for message in messages]
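The new truncation logic works in tiers: while the prompt is too long, it drops the oldest non-system message (preserving a leading system message), and only when nothing but the latest user message can be removed does it truncate inside that message. Where the old code would simply pop the final message and end up with an empty prompt, the new code bisects a character-level cut point, re-tokenizing the suffix at each probe, so the fallback costs only O(log n) tokenizer calls in the message length. The same bisection in isolation looks roughly like this (a minimal sketch, assuming `encode` is any text-to-tokens callable whose output length shrinks more or less monotonically as the prefix is removed; `truncate_to_fit` is an illustrative name, not a webui helper):

    def truncate_to_fit(text, encode, max_tokens):
        """Drop the smallest character prefix of `text` so that the
        remaining suffix encodes to at most `max_tokens` tokens."""
        if len(encode(text)) <= max_tokens:
            return text

        # Invariant: text[left:] is known to be too long; text[right:] is the
        # current candidate cut. Each iteration re-encodes a single suffix.
        left, right = 0, len(text) - 1
        while right - left > 1:
            mid = (left + right) // 2
            if len(encode(text[mid:])) <= max_tokens:
                right = mid
            else:
                left = mid

        return text[right:]

    # Toy usage with whitespace "tokens"; a real tokenizer's encode goes here.
    suffix = truncate_to_fit("one two three four five", str.split, 3)
    assert len(suffix.split()) <= 3

As in the diff, the result still needs one final check after the loop: the search never verifies that the one-character suffix it starts from actually fits, which is why the PR re-encodes `user_message[right:]` and raises ValueError when even the maximally truncated input exceeds `max_length`.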
13 changes: 9 additions & 4 deletions modules/text_generation.py
@@ -50,6 +50,11 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
     else:
         generate_func = generate_reply_HF
 
+    if generate_func != generate_reply_HF and shared.args.verbose:
+        logger.info("PROMPT=")
+        print(question)
+        print()
+
     # Prepare the input
     original_question = question
     if not is_chat:
@@ -65,10 +70,6 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
         if type(st) is list and len(st) > 0:
             all_stop_strings += st
 
-    if shared.args.verbose:
-        logger.info("PROMPT=")
-        print(question)
-
     shared.stop_everything = False
     clear_torch_cache()
     seed = set_manual_seed(state['seed'])
@@ -355,6 +356,10 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
         pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(filtered_params)
         print()
 
+        logger.info("PROMPT=")
+        print(decode(input_ids[0], skip_special_tokens=False))
+        print()
+
     t0 = time.time()
     try:
         if not is_chat and not shared.is_seq2seq:
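The logging change splits the verbose PROMPT= dump by backend: non-HF loaders log the raw prompt string once, up front, while the HF path now logs the prompt by decoding `input_ids` with `skip_special_tokens=False` after tokenization, so the printout includes any special tokens the tokenizer inserted and matches what the model actually receives. A standalone illustration of why the flag matters, using the `transformers` API directly rather than the webui's `decode` wrapper (the checkpoint name is a stand-in; any tokenizer that prepends a BOS token behaves this way):

    from transformers import AutoTokenizer

    # Stand-in checkpoint: any Llama-style tokenizer that prepends <s> will do.
    tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    ids = tok("Hello", return_tensors="pt").input_ids

    print(tok.decode(ids[0], skip_special_tokens=True))   # -> Hello
    print(tok.decode(ids[0], skip_special_tokens=False))  # -> <s> Hello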