Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions modules/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,9 @@ def load_character(character, name1, name2):


def load_instruction_template(template):
if template == 'None':
return ''

for filepath in [Path(f'instruction-templates/{template}.yaml'), Path('instruction-templates/Alpaca.yaml')]:
if filepath.exists():
break
Expand Down
61 changes: 44 additions & 17 deletions modules/models_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,27 +243,54 @@ def save_model_settings(model, state):
Save the settings for this model to models/config-user.yaml
'''
if model == 'None':
yield ("Not saving the settings because no model is loaded.")
yield ("Not saving the settings because no model is selected in the menu.")
return

with Path(f'{shared.args.model_dir}/config-user.yaml') as p:
if p.exists():
user_config = yaml.safe_load(open(p, 'r').read())
else:
user_config = {}
user_config = shared.load_user_config()
model_regex = model + '$' # For exact matches
if model_regex not in user_config:
user_config[model_regex] = {}

for k in ui.list_model_elements():
if k == 'loader' or k in loaders.loaders_and_params[state['loader']]:
user_config[model_regex][k] = state[k]

model_regex = model + '$' # For exact matches
if model_regex not in user_config:
user_config[model_regex] = {}
shared.user_config = user_config

for k in ui.list_model_elements():
if k == 'loader' or k in loaders.loaders_and_params[state['loader']]:
user_config[model_regex][k] = state[k]
output = yaml.dump(user_config, sort_keys=False)
p = Path(f'{shared.args.model_dir}/config-user.yaml')
with open(p, 'w') as f:
f.write(output)

shared.user_config = user_config
yield (f"Settings for `{model}` saved to `{p}`.")

output = yaml.dump(user_config, sort_keys=False)
with open(p, 'w') as f:
f.write(output)

yield (f"Settings for `{model}` saved to `{p}`.")
def save_instruction_template(model, template):
'''
Similar to the function above, but it saves only the instruction template.
'''
if model == 'None':
yield ("Not saving the template because no model is selected in the menu.")
return

user_config = shared.load_user_config()
model_regex = model + '$' # For exact matches
if model_regex not in user_config:
user_config[model_regex] = {}

if template == 'None':
user_config[model_regex].pop('instruction_template', None)
else:
user_config[model_regex]['instruction_template'] = template

shared.user_config = user_config

output = yaml.dump(user_config, sort_keys=False)
p = Path(f'{shared.args.model_dir}/config-user.yaml')
with open(p, 'w') as f:
f.write(output)

if template == 'None':
yield (f"Instruction template for `{model}` unset in `{p}`, as the value for template was `{template}`.")
else:
yield (f"Instruction template for `{model}` saved to `{p}` as `{template}`.")
23 changes: 18 additions & 5 deletions modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,23 @@ def is_chat():
return True


def load_user_config():
'''
Loads custom model-specific settings
'''
if Path(f'{args.model_dir}/config-user.yaml').exists():
file_content = open(f'{args.model_dir}/config-user.yaml', 'r').read().strip()

if file_content:
user_config = yaml.safe_load(file_content)
else:
user_config = {}
else:
user_config = {}

return user_config


args.loader = fix_loader_name(args.loader)

# Activate the multimodal extension
Expand All @@ -297,11 +314,7 @@ def is_chat():
model_config = {}

# Load custom model-specific settings
with Path(f'{args.model_dir}/config-user.yaml') as p:
if p.exists():
user_config = yaml.safe_load(open(p, 'r').read())
else:
user_config = {}
user_config = load_user_config()

model_config = OrderedDict(model_config)
user_config = OrderedDict(user_config)
2 changes: 1 addition & 1 deletion modules/ui_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def create_chat_settings_ui():
with gr.Row():
with gr.Column():
with gr.Row():
shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Saved instruction templates', info="After selecting the template, click on \"Load\" to load and apply it.", value='Select template to load...', elem_classes='slim-dropdown')
shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Saved instruction templates', info="After selecting the template, click on \"Load\" to load and apply it.", value='None', elem_classes='slim-dropdown')
ui.create_refresh_button(shared.gradio['instruction_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu)
shared.gradio['load_template'] = gr.Button("Load", elem_classes='refresh-button')
shared.gradio['save_template'] = gr.Button('💾', elem_classes='refresh-button', interactive=not mu)
Expand Down
14 changes: 14 additions & 0 deletions modules/ui_model_menu.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from modules.models_settings import (
apply_model_settings_to_state,
get_model_metadata,
save_instruction_template,
save_model_settings,
update_model_parameters
)
Expand Down Expand Up @@ -165,6 +166,14 @@ def create_ui():
shared.gradio['create_llamacpp_hf_button'] = gr.Button("Submit", variant="primary", interactive=not mu)
gr.Markdown("This will move your gguf file into a subfolder of `models` along with the necessary tokenizer files.")

with gr.Tab("Customize instruction template"):
with gr.Row():
shared.gradio['customized_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), value='None', label='Select the desired instruction template', elem_classes='slim-dropdown')
ui.create_refresh_button(shared.gradio['customized_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu)

shared.gradio['customized_template_submit'] = gr.Button("Submit", variant="primary", interactive=not mu)
gr.Markdown("This allows you to set a customized template for the model currently selected in the \"Model loader\" menu. Whenver the model gets loaded, this template will be used in place of the template specified in the model's medatada, which sometimes is wrong.")

with gr.Row():
shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready')

Expand Down Expand Up @@ -214,6 +223,7 @@ def create_event_handlers():
shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True)
shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), gradio('autoload_model'), gradio('load_model'))
shared.gradio['create_llamacpp_hf_button'].click(create_llamacpp_hf, gradio('gguf_menu', 'unquantized_url'), gradio('model_status'), show_progress=True)
shared.gradio['customized_template_submit'].click(save_instruction_template, gradio('model_menu', 'customized_template'), gradio('model_status'), show_progress=True)


def load_model_wrapper(selected_model, loader, autoload=False):
Expand Down Expand Up @@ -320,3 +330,7 @@ def update_truncation_length(current_length, state):
return state['n_ctx']

return current_length


def save_model_template(model, template):
pass
2 changes: 1 addition & 1 deletion modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def get_available_instruction_templates():
if os.path.exists(path):
paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))

return ['Select template to load...'] + sorted(set((k.stem for k in paths)), key=natural_keys)
return ['None'] + sorted(set((k.stem for k in paths)), key=natural_keys)


def get_available_extensions():
Expand Down