Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

- Implement HuggingFace inputs in all training tabs #2287

Merged
merged 1 commit into from
Apr 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -404,10 +404,12 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG

### 2024/04/12 (v23.1.6)

- Rewrote a lot of the code to fix issues resulting from the security upgrades to remove shell=True from process calls.
- Improved the training and tensorboard buttons
- Upgrade the gradio version to 4.20.0 to fix a bug with runpod.
- Various other minor fixes.
- Rewrote significant portions of the code to address security vulnerabilities and remove the `shell=True` parameter from process calls.
- Enhanced the training and tensorboard buttons to provide a more intuitive and user-friendly experience.
- Upgraded the gradio version to 4.20.0 to address a bug that was causing issues with the runpod platform.
- Added a HuggingFace section to all trainers tabs, enabling users to authenticate and utilize HuggingFace's powerful AI models.
- Converted the Graphical User Interface (GUI) to use the configuration TOML file format to pass arguments to sd-scripts. This change improves security by eliminating the need for sensitive information to be passed through the command-line interface (CLI).
- Made various other minor improvements and bug fixes to enhance the overall functionality and user experience.

### 2024/04/10 (v23.1.5)

Expand Down
4 changes: 2 additions & 2 deletions kohya_gui/class_command_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@
if self.process and self.process.poll() is None:
log.info("The command is already running. Please wait for it to finish.")
else:
for i, item in enumerate(run_cmd):
log.info(f"{i}: {item}")
# for i, item in enumerate(run_cmd):
# log.info(f"{i}: {item}")

# Reconstruct the safe command string for display
command_to_run = ' '.join(run_cmd)
log.info(f"Executings command: {command_to_run}")

Check warning on line 40 in kohya_gui/class_command_executor.py

View workflow job for this annotation

GitHub Actions / build

"Executings" should be "Executions".

Check warning on line 40 in kohya_gui/class_command_executor.py

View workflow job for this annotation

GitHub Actions / build

"Executings" should be "Executions".

# Execute the command securely
self.process = subprocess.Popen(run_cmd, **kwargs)
Expand Down
82 changes: 82 additions & 0 deletions kohya_gui/class_huggingface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import gradio as gr
import toml
from .class_gui_config import KohyaSSGUIConfig

class HuggingFace:
def __init__(
self,
config: KohyaSSGUIConfig = {},
) -> None:
self.config = config

# Initialize the UI components
self.initialize_ui_components()

def initialize_ui_components(self) -> None:
# --huggingface_repo_id HUGGINGFACE_REPO_ID
# huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名
# --huggingface_repo_type HUGGINGFACE_REPO_TYPE
# huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類
# --huggingface_path_in_repo HUGGINGFACE_PATH_IN_REPO
# huggingface model path to upload files / huggingfaceにアップロードするファイルのパス
# --huggingface_token HUGGINGFACE_TOKEN
# huggingface token / huggingfaceのトークン
# --huggingface_repo_visibility HUGGINGFACE_REPO_VISIBILITY
# huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定('public'で公開、'private'またはNoneで非公開)
# --save_state_to_huggingface
# save state to huggingface / huggingfaceにstateを保存する
# --resume_from_huggingface
# resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})
# --async_upload upload to huggingface asynchronously / huggingfaceに非同期でアップロードする
with gr.Row():
self.huggingface_repo_id = gr.Textbox(
label="Huggingface repo id",
placeholder="huggingface repo id",
value=self.config.get("huggingface.repo_id", ""),
)

self.huggingface_token = gr.Textbox(
label="Huggingface token",
placeholder="huggingface token",
value=self.config.get("huggingface.token", ""),
)

with gr.Row():
# Repository settings
self.huggingface_repo_type = gr.Textbox(
label="Huggingface repo type",
placeholder="huggingface repo type",
value=self.config.get("huggingface.repo_type", ""),
)

self.huggingface_repo_visibility = gr.Textbox(
label="Huggingface repo visibility",
placeholder="huggingface repo visibility",
value=self.config.get("huggingface.repo_visibility", ""),
)

with gr.Row():
# File location in the repository
self.huggingface_path_in_repo = gr.Textbox(
label="Huggingface path in repo",
placeholder="huggingface path in repo",
value=self.config.get("huggingface.path_in_repo", ""),
)

with gr.Row():
# Functions
self.save_state_to_huggingface = gr.Checkbox(
label="Save state to huggingface",
value=self.config.get("huggingface.save_state_to_huggingface", False),
)

self.resume_from_huggingface = gr.Textbox(
label="Resume from huggingface",
placeholder="resume from huggingface",
value=self.config.get("huggingface.resume_from_huggingface", ""),
)

self.async_upload = gr.Checkbox(
label="Async upload",
value=self.config.get("huggingface.async_upload", False),
)
17 changes: 17 additions & 0 deletions kohya_gui/class_sample_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,23 @@
###
### Gradio common sampler GUI section
###
def create_prompt_file(sample_prompts, output_dir):
"""
Creates a prompt file for image sampling.

Args:
sample_prompts (str): The prompts to use for image sampling.
output_dir (str): The directory where the output images will be saved.

Returns:
str: The path to the prompt file.
"""
sample_prompts_path = os.path.join(output_dir, "prompt.txt")

with open(sample_prompts_path, "w") as f:
f.write(sample_prompts)

return sample_prompts_path


def run_cmd_sample(
Expand Down
3 changes: 2 additions & 1 deletion kohya_gui/common_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1575,7 +1575,8 @@ def run_cmd_advanced_training(run_cmd: list = [], **kwargs):

# Use Weights and Biases logging
if kwargs.get("use_wandb"):
run_cmd.append("--log_with wandb")
run_cmd.append("--log_with")
run_cmd.append("wandb")

# V parameterization
if kwargs.get("v_parameterization"):
Expand Down
Loading
Loading