Flux Validation #1518
Conversation
```python
dp_rank (int): Data parallel rank.
dp_world_size (int): Data parallel world size.
infinite (bool): Whether to loop over the dataset infinitely.
generate_timesteps (booL): Generate stratified timesteps in round-robin style for validation
```
Suggested change:

```diff
-generate_timesteps (booL): Generate stratified timesteps in round-robin style for validation
+generate_timesteps (bool): Generate stratified timesteps in round-robin style for validation
```
let's try to create a subclass to do this, as discussed offline.
Otherwise it's easy to confuse with the timestep concept used during training / inference.
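A minimal sketch of what that subclass could look like. The class names, the base-class interface, and the `num_timesteps` parameter are assumptions for illustration, not the actual torchtitan API; the idea is just to keep the validation-only round-robin timestep logic out of the training dataset.

```python
import itertools


class FluxDataset:
    """Stand-in for the existing Flux dataset (interface assumed)."""

    def __init__(self, samples):
        self.samples = samples

    def __iter__(self):
        for sample in self.samples:
            yield dict(sample)


class FluxValidationDataset(FluxDataset):
    """Hypothetical subclass: attaches one of `num_timesteps` equidistant
    timesteps to each sample in round-robin order, keeping the concept
    separate from training / inference timesteps."""

    def __init__(self, samples, num_timesteps=8):
        super().__init__(samples)
        # Equidistant grid in (0, 1): 1/9, 2/9, ..., 8/9 for num_timesteps=8
        self.timestep_grid = [
            (i + 1) / (num_timesteps + 1) for i in range(num_timesteps)
        ]

    def __iter__(self):
        cycle = itertools.cycle(self.timestep_grid)
        for sample in super().__iter__():
            sample["timestep"] = next(cycle)
            yield sample
```

With 8 timesteps, sample 0 gets 1/9, sample 7 gets 8/9, and sample 8 wraps back to 1/9.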
```python
# skip low quality image or image with color channel = 1
if sample_dict["image"] is None:
    # sample_id = sample.get('sample_id')
```
remove comment
```python
eval_freq: int = 100
"""Frequency of evaluation/sampling during training"""
save_imgs: int = 1
"""How many images to generate and save in validation, -1 means same number as steps"""
```
explain that the source prompt comes from the validation dataset and is taken from "the beginning"
| """How many denoising steps to sample when generating an image""" | ||
| eval_freq: int = 100 | ||
| """Frequency of evaluation/sampling during training""" | ||
| save_imgs: int = 1 |
call it something like save_imgs_count? save_imgs sounds like a bool
also, can you show examples of a generated image? Just to verify the capability
```python
@dataclass
class Eval:
class Validation(Validation):
```
do we have to inherit? Training is not inheriting.
```python
# Patchify: Convert latent into a sequence of patches
latents = pack_latents(latents)

latent_noise_pred = model(
```
please adapt to #1494 for the amp part
```python
bsz = labels.shape[0]

# To evaluate all 8 timesteps per sample do
if self.all_timesteps:
```
I think it's OK to keep this.
```python
validation_end_time = time.time()
validation_duration = validation_end_time - validation_start_time

# Log timing information
from torchtitan.tools.logging import logger

logger.info(f"Validation step {step} completed in {validation_duration:.3f}s ")
```
is this for debugging? we didn't need this in llama3 validator
```python
save_imgs = self.job_config.validation.save_imgs
if save_imgs == -1 or num_steps < save_imgs:
    t5_tokenizer, clip_tokenizer = build_flux_tokenizer(self.job_config)
```
why build this multiple times?
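One way to address this, as a sketch: build the tokenizers once in the validator's `__init__` and reuse them on every validation pass. The class and parameter names below are assumed from the diff above for illustration, not the actual torchtitan API.

```python
class FluxValidator:
    """Sketch: construct the tokenizers once instead of rebuilding them
    each time images are generated (names assumed, not the real API)."""

    def __init__(self, job_config, build_tokenizer_fn):
        self.job_config = job_config
        # Built once, reused for every validation step.
        self.t5_tokenizer, self.clip_tokenizer = build_tokenizer_fn(job_config)

    def validate(self):
        # Generation code would use the cached tokenizers here.
        return self.t5_tokenizer, self.clip_tokenizer
```

Calling `validate()` repeatedly then invokes the builder exactly once, at construction time.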
```python
if isinstance(prompt, list):
    prompt = " ".join(prompt)
```
in validate.py you are taking prompt[0]. Do you still need to concatenate here?
This concatenation is because the COCO dataset gives the prompts as a list of strings instead of a single string per sample. This part just combines the list-style prompt into a single string.
The prompt[0] in validate.py is there because we only generate an image from the first sample in each validation batch.
Edit: I'm changing this so that it generates the correct number of images regardless of batch size.
@tianyu-l addressed comments.
…d flux validator dataset for clarity. Corrects the image generation in validation to reflect save_img_count
ed8068d to 09141ad
tianyu-l
left a comment
Looks great! Had some final comments.
```python
from .model.args import FluxModelArgs
from .model.autoencoder import AutoEncoderParams
from .model.model import FluxModel
from .model.validate import build_flux_validator
```
validate doesn't sound part of model. Let's just put it under flux/ root folder
```python
img = _process_cc12m_image(sample["image"], output_size=output_size)
prompt = sample["caption"]
if isinstance(prompt, list):
    prompt = " ".join(prompt)
```
If you look at one example:

```python
[
    "A desktop computer monitor sitting next to a keyboard.",
    "A computer, keyboard, modem, and mouse are sitting on a work desk.",
    "a computer monitor, keyboard, and mouse sit on a table.",
    "A black desktop computer atop a wooden table.",
    "a desktop computer monitor with a keyboard and mouse",
]
```

It's always a list of alternative captions. I think we should only pick one (maybe the first one for deterministic training / eval).
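A tiny sketch of that suggestion: pick the first caption deterministically instead of joining them. The helper name `select_caption` is hypothetical, chosen here only for illustration.

```python
def select_caption(caption):
    """Deterministically pick the first caption when a sample provides
    a list of alternative captions (COCO-style); pass strings through."""
    if isinstance(caption, list):
        return caption[0]
    return caption
```

So `select_caption(["A desktop computer monitor sitting next to a keyboard.", "..."])` always returns the first alternative, keeping training and eval reproducible.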
```python
# skip low quality image or image with color channel = 1
if sample_dict["image"] is None:
    # sample_id = sample.get('sample_id')
```
let's remove comment
```python
break

prompt = input_dict.pop("prompt")
if not isinstance(prompt, list):
```
To educate me: if prompt is yielded from the dataset as a str, would the dataloader batchify it into a list? Could you please share source on this behavior?
The torch dataloader handles this batching all the way up the inheritance stack: ParallelAwareDataloader -> StatefulDataLoader -> DataLoader. The FluxDataset gets passed to the ParallelAwareDataLoader in the build_dataloader_fn. I'm not sure of the exact implementation details and edge cases, but the API wants an IterableDataset with an overridden __iter__ for the data stream. https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
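To make the behavior in question concrete: PyTorch's default collate function turns a batch of per-sample dicts into a single dict, and string fields are gathered into a list rather than tensorized. Below is a pure-Python sketch of that string-gathering behavior (it mimics only the dict-of-strings case, without importing torch), so the example stays self-contained.

```python
def collate_like_default(batch):
    """Sketch of the string case of DataLoader's default collate:
    a list of per-sample dicts becomes one dict whose values are
    lists of the per-sample string values."""
    return {key: [sample[key] for sample in batch] for key in batch[0]}


samples = [{"prompt": "a cat on a mat"}, {"prompt": "a dog on a log"}]
batch = collate_like_default(samples)
# batch["prompt"] is now a list of strings, one per sample in the batch
```

This is why a dataset yielding `prompt` as a single `str` still surfaces as a `list[str]` after batching, and why `prompt[0]` picks the first sample's prompt.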
…pt instead of joining,
tianyu-l
left a comment
really awesome work!
This PR implements the validator class for Flux following the method discussed in the Stable Diffusion 3 paper. The paper shows that creating 8 equidistant timesteps and calculating the average loss over them results in a loss highly correlated with external validation metrics such as CLIP or FID score. Rather than creating 8 stratified timesteps per sample, this PR's implementation applies only one of these equidistant timesteps to each sample, in a round-robin fashion. Aggregated over many samples in a validation set, this should give a similar validation score to the full-timestep method while processing validation samples more quickly.

### Implementations

- Integrates the image generation evaluation into the validation step
- Refactors and combines the eval job_config with validation
- Adds an `all_timesteps` option to the job_config to choose whether to use round-robin timesteps or full timesteps per sample
- Creates a validator class and validation dataloader for Flux; the validation dataloader handles generating timesteps for the round-robin method of validation

### Enabling all timesteps

Developers can enable the full-timestep method of validation by setting `all_timesteps = True` in the Flux validation job config. Enabling `all_timesteps` may require tweaking some hyperparameters (`validation.local_batch_size`, `validation.steps`) to prevent memory spikes and to optimize throughput. Using a ratio of around 1/4 for `validation.local_batch_size` to `training.local_batch_size` will not spike memory higher than training when `fsdp = 8`.

Below we can see the difference between round robin and all timesteps. In this comparison the total number of validation samples processed is the same, but in the `all_timesteps=True` configuration we have to lower the batch size to prevent memory spikes. All timesteps also achieves a higher throughput (tps), but still processes the full validation set more slowly.



| Round Robin (batch_size=32, steps=1, fsdp=8) | All Timesteps (batch_size=8, steps=4, fsdp=8) |
| ---- | --- |
| <img width="682" height="303" alt="Screenshot 2025-08-01 at 3 46 42 PM" src="https://github.com/user-attachments/assets/30328bfe-4c3c-4912-a329-2b94c834b67b" /> | <img width="719" height="308" alt="Screenshot 2025-08-01 at 3 30 10 PM" src="https://github.com/user-attachments/assets/c7325d21-8a7b-41d9-a0d2-74052e425083" /> |
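The two schedules described above can be sketched as follows. This is an illustrative sketch, not the PR's actual code: the exact placement of the equidistant grid (here `1/9 ... 8/9` for 8 timesteps) is an assumption, and the function names are hypothetical. It shows why round-robin covers the same timestep grid across a validation set while costing one forward pass per sample instead of eight.

```python
def round_robin_timesteps(num_samples, num_timesteps=8):
    """One equidistant timestep per sample, assigned round-robin:
    one forward pass per validation sample."""
    grid = [(i + 1) / (num_timesteps + 1) for i in range(num_timesteps)]
    return [grid[i % num_timesteps] for i in range(num_samples)]


def all_timesteps_schedule(num_samples, num_timesteps=8):
    """Every sample is evaluated at all equidistant timesteps:
    num_timesteps forward passes per validation sample."""
    grid = [(i + 1) / (num_timesteps + 1) for i in range(num_timesteps)]
    return [list(grid) for _ in range(num_samples)]
```

Over 16 samples, for example, round-robin visits each of the 8 grid values exactly twice, so aggregated over enough samples the averaged loss samples the same grid as the full-timestep schedule.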