Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchtitan/experiments/flux/infra/parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def parallelize_encoders(
fully_shard(t5_model.hf_module, **fsdp_config)

if parallel_dims.dp_replicate_enabled:
logger.info("Applied FSDP to the T5 encoder model")
logger.info("Applied HSDP to the T5 encoder model")
else:
logger.info("Applied FSDP to the T5 encoder model")

Expand Down
2 changes: 1 addition & 1 deletion torchtitan/experiments/flux/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def denoise(
_, latent_channels, latent_height, latent_width = latents.shape

# create denoising schedule
timesteps = get_schedule(denoising_steps, latent_channels, shift=True)
timesteps = get_schedule(denoising_steps, latent_height * latent_width, shift=True)

# create positional encodings
POSITION_DIM = 3
Expand Down
25 changes: 13 additions & 12 deletions torchtitan/experiments/flux/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,19 +126,20 @@ def forward_backward_step(
# Patchify: Convert latent into a sequence of patches
latents = pack_latents(latents)

latent_noise_pred = model(
img=latents,
img_ids=latent_pos_enc,
txt=t5_encodings,
txt_ids=text_pos_enc,
y=clip_encodings,
timesteps=timesteps,
)
with self.maybe_enable_amp:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you provide test / run results to show this would work for DDP or single-device training? Because we only support FSDP for the FLUX model, maybe_enable_amp will only help for DDP or single-device training.

latent_noise_pred = model(
img=latents,
img_ids=latent_pos_enc,
txt=t5_encodings,
txt_ids=text_pos_enc,
y=clip_encodings,
timesteps=timesteps,
)

# Convert sequence of patches to latent shape
pred = unpack_latents(latent_noise_pred, latent_height, latent_width)
target = noise - labels
loss = self.loss_fn(pred, target)
# Convert sequence of patches to latent shape
pred = unpack_latents(latent_noise_pred, latent_height, latent_width)
target = noise - labels
loss = self.loss_fn(pred, target)
# pred.shape=(bs, seq_len, vocab_size)
# need to free to before bwd to avoid peaking memory
del (pred, noise, target)
Expand Down
4 changes: 2 additions & 2 deletions torchtitan/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,11 @@ def check_if_feature_in_pytorch(
# notify users to check if the pull request is included in their pytorch
logger.warning(
"Detected that the pytorch is built from source. Please make sure the PR "
f"({pull_request_link}) is included in pytorch for correct {feature_name}."
f"({pull_request}) is included in pytorch for correct {feature_name}."
)
elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
logger.warning(
f"Detected that the pytorch version {torch.__version__} is older than "
f"{min_nightly_version}. Please upgrade a newer version to include the "
f"change in ({pull_request_link}) for correct {feature_name}."
f"change in ({pull_request}) for correct {feature_name}."
)
Loading