Commit
Merge pull request #1263 from bghira/feature/dataset-configurator-for-sd3

updates for configure.py multi-resolution suggestions & dataset cache path handling
bghira authored Jan 3, 2025
2 parents 13c405f + 9e32d34 commit 6588eac
Showing 1 changed file with 95 additions and 63 deletions.
configure.py: 158 changes (95 additions & 63 deletions)
@@ -752,69 +752,46 @@ def configure_env():
sys.exit(1)

# dataloader configuration
resolution_configs = {
256: {"resolution": 256, "minimum_image_size": 128},
512: {"resolution": 512, "minimum_image_size": 256},
768: {"resolution": 768, "minimum_image_size": 512},
1024: {"resolution": 1024, "minimum_image_size": 768},
1440: {"resolution": 1440, "minimum_image_size": 1024},
2048: {"resolution": 2048, "minimum_image_size": 1440},
}
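# For reference, each key above maps a target training resolution to the
# fragment that gets merged into a dataset entry, e.g.:
#   >>> resolution_configs[768]
#   {'resolution': 768, 'minimum_image_size': 512}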
default_dataset_configuration = {
"id": "PLACEHOLDER",
"type": "local",
"instance_data_dir": None,
"crop": False,
"resolution_type": "pixel_area",
"metadata_backend": "discovery",
"caption_strategy": "filename",
"cache_dir_vae": "vae",
}
default_cropped_dataset_configuration = {
"id": "PLACEHOLDER-crop",
"type": "local",
"instance_data_dir": None,
"crop": True,
"crop_aspect": "square",
"crop_style": "center",
"vae_cache_clear_each_epoch": False,
"resolution_type": "pixel_area",
"metadata_backend": "discovery",
"caption_strategy": "filename",
"cache_dir_vae": "vae-crop",
}
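# The two templates above differ only in cropping behaviour: the cropped
# variant enables center square crops and writes its latents to a separate
# "vae-crop" cache directory so the two variants never collide.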

default_local_configuration = [
{
"id": "PLACEHOLDER-512",
"type": "local",
"instance_data_dir": None,
"crop": False,
"crop_style": "random",
"minimum_image_size": 128,
"resolution": 512,
"resolution_type": "pixel_area",
"repeats": 10,
"metadata_backend": "discovery",
"caption_strategy": "filename",
"cache_dir_vae": "vae-512",
},
{
"id": "PLACEHOLDER-1024",
"type": "local",
"instance_data_dir": None,
"crop": False,
"crop_style": "random",
"minimum_image_size": 128,
"resolution": 1024,
"resolution_type": "pixel_area",
"repeats": 10,
"metadata_backend": "discovery",
"caption_strategy": "filename",
"cache_dir_vae": "vae-1024",
},
{
"id": "PLACEHOLDER-512-crop",
"type": "local",
"instance_data_dir": None,
"crop": True,
"crop_style": "random",
"minimum_image_size": 128,
"resolution": 512,
"resolution_type": "pixel_area",
"repeats": 10,
"metadata_backend": "discovery",
"caption_strategy": "filename",
"cache_dir_vae": "vae-512-crop",
},
{
"id": "PLACEHOLDER-1024-crop",
"type": "local",
"instance_data_dir": None,
"crop": True,
"crop_style": "random",
"minimum_image_size": 128,
"resolution": 1024,
"resolution_type": "pixel_area",
"repeats": 10,
"metadata_backend": "discovery",
"caption_strategy": "filename",
"cache_dir_vae": "vae-1024-crop",
},
{
"id": "text-embed-cache",
"dataset_type": "text_embeds",
"default": True,
"type": "local",
"cache_dir": "text",
"write_batch_size": 128,
},
]

@@ -894,9 +871,36 @@ def configure_env():
)
dataset_repeats = int(
prompt_user(
"How many times do you want to repeat each image in the dataset?", 10
"How many times do you want to repeat each image in the dataset? A value of zero means the dataset will only be seen once; a value of one will cause the dataset to be sampled twice.",
10,
)
)
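# Note that repeats are additive: a dataset with repeats=N is sampled N + 1
# times per epoch (0 -> once, 1 -> twice, 10 -> eleven times).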
default_base_resolutions = "1024"
multi_resolution_recommendation_text = (
"Multiple resolutions may be provided, but this is only recommended for Flux."
)
multi_resolution_capable_models = ["flux"]
if env_contents["--model_family"] in multi_resolution_capable_models:
default_base_resolutions = "256,512,768,1024,1440"
multi_resolution_recommendation_text = "A comma-separated list of values or a single item can be given to train on multiple base resolutions."
dataset_resolutions = prompt_user(
f"Which resolutions do you want to train? {multi_resolution_recommendation_text}",
default_base_resolutions,
)
if "," in dataset_resolutions:
# Most models don't work with multi-base-resolution training.
if env_contents["--model_family"] not in multi_resolution_capable_models:
print(
"WARNING: Most models do not play well with multi-resolution training, resulting in degraded outputs and broken hearts. Proceed with caution."
)
dataset_resolutions = [int(res) for res in dataset_resolutions.split(",")]
else:
try:
dataset_resolutions = [int(dataset_resolutions)]
except ValueError:
print("Invalid resolution value. Using 1024 instead.")
dataset_resolutions = [1024]
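# Examples of how the input above is parsed:
#   "256,512,1024" -> [256, 512, 1024]
#   "1024"         -> [1024]
#   "not-a-number" -> [1024], after the invalid-value warning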

dataset_cache_prefix = prompt_user(
"Where will your VAE and text encoder caches be written to? Subdirectories will be created inside for you automatically.",
"cache/",
@@ -910,13 +914,21 @@
)

# Build the per-resolution dataset entries from the default templates. When
# has_very_large_images is True, each image dataset also receives
# 'maximum_image_size' and 'target_downsample_size' equal to its resolution.
for dataset in default_local_configuration:
if dataset.get("dataset_type") == "text_embeds":
dataset["cache_dir"] = f"{dataset_cache_prefix}/{dataset['cache_dir']}"
continue
dataset["instance_data_dir"] = dataset_path
def create_dataset_config(resolution, default_config):
dataset = default_config.copy()
dataset.update(resolution_configs[resolution])
dataset["id"] = f"{dataset['id']}-{resolution}"
dataset["instance_data_dir"] = os.path.abspath(dataset_path)
dataset["repeats"] = dataset_repeats
dataset["cache_dir_vae"] = f"{dataset_cache_prefix}/{dataset['cache_dir_vae']}"
# We want the absolute path, as this works best with datasets containing nested subdirectories.
dataset["cache_dir_vae"] = os.path.abspath(
os.path.join(
dataset_cache_prefix,
env_contents["--model_family"],
dataset["cache_dir_vae"],
str(resolution),
)
)
if has_very_large_images:
dataset["maximum_image_size"] = dataset["resolution"]
dataset["target_downsample_size"] = dataset["resolution"]
@@ -925,6 +937,26 @@ def configure_env():
dataset["instance_prompt"] = dataset_instance_prompt
dataset["caption_strategy"] = dataset_caption_strategy

if has_very_large_images:
dataset["maximum_image_size"] = dataset["resolution"]
dataset["target_downsample_size"] = dataset["resolution"]
return dataset
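# A sketch of one resulting entry, assuming hypothetical answers of
# dataset_path="data", repeats=10, cache prefix "cache/", and model family
# "flux" (instance_prompt and caption_strategy follow the earlier prompts):
#   >>> create_dataset_config(512, default_dataset_configuration)
#   {'id': 'PLACEHOLDER-512', 'type': 'local',
#    'instance_data_dir': '/abs/path/to/data', 'crop': False,
#    'resolution_type': 'pixel_area', 'metadata_backend': 'discovery',
#    'cache_dir_vae': '/abs/path/to/cache/flux/vae/512',
#    'resolution': 512, 'minimum_image_size': 256, 'repeats': 10, ...}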

# The text embed dataset is the first entry in the default config list above,
# so its cache_dir is updated in place here.
default_local_configuration[0]["cache_dir"] = os.path.abspath(
os.path.join(dataset_cache_prefix, env_contents["--model_family"], "text")
)
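# With a prefix of "cache/" and model family "flux", for example, the text
# embed cache resolves to an absolute path ending in cache/flux/text.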
for resolution in dataset_resolutions:
uncropped_dataset = create_dataset_config(
resolution, default_dataset_configuration
)
default_local_configuration.append(uncropped_dataset)
cropped_dataset = create_dataset_config(
resolution, default_cropped_dataset_configuration
)
default_local_configuration.append(cropped_dataset)
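# After this loop, the dataloader list holds the text-embed cache entry plus
# one uncropped and one cropped image dataset per resolution, e.g. five
# resolutions yield 1 + 5 * 2 = 11 entries.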

print("Dataloader configuration:")
print(default_local_configuration)
confirm = prompt_user("Does this look correct? (y/n)", "y").lower() == "y"
