Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion modules/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import gradio as gr
import yaml
from jinja2 import Template
from PIL import Image

import modules.shared as shared
Expand Down Expand Up @@ -71,12 +72,36 @@ def get_turn_substrings(state, instruct=False):
return output


def generate_chat_prompt_jinja(user_input, state, **kwargs):
    """Render an instruct-mode chat prompt using the Jinja2 chat template
    stored in state['jinja_template'] (as found in a model's
    tokenizer_config.json 'chat_template' field).

    Parameters:
        user_input: The new user message to append as the final turn.
        state: UI state dict; must contain 'jinja_template' and 'history'.
        **kwargs: Optional 'history' override (same structure as
            state['history']). Other kwargs accepted by generate_chat_prompt
            (impersonate, _continue, also_return_rows) are ignored here.

    Returns:
        The rendered prompt string.
    """
    history = kwargs.get('history', state['history'])['internal']

    # NOTE(review): this path applies no truncation to max prompt length,
    # unlike generate_chat_prompt — confirm this is intended.
    # (Removed dead code: max_length was computed but never used.)

    # Compile the user-provided Jinja2 chat template.
    chat_template = Template(state['jinja_template'])

    # Flatten the internal history ([user_message, assistant_reply] pairs)
    # into the OpenAI-style message list that HF chat templates expect.
    messages = []
    for user_message, assistant_reply in history:
        messages.append({"role": "user", "content": user_message})
        messages.append({"role": "assistant", "content": assistant_reply})

    # The new user input forms the final (unanswered) user turn.
    messages.append({"role": "user", "content": user_input})

    return chat_template.render(messages=messages)


def generate_chat_prompt(user_input, state, **kwargs):
is_instruct = state['mode'] == 'instruct'
if is_instruct and state['jinja_template'] != '':
return generate_chat_prompt_jinja(user_input, state, **kwargs)

impersonate = kwargs.get('impersonate', False)
_continue = kwargs.get('_continue', False)
also_return_rows = kwargs.get('also_return_rows', False)
history = kwargs.get('history', state['history'])['internal']
is_instruct = state['mode'] == 'instruct'

# Find the maximum prompt size
max_length = get_max_prompt_length(state)
Expand Down
2 changes: 1 addition & 1 deletion modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def load_model(model_name, loader=None):
if any((shared.args.xformers, shared.args.sdp_attention)):
llama_attn_hijack.hijack_llama_attention()

logger.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.\n")
logger.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
return model, tokenizer


Expand Down
13 changes: 13 additions & 0 deletions modules/models_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def get_fallback_settings():
'truncation_length': shared.settings['truncation_length'],
'skip_special_tokens': shared.settings['skip_special_tokens'],
'custom_stopping_strings': shared.settings['custom_stopping_strings'],
'jinja_template': '',
}


Expand Down Expand Up @@ -91,6 +92,18 @@ def get_model_metadata(model):
if 'desc_act' in metadata:
model_settings['desc_act'] = metadata['desc_act']

# Try to find the Jinja instruct template
path = Path(f'{shared.args.model_dir}/{model}') / 'tokenizer_config.json'
if path.exists():
metadata = json.loads(open(path, 'r').read())
if 'chat_template' in metadata:
template = metadata['chat_template']
for k in ['eos_token', 'bos_token']:
if k in metadata:
template = template.replace(k, "'{}'".format(metadata[k]))

model_settings['jinja_template'] = template

# Apply user settings from models/config-user.yaml
settings = shared.user_config
for pat in settings:
Expand Down
1 change: 1 addition & 0 deletions modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def list_interface_input_elements():
'name2_instruct',
'context_instruct',
'turn_template',
'jinja_template',
'chat_style',
'chat-instruct_command',
]
Expand Down
30 changes: 17 additions & 13 deletions modules/ui_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,24 +106,28 @@ def create_chat_settings_ui():

with gr.Tab('Instruction template'):
with gr.Row():
with gr.Row():
shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Instruction template', value='None', info='Change this according to the model/LoRA that you are using. Used in instruct and chat-instruct modes.', elem_classes='slim-dropdown')
ui.create_refresh_button(shared.gradio['instruction_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu)
shared.gradio['save_template'] = gr.Button('💾', elem_classes='refresh-button', interactive=not mu)
shared.gradio['delete_template'] = gr.Button('🗑️ ', elem_classes='refresh-button', interactive=not mu)

shared.gradio['name1_instruct'] = gr.Textbox(value='', lines=2, label='User string')
shared.gradio['name2_instruct'] = gr.Textbox(value='', lines=1, label='Bot string')
shared.gradio['context_instruct'] = gr.Textbox(value='', lines=4, label='Context', elem_classes=['add_scrollbar'])
shared.gradio['turn_template'] = gr.Textbox(value='', lines=1, label='Turn template', info='Used to precisely define the placement of spaces and new line characters in instruction prompts.', elem_classes=['add_scrollbar'])
with gr.Column():
with gr.Row():
shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Instruction template', value='None', info='Change this according to the model/LoRA that you are using. Used in instruct and chat-instruct modes.', elem_classes='slim-dropdown')
ui.create_refresh_button(shared.gradio['instruction_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu)
shared.gradio['save_template'] = gr.Button('💾', elem_classes='refresh-button', interactive=not mu)
shared.gradio['delete_template'] = gr.Button('🗑️ ', elem_classes='refresh-button', interactive=not mu)

shared.gradio['name1_instruct'] = gr.Textbox(value='', lines=2, label='User string')
shared.gradio['name2_instruct'] = gr.Textbox(value='', lines=1, label='Bot string')
shared.gradio['context_instruct'] = gr.Textbox(value='', lines=4, label='Context', elem_classes=['add_scrollbar'])
shared.gradio['turn_template'] = gr.Textbox(value='', lines=1, label='Turn template', info='Used to precisely define the placement of spaces and new line characters in instruction prompts.', elem_classes=['add_scrollbar'])

with gr.Column():
shared.gradio['jinja_template'] = gr.Textbox(value='', lines=14, label='Hugging Face Jinja template', info='If set, the fields under \"Instruction template\" are ignored (except for the chat-instruct command).', elem_classes=['add_scrollbar', 'monospace'])

shared.gradio['chat-instruct_command'] = gr.Textbox(value=shared.settings['chat-instruct_command'], lines=4, label='Command for chat-instruct mode', info='<|character|> gets replaced by the bot name, and <|prompt|> gets replaced by the regular chat prompt.', elem_classes=['add_scrollbar'])

with gr.Row():
shared.gradio['send_instruction_to_default'] = gr.Button('Send to default', elem_classes=['small-button'])
shared.gradio['send_instruction_to_notebook'] = gr.Button('Send to notebook', elem_classes=['small-button'])
shared.gradio['send_instruction_to_negative_prompt'] = gr.Button('Send to negative prompt', elem_classes=['small-button'])

with gr.Row():
shared.gradio['chat-instruct_command'] = gr.Textbox(value=shared.settings['chat-instruct_command'], lines=4, label='Command for chat-instruct mode', info='<|character|> gets replaced by the bot name, and <|prompt|> gets replaced by the regular chat prompt.', elem_classes=['add_scrollbar'])

with gr.Tab('Chat history'):
with gr.Row():
with gr.Column():
Expand Down