Merge pull request #942 from bghira/main
merge
bghira authored Sep 5, 2024
2 parents f69fd58 + 5c92cdc commit 38f84f8
Showing 12 changed files with 448 additions and 198 deletions.
19 changes: 13 additions & 6 deletions helpers/caching/text_embeds.py
@@ -28,7 +28,7 @@ def _encode_sd3_prompt_with_t5(
prompt=None,
num_images_per_prompt=1,
device=None,
- return_masked_embed: bool = True,
+ zero_padding_tokens: bool = True,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
@@ -54,7 +54,7 @@ def _encode_sd3_prompt_with_t5(
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
attention_mask = text_inputs.attention_mask.to(device)

- if return_masked_embed:
+ if zero_padding_tokens:
# for some reason, SAI's reference code doesn't bother to mask the prompt embeddings.
# this can lead to a problem where the model fails to represent short and long prompts equally well.
# additionally, the model learns the bias of the prompt embeds' noise.
@@ -255,7 +255,7 @@ def encode_flux_prompt(
tokenizers,
prompt: str,
is_validation: bool = False,
- return_masked_embed: bool = True,
+ zero_padding_tokens: bool = True,
):
"""
Encode a prompt for a Flux model.
@@ -288,6 +288,11 @@ def encode_flux_prompt(
device=self.accelerator.device,
max_sequence_length=StateTracker.get_args().tokenizer_max_length,
)
+ if zero_padding_tokens:
+     # we can zero the padding tokens if we're just going to mask them later anyway.
+     prompt_embeds = prompt_embeds * masks.to(
+         device=prompt_embeds.device
+     ).unsqueeze(-1).expand(prompt_embeds.shape)

return prompt_embeds, pooled_prompt_embeds, time_ids, masks

@@ -298,7 +303,7 @@ def encode_sd3_prompt(
tokenizers,
prompt: str,
is_validation: bool = False,
- return_masked_embed: bool = True,
+ zero_padding_tokens: bool = True,
):
"""
Encode a prompt for an SD3 model.
@@ -341,7 +346,7 @@ def encode_sd3_prompt(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
device=self.accelerator.device,
- return_masked_embed=return_masked_embed,
+ zero_padding_tokens=zero_padding_tokens,
)

clip_prompt_embeds = torch.nn.functional.pad(
@@ -494,7 +499,7 @@ def encode_prompt(self, prompt: str, is_validation: bool = False):
self.tokenizers,
prompt,
is_validation,
- return_masked_embed=(
+ zero_padding_tokens=(
True
if StateTracker.get_args().sd3_t5_mask_behaviour == "mask"
else False
@@ -1021,6 +1026,8 @@ def compute_embeddings_for_flux_prompts(
self.tokenizers,
[prompt],
is_validation,
+ zero_padding_tokens=StateTracker.get_args().t5_padding
+ == "zero",
)
)
logger.debug(
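Note on the hunks above: the renamed zero_padding_tokens flag controls whether the T5 prompt embeddings are multiplied by the tokenizer's attention mask so that padded positions contribute nothing, as the comments in _encode_sd3_prompt_with_t5 explain. A minimal, standalone sketch of that masking step, using made-up tensor shapes rather than the repository's helpers:

    import torch

    def zero_out_padding(prompt_embeds: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # prompt_embeds: (batch, seq_len, hidden); attention_mask: (batch, seq_len) of 0/1
        mask = attention_mask.to(prompt_embeds.device).unsqueeze(-1)
        return prompt_embeds * mask.expand(prompt_embeds.shape)

    # toy example: the two padded positions are zeroed, the real tokens are untouched
    embeds = torch.randn(1, 4, 8)
    mask = torch.tensor([[1, 1, 0, 0]])
    assert torch.all(zero_out_padding(embeds, mask)[:, 2:] == 0)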
10 changes: 9 additions & 1 deletion helpers/configuration/cmd_args.py
@@ -209,6 +209,15 @@ def parse_cmdline_args(input_args=None):
default=False,
help="Use attention masking while training flux.",
)
+ parser.add_argument(
+     "--t5_padding",
+     choices=["zero", "unmodified"],
+     default="unmodified",
+     help=(
+         "The padding behaviour for Flux's T5 embeds. 'zero' will zero out the embeddings of padding tokens,"
+         " while the default, 'unmodified', leaves them untouched."
+     ),
+ )
parser.add_argument(
"--smoldit",
action="store_true",
@@ -1908,7 +1917,6 @@ def parse_cmdline_args(input_args=None):
raise ValueError(
f"Model is not using bf16 precision, but the optimizer {chosen_optimizer} requires it."
)
print(f"optimizer: {optimizer_details}")

if torch.backends.mps.is_available():
if (
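For reference, a small self-contained sketch of how a choices-style flag like the new --t5_padding maps onto the boolean the embed cache expects; the parser below is illustrative and is not SimpleTuner's actual entry point:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--t5_padding",
        choices=["zero", "unmodified"],
        default="unmodified",
        help="Zero out T5 padding-token embeddings ('zero') or leave them untouched ('unmodified').",
    )
    args = parser.parse_args(["--t5_padding", "zero"])
    zero_padding_tokens = args.t5_padding == "zero"  # same comparison used in compute_embeddings_for_flux_prompts
    assert zero_padding_tokens is True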
7 changes: 6 additions & 1 deletion helpers/configuration/json_file.py
@@ -3,8 +3,13 @@
import logging

# Set up logging
+ from helpers.training.multi_process import _get_rank
+
logger = logging.getLogger("SimpleTuner")
- logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))
+ if _get_rank() > 0:
+     logger.setLevel(logging.WARNING)
+ else:
+     logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))


def normalize_args(args_dict):
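The logging change above keeps INFO-level output only on the main process. _get_rank's implementation is not part of this diff; a plausible stand-in that reads the rank variables set by torchrun or accelerate launch looks like this:

    import logging
    import os

    def _get_rank() -> int:
        # assumption: default to rank 0 (the main process) when no launcher set a rank
        return int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0)))

    logger = logging.getLogger("SimpleTuner")
    if _get_rank() > 0:
        # secondary workers would otherwise duplicate every INFO line
        logger.setLevel(logging.WARNING)
    else:
        logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))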
7 changes: 6 additions & 1 deletion helpers/configuration/toml_file.py
@@ -3,8 +3,13 @@
import logging

# Set up logging
+ from helpers.training.multi_process import _get_rank
+
logger = logging.getLogger("SimpleTuner")
- logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))
+ if _get_rank() > 0:
+     logger.setLevel(logging.WARNING)
+ else:
+     logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))


def normalize_args(args_dict):
2 changes: 1 addition & 1 deletion helpers/log_format.py
@@ -32,7 +32,7 @@ def format(self, record):
logging.INFO
) # Change to ERROR if you want to suppress INFO messages too
console_handler.setFormatter(
- ColorizedFormatter("%(asctime)s [%(levelname)s] (%(name)s) %(message)s")
+ ColorizedFormatter("%(asctime)s [%(levelname)s] %(message)s")
)

# blank out the existing debug.log, if exists
8 changes: 4 additions & 4 deletions helpers/models/flux/pipeline.py
@@ -304,10 +304,10 @@ def _get_clip_prompt_embeds(
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
)
- logger.warning(
-     "The following part of your input was truncated because CLIP can only handle sequences up to"
-     f" {self.tokenizer_max_length} tokens: {removed_text}"
- )
+ # logger.warning(
+ #     "The following part of your input was truncated because CLIP can only handle sequences up to"
+ #     f" {self.tokenizer_max_length} tokens: {removed_text}"
+ # )
prompt_embeds = self.text_encoder(
text_input_ids.to(device), output_hidden_states=False
)
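The warning silenced above comes from the usual truncation check in diffusers-style pipelines: the prompt is tokenized once with truncation and once without, and the overflow is decoded for the message. A sketch of that pattern (it downloads a CLIP tokenizer; the model id below is an assumption, not taken from this commit):

    import torch
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    max_length = tokenizer.model_max_length  # 77 for CLIP

    prompt = ["an extremely long prompt " * 20]
    truncated = tokenizer(
        prompt, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt"
    ).input_ids
    untruncated = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

    if untruncated.shape[-1] >= max_length and not torch.equal(truncated, untruncated[:, :max_length]):
        removed_text = tokenizer.batch_decode(untruncated[:, max_length - 1 : -1])
        print(f"Truncated from the prompt: {removed_text}")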