From dc043d366eb617c1124a4cc2bff8bee1ee4d2969 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Fri, 6 Oct 2023 08:40:50 -0700 Subject: [PATCH] Jinja instruct templates --- modules/chat.py | 27 ++++++++++++++++++++++++++- modules/models.py | 2 +- modules/models_settings.py | 13 +++++++++++++ modules/ui.py | 1 + modules/ui_chat.py | 30 +++++++++++++++++------------- 5 files changed, 58 insertions(+), 15 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 334693ab04..11294bdd0e 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -9,6 +9,7 @@ import gradio as gr import yaml +from jinja2 import Template from PIL import Image import modules.shared as shared @@ -71,12 +72,36 @@ def get_turn_substrings(state, instruct=False): return output +def generate_chat_prompt_jinja(user_input, state, **kwargs): + history = kwargs.get('history', state['history'])['internal'] + + # Find the maximum prompt size + max_length = get_max_prompt_length(state) + + # Initialize the chat template and data + chat_template = Template(state['jinja_template']) + chat_data = {"messages": []} + + # Iterate through the list of lists and populate the chat_data dictionary + for row in history: + message, reply = row + chat_data["messages"].append({"role": "user", "content": message}) + chat_data["messages"].append({"role": "assistant", "content": reply}) + + chat_data["messages"].append({"role": "user", "content": user_input}) + + return chat_template.render(**chat_data) + + def generate_chat_prompt(user_input, state, **kwargs): + is_instruct = state['mode'] == 'instruct' + if is_instruct and state['jinja_template'] != '': + return generate_chat_prompt_jinja(user_input, state, **kwargs) + impersonate = kwargs.get('impersonate', False) _continue = kwargs.get('_continue', False) also_return_rows = kwargs.get('also_return_rows', False) history = kwargs.get('history', state['history'])['internal'] - is_instruct = state['mode'] == 'instruct' # Find the maximum prompt size max_length = get_max_prompt_length(state) diff --git a/modules/models.py b/modules/models.py index db515636d8..4303b93769 100644 --- a/modules/models.py +++ b/modules/models.py @@ -90,7 +90,7 @@ def load_model(model_name, loader=None): if any((shared.args.xformers, shared.args.sdp_attention)): llama_attn_hijack.hijack_llama_attention() - logger.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.\n") + logger.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.") return model, tokenizer diff --git a/modules/models_settings.py b/modules/models_settings.py index aecb7a89ab..da19c35f38 100644 --- a/modules/models_settings.py +++ b/modules/models_settings.py @@ -20,6 +20,7 @@ def get_fallback_settings(): 'truncation_length': shared.settings['truncation_length'], 'skip_special_tokens': shared.settings['skip_special_tokens'], 'custom_stopping_strings': shared.settings['custom_stopping_strings'], + 'jinja_template': '', } @@ -91,6 +92,18 @@ def get_model_metadata(model): if 'desc_act' in metadata: model_settings['desc_act'] = metadata['desc_act'] + # Try to find the Jinja instruct template + path = Path(f'{shared.args.model_dir}/{model}') / 'tokenizer_config.json' + if path.exists(): + metadata = json.loads(open(path, 'r').read()) + if 'chat_template' in metadata: + template = metadata['chat_template'] + for k in ['eos_token', 'bos_token']: + if k in metadata: + template = template.replace(k, "'{}'".format(metadata[k])) + + model_settings['jinja_template'] = template + # Apply user settings from models/config-user.yaml settings = shared.user_config for pat in settings: diff --git a/modules/ui.py b/modules/ui.py index 77e56e92fa..0502bc4e5a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -147,6 +147,7 @@ def list_interface_input_elements(): 'name2_instruct', 'context_instruct', 'turn_template', + 'jinja_template', 'chat_style', 'chat-instruct_command', ] diff --git a/modules/ui_chat.py b/modules/ui_chat.py index 88257989b1..0c4af1ce1c 100644 --- a/modules/ui_chat.py +++ b/modules/ui_chat.py @@ -106,24 +106,28 @@ def create_chat_settings_ui(): with gr.Tab('Instruction template'): with gr.Row(): - with gr.Row(): - shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Instruction template', value='None', info='Change this according to the model/LoRA that you are using. Used in instruct and chat-instruct modes.', elem_classes='slim-dropdown') - ui.create_refresh_button(shared.gradio['instruction_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu) - shared.gradio['save_template'] = gr.Button('💾', elem_classes='refresh-button', interactive=not mu) - shared.gradio['delete_template'] = gr.Button('🗑️ ', elem_classes='refresh-button', interactive=not mu) - - shared.gradio['name1_instruct'] = gr.Textbox(value='', lines=2, label='User string') - shared.gradio['name2_instruct'] = gr.Textbox(value='', lines=1, label='Bot string') - shared.gradio['context_instruct'] = gr.Textbox(value='', lines=4, label='Context', elem_classes=['add_scrollbar']) - shared.gradio['turn_template'] = gr.Textbox(value='', lines=1, label='Turn template', info='Used to precisely define the placement of spaces and new line characters in instruction prompts.', elem_classes=['add_scrollbar']) + with gr.Column(): + with gr.Row(): + shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Instruction template', value='None', info='Change this according to the model/LoRA that you are using. Used in instruct and chat-instruct modes.', elem_classes='slim-dropdown') + ui.create_refresh_button(shared.gradio['instruction_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu) + shared.gradio['save_template'] = gr.Button('💾', elem_classes='refresh-button', interactive=not mu) + shared.gradio['delete_template'] = gr.Button('🗑️ ', elem_classes='refresh-button', interactive=not mu) + + shared.gradio['name1_instruct'] = gr.Textbox(value='', lines=2, label='User string') + shared.gradio['name2_instruct'] = gr.Textbox(value='', lines=1, label='Bot string') + shared.gradio['context_instruct'] = gr.Textbox(value='', lines=4, label='Context', elem_classes=['add_scrollbar']) + shared.gradio['turn_template'] = gr.Textbox(value='', lines=1, label='Turn template', info='Used to precisely define the placement of spaces and new line characters in instruction prompts.', elem_classes=['add_scrollbar']) + + with gr.Column(): + shared.gradio['jinja_template'] = gr.Textbox(value='', lines=14, label='Hugging Face Jinja template', info='If set, the fields under \"Instruction template\" are ignored (except for the chat-instruct command).', elem_classes=['add_scrollbar', 'monospace']) + + shared.gradio['chat-instruct_command'] = gr.Textbox(value=shared.settings['chat-instruct_command'], lines=4, label='Command for chat-instruct mode', info='<|character|> gets replaced by the bot name, and <|prompt|> gets replaced by the regular chat prompt.', elem_classes=['add_scrollbar']) + with gr.Row(): shared.gradio['send_instruction_to_default'] = gr.Button('Send to default', elem_classes=['small-button']) shared.gradio['send_instruction_to_notebook'] = gr.Button('Send to notebook', elem_classes=['small-button']) shared.gradio['send_instruction_to_negative_prompt'] = gr.Button('Send to negative prompt', elem_classes=['small-button']) - with gr.Row(): - shared.gradio['chat-instruct_command'] = gr.Textbox(value=shared.settings['chat-instruct_command'], lines=4, label='Command for chat-instruct mode', info='<|character|> gets replaced by the bot name, and <|prompt|> gets replaced by the regular chat prompt.', elem_classes=['add_scrollbar']) - with gr.Tab('Chat history'): with gr.Row(): with gr.Column():