Skip to content

Commit

Permalink
Move folder code to class
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Jul 3, 2023
1 parent 58809ac commit 305d13e
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 321 deletions.
131 changes: 21 additions & 110 deletions dreambooth_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import pathlib
import argparse
from library.common_gui import (
get_folder_path,
remove_doublequote,
get_file_path,
get_any_file_path,
get_saveasfile_path,
Expand All @@ -29,6 +27,7 @@
from library.class_source_model import SourceModel
from library.class_basic_training import BasicTraining
from library.class_advanced_training import AdvancedTraining
from library.class_folders import Folders
from library.tensorboard_gui import (
gradio_tensorboard,
start_tensorboard,
Expand All @@ -45,13 +44,6 @@
# Set up logging
log = setup_logging()

# from easygui import msgbox

folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄


def save_configuration(
save_as,
Expand Down Expand Up @@ -653,10 +645,10 @@ def train_model(


def dreambooth_tab(
train_data_dir=gr.Textbox(),
reg_data_dir=gr.Textbox(),
output_dir=gr.Textbox(),
logging_dir=gr.Textbox(),
# train_data_dir=gr.Textbox(),
# reg_data_dir=gr.Textbox(),
# output_dir=gr.Textbox(),
# logging_dir=gr.Textbox(),
headless=False,
):
dummy_db_true = gr.Label(value=True, visible=False)
Expand All @@ -669,90 +661,9 @@ def dreambooth_tab(

source_model = SourceModel(headless=headless)

# (
# pretrained_model_name_or_path,
# v2,
# v_parameterization,
# sdxl,
# save_model_as,
# model_list,
# ) = gradio_source_model(headless=headless)

with gr.Tab('Folders'):
with gr.Row():
train_data_dir = gr.Textbox(
label='Image folder',
placeholder='Folder where the training folders containing the images are located',
)
train_data_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small', visible=(not headless)
)
train_data_dir_input_folder.click(
get_folder_path,
outputs=train_data_dir,
show_progress=False,
)
reg_data_dir = gr.Textbox(
label='Regularisation folder',
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
)
reg_data_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small', visible=(not headless)
)
reg_data_dir_input_folder.click(
get_folder_path,
outputs=reg_data_dir,
show_progress=False,
)
with gr.Row():
output_dir = gr.Textbox(
label='Model output folder',
placeholder='Folder to output trained model',
)
output_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small', visible=(not headless)
)
output_dir_input_folder.click(get_folder_path, outputs=output_dir)
logging_dir = gr.Textbox(
label='Logging folder',
placeholder='Optional: enable logging and output TensorBoard log to this folder',
)
logging_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small', visible=(not headless)
)
logging_dir_input_folder.click(
get_folder_path,
outputs=logging_dir,
show_progress=False,
)
with gr.Row():
output_name = gr.Textbox(
label='Model output name',
placeholder='Name of the model to output',
value='last',
interactive=True,
)
train_data_dir.change(
remove_doublequote,
inputs=[train_data_dir],
outputs=[train_data_dir],
)
reg_data_dir.change(
remove_doublequote,
inputs=[reg_data_dir],
outputs=[reg_data_dir],
)
output_dir.change(
remove_doublequote,
inputs=[output_dir],
outputs=[output_dir],
)
logging_dir.change(
remove_doublequote,
inputs=[logging_dir],
outputs=[logging_dir],
)
with gr.Tab('Training parameters'):
folders = Folders(headless=headless)
with gr.Tab('Parameters'):
basic_training = BasicTraining(
learning_rate_value='1e-5',
lr_scheduler_value='cosine',
Expand Down Expand Up @@ -818,10 +729,10 @@ def dreambooth_tab(
'This section provide Dreambooth tools to help setup your dataset...'
)
gradio_dreambooth_folder_creation_tab(
train_data_dir_input=train_data_dir,
reg_data_dir_input=reg_data_dir,
output_dir_input=output_dir,
logging_dir_input=logging_dir,
train_data_dir_input=folders.train_data_dir,
reg_data_dir_input=folders.reg_data_dir,
output_dir_input=folders.output_dir,
logging_dir_input=folders.logging_dir,
headless=headless,
)

Expand All @@ -834,7 +745,7 @@ def dreambooth_tab(

button_start_tensorboard.click(
start_tensorboard,
inputs=logging_dir,
inputs=folders.logging_dir,
show_progress=False,
)

Expand All @@ -848,10 +759,10 @@ def dreambooth_tab(
source_model.v2,
source_model.v_parameterization,
source_model.sdxl_checkbox,
logging_dir,
train_data_dir,
reg_data_dir,
output_dir,
folders.logging_dir,
folders.train_data_dir,
folders.reg_data_dir,
folders.output_dir,
max_resolution,
basic_training.learning_rate,
basic_training.lr_scheduler,
Expand Down Expand Up @@ -881,7 +792,7 @@ def dreambooth_tab(
advanced_training.flip_aug,
advanced_training.clip_skip,
vae,
output_name,
folders.output_name,
advanced_training.max_token_length,
advanced_training.max_train_epochs,
advanced_training.max_data_loader_n_workers,
Expand Down Expand Up @@ -961,10 +872,10 @@ def dreambooth_tab(
)

return (
train_data_dir,
reg_data_dir,
output_dir,
logging_dir,
folders.train_data_dir,
folders.reg_data_dir,
folders.output_dir,
folders.logging_dir,
)


Expand Down
89 changes: 89 additions & 0 deletions library/class_folders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import gradio as gr
from .common_gui import remove_doublequote, get_folder_path

class Folders:
def __init__(self, headless=False):
self.headless = headless

with gr.Row():
self.train_data_dir = gr.Textbox(
label='Image folder',
placeholder='Folder where the training folders containing the images are located',
)
self.train_data_dir_folder = gr.Button(
'📂', elem_id='open_folder_small', visible=(not self.headless)
)
self.train_data_dir_folder.click(
get_folder_path,
outputs=self.train_data_dir,
show_progress=False,
)
self.reg_data_dir = gr.Textbox(
label='Regularisation folder',
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
)
self.reg_data_dir_folder = gr.Button(
'📂', elem_id='open_folder_small', visible=(not self.headless)
)
self.reg_data_dir_folder.click(
get_folder_path,
outputs=self.reg_data_dir,
show_progress=False,
)
with gr.Row():
self.output_dir = gr.Textbox(
label='Output folder',
placeholder='Folder to output trained model',
)
self.output_dir_folder = gr.Button(
'📂', elem_id='open_folder_small', visible=(not self.headless)
)
self.output_dir_folder.click(
get_folder_path,
outputs=self.output_dir,
show_progress=False,
)
self.logging_dir = gr.Textbox(
label='Logging folder',
placeholder='Optional: enable logging and output TensorBoard log to this folder',
)
self.logging_dir_folder = gr.Button(
'📂', elem_id='open_folder_small', visible=(not self.headless)
)
self.logging_dir_folder.click(
get_folder_path,
outputs=self.logging_dir,
show_progress=False,
)
with gr.Row():
self.output_name = gr.Textbox(
label='Model output name',
placeholder='(Name of the model to output)',
value='last',
interactive=True,
)
self.training_comment = gr.Textbox(
label='Training comment',
placeholder='(Optional) Add training comment to be included in metadata',
interactive=True,
)
self.train_data_dir.blur(
remove_doublequote,
inputs=[self.train_data_dir],
outputs=[self.train_data_dir],
)
self.reg_data_dir.blur(
remove_doublequote,
inputs=[self.reg_data_dir],
outputs=[self.reg_data_dir],
)
self.output_dir.blur(
remove_doublequote,
inputs=[self.output_dir],
outputs=[self.output_dir],
)
self.logging_dir.blur(
remove_doublequote,
inputs=[self.logging_dir],
outputs=[self.logging_dir],
)
Loading

0 comments on commit 305d13e

Please sign in to comment.