Skip to content

Commit

Permalink
Refactor Image Sampler code
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Jul 3, 2023
1 parent 305d13e commit 22dfa09
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 48 deletions.
17 changes: 6 additions & 11 deletions dreambooth_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
gradio_dreambooth_folder_creation_tab,
)
from library.utilities import utilities_tab
from library.sampler_gui import sample_gradio_config, run_cmd_sample
from library.class_sample_images import SampleImages, run_cmd_sample

from library.custom_logging import setup_logging

Expand Down Expand Up @@ -717,12 +717,7 @@ def dreambooth_tab(
outputs=[basic_training.cache_latents],
)

(
sample_every_n_steps,
sample_every_n_epochs,
sample_sampler,
sample_prompts,
) = sample_gradio_config()
sample = SampleImages()

with gr.Tab('Tools'):
gr.Markdown(
Expand Down Expand Up @@ -813,10 +808,10 @@ def dreambooth_tab(
advanced_training.adaptive_noise_scale,
advanced_training.multires_noise_iterations,
advanced_training.multires_noise_discount,
sample_every_n_steps,
sample_every_n_epochs,
sample_sampler,
sample_prompts,
sample.sample_every_n_steps,
sample.sample_every_n_epochs,
sample.sample_sampler,
sample.sample_prompts,
advanced_training.additional_parameters,
advanced_training.vae_batch_size,
advanced_training.min_snr_gamma,
Expand Down
17 changes: 6 additions & 11 deletions finetune_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
stop_tensorboard,
)
from library.utilities import utilities_tab
from library.sampler_gui import sample_gradio_config, run_cmd_sample
from library.class_sample_images import SampleImages, run_cmd_sample

from library.custom_logging import setup_logging

Expand Down Expand Up @@ -792,12 +792,7 @@ def finetune_tab(headless=False):
outputs=[basic_training.cache_latents], # Not applicable to fine_tune.py
)

(
sample_every_n_steps,
sample_every_n_epochs,
sample_sampler,
sample_prompts,
) = sample_gradio_config()
sample = SampleImages()

button_run = gr.Button('Train model', variant='primary')

Expand Down Expand Up @@ -881,10 +876,10 @@ def finetune_tab(headless=False):
advanced_training.adaptive_noise_scale,
advanced_training.multires_noise_iterations,
advanced_training.multires_noise_discount,
sample_every_n_steps,
sample_every_n_epochs,
sample_sampler,
sample_prompts,
sample.sample_every_n_steps,
sample.sample_every_n_epochs,
sample.sample_sampler,
sample.sample_prompts,
advanced_training.additional_parameters,
advanced_training.vae_batch_size,
advanced_training.min_snr_gamma,
Expand Down
57 changes: 53 additions & 4 deletions library/sampler_gui.py → library/class_sample_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
### Gradio common sampler GUI section
###


def sample_gradio_config():
with gr.Accordion('Sample images config', open=False):
with gr.Row():
Expand Down Expand Up @@ -70,8 +69,7 @@ def sample_gradio_config():
sample_sampler,
sample_prompts,
)



def run_cmd_sample(
sample_every_n_steps,
sample_every_n_epochs,
Expand Down Expand Up @@ -104,4 +102,55 @@ def run_cmd_sample(
if not sample_every_n_steps == 0:
run_cmd += f' --sample_every_n_steps="{sample_every_n_steps}"'

return run_cmd
return run_cmd


class SampleImages:
def __init__(
self,
):
with gr.Accordion('Sample images config', open=False):
with gr.Row():
self.sample_every_n_steps = gr.Number(
label='Sample every n steps',
value=0,
precision=0,
interactive=True,
)
self.sample_every_n_epochs = gr.Number(
label='Sample every n epochs',
value=0,
precision=0,
interactive=True,
)
self.sample_sampler = gr.Dropdown(
label='Sample sampler',
choices=[
'ddim',
'pndm',
'lms',
'euler',
'euler_a',
'heun',
'dpm_2',
'dpm_2_a',
'dpmsolver',
'dpmsolver++',
'dpmsingle',
'k_lms',
'k_euler',
'k_euler_a',
'k_dpm_2',
'k_dpm_2_a',
],
value='euler_a',
interactive=True,
)
with gr.Row():
self.sample_prompts = gr.Textbox(
lines=5,
label='Sample prompts',
interactive=True,
placeholder='masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28',
)

17 changes: 6 additions & 11 deletions lora_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab
from library.verify_lora_gui import gradio_verify_lora_tab
from library.resize_lora_gui import gradio_resize_lora_tab
from library.sampler_gui import sample_gradio_config, run_cmd_sample
from library.class_sample_images import SampleImages, run_cmd_sample

from library.custom_logging import setup_logging

Expand Down Expand Up @@ -1415,12 +1415,7 @@ def update_LoRA_settings(LoRA_type):
outputs=[basic_training.cache_latents],
)

(
sample_every_n_steps,
sample_every_n_epochs,
sample_sampler,
sample_prompts,
) = sample_gradio_config()
sample = SampleImages()

LoRA_type.change(
update_LoRA_settings,
Expand Down Expand Up @@ -1536,10 +1531,10 @@ def update_LoRA_settings(LoRA_type):
train_on_input,
conv_dim,
conv_alpha,
sample_every_n_steps,
sample_every_n_epochs,
sample_sampler,
sample_prompts,
sample.sample_every_n_steps,
sample.sample_every_n_epochs,
sample.sample_sampler,
sample.sample_prompts,
advanced_training.additional_parameters,
advanced_training.vae_batch_size,
advanced_training.min_snr_gamma,
Expand Down
17 changes: 6 additions & 11 deletions textual_inversion_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
gradio_dreambooth_folder_creation_tab,
)
from library.utilities import utilities_tab
from library.sampler_gui import sample_gradio_config, run_cmd_sample
from library.class_sample_images import SampleImages, run_cmd_sample

from library.custom_logging import setup_logging

Expand Down Expand Up @@ -780,12 +780,7 @@ def ti_tab(
outputs=[basic_training.cache_latents],
)

(
sample_every_n_steps,
sample_every_n_epochs,
sample_sampler,
sample_prompts,
) = sample_gradio_config()
sample = SampleImages()

with gr.Tab('Tools'):
gr.Markdown(
Expand Down Expand Up @@ -882,10 +877,10 @@ def ti_tab(
advanced_training.adaptive_noise_scale,
advanced_training.multires_noise_iterations,
advanced_training.multires_noise_discount,
sample_every_n_steps,
sample_every_n_epochs,
sample_sampler,
sample_prompts,
sample.sample_every_n_steps,
sample.sample_every_n_epochs,
sample.sample_sampler,
sample.sample_prompts,
advanced_training.additional_parameters,
advanced_training.vae_batch_size,
advanced_training.min_snr_gamma,
Expand Down

0 comments on commit 22dfa09

Please sign in to comment.