revamp model card to work by default and provide quanto hints #1133

Merged: 6 commits, Nov 10, 2024
7 changes: 5 additions & 2 deletions documentation/quickstart/SD3.md
@@ -276,7 +276,9 @@ The following values are recommended for `config.json`:
"--validation_guidance_skip_layers_start": 0.01,
"--validation_guidance_skip_layers_stop": 0.2,
"--validation_guidance_skip_scale": 2.8,
"--validation_guidance": 4.0
"--validation_guidance": 4.0,
"--flux_use_uniform_schedule": true,
"--flux_schedule_auto_shift": true
}
```

@@ -308,7 +310,8 @@ Some changes were made to SimpleTuner's SD3.5 support:
- Offering a switch (`--sd3_clip_uncond_behaviour` and `--sd3_t5_uncond_behaviour`) to use encoded blank captions for unconditional predictions (`empty_string`, **default**) or zeros (`zero`); tweaking this is not recommended.
- SD3.5 training loss function was updated to match that found in the upstream StabilityAI/SD3.5 repository
- Updated default `--flux_schedule_shift` value to 3 to match the static 1024px value for SD3
- 512px training requires the use of `--flux_schedule_shift=1`
- StabilityAI followed up with documentation recommending `--flux_schedule_shift=1` together with `--flux_use_uniform_schedule`
- Community members have reported that `--flux_schedule_auto_shift` works better for multi-aspect or multi-resolution training (see the sketch after this list)
- Updated the hard-coded tokeniser sequence length limit to **256** with the option to revert it to **77** tokens to save disk space or compute at the cost of output quality degradation
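
As a quick sketch of how these schedule flags fit together, the snippet below writes the relevant keys into `config.json`. It is an assumption-laden illustration, not part of this PR: the values mirror the notes above, and `--flux_schedule_auto_shift` is left commented out because the safety check treats it as an error when combined with a non-zero `--flux_schedule_shift`.

```python
# Minimal sketch (assumed values): persisting the SD3.5 schedule hints into config.json.
import json

config = {
    # StabilityAI's follow-up guidance for 512px training:
    "--flux_schedule_shift": 1,
    "--flux_use_uniform_schedule": True,
    # Reported by community members to help multi-aspect / multi-resolution runs;
    # do not enable together with a non-zero --flux_schedule_shift.
    # "--flux_schedule_auto_shift": True,
}

with open("config.json", "w") as handle:
    json.dump(config, handle, indent=4)
```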


16 changes: 10 additions & 6 deletions helpers/caching/text_embeds.py
@@ -249,9 +249,13 @@ def batch_write_embeddings(self):
if len(batch) > 0:
self.process_write_batch(batch)
self.write_thread_bar.update(len(batch))
logger.debug(f"Exiting batch write thread, no more work to do after writing {written_elements} elements")
logger.debug(
f"Exiting batch write thread, no more work to do after writing {written_elements} elements"
)
break
logger.debug(f"Queue is empty. Retrieving new entries. Should retrieve? {self.process_write_batches}")
logger.debug(
f"Queue is empty. Retrieving new entries. Should retrieve? {self.process_write_batches}"
)
pass
except Exception:
logger.exception("An error occurred while writing embeddings to disk.")
@@ -525,9 +529,7 @@ def encode_prompt(self, prompt: str, is_validation: bool = False):
prompt,
is_validation,
zero_padding_tokens=(
True
if StateTracker.get_args().t5_padding == "zero"
else False
True if StateTracker.get_args().t5_padding == "zero" else False
),
)
else:
@@ -1320,7 +1322,9 @@ def compute_embeddings_for_sd3_prompts(
)
if should_encode:
# If load_from_cache is True, should_encode would be False unless we failed to load.
self.debug_log(f"Encoding filename {filename} :: device {self.text_encoders[0].device} :: prompt {prompt}")
self.debug_log(
f"Encoding filename {filename} :: device {self.text_encoders[0].device} :: prompt {prompt}"
)
prompt_embeds, pooled_prompt_embeds = self.encode_sd3_prompt(
self.text_encoders,
self.tokenizers,
8 changes: 5 additions & 3 deletions helpers/data_backend/factory.py
@@ -50,7 +50,9 @@ def info_log(message):
logger.info(message)


def check_column_values(column_data, column_name, parquet_path, fallback_caption_column=False):
def check_column_values(
column_data, column_name, parquet_path, fallback_caption_column=False
):
# Determine if the column contains arrays or scalar values
non_null_values = column_data.dropna()
if non_null_values.empty:
@@ -362,15 +364,15 @@ def configure_parquet_database(backend: dict, args, data_backend: BaseDataBacken
df[caption_column],
caption_column,
parquet_path,
fallback_caption_column=fallback_caption_column
fallback_caption_column=fallback_caption_column,
)

# Apply the function to the filename_column.
check_column_values(
df[filename_column],
filename_column,
parquet_path,
fallback_caption_column=False # Always check filename_column
fallback_caption_column=False, # Always check filename_column
)

# Store the database in StateTracker
5 changes: 4 additions & 1 deletion helpers/image_manipulation/training_sample.py
@@ -80,7 +80,10 @@ def __init__(
self.resolution = self.data_backend_config.get("resolution")
self.resolution_type = self.data_backend_config.get("resolution_type")
self.target_size_calculator = resize_helpers.get(self.resolution_type)
if self.target_size_calculator is None and conditioning_type not in ["mask", "controlnet"]:
if self.target_size_calculator is None and conditioning_type not in [
"mask",
"controlnet",
]:
raise ValueError(f"Unknown resolution type: {self.resolution_type}")
self._set_resolution()
self.target_downsample_size = self.data_backend_config.get(
12 changes: 6 additions & 6 deletions helpers/models/flux/pipeline.py
@@ -824,16 +824,16 @@ def __call__(

noise_pred = self.transformer(
hidden_states=latents.to(
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds.to(
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
encoder_hidden_states=prompt_embeds.to(
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
txt_ids=text_ids,
img_ids=latent_image_ids,
@@ -846,16 +846,16 @@
if guidance_scale_real > 1.0 and i >= no_cfg_until_timestep:
noise_pred_uncond = self.transformer(
hidden_states=latents.to(
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=negative_pooled_prompt_embeds.to(
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
encoder_hidden_states=negative_prompt_embeds.to(
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
device=self.transformer.device # , dtype=self.transformer.dtype # can't cast dtype like this because of NF4
),
txt_ids=negative_text_ids.to(device=self.transformer.device),
img_ids=latent_image_ids.to(device=self.transformer.device),
4 changes: 3 additions & 1 deletion helpers/prompts.py
@@ -266,7 +266,9 @@ def prepare_instance_prompt_from_parquet(
if type(image_caption) == str:
image_caption = image_caption.strip()
if type(image_caption) in (list, tuple, numpy.ndarray, pd.Series):
image_caption = [str(item).strip() for item in image_caption if item is not None]
image_caption = [
str(item).strip() for item in image_caption if item is not None
]
if prepend_instance_prompt:
if type(image_caption) == list:
image_caption = [instance_prompt + " " + x for x in image_caption]
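
A small illustration of the caption normalisation above, using made-up values (the captions and instance prompt are hypothetical):

```python
# Hypothetical values, showing only the list/array branch of the caption handling above.
import numpy

image_caption = numpy.array(["  a photo of a cat ", None, "a sketch of a cat"])
image_caption = [str(item).strip() for item in image_caption if item is not None]
# -> ["a photo of a cat", "a sketch of a cat"]

instance_prompt = "sks cat"
image_caption = [instance_prompt + " " + x for x in image_caption]
# -> ["sks cat a photo of a cat", "sks cat a sketch of a cat"]
```
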
82 changes: 74 additions & 8 deletions helpers/publishing/metadata.py
@@ -105,7 +105,38 @@ def _model_imports(args):
return f"{output}"


def lycoris_download_info():
"""output a function to download the adapter"""
output_fn = """
def download_adapter(repo_id: str):
import os
from huggingface_hub import hf_hub_download
adapter_filename = "pytorch_lora_weights.safetensors"
cache_dir = os.environ.get('HF_PATH', os.path.expanduser('~/.cache/huggingface/hub/models'))
cleaned_adapter_path = repo_id.replace("/", "_").replace("\\\\", "_").replace(":", "_")
path_to_adapter = os.path.join(cache_dir, cleaned_adapter_path)
path_to_adapter_file = os.path.join(path_to_adapter, adapter_filename)
os.makedirs(path_to_adapter, exist_ok=True)
hf_hub_download(
repo_id=repo_id, filename=adapter_filename, local_dir=path_to_adapter
)

return path_to_adapter_file
"""

return output_fn


def _model_component_name(args):
model_component_name = "pipeline.transformer"
if args.model_family in ["sdxl", "kolors", "legacy", "deepfloyd"]:
model_component_name = "pipeline.unet"

return model_component_name


def _model_load(args, repo_id: str = None):
model_component_name = _model_component_name(args)
hf_user_name = StateTracker.get_hf_username()
if hf_user_name is not None:
repo_id = f"{hf_user_name}/{repo_id}" if hf_user_name else repo_id
@@ -114,22 +145,26 @@ def _model_load(args, repo_id: str = None):
output = (
f"model_id = '{args.pretrained_model_name_or_path}'"
f"\nadapter_id = '{repo_id if repo_id is not None else args.output_dir}'"
f"\npipeline = DiffusionPipeline.from_pretrained(model_id)"
f"\npipeline = DiffusionPipeline.from_pretrained(model_id), torch_dtype={StateTracker.get_weight_dtype()}) # loading directly in bf16"
f"\npipeline.load_lora_weights(adapter_id)"
)
elif args.lora_type.lower() == "lycoris":
output = (
f"model_id = '{args.pretrained_model_name_or_path}'"
f"\nadapter_id = 'pytorch_lora_weights.safetensors' # you will have to download this manually"
f"{lycoris_download_info()}"
f"\nmodel_id = '{args.pretrained_model_name_or_path}'"
f"\nadapter_repo_id = '{repo_id if repo_id is not None else args.output_dir}'"
f"\nadapter_filename = 'pytorch_lora_weights.safetensors'"
f"\nadapter_file_path = download_adapter(repo_id=adapter_repo_id)"
f"\npipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype={StateTracker.get_weight_dtype()}) # loading directly in bf16"
"\nlora_scale = 1.0"
)
else:
output = (
f"model_id = '{repo_id if repo_id else os.path.join(args.output_dir, 'pipeline')}'"
f"\npipeline = DiffusionPipeline.from_pretrained(model_id)"
f"\npipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype={StateTracker.get_weight_dtype()}) # loading directly in bf16"
)
if args.model_type == "lora" and args.lora_type.lower() == "lycoris":
output += f"\nwrapper, _ = create_lycoris_from_weights(lora_scale, adapter_id, pipeline.transformer)"
output += f"\nwrapper, _ = create_lycoris_from_weights(lora_scale, adapter_file_path, {model_component_name})"
output += "\nwrapper.merge_to()"

return output
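
For readability, this is roughly what the LyCORIS branch assembles to in the finished model card. The model and adapter repository ids below are placeholders and the bf16 dtype is an assumption; the real values are filled in from the training arguments:

```python
# Hypothetical rendering of the generated LyCORIS loading snippet (placeholder ids).
import os
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download
from lycoris import create_lycoris_from_weights

def download_adapter(repo_id: str):
    adapter_filename = "pytorch_lora_weights.safetensors"
    cache_dir = os.environ.get("HF_PATH", os.path.expanduser("~/.cache/huggingface/hub/models"))
    cleaned_adapter_path = repo_id.replace("/", "_").replace("\\", "_").replace(":", "_")
    path_to_adapter = os.path.join(cache_dir, cleaned_adapter_path)
    os.makedirs(path_to_adapter, exist_ok=True)
    hf_hub_download(repo_id=repo_id, filename=adapter_filename, local_dir=path_to_adapter)
    return os.path.join(path_to_adapter, adapter_filename)

model_id = "stabilityai/stable-diffusion-3.5-large"      # placeholder base model
adapter_repo_id = "your-username/your-lycoris-adapter"   # placeholder adapter repo
adapter_file_path = download_adapter(repo_id=adapter_repo_id)

pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
lora_scale = 1.0
wrapper, _ = create_lycoris_from_weights(lora_scale, adapter_file_path, pipeline.transformer)
wrapper.merge_to()
```
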
@@ -162,6 +197,33 @@ def _skip_layers(args):
return f"\n skip_guidance_layers={args.validation_guidance_skip_layers},"


def _pipeline_move_to(args):
output = f"pipeline.to({_torch_device()}) # the pipeline is already in its target precision level"

return output


def _pipeline_quanto(args):
# return some optional lines to run Quanto on the model pipeline
if args.model_type == "full":
return ""
model_component_name = _model_component_name(args)
comment_character = ""
was_quantised = "The model was quantised during training, and so it is recommended to do the same during inference time."
if args.base_model_precision == "no_change":
comment_character = "#"
was_quantised = "The model was not quantised during training, so it is not necessary to quantise it during inference time."
output = f"""
## Optional: quantise the model to save on vram.
## Note: {was_quantised}
{comment_character}from optimum.quanto import quantize, freeze, qint8
{comment_character}quantize({model_component_name}, weights=qint8)
{comment_character}freeze({model_component_name})
"""

return output
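
A minimal sketch of what the optional quanto hint looks like once rendered, assuming a LoRA/LyCORIS run on a transformer-based model that was quantised during training (the model id is a placeholder):

```python
# Hypothetical rendering of the quanto hint (placeholder model id, assumed bf16 weights).
import torch
from diffusers import DiffusionPipeline
from optimum.quanto import quantize, freeze, qint8

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
)

# Optional: quantise the transformer to int8 weights to save on VRAM.
quantize(pipeline.transformer, weights=qint8)
freeze(pipeline.transformer)

pipeline.to("cuda")  # the pipeline is already in its target precision level
```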


def _validation_resolution(args):
if args.validation_resolution == "" or args.validation_resolution is None:
return f"width=1024,\n" f" height=1024,"
Expand All @@ -188,13 +250,14 @@ def code_example(args, repo_id: str = None):

prompt = "{args.validation_prompt if args.validation_prompt else 'An astronaut is riding a horse through the jungles of Thailand.'}"
{_negative_prompt(args)}
pipeline.to({_torch_device()})
{_pipeline_quanto(args)}
{_pipeline_move_to(args)}
image = pipeline(
prompt=prompt,{_negative_prompt(args, in_call=True) if args.model_family.lower() != 'flux' else ''}
num_inference_steps={args.validation_num_inference_steps},
generator=torch.Generator(device={_torch_device()}).manual_seed(1641421826),
{_validation_resolution(args)}
guidance_scale={args.validation_guidance},{_guidance_rescale(args)},{_skip_layers(args)}
guidance_scale={args.validation_guidance},{_guidance_rescale(args)}{_skip_layers(args)}
).images[0]
image.save("output.png", format="PNG")
```
@@ -226,7 +289,10 @@ def lora_info(args):
lycoris_config_file = args.lycoris_config
# read the json file
with open(lycoris_config_file, "r") as file:
lycoris_config = json.load(file)
try:
lycoris_config = json.load(file)
except:
lycoris_config = {"error": "could not locate or load LyCORIS config."}
return f"""- LyCORIS Config:\n```json\n{json.dumps(lycoris_config, indent=4)}\n```"""


3 changes: 2 additions & 1 deletion helpers/training/default_settings/safety_check.py
@@ -109,7 +109,8 @@ def safety_check(args, accelerator):
sys.exit(1)

if (
args.flux_schedule_shift is not None and args.flux_schedule_shift > 0
args.flux_schedule_shift is not None
and args.flux_schedule_shift > 0
and args.flux_schedule_auto_shift
):
logger.error(
12 changes: 8 additions & 4 deletions helpers/training/trainer.py
@@ -2183,9 +2183,11 @@ def train(self):
)
training_logger.debug(f"Working on batch size: {bsz}")
if self.config.flow_matching:
if (
not self.config.flux_fast_schedule
and not any([self.config.flux_use_beta_schedule, self.config.flux_use_uniform_schedule])
if not self.config.flux_fast_schedule and not any(
[
self.config.flux_use_beta_schedule,
self.config.flux_use_uniform_schedule,
]
):
# imported from cloneofsimo's minRF trainer: https://github.com/cloneofsimo/minRF
# also used by: https://github.com/XLabs-AI/x-flux/tree/main
@@ -2316,7 +2318,9 @@ def train(self):
elif self.config.flow_matching_loss == "compatible":
target = noise - latents
elif self.config.flow_matching_loss == "sd35":
sigma_reshaped = sigmas.view(-1, 1, 1, 1) # Ensure sigma has the correct shape
sigma_reshaped = sigmas.view(
-1, 1, 1, 1
) # Ensure sigma has the correct shape
target = (noisy_latents - latents) / sigma_reshaped

elif self.noise_scheduler.config.prediction_type == "epsilon":
12 changes: 9 additions & 3 deletions helpers/training/validation.py
@@ -642,7 +642,9 @@ def _gather_prompt_embeds(self, validation_prompt: str):
device=self.inference_device, dtype=self.weight_dtype
)
)
prompt_embeds["pooled_prompt_embeds"] = current_validation_pooled_embeds
prompt_embeds["pooled_prompt_embeds"] = current_validation_pooled_embeds.to(
device=self.inference_device, dtype=self.weight_dtype
)
prompt_embeds["negative_pooled_prompt_embeds"] = (
self.validation_negative_pooled_embeds
)
@@ -662,7 +664,9 @@ def _gather_prompt_embeds(self, validation_prompt: str):
current_validation_prompt_embeds, current_validation_prompt_mask = (
current_validation_prompt_embeds
)
current_validation_prompt_embeds = current_validation_prompt_embeds[0]
current_validation_prompt_embeds = current_validation_prompt_embeds[
0
].to(device=self.inference_device, dtype=self.weight_dtype)
if (
type(self.validation_negative_prompt_embeds) is tuple
or type(self.validation_negative_prompt_embeds) is list
@@ -672,7 +676,9 @@
self.validation_negative_prompt_mask,
) = self.validation_negative_prompt_embeds[0]
else:
current_validation_prompt_embeds = current_validation_prompt_embeds[0]
current_validation_prompt_embeds = current_validation_prompt_embeds[
0
].to(device=self.inference_device, dtype=self.weight_dtype)
# logger.debug(
# f"Validations received the prompt embed: ({type(current_validation_prompt_embeds)}) positive={current_validation_prompt_embeds.shape if type(current_validation_prompt_embeds) is not list else current_validation_prompt_embeds[0].shape},"
# f" ({type(self.validation_negative_prompt_embeds)}) negative={self.validation_negative_prompt_embeds.shape if type(self.validation_negative_prompt_embeds) is not list else self.validation_negative_prompt_embeds[0].shape}"