@wesleytruong wesleytruong commented Aug 1, 2025

This PR implements the validator class for Flux, following the method discussed in the Stable Diffusion 3 paper.

The paper shows that creating 8 equidistant timesteps and averaging the loss over them yields a loss that is highly correlated with external validation metrics such as CLIP or FID scores.

Rather than creating 8 stratified timesteps per sample, this PR's implementation applies one of these equidistant timesteps to each sample in round-robin fashion. Aggregated over many samples in a validation set, this should give a similar validation score to the full-timestep method while processing validation samples more quickly.
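The round-robin scheme described above can be sketched in a few lines; the function names below are illustrative, not taken from the PR:

```python
# Illustrative sketch of the round-robin timestep assignment (names are
# ours, not the PR's).
def stratified_timesteps(num_strata: int = 8) -> list[float]:
    """Midpoints of num_strata equal-width bins over (0, 1)."""
    return [(i + 0.5) / num_strata for i in range(num_strata)]

def round_robin_timestep(sample_idx: int, num_strata: int = 8) -> float:
    """Assign each validation sample one stratum, cycling through strata
    so every timestep is visited equally often across the set."""
    return stratified_timesteps(num_strata)[sample_idx % num_strata]
```

Averaged over a large validation set, the per-sample losses at these cycled timesteps approximate the 8-timestep average loss that the SD3 paper correlates with CLIP/FID.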

Implementations

  • Integrates image generation evaluation into the validation step
  • Refactors and combines the eval job_config with validation
    • Adds an all_timesteps option to the job_config to choose between round-robin timesteps and full timesteps per sample
  • Creates a validator class and validation dataloader for flux; the validator dataloader handles generating timesteps for the round-robin method of validation

Enabling all timesteps

Developers can enable the full timestep method of validation by setting all_timesteps = True in the flux validation job config. Enabling all_timesteps may require tuning hyperparameters such as validation.local_batch_size and validation.steps to avoid memory spikes and to optimize throughput. Using a ratio of around 1/4 for validation.local_batch_size relative to training.local_batch_size keeps memory from spiking above training levels when fsdp = 8.
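For illustration, such a setup might look like the TOML fragment below; the section and key names follow the description above, but the exact config layout is an assumption:

```toml
# Hypothetical fragment -- key names follow the PR description,
# the exact layout of the flux job config may differ.
[validation]
all_timesteps = true
local_batch_size = 8  # ~1/4 of training.local_batch_size to avoid memory spikes
steps = 4
```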

Below we compare round robin and all timesteps. In this comparison the total number of validation samples processed is the same, but in the all_timesteps=True configuration we have to lower the batch size to prevent memory from spiking. All timesteps also achieves higher throughput (tps), but still processes the full validation set more slowly.

| Round Robin (batch_size=32, steps=1, fsdp=8) | All Timesteps (batch_size=8, steps=4, fsdp=8) |
| ---- | ---- |
| <img width="682" height="303" alt="Screenshot 2025-08-01 at 3 46 42 PM" src="https://github.com/user-attachments/assets/30328bfe-4c3c-4912-a329-2b94c834b67b" /> | <img width="719" height="308" alt="Screenshot 2025-08-01 at 3 30 10 PM" src="https://github.com/user-attachments/assets/c7325d21-8a7b-41d9-a0d2-74052e425083" /> |

@wesleytruong wesleytruong changed the title creates an efficient validator for flux using loss method Flux Validation Aug 1, 2025
dp_rank (int): Data parallel rank.
dp_world_size (int): Data parallel world size.
infinite (bool): Whether to loop over the dataset infinitely.
generate_timesteps (booL): Generate stratified timesteps in round-robin style for validation

Suggested change
generate_timesteps (booL): Generate stratified timesteps in round-robin style for validation
generate_timesteps (bool): Generate stratified timesteps in round-robin style for validation


let's try to create a subclass to do this, as discussed offline.
Otherwise it's easy to confuse with the timestep concept used during training / inference.


# skip low quality image or image with color channel = 1
if sample_dict["image"] is None:
# sample_id = sample.get('sample_id')

remove comment

eval_freq: int = 100
"""Frequency of evaluation/sampling during training"""
save_imgs: int = 1
""" How many images to generate and save in validation, -1 means same number as steps"""

explain that the source prompt is coming from validation dataset and taken from "the beginning"

"""How many denoising steps to sample when generating an image"""
eval_freq: int = 100
"""Frequency of evaluation/sampling during training"""
save_imgs: int = 1

call it something like save_imgs_count? save_imgs sounds like a bool


also, can you show examples of generate img? Just to verify capability


> also, can you show examples of generate img? Just to verify capability

Here's a generated image to show that generation works. The image quality isn't good yet since I don't have pre-trained weights; I will revisit after adding HF conversion.
image


@dataclass
class Eval:
class Validation(Validation):

do we have to inherit? Training is not inheriting.

# Patchify: Convert latent into a sequence of patches
latents = pack_latents(latents)

latent_noise_pred = model(

please adapt to #1494 for the amp part

bsz = labels.shape[0]

# To evaluate all 8 timesteps per sample do
if self.all_timesteps:

I think it's OK to keep this.

Comment on lines 238 to 244
validation_end_time = time.time()
validation_duration = validation_end_time - validation_start_time

# Log timing information
from torchtitan.tools.logging import logger

logger.info(f"Validation step {step} completed in {validation_duration:.3f}s ")

is this for debugging? we didn't need this in llama3 validator


save_imgs = self.job_config.validation.save_imgs
if save_imgs == -1 or num_steps < save_imgs:
t5_tokenizer, clip_tokenizer = build_flux_tokenizer(self.job_config)

why build this multiple times?

Comment on lines 129 to 130
if isinstance(prompt, list):
prompt = " ".join(prompt)

in validate.py you are taking prompt[0]. Do you still need to concatenate here?

@wesleytruong wesleytruong Aug 4, 2025


This concatenation is because the COCO dataset gives the prompts as a list of strings instead of a single string per sample. This part just combines the list-style prompt into a single string.

The prompt[0] in validation is because we only generate an image from the first sample in each validation batch.
Edit: I'm changing this so that it will generate the correct number of images regardless of batch size.

@wesleytruong

@tianyu-l addressed comments.

  • added the all_timesteps option to job_config
  • separated the flux dataset and flux validation dataset to isolate the timestep generation logic that is only used for validation
  • changed save_img_count to generate images more logically rather than one per batch


@tianyu-l tianyu-l left a comment


Looks great! Had some final comments.

from .model.args import FluxModelArgs
from .model.autoencoder import AutoEncoderParams
from .model.model import FluxModel
from .model.validate import build_flux_validator

validate doesn't sound like part of model. Let's just put it under the flux/ root folder.

img = _process_cc12m_image(sample["image"], output_size=output_size)
prompt = sample["caption"]
if isinstance(prompt, list):
prompt = " ".join(prompt)

If you look at one example

[
"A desktop computer monitor sitting next to a keyboard.",
"A computer, keyboard, modem, and mouse are sitting on a work desk.",
"a computer monitor, keyboard, and mouse sit on a table.",
"A black desktop computer atop a wooden table.",
"a desktop computer monitor with a keyboard and mouse"
]
It's always a list of alternative captions. I think we should only pick one (maybe the first one for deterministic training / eval).
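A deterministic pick of the first caption could look like this minimal sketch (the helper name is hypothetical, not from the PR):

```python
def pick_caption(caption):
    """Return a single caption string; if the dataset yields a list of
    alternative captions, deterministically take the first one."""
    if isinstance(caption, list):
        return caption[0]
    return caption
```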


# skip low quality image or image with color channel = 1
if sample_dict["image"] is None:
# sample_id = sample.get('sample_id')

let's remove comment

break

prompt = input_dict.pop("prompt")
if not isinstance(prompt, list):

To educate me: if prompt is yielded from the dataset as a str, would the dataloader batchify it into a list? Could you please share source on this behavior?


The torch DataLoader handles this batchifying all the way up the inheritance stack: ParallelAwareDataloader -> StatefulDataLoader -> DataLoader. The FluxDataset gets passed to the ParallelAwareDataLoader in build_dataloader_fn. I'm not sure of the exact implementation details and edge cases, but the API expects an IterableDataset with an overloaded __iter__ for data streaming. https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
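As a quick standalone check of this behavior (a sketch, not code from the PR): the default collate_fn in torch.utils.data.DataLoader batches per-sample strings into a list of strings.

```python
from torch.utils.data import DataLoader, IterableDataset

class PromptStream(IterableDataset):
    """Yields one prompt string per sample, like a streaming text dataset."""
    def __iter__(self):
        yield from ["a cat", "a dog", "a bird", "a fish"]

# batch_size=2 -> the default collate_fn returns each batch as a list of str
loader = DataLoader(PromptStream(), batch_size=2)
batches = list(loader)  # [['a cat', 'a dog'], ['a bird', 'a fish']]
```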

@tianyu-l tianyu-l left a comment


really awesome work!

@tianyu-l tianyu-l merged commit a204e31 into main Aug 5, 2025
10 checks passed
@tianyu-l tianyu-l deleted the flux_validator branch August 5, 2025 22:00
joellidin pushed a commit to one-covenant/torchtitan that referenced this pull request Aug 8, 2025