diff --git a/examples/diffusion/recipes/flux/README.md b/examples/diffusion/recipes/flux/README.md new file mode 100644 index 0000000000..ce77166b2b --- /dev/null +++ b/examples/diffusion/recipes/flux/README.md @@ -0,0 +1,158 @@ +# FLUX Examples + +This directory contains example scripts for the FLUX diffusion model (text-to-image) with Megatron-Bridge: checkpoint conversion, inference, pretraining, and fine-tuning. + +All commands below assume you run them from the **Megatron-Bridge repository root** unless noted. Use `uv run` when you need the project’s virtualenv (e.g. `uv run python ...`, `uv run torchrun ...`). + +## Workspace Configuration + +Use a `WORKSPACE` environment variable as the base directory for checkpoints and results. The default is `/workspace`. Override it if needed: + +```bash +export WORKSPACE=/your/custom/path +``` + +Suggested layout: + +- `${WORKSPACE}/checkpoints/flux/` – Megatron FLUX checkpoints (after import) +- `${WORKSPACE}/checkpoints/flux_hf/` – Hugging Face FLUX model (download or export) +- `${WORKSPACE}/results/flux/` – Training outputs (pretrain/finetune) + +--- + +## 1. Checkpoint Conversion + +The script [conversion/convert_checkpoints.py](conversion/convert_checkpoints.py) converts between Hugging Face (diffusers) and Megatron checkpoint formats. + +**Source model:** [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) (or a local clone). + +### Download the Hugging Face model (optional) + +If you want a local copy before conversion: + +```bash +huggingface-cli download black-forest-labs/FLUX.1-dev \ + --local-dir ${WORKSPACE}/checkpoints/flux_hf/flux.1-dev \ + --local-dir-use-symlinks False +``` + +**Note**: Keeping this local copy is recommended: the inference pipeline later reuses its VAE and text encoders.
### Import: Hugging Face → Megatron + +Convert a Hugging Face FLUX model to Megatron format: + +```bash +uv run python examples/diffusion/recipes/flux/conversion/convert_checkpoints.py import \ + --hf-model ${WORKSPACE}/checkpoints/flux_hf/flux.1-dev \ + --megatron-path ${WORKSPACE}/checkpoints/flux/flux.1-dev +``` + +The Megatron checkpoint is written under `--megatron-path` (e.g. `.../flux.1-dev/iter_0000000/`). Use that path for inference and fine-tuning. + +### Export: Megatron → Hugging Face + +Export a Megatron checkpoint back to Hugging Face (e.g. for use in diffusers). You must pass the **reference** HF model (for config and non-DiT components) and the **Megatron iteration directory**: + +```bash +uv run python examples/diffusion/recipes/flux/conversion/convert_checkpoints.py export \ + --hf-model ${WORKSPACE}/checkpoints/flux_hf/flux.1-dev \ + --megatron-path ${WORKSPACE}/checkpoints/flux/flux.1-dev/iter_0000000 \ + --hf-path ${WORKSPACE}/checkpoints/flux_hf/flux.1-dev_export +``` + +**Note:** The exported directory contains only the DiT transformer weights. For a full pipeline (VAE, text encoders, etc.), copy the original HF repo and replace its `transformer` folder with the exported one. + +--- + +## 2. Inference + +The script [inference_flux.py](inference_flux.py) runs text-to-image generation with a Megatron-format FLUX checkpoint. You need: + +- **FLUX checkpoint:** Megatron DiT (e.g. from the import step above). +- **VAE:** Path to VAE weights (usually the `vae` folder alongside `transformer` in the same HF repo, or a separate VAE checkpoint). +- **Text encoders:** T5 and CLIP are loaded from Hugging Face by default; you can override with local paths.
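Image resolution directly sets the sequence length the DiT has to process. A back-of-the-envelope helper, assuming the standard FLUX setup of 8× spatial VAE downsampling plus 2×2 latent packing (one token per 16×16 pixel patch); this is illustrative and not part of the scripts:

```python
def packed_seq_len(height: int, width: int, vae_factor: int = 8, patch: int = 2) -> int:
    """Estimate the packed latent token count the FLUX DiT sees for an image size.

    Assumes 8x spatial VAE downsampling and 2x2 latent packing, i.e. one
    token per 16x16 pixel patch (the usual FLUX configuration).
    """
    pixels_per_token = vae_factor * patch  # 16 with the defaults
    if height % pixels_per_token or width % pixels_per_token:
        raise ValueError("height and width should be multiples of 16 for FLUX")
    return (height // pixels_per_token) * (width // pixels_per_token)
```

With these assumptions, the default 1024×1024 generation corresponds to 4096 image tokens, while 512×512 drops to 1024, which is worth keeping in mind when GPU memory is tight.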
### Single prompt (default 1024×1024, 10 steps) + +```bash +uv run python examples/diffusion/recipes/flux/inference_flux.py \ + --flux_ckpt ${WORKSPACE}/checkpoints/flux/flux.1-dev/iter_0000000 \ + --vae_ckpt ${WORKSPACE}/checkpoints/flux_hf/flux.1-dev/vae \ + --prompts "a dog holding a sign that says hello world" \ + --output_path ./flux_output +``` + + +**VAE path:** If you downloaded FLUX.1-dev with `huggingface-cli`, the VAE lives in the `vae` subfolder of the same repo (e.g. `${WORKSPACE}/checkpoints/flux_hf/flux.1-dev/vae`); pass that subfolder to `--vae_ckpt`, as in the command above. + +--- + +## 3. Pretraining + +The script [pretrain_flux.py](pretrain_flux.py) runs FLUX pretraining with the `pretrain_config()` recipe. Configuration can be overridden with Hydra-style CLI keys. + +**Recipe:** [megatron.bridge.diffusion.recipes.flux.flux.pretrain_config](https://github.com/NVIDIA-NeMo/Megatron-Bridge/blob/main/src/megatron/bridge/diffusion/recipes/flux/flux.py) + +### Quick run with mock data (single node, 8 GPUs) + +```bash +uv run torchrun --nproc_per_node=8 examples/diffusion/recipes/flux/pretrain_flux.py --mock +``` + +### With CLI overrides only + +```bash +uv run torchrun --nproc_per_node=8 examples/diffusion/recipes/flux/pretrain_flux.py --mock \ + model.tensor_model_parallel_size=4 \ + train.train_iters=10000 \ + optimizer.lr=1e-4 +``` + + +### Flow matching options + +```bash +uv run torchrun --nproc_per_node=8 examples/diffusion/recipes/flux/pretrain_flux.py --mock \ + --timestep-sampling logit_normal \ + --flow-shift 1.0 \ + --use-loss-weighting +``` + +Before pretraining with real data, set the dataset in the recipe or in your YAML/CLI (e.g. `data_paths`, dataset blend, and cache paths). For data preprocessing, see the Megatron-Bridge data tutorials. + +--- + +## 4. Fine-Tuning + +The script [finetune_flux.py](finetune_flux.py) fine-tunes a pretrained FLUX checkpoint (Megatron format).
It loads the model weights and resets the optimizer state and step count; the config can be overridden via YAML and CLI, as with pretraining. + +Point `--load-checkpoint` at the **Megatron checkpoint directory** (either the base dir, e.g. `.../flux.1-dev`, or a specific iteration, e.g. `.../flux.1-dev/iter_0000000`): + +```bash +uv run torchrun --nproc_per_node=8 examples/diffusion/recipes/flux/finetune_flux.py \ + --load-checkpoint ${WORKSPACE}/checkpoints/flux/flux.1-dev/iter_0000000 \ + --mock +``` + +**Note**: If you pass a path that ends with an `iter_XXXXXXX` directory, the script loads that iteration; otherwise it uses the latest iteration under the given path. + +**Note**: Loss may spike or diverge when training on the mock dataset. + +--- + +## Summary: End-to-End Flow + +1. **Conversion (HF → Megatron)** + Download FLUX.1-dev (optional), then run the `import` command. Use the created `iter_0000000` path as your Megatron checkpoint. + +2. **Inference** + Run [inference_flux.py](inference_flux.py) with `--flux_ckpt` (Megatron `iter_*` path), `--vae_ckpt`, and `--prompts`. + +3. **Pretraining** + Run [pretrain_flux.py](pretrain_flux.py) with `--mock` or your data config; optionally use `--config-file` and CLI overrides. + +4. **Fine-Tuning** + Run [finetune_flux.py](finetune_flux.py) with `--load-checkpoint` set to a Megatron checkpoint (from the import step or a previous pretrain/finetune run), plus `--mock` or your data config and overrides. + +For more details, see the docstrings in each script and the recipe in `src/megatron/bridge/diffusion/recipes/flux/flux.py`.
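The `--timestep-sampling`, `--flow-shift`, and `--use-loss-weighting` flags used throughout these recipes follow the common SD3/FLUX-style flow matching recipe. A dependency-free sketch of the timestep sampling step (an illustration of the general technique, not the exact Megatron-Bridge implementation):

```python
import math
import random

def sample_timestep(timestep_sampling="logit_normal", logit_mean=0.0,
                    logit_std=1.0, flow_shift=1.0, rng=None):
    """Draw one flow matching timestep t in (0, 1) and apply the flow shift."""
    rng = rng or random.Random()
    if timestep_sampling == "logit_normal":
        u = rng.gauss(logit_mean, logit_std)
        t = 1.0 / (1.0 + math.exp(-u))  # sigmoid maps the normal draw into (0, 1)
    elif timestep_sampling == "uniform":
        t = rng.random()
    else:
        raise ValueError(f"unsupported sampling method: {timestep_sampling}")
    # SD3/FLUX-style shift; the identity when flow_shift == 1 (the FLUX default).
    return flow_shift * t / (1.0 + (flow_shift - 1.0) * t)
```

Logit-normal sampling concentrates training timesteps around the middle of the trajectory, while `flow_shift > 1` biases the samples toward larger t.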
diff --git a/examples/diffusion/recipes/flux/conversion/convert_checkpoints.py b/examples/diffusion/recipes/flux/conversion/convert_checkpoints.py index f9a1d85643..a7f90a7d6c 100644 --- a/examples/diffusion/recipes/flux/conversion/convert_checkpoints.py +++ b/examples/diffusion/recipes/flux/conversion/convert_checkpoints.py @@ -129,6 +129,8 @@ def import_hf_to_megatron( bridge = FluxBridge() provider = bridge.provider_bridge(hf) provider.perform_initialization = False + # Finalize config so init_method/output_layer_init_method are set (required by Megatron MLP) + provider.finalize() megatron_models = provider.provide_distributed_model(wrap_with_ddp=False, use_cpu_initialization=True) bridge.load_weights_hf_to_megatron(hf, megatron_models) diff --git a/examples/diffusion/recipes/flux/finetune_flux.py b/examples/diffusion/recipes/flux/finetune_flux.py index 9c48f70e43..1dbfbe9023 100644 --- a/examples/diffusion/recipes/flux/finetune_flux.py +++ b/examples/diffusion/recipes/flux/finetune_flux.py @@ -22,40 +22,28 @@ The script loads a pretrained checkpoint and continues training with your custom dataset. Fine-tuning typically uses lower learning rates and fewer training iterations compared to pretraining. 
-Forward Step Options: - - Automodel FlowMatchingPipeline (default): Unified flow matching implementation - - Original FluxForwardStep (--use-original-step): Classic implementation - Examples: - Basic usage with checkpoint loading (uses automodel pipeline): - $ torchrun --nproc_per_node=8 finetune_flux.py \ + Basic usage with checkpoint loading: + $ uv run torchrun --nproc_per_node=8 finetune_flux.py \ --load-checkpoint /path/to/pretrained/checkpoint --mock - Using original FluxForwardStep: - $ torchrun --nproc_per_node=8 finetune_flux.py \ - --load-checkpoint /path/to/pretrained/checkpoint --mock --use-original-step - Using a custom YAML config file: - $ torchrun --nproc_per_node=8 finetune_flux.py \ + $ uv run torchrun --nproc_per_node=8 finetune_flux.py \ --load-checkpoint /path/to/pretrained/checkpoint \ --config-file my_custom_config.yaml Using CLI overrides only: - $ torchrun --nproc_per_node=8 finetune_flux.py \ + $ uv run torchrun --nproc_per_node=8 finetune_flux.py \ --load-checkpoint /path/to/pretrained/checkpoint \ model.tensor_model_parallel_size=4 train.train_iters=5000 optimizer.lr=1e-5 Combining YAML and CLI overrides (CLI takes precedence): - $ torchrun --nproc_per_node=8 finetune_flux.py \ + $ uv run torchrun --nproc_per_node=8 finetune_flux.py \ --load-checkpoint /path/to/pretrained/checkpoint \ --config-file conf/my_config.yaml \ model.pipeline_dtype=torch.float16 \ train.global_batch_size=512 - Using automodel pipeline with custom parameters (automodel is default): - $ torchrun --nproc_per_node=8 finetune_flux.py \ - --load-checkpoint /path/to/pretrained/checkpoint --mock \ - --flow-shift=1.0 --use-loss-weighting Configuration Precedence: 1. 
Base configuration from pretrain_config() recipe @@ -81,7 +69,7 @@ from omegaconf import OmegaConf -from megatron.bridge.diffusion.models.flux.flux_step_with_automodel import create_flux_forward_step +from megatron.bridge.diffusion.models.flux.flux_step import FluxForwardStep from megatron.bridge.diffusion.recipes.flux.flux import pretrain_config from megatron.bridge.training.config import ConfigContainer from megatron.bridge.training.pretrain import pretrain @@ -159,22 +147,16 @@ def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]: ) parser.add_argument("--debug", action="store_true", help="Enable debug logging") - # Forward step implementation choice - parser.add_argument( - "--use-original-step", - action="store_true", - help="Use original FluxForwardStep instead of automodel FlowMatchingPipeline (default)", - ) parser.add_argument( "--flow-shift", type=float, default=1.0, - help="Flow shift parameter (for automodel pipeline)", + help="Flow shift parameter", ) parser.add_argument( "--use-loss-weighting", action="store_true", - help="Use loss weighting (for automodel pipeline)", + help="Use loss weighting", ) # Parse known args for the script, remaining will be treated as overrides @@ -197,20 +179,16 @@ def main() -> None: and handles type conversions automatically. 
Examples of CLI usage: - # Fine-tune with default config and custom learning rate (automodel pipeline is default) + # Fine-tune with default config and custom learning rate torchrun --nproc_per_node=8 finetune_flux.py \ --load-checkpoint /path/to/checkpoint --mock optimizer.lr=1e-5 - # Use original FluxForwardStep instead of automodel pipeline - torchrun --nproc_per_node=8 finetune_flux.py \ - --load-checkpoint /path/to/checkpoint --mock --use-original-step - # Custom config file with additional overrides torchrun --nproc_per_node=8 finetune_flux.py \ --load-checkpoint /path/to/checkpoint \ --config-file my_config.yaml train.train_iters=5000 - # Multiple overrides for distributed fine-tuning (uses automodel by default) + # Multiple overrides for distributed fine-tuning torchrun --nproc_per_node=8 finetune_flux.py \ --load-checkpoint /path/to/checkpoint --mock \ model.tensor_model_parallel_size=4 \ @@ -218,10 +196,6 @@ def main() -> None: train.global_batch_size=512 \ optimizer.lr=5e-6 - # Automodel pipeline with custom flow matching parameters - torchrun --nproc_per_node=8 finetune_flux.py \ - --load-checkpoint /path/to/checkpoint --mock \ - --flow-shift=1.0 --use-loss-weighting """ args, cli_overrides = parse_cli_args() @@ -337,43 +311,21 @@ def main() -> None: cfg.checkpoint.load = None # Clear load to ensure pretrained_checkpoint takes precedence cfg.checkpoint.finetune = True - # Create forward step (configurable: original or automodel pipeline) - # Default is automodel pipeline unless --use-original-step is specified - if not args.use_original_step: - # Use automodel FlowMatchingPipeline - flux_forward_step = create_flux_forward_step( - use_automodel_pipeline=True, - timestep_sampling=args.timestep_sampling, - logit_mean=args.logit_mean, - logit_std=args.logit_std, - flow_shift=args.flow_shift, - scheduler_steps=args.scheduler_steps, - guidance_scale=args.guidance_scale, - use_loss_weighting=args.use_loss_weighting, - ) - if get_rank_safe() == 0: - 
logger.info("=" * 70) - logger.info("✅ Using AUTOMODEL FlowMatchingPipeline") - logger.info(f" Timestep Sampling: {args.timestep_sampling}") - logger.info(f" Flow Shift: {args.flow_shift}") - logger.info(f" Loss Weighting: {args.use_loss_weighting}") - logger.info("=" * 70) - else: - # Use original FluxForwardStep - flux_forward_step = create_flux_forward_step( - use_automodel_pipeline=False, - timestep_sampling=args.timestep_sampling, - logit_mean=args.logit_mean, - logit_std=args.logit_std, - mode_scale=args.mode_scale, - scheduler_steps=args.scheduler_steps, - guidance_scale=args.guidance_scale, - ) - if get_rank_safe() == 0: - logger.info("=" * 70) - logger.info("✅ Using ORIGINAL FluxForwardStep") - logger.info(f" Timestep Sampling: {args.timestep_sampling}") - logger.info("=" * 70) + flux_forward_step = FluxForwardStep( + timestep_sampling=args.timestep_sampling, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + flow_shift=args.flow_shift, + scheduler_steps=args.scheduler_steps, + guidance_scale=args.guidance_scale, + use_loss_weighting=args.use_loss_weighting, + ) + if get_rank_safe() == 0: + logger.info("=" * 70) + logger.info(f" Timestep Sampling: {args.timestep_sampling}") + logger.info(f" Flow Shift: {args.flow_shift}") + logger.info(f" Loss Weighting: {args.use_loss_weighting}") + logger.info("=" * 70) # Display final configuration if get_rank_safe() == 0: @@ -394,9 +346,8 @@ def main() -> None: logger.info(f" mode_scale: {args.mode_scale}") logger.info(f" scheduler_steps: {args.scheduler_steps}") logger.info(f" guidance_scale: {args.guidance_scale}") - if not args.use_original_step: - logger.info(f" flow_shift: {args.flow_shift}") - logger.info(f" use_loss_weighting: {args.use_loss_weighting}") + logger.info(f" flow_shift: {args.flow_shift}") + logger.info(f" use_loss_weighting: {args.use_loss_weighting}") # Start training (fine-tuning) logger.debug("Starting fine-tuning...") diff --git a/examples/diffusion/recipes/flux/inference_flux.py 
b/examples/diffusion/recipes/flux/inference_flux.py index 2acc7bf468..38678fdf12 100644 --- a/examples/diffusion/recipes/flux/inference_flux.py +++ b/examples/diffusion/recipes/flux/inference_flux.py @@ -12,6 +12,36 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +FLUX Inference Script for text-to-image generation. + +Runs the FLUX diffusion model to generate images from text prompts. Requires a Megatron-format +FLUX checkpoint (e.g. from convert_checkpoints.py import), a VAE checkpoint, and optionally +T5/CLIP text encoder IDs (downloaded from Hugging Face by default). + +Examples: + Single prompt, default resolution (1024x1024) and 10 steps: + $ uv run python inference_flux.py --flux_ckpt /path/to/flux/iter_0000000 \\ + --vae_ckpt /path/to/vae --prompts "a dog holding a sign that says hello world" \\ + --output_path ./flux_output + + Multiple prompts: + $ uv run python inference_flux.py --flux_ckpt /path/to/flux/iter_0000000 \\ + --vae_ckpt /path/to/vae --prompts "prompt one" --prompts "prompt two" \\ + --output_path ./flux_output + + Custom resolution and inference steps: + $ uv run python inference_flux.py --flux_ckpt /path/to/flux/iter_0000000 \\ + --vae_ckpt /path/to/vae --prompts "a cat on a mat" \\ + --height 512 --width 512 --num_inference_steps 20 --output_path ./flux_output + + From repository root: + $ uv run python examples/diffusion/recipes/flux/inference_flux.py \\ + --flux_ckpt /aot/checkpoints/dfm/flux.1-dev/iter_0000000 \\ + --vae_ckpt /aot/checkpoints/flux.1-dev/vae \\ + --prompts "a dog holding a sign that says hello world" --output_path ./flux_output +""" + import argparse import os @@ -25,7 +55,6 @@ def parse_args(): # noqa: D103 parser.add_argument("--vae_ckpt", type=str, default=None, help="Path to VAE") parser.add_argument("--t5_version", type=str, default="google/t5-v1_1-xxl") parser.add_argument("--clip_version", type=str, default="openai/clip-vit-large-patch14") -
parser.add_argument("--do_convert_from_hf", action="store_true", default=False) parser.add_argument( "--prompts", type=str, @@ -38,6 +67,7 @@ def parse_args(): # noqa: D103 parser.add_argument("--num_inference_steps", type=int, default=10) parser.add_argument("--guidance_scale", type=float, default=0.0) parser.add_argument("--output_path", type=str, default="/tmp/flux_output") + parser.add_argument("--base_seed", type=int, default=42, help="Random seed for reproducibility") return parser.parse_args() diff --git a/examples/diffusion/recipes/flux/pretrain_flux.py b/examples/diffusion/recipes/flux/pretrain_flux.py index 127c4346c0..1280f7b49f 100644 --- a/examples/diffusion/recipes/flux/pretrain_flux.py +++ b/examples/diffusion/recipes/flux/pretrain_flux.py @@ -19,31 +19,21 @@ This script provides a flexible way to pretrain FLUX models using Megatron-Bridge with support for both YAML configuration files and command-line overrides using Hydra-style syntax. -Forward Step Options: - - Automodel FlowMatchingPipeline (default): Unified flow matching implementation - - Original FluxForwardStep (--use-original-step): Classic implementation - Examples: - Basic usage with default configuration (uses automodel pipeline): - $ torchrun --nproc_per_node=8 pretrain_flux.py --mock - - Using original FluxForwardStep: - $ torchrun --nproc_per_node=8 pretrain_flux.py --mock --use-original-step + Basic usage with default configuration: + $ uv run torchrun --nproc_per_node=8 pretrain_flux.py --mock Using a custom YAML config file: - $ torchrun --nproc_per_node=8 pretrain_flux.py --config-file my_custom_config.yaml + $ uv run torchrun --nproc_per_node=8 pretrain_flux.py --config-file my_custom_config.yaml Using CLI overrides only: - $ torchrun --nproc_per_node=8 pretrain_flux.py model.tensor_model_parallel_size=4 train.train_iters=100000 + $ uv run torchrun --nproc_per_node=8 pretrain_flux.py model.tensor_model_parallel_size=4 train.train_iters=100000 Combining YAML and CLI overrides (CLI 
takes precedence): - $ torchrun --nproc_per_node=8 pretrain_flux.py --config-file conf/my_config.yaml \ + $ uv run torchrun --nproc_per_node=8 pretrain_flux.py --config-file conf/my_config.yaml \ model.pipeline_dtype=torch.float16 \ train.global_batch_size=512 - Using automodel pipeline with custom parameters (automodel is default): - $ torchrun --nproc_per_node=8 pretrain_flux.py --mock \ - --flow-shift=1.0 --use-loss-weighting Configuration Precedence: 1. Base configuration from pretrain_config() recipe @@ -68,7 +58,7 @@ from omegaconf import OmegaConf -from megatron.bridge.diffusion.models.flux.flux_step_with_automodel import create_flux_forward_step +from megatron.bridge.diffusion.models.flux.flux_step import FluxForwardStep from megatron.bridge.diffusion.recipes.flux.flux import pretrain_config from megatron.bridge.training.config import ConfigContainer from megatron.bridge.training.pretrain import pretrain @@ -141,21 +131,16 @@ def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]: parser.add_argument("--debug", action="store_true", help="Enable debug logging") # Forward step implementation choice - parser.add_argument( - "--use-original-step", - action="store_true", - help="Use original FluxForwardStep instead of automodel FlowMatchingPipeline (default)", - ) parser.add_argument( "--flow-shift", type=float, default=1.0, - help="Flow shift parameter (for automodel pipeline)", + help="Flow shift parameter", ) parser.add_argument( "--use-loss-weighting", action="store_true", - help="Use loss weighting (for automodel pipeline)", + help="Use loss weighting", ) # Parse known args for the script, remaining will be treated as overrides @@ -177,22 +162,19 @@ def main() -> None: and handles type conversions automatically. 
Examples of CLI usage: - # Use default config with custom learning rate (automodel pipeline is default) + # Use default config with custom learning rate torchrun --nproc_per_node=8 pretrain_flux.py --mock optimizer.lr=0.0002 - # Use original FluxForwardStep instead of automodel pipeline - torchrun --nproc_per_node=8 pretrain_flux.py --mock --use-original-step - # Custom config file with additional overrides torchrun --nproc_per_node=8 pretrain_flux.py --config-file my_config.yaml train.train_iters=50000 - # Multiple overrides for distributed training (uses automodel by default) + # Multiple overrides for distributed training torchrun --nproc_per_node=8 pretrain_flux.py --mock \ model.tensor_model_parallel_size=4 \ model.pipeline_model_parallel_size=2 \ train.global_batch_size=512 - # Automodel pipeline with custom flow matching parameters + # Pipeline with custom flow matching parameters torchrun --nproc_per_node=8 pretrain_flux.py --mock \ --flow-shift=1.0 --use-loss-weighting """ @@ -234,43 +216,21 @@ def main() -> None: # Apply overrides while preserving excluded fields apply_overrides(cfg, final_overrides_as_dict, excluded_fields) - # Create forward step (configurable: original or automodel pipeline) - # Default is automodel pipeline unless --use-original-step is specified - if not args.use_original_step: - # Use automodel FlowMatchingPipeline - flux_forward_step = create_flux_forward_step( - use_automodel_pipeline=True, - timestep_sampling=args.timestep_sampling, - logit_mean=args.logit_mean, - logit_std=args.logit_std, - flow_shift=args.flow_shift, - scheduler_steps=args.scheduler_steps, - guidance_scale=args.guidance_scale, - use_loss_weighting=args.use_loss_weighting, - ) - if get_rank_safe() == 0: - logger.info("=" * 70) - logger.info("✅ Using AUTOMODEL FlowMatchingPipeline") - logger.info(f" Timestep Sampling: {args.timestep_sampling}") - logger.info(f" Flow Shift: {args.flow_shift}") - logger.info(f" Loss Weighting: {args.use_loss_weighting}") - 
logger.info("=" * 70) - else: - # Use original FluxForwardStep - flux_forward_step = create_flux_forward_step( - use_automodel_pipeline=False, - timestep_sampling=args.timestep_sampling, - logit_mean=args.logit_mean, - logit_std=args.logit_std, - mode_scale=args.mode_scale, - scheduler_steps=args.scheduler_steps, - guidance_scale=args.guidance_scale, - ) - if get_rank_safe() == 0: - logger.info("=" * 70) - logger.info("✅ Using ORIGINAL FluxForwardStep") - logger.info(f" Timestep Sampling: {args.timestep_sampling}") - logger.info("=" * 70) + flux_forward_step = FluxForwardStep( + timestep_sampling=args.timestep_sampling, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + flow_shift=args.flow_shift, + scheduler_steps=args.scheduler_steps, + guidance_scale=args.guidance_scale, + use_loss_weighting=args.use_loss_weighting, + ) + if get_rank_safe() == 0: + logger.info("=" * 70) + logger.info(f" Timestep Sampling: {args.timestep_sampling}") + logger.info(f" Flow Shift: {args.flow_shift}") + logger.info(f" Loss Weighting: {args.use_loss_weighting}") + logger.info("=" * 70) # Display final configuration if get_rank_safe() == 0: @@ -284,9 +244,8 @@ def main() -> None: logger.info(f" mode_scale: {args.mode_scale}") logger.info(f" scheduler_steps: {args.scheduler_steps}") logger.info(f" guidance_scale: {args.guidance_scale}") - if not args.use_original_step: - logger.info(f" flow_shift: {args.flow_shift}") - logger.info(f" use_loss_weighting: {args.use_loss_weighting}") + logger.info(f" flow_shift: {args.flow_shift}") + logger.info(f" use_loss_weighting: {args.use_loss_weighting}") # Start training logger.debug("Starting pretraining...") diff --git a/src/megatron/bridge/diffusion/conversion/flux/flux_bridge.py b/src/megatron/bridge/diffusion/conversion/flux/flux_bridge.py index b517883859..7ed21b2775 100644 --- a/src/megatron/bridge/diffusion/conversion/flux/flux_bridge.py +++ b/src/megatron/bridge/diffusion/conversion/flux/flux_bridge.py @@ -12,7 +12,7 @@ # See the 
License for the specific language governing permissions and # limitations under the License. -from typing import Mapping +from typing import Dict, Mapping import torch from diffusers import FluxTransformer2DModel @@ -21,7 +21,7 @@ from megatron.bridge.diffusion.models.flux.flux_model import Flux from megatron.bridge.diffusion.models.flux.flux_provider import FluxProvider from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry -from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge, WeightConversionTask from megatron.bridge.models.conversion.param_mapping import ( AutoMapping, QKVMapping, @@ -94,6 +94,44 @@ def maybe_modify_loaded_hf_weight( hf_weights = {k: hf_state_dict[v] for k, v in hf_param.items()} return hf_weights + def maybe_modify_converted_hf_weight( + self, + task: WeightConversionTask, + converted_weights_dict: Dict[str, torch.Tensor], + hf_state_dict: Mapping[str, torch.Tensor], + ) -> Dict[str, torch.Tensor]: + """Merge split proj_out weight_1 and weight_2 back into a single HF 'weight' for export. + + On load we split HF proj_out.weight into weight_1 (linear_fc2) and weight_2 (linear_proj). + On export we must merge them back as [weight_2, weight_1] along dim=1 to match HF format. 
+ """ + if not hasattr(self, "_export_proj_out_pending"): + self._export_proj_out_pending = {} + + result = {} + for hf_name, tensor in list(converted_weights_dict.items()): + if hf_name.endswith(".weight_1"): + base = hf_name[: -len(".weight_1")] + self._export_proj_out_pending.setdefault(base, {})["weight_1"] = tensor + if "weight_2" in self._export_proj_out_pending[base]: + w1 = self._export_proj_out_pending[base]["weight_1"] + w2 = self._export_proj_out_pending[base]["weight_2"] + merged = torch.cat([w2, w1], dim=1) + result[f"{base}.weight"] = merged + del self._export_proj_out_pending[base] + elif hf_name.endswith(".weight_2"): + base = hf_name[: -len(".weight_2")] + self._export_proj_out_pending.setdefault(base, {})["weight_2"] = tensor + if "weight_1" in self._export_proj_out_pending[base]: + w1 = self._export_proj_out_pending[base]["weight_1"] + w2 = self._export_proj_out_pending[base]["weight_2"] + merged = torch.cat([w2, w1], dim=1) + result[f"{base}.weight"] = merged + del self._export_proj_out_pending[base] + else: + result[hf_name] = tensor + return result + def mapping_registry(self) -> MegatronMappingRegistry: """Return MegatronMappingRegistry containing parameter mappings from HF to Megatron format. 
diff --git a/src/megatron/bridge/diffusion/data/flux/flux_energon_datamodule.py b/src/megatron/bridge/diffusion/data/flux/flux_energon_datamodule.py index 26605b48b6..35215e2a0a 100644 --- a/src/megatron/bridge/diffusion/data/flux/flux_energon_datamodule.py +++ b/src/megatron/bridge/diffusion/data/flux/flux_energon_datamodule.py @@ -57,4 +57,8 @@ def __post_init__(self): self.sequence_length = self.dataset.seq_length def build_datasets(self, context: DatasetBuildContext): - return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() + return ( + iter(self.dataset.train_dataloader()), + iter(self.dataset.val_dataloader()), + iter(self.dataset.val_dataloader()), + ) diff --git a/src/megatron/bridge/diffusion/models/flux/flux_attention.py b/src/megatron/bridge/diffusion/models/flux/flux_attention.py index 72bd1f91a2..5e77005c10 100644 --- a/src/megatron/bridge/diffusion/models/flux/flux_attention.py +++ b/src/megatron/bridge/diffusion/models/flux/flux_attention.py @@ -115,6 +115,7 @@ def __init__( tp_comm_buffer_name="qkv", ) + self.context_pre_only = context_pre_only if not context_pre_only: self.added_linear_proj = build_module( submodules.linear_proj, @@ -128,6 +129,8 @@ def __init__( is_expert=False, tp_comm_buffer_name="proj", ) + else: + self.added_linear_proj = None if submodules.q_layernorm is not None: self.q_layernorm = build_module( @@ -360,10 +363,14 @@ def forward( attention_output = core_attn_out[additional_hidden_states.shape[0] :, :, :] output, bias = self.linear_proj(attention_output) - encoder_output, encoder_bias = self.added_linear_proj(encoder_attention_output) + if self.added_linear_proj is not None: + encoder_output, encoder_bias = self.added_linear_proj(encoder_attention_output) + encoder_output = encoder_output + encoder_bias + else: + # context_pre_only: encoder output not used by caller; return as-is + encoder_output = encoder_attention_output output = output + bias - encoder_output = 
encoder_output + encoder_bias return output, encoder_output diff --git a/src/megatron/bridge/diffusion/models/flux/flux_step.py b/src/megatron/bridge/diffusion/models/flux/flux_step.py index f9de829bb7..c8570bc3bf 100644 --- a/src/megatron/bridge/diffusion/models/flux/flux_step.py +++ b/src/megatron/bridge/diffusion/models/flux/flux_step.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +FLUX Forward Step. + +This is a prototype showing how to integrate the FlowMatchingPipeline +into Megatron's training flow, reusing the well-tested flow matching logic. +""" import logging -import math -from functools import lru_cache, partial +from functools import partial from typing import Iterable import torch @@ -23,7 +28,8 @@ from megatron.core.models.common.vision_module.vision_module import VisionModule from megatron.core.utils import get_model_config -from megatron.bridge.diffusion.models.flux.flow_matching.flux_inference_pipeline import FlowMatchEulerDiscreteScheduler +from megatron.bridge.diffusion.common.flow_matching.flow_matching_pipeline import FlowMatchingPipeline +from megatron.bridge.diffusion.models.flux.flow_matching.flux_adapter import MegatronFluxAdapter from megatron.bridge.training.losses import masked_next_token_loss from megatron.bridge.training.state import GlobalState @@ -31,6 +37,11 @@ logger = logging.getLogger(__name__) +# ============================================================================= +# Megatron Forward Step +# ============================================================================= + + def flux_data_step(dataloader_iter, store_in_state=False): """Process batch data for FLUX model. @@ -66,21 +77,18 @@ def flux_data_step(dataloader_iter, store_in_state=False): class FluxForwardStep: - """Forward step for FLUX diffusion model training. 
- - This class handles the forward pass during training, including: - - Timestep sampling using flow matching - - Noise injection with latent packing - - Model prediction - - Loss computation + """ + Forward step for FLUX using FlowMatchingPipeline. + This class demonstrates how to integrate the FlowMatchingPipeline into Megatron's training flow. Args: timestep_sampling: Method for sampling timesteps ("logit_normal", "uniform", "mode"). logit_mean: Mean for logit-normal sampling. logit_std: Standard deviation for logit-normal sampling. - mode_scale: Scale for mode sampling. + flow_shift: Shift parameter for timestep transformation (default: 1.0 for FLUX). scheduler_steps: Number of scheduler training steps. guidance_scale: Guidance scale for FLUX-dev models. + use_loss_weighting: Whether to apply flow-based loss weighting. """ def __init__( @@ -88,24 +96,44 @@ def __init__( self, timestep_sampling: str = "logit_normal", logit_mean: float = 0.0, logit_std: float = 1.0, - mode_scale: float = 1.29, + flow_shift: float = 1.0, # FLUX uses shift=1.0 typically scheduler_steps: int = 1000, guidance_scale: float = 3.5, + use_loss_weighting: bool = False, # FLUX typically doesn't use loss weighting ): - self.timestep_sampling = timestep_sampling - self.logit_mean = logit_mean - self.logit_std = logit_std - self.mode_scale = mode_scale - self.scheduler_steps = scheduler_steps - self.guidance_scale = guidance_scale self.autocast_dtype = torch.bfloat16 - # Initialize scheduler for timestep/sigma computations - self.scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=scheduler_steps) + + # Create the FlowMatchingPipeline with Megatron adapter + adapter = MegatronFluxAdapter(guidance_scale=guidance_scale) + + self.pipeline = FlowMatchingPipeline( + model_adapter=adapter, + num_train_timesteps=scheduler_steps, + timestep_sampling=timestep_sampling, + flow_shift=flow_shift, + logit_mean=logit_mean, + logit_std=logit_std, + sigma_min=0.0, + sigma_max=1.0, + use_loss_weighting=use_loss_weighting, +
cfg_dropout_prob=0.0, # No CFG dropout in Megatron training + log_interval=100, + summary_log_interval=10, + ) + + logger.info( + f"FluxForwardStep initialized with:\n" + f" - Timestep sampling: {timestep_sampling}\n" + f" - Flow shift: {flow_shift}\n" + f" - Guidance scale: {guidance_scale}\n" + f" - Loss weighting: {use_loss_weighting}" + ) def __call__( self, state: GlobalState, data_iterator: Iterable, model: VisionModule ) -> tuple[torch.Tensor, partial]: - """Forward training step. + """ + Forward training step using FlowMatchingPipeline. Args: state: Global state for the run. @@ -118,7 +146,7 @@ def __call__( timers = state.timers straggler_timer = state.straggler_timer - config = get_model_config(model) + config = get_model_config(model) # noqa: F841 timers("batch-generator", log_level=2).start() @@ -132,289 +160,111 @@ def __call__( check_for_nan_in_loss = state.cfg.rerun_state_machine.check_for_nan_in_loss check_for_spiky_loss = state.cfg.rerun_state_machine.check_for_spiky_loss - # Run diffusion training step + # Prepare batch for FlowMatchingPipeline + # Map Megatron keys to FlowMatchingPipeline expected keys + pipeline_batch = self._prepare_batch_for_pipeline(batch) + + # Run the pipeline step with straggler_timer: if parallel_state.is_pipeline_last_stage(): - output_tensor, loss, loss_mask = self._training_step(model, batch, config) + output_tensor, loss, loss_mask = self._training_step_with_pipeline(model, pipeline_batch) + # loss_mask is already created correctly in _training_step_with_pipeline batch["loss_mask"] = loss_mask else: - output_tensor = self._training_step(model, batch, config) + # For non-final pipeline stages, we still need to run the model + # but loss computation happens only on the last stage + output_tensor = self._training_step_with_pipeline(model, pipeline_batch) + loss_mask = None - loss = output_tensor - if "loss_mask" not in batch or batch["loss_mask"] is None: - loss_mask = torch.ones_like(loss) - else: - loss_mask = 
batch["loss_mask"] + # Use the loss_mask from training step (already has correct shape) + if loss_mask is None: + # This should only happen for non-final pipeline stages + loss_mask = torch.ones(1, device="cuda") loss_function = self._create_loss_function(loss_mask, check_for_nan_in_loss, check_for_spiky_loss) return output_tensor, loss_function - def _training_step( - self, model: VisionModule, batch: dict, config + def _prepare_batch_for_pipeline(self, batch: dict) -> dict: + """ + Prepare Megatron batch for FlowMatchingPipeline. + + Maps Megatron batch keys to FlowMatchingPipeline expected format: + - latents -> image_latents (for consistency) + - Keeps prompt_embeds, pooled_prompt_embeds, text_ids as-is + """ + pipeline_batch = { + "image_latents": batch["latents"], # Map to FlowMatchingPipeline expected key + "prompt_embeds": batch.get("prompt_embeds"), + "pooled_prompt_embeds": batch.get("pooled_prompt_embeds"), + "text_ids": batch.get("text_ids"), + "data_type": "image", # FLUX is for image generation + } + + # Copy any additional keys + for key in batch: + if key not in pipeline_batch and key != "latents": + pipeline_batch[key] = batch[key] + + return pipeline_batch + + def _training_step_with_pipeline( + self, model: VisionModule, batch: dict ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | torch.Tensor: - """Perform single training step with flow matching. + """ + Perform single training step using FlowMatchingPipeline. Args: model: The FLUX model. - batch: Data batch containing latents and text embeddings. - config: Model configuration. + batch: Data batch prepared for pipeline. Returns: On last pipeline stage: tuple of (output_tensor, loss, loss_mask). - On other stages: hidden_states tensor. + On other stages: output tensor. """ - # Get latents from batch - expected in [B, C, H, W] format - if "latents" in batch: - latents = batch["latents"] - else: - raise ValueError("Expected 'latents' in batch. 
VAE encoding should be done in data preprocessing.") - - # Prepare image latents with flow matching noise - ( - latents, - noise, - packed_noisy_model_input, - latent_image_ids, - guidance_vec, - timesteps, - ) = self.prepare_image_latent(latents, model) - - # Get text embeddings (precached) - if "prompt_embeds" in batch: - prompt_embeds = batch["prompt_embeds"].transpose(0, 1) - pooled_prompt_embeds = batch["pooled_prompt_embeds"] - text_ids = batch["text_ids"] - else: - raise ValueError("Expected precached text embeddings in batch.") - - # Forward pass - with torch.amp.autocast( - "cuda", enabled=self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype - ): - noise_pred = model( - img=packed_noisy_model_input, - txt=prompt_embeds, - y=pooled_prompt_embeds, - timesteps=timesteps / 1000, - img_ids=latent_image_ids, - txt_ids=text_ids, - guidance=guidance_vec, + device = torch.device("cuda") + dtype = self.autocast_dtype + + # Pass model in batch so adapter can check for guidance support + batch["_model"] = model + + with torch.amp.autocast("cuda", enabled=dtype in (torch.half, torch.bfloat16), dtype=dtype): + # Run the FlowMatchingPipeline step (global_step defaults to 0) + weighted_loss, average_weighted_loss, loss_mask, metrics = self.pipeline.step( + model=model, + batch=batch, + device=device, + dtype=dtype, ) - # Unpack predictions for loss computation - noise_pred = self._unpack_latents( - noise_pred.transpose(0, 1), - latents.shape[2], - latents.shape[3], - ).transpose(0, 1) + # Clean up temporary model reference + batch.pop("_model", None) - # Flow matching target: v = noise - latents (velocity formulation) - target = noise - latents + if parallel_state.is_pipeline_last_stage(): + # Match original implementation's reduction pattern + # Original does: loss = mse(..., reduction="none"), then output_tensor = mean(loss, dim=-1) + # This keeps most dimensions and only reduces the last one + # But FlowMatchingPipeline returns full loss, so 
we reduce to match expected shape - # MSE loss - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - output_tensor = torch.mean(loss, dim=-1) + # For FLUX with images: weighted_loss is [B, C, H, W] + # Original pattern: mean over spatial dimensions -> [B, C] or similar + # But Megatron expects a 1D tensor per sample, so reduce to [B] + output_tensor = torch.mean(weighted_loss, dim=list(range(1, weighted_loss.ndim))) - # Create loss mask (all ones for now) + # Always create a fresh loss_mask matching output_tensor shape + # Ignore any loss_mask from batch as it may have incompatible shape loss_mask = torch.ones_like(output_tensor) - return output_tensor, loss, loss_mask - - # else: - # hidden_states = model( - # img=packed_noisy_model_input, - # txt=prompt_embeds, - # y=pooled_prompt_embeds, - # timesteps=timesteps / 1000, - # img_ids=latent_image_ids, - # txt_ids=text_ids, - # guidance=guidance_vec, - # ) - # return hidden_states - - def prepare_image_latent(self, latents: torch.Tensor, model: VisionModule): - """Prepare image latents with flow matching noise. - - Args: - latents: Input latent tensor [B, C, H, W]. - model: The FLUX model (for guidance_embed config). - - Returns: - Tuple of (latents, noise, packed_noisy_input, latent_image_ids, guidance, timesteps). 
- """ - latent_image_ids = self._prepare_latent_image_ids( - latents.shape[0], - latents.shape[2], - latents.shape[3], - latents.device, - latents.dtype, - ) - - noise = torch.randn_like(latents, device=latents.device, dtype=latents.dtype) - batch_size = latents.shape[0] - u = self.compute_density_for_timestep_sampling( - self.timestep_sampling, - batch_size, - ) - indices = (u * self.scheduler.num_train_timesteps).long() - timesteps = self.scheduler.timesteps[indices].to(device=latents.device) - - sigmas = self.scheduler.sigmas.to(device=latents.device, dtype=latents.dtype) - scheduler_timesteps = self.scheduler.timesteps.to(device=latents.device) - step_indices = [(scheduler_timesteps == t).nonzero().item() for t in timesteps] - timesteps = timesteps.to(dtype=latents.dtype) - sigma = sigmas[step_indices].flatten() - - while len(sigma.shape) < latents.ndim: - sigma = sigma.unsqueeze(-1) - - noisy_model_input = (1.0 - sigma) * latents + sigma * noise - packed_noisy_model_input = self._pack_latents( - noisy_model_input, - batch_size=latents.shape[0], - num_channels_latents=latents.shape[1], - height=latents.shape[2], - width=latents.shape[3], - ) - - # Guidance embedding (for FLUX-dev) - if hasattr(model, "guidance_embed") and model.guidance_embed: - guidance_vec = torch.full( - (noisy_model_input.shape[0],), - self.guidance_scale, - device=latents.device, - dtype=latents.dtype, - ) + return output_tensor, average_weighted_loss, loss_mask else: - guidance_vec = None - - return ( - latents.transpose(0, 1), - noise.transpose(0, 1), - packed_noisy_model_input.transpose(0, 1), - latent_image_ids, - guidance_vec, - timesteps, - ) - - def compute_density_for_timestep_sampling( - self, - weighting_scheme: str, - batch_size: int, - logit_mean: float = None, - logit_std: float = None, - mode_scale: float = None, - ) -> torch.Tensor: - """Compute the density for sampling the timesteps when doing SD3 training. 
- - Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. - SD3 paper reference: https://arxiv.org/abs/2403.03206v1. - - Args: - weighting_scheme: Sampling scheme ("logit_normal", "mode", or "uniform"). - batch_size: Number of samples in batch. - logit_mean: Mean for logit-normal sampling. - logit_std: Standard deviation for logit-normal sampling. - mode_scale: Scale for mode sampling. - - Returns: - Tensor of sampled u values in [0, 1]. - """ - # Use instance defaults if not provided - logit_mean = logit_mean if logit_mean is not None else self.logit_mean - logit_std = logit_std if logit_std is not None else self.logit_std - mode_scale = mode_scale if mode_scale is not None else self.mode_scale - - if weighting_scheme == "logit_normal": - # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$) - u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") - u = torch.nn.functional.sigmoid(u) - elif weighting_scheme == "mode": - u = torch.rand(size=(batch_size,), device="cpu") - u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) - else: - u = torch.rand(size=(batch_size,), device="cpu") - return u - - @lru_cache - def _prepare_latent_image_ids( - self, batch_size: int, height: int, width: int, device: torch.device, dtype: torch.dtype - ) -> torch.Tensor: - """Prepare latent image IDs for positional encoding. - - Args: - batch_size: Number of samples. - height: Latent height. - width: Latent width. - device: Target device. - dtype: Target dtype. - - Returns: - Tensor of shape [B, (H/2)*(W/2), 3] with position IDs. 
- """ - latent_image_ids = torch.zeros(height // 2, width // 2, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) - latent_image_ids = latent_image_ids.reshape( - batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype, non_blocking=True) - - def _pack_latents( - self, latents: torch.Tensor, batch_size: int, num_channels_latents: int, height: int, width: int - ) -> torch.Tensor: - """Pack latents for FLUX processing. - - Rearranges [B, C, H, W] -> [B, (H/2)*(W/2), C*4]. - - Args: - latents: Input tensor [B, C, H, W]. - batch_size: Batch size. - num_channels_latents: Number of latent channels. - height: Latent height. - width: Latent width. - - Returns: - Packed tensor [B, num_patches, C*4]. - """ - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - return latents - - def _unpack_latents(self, latents: torch.Tensor, height: int, width: int) -> torch.Tensor: - """Unpack latents from FLUX format. - - Rearranges [B, num_patches, C*4] -> [B, C, H, W]. - - Args: - latents: Packed tensor [B, num_patches, C*4]. - height: Target height. - width: Target width. - - Returns: - Unpacked tensor [B, C, H, W]. 
- """ - batch_size, num_patches, channels = latents.shape - - # Adjust h and w for patching - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // 4, height, width) - - return latents + # For intermediate stages, return the tensor for pipeline communication + return weighted_loss def _create_loss_function( self, loss_mask: torch.Tensor, check_for_nan_in_loss: bool, check_for_spiky_loss: bool ) -> partial: - """Create a partial loss function with the specified configuration. + """ + Create a partial loss function with the specified configuration. Args: loss_mask: Used to mask out some portions of the loss. diff --git a/src/megatron/bridge/diffusion/models/flux/flux_step_with_automodel.py b/src/megatron/bridge/diffusion/models/flux/flux_step_with_automodel.py deleted file mode 100644 index e2f021d604..0000000000 --- a/src/megatron/bridge/diffusion/models/flux/flux_step_with_automodel.py +++ /dev/null @@ -1,333 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -FLUX Forward Step with Automodel Pipeline Integration. - -This is a prototype showing how to integrate the automodel FlowMatchingPipeline -into Megatron's training flow, reusing the well-tested flow matching logic. 
-""" - -import logging -from functools import partial -from typing import Iterable - -import torch -from megatron.core import parallel_state -from megatron.core.models.common.vision_module.vision_module import VisionModule -from megatron.core.utils import get_model_config - -# Import automodel pipeline components -from megatron.bridge.diffusion.common.flow_matching.flow_matching_pipeline import FlowMatchingPipeline - -# Import MegatronFluxAdapter from flow_matching module -from megatron.bridge.diffusion.models.flux.flow_matching.flux_adapter import MegatronFluxAdapter -from megatron.bridge.training.losses import masked_next_token_loss -from megatron.bridge.training.state import GlobalState - - -logger = logging.getLogger(__name__) - - -# ============================================================================= -# Megatron Forward Step with Automodel Pipeline -# ============================================================================= - - -def flux_data_step(dataloader_iter, store_in_state=False): - """Process batch data for FLUX model. - - Args: - dataloader_iter: Iterator over the dataloader. - store_in_state: If True, store the batch in GlobalState for callbacks. - - Returns: - Processed batch dictionary with tensors moved to CUDA. 
- """ - batch = next(dataloader_iter) - if isinstance(batch, tuple) and len(batch) == 3: - _batch = batch[0] - else: - _batch = batch - - _batch = {k: v.to(device="cuda", non_blocking=True) if torch.is_tensor(v) else v for k, v in _batch.items()} - - if "loss_mask" not in _batch or _batch["loss_mask"] is None: - _batch["loss_mask"] = torch.ones(1, device="cuda") - - # Store batch in state for callbacks (e.g., validation image generation) - if store_in_state: - try: - from megatron.bridge.training.pretrain import get_current_state - - state = get_current_state() - state._last_validation_batch = _batch - except: - pass # If state access fails, silently continue - - return _batch - - -class FluxForwardStepWithAutomodel: - """ - Forward step for FLUX using the automodel FlowMatchingPipeline. - - This class demonstrates how to integrate the well-tested automodel pipeline - into Megatron's training flow, gaining benefits like: - - Unified flow matching implementation - - Better logging and debugging - - Consistent timestep sampling across models - - Easier maintenance - - Args: - timestep_sampling: Method for sampling timesteps ("logit_normal", "uniform", "mode"). - logit_mean: Mean for logit-normal sampling. - logit_std: Standard deviation for logit-normal sampling. - flow_shift: Shift parameter for timestep transformation (default: 1.0 for FLUX). - scheduler_steps: Number of scheduler training steps. - guidance_scale: Guidance scale for FLUX-dev models. - use_loss_weighting: Whether to apply flow-based loss weighting. 
- """ - - def __init__( - self, - timestep_sampling: str = "logit_normal", - logit_mean: float = 0.0, - logit_std: float = 1.0, - flow_shift: float = 1.0, # FLUX uses shift=1.0 typically - scheduler_steps: int = 1000, - guidance_scale: float = 3.5, - use_loss_weighting: bool = False, # FLUX typically doesn't use loss weighting - ): - self.autocast_dtype = torch.bfloat16 - - # Create the automodel pipeline with Megatron adapter - adapter = MegatronFluxAdapter(guidance_scale=guidance_scale) - - self.pipeline = FlowMatchingPipeline( - model_adapter=adapter, - num_train_timesteps=scheduler_steps, - timestep_sampling=timestep_sampling, - flow_shift=flow_shift, - logit_mean=logit_mean, - logit_std=logit_std, - sigma_min=0.0, - sigma_max=1.0, - use_loss_weighting=use_loss_weighting, - cfg_dropout_prob=0.0, # No CFG dropout in Megatron training - log_interval=100, - summary_log_interval=10, - ) - - logger.info( - f"FluxForwardStepWithAutomodel initialized with:\n" - f" - Timestep sampling: {timestep_sampling}\n" - f" - Flow shift: {flow_shift}\n" - f" - Guidance scale: {guidance_scale}\n" - f" - Loss weighting: {use_loss_weighting}" - ) - - def __call__( - self, state: GlobalState, data_iterator: Iterable, model: VisionModule - ) -> tuple[torch.Tensor, partial]: - """ - Forward training step using automodel pipeline. - - Args: - state: Global state for the run. - data_iterator: Input data iterator. - model: The FLUX model. - - Returns: - Tuple containing the output tensor and the loss function. 
- """ - timers = state.timers - straggler_timer = state.straggler_timer - - config = get_model_config(model) # noqa: F841 - - timers("batch-generator", log_level=2).start() - - with straggler_timer(bdata=True): - batch = flux_data_step(data_iterator) - # Store batch for validation callbacks (only during evaluation) - if not torch.is_grad_enabled(): - state._last_batch = batch - timers("batch-generator").stop() - - check_for_nan_in_loss = state.cfg.rerun_state_machine.check_for_nan_in_loss - check_for_spiky_loss = state.cfg.rerun_state_machine.check_for_spiky_loss - - # Prepare batch for automodel pipeline - # Map Megatron keys to automodel expected keys - pipeline_batch = self._prepare_batch_for_pipeline(batch) - - # Run the pipeline step - with straggler_timer: - if parallel_state.is_pipeline_last_stage(): - output_tensor, loss, loss_mask = self._training_step_with_pipeline(model, pipeline_batch) - # loss_mask is already created correctly in _training_step_with_pipeline - batch["loss_mask"] = loss_mask - else: - # For non-final pipeline stages, we still need to run the model - # but loss computation happens only on the last stage - output_tensor = self._training_step_with_pipeline(model, pipeline_batch) - loss_mask = None - - # Use the loss_mask from training step (already has correct shape) - if loss_mask is None: - # This should only happen for non-final pipeline stages - loss_mask = torch.ones(1, device="cuda") - - loss_function = self._create_loss_function(loss_mask, check_for_nan_in_loss, check_for_spiky_loss) - - return output_tensor, loss_function - - def _prepare_batch_for_pipeline(self, batch: dict) -> dict: - """ - Prepare Megatron batch for automodel pipeline. 
- - Maps Megatron batch keys to automodel expected format: - - latents -> image_latents (for consistency) - - Keeps prompt_embeds, pooled_prompt_embeds, text_ids as-is - """ - pipeline_batch = { - "image_latents": batch["latents"], # Map to automodel expected key - "prompt_embeds": batch.get("prompt_embeds"), - "pooled_prompt_embeds": batch.get("pooled_prompt_embeds"), - "text_ids": batch.get("text_ids"), - "data_type": "image", # FLUX is for image generation - } - - # Copy any additional keys - for key in batch: - if key not in pipeline_batch and key != "latents": - pipeline_batch[key] = batch[key] - - return pipeline_batch - - def _training_step_with_pipeline( - self, model: VisionModule, batch: dict - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | torch.Tensor: - """ - Perform single training step using automodel pipeline. - - Args: - model: The FLUX model. - batch: Data batch prepared for pipeline. - - Returns: - On last pipeline stage: tuple of (output_tensor, loss, loss_mask). - On other stages: output tensor. 
- """ - device = torch.device("cuda") - dtype = self.autocast_dtype - - # Pass model in batch so adapter can check for guidance support - batch["_model"] = model - - with torch.amp.autocast("cuda", enabled=dtype in (torch.half, torch.bfloat16), dtype=dtype): - # Run the automodel pipeline step (global_step defaults to 0) - weighted_loss, average_weighted_loss, loss_mask, metrics = self.pipeline.step( - model=model, - batch=batch, - device=device, - dtype=dtype, - ) - - # Clean up temporary model reference - batch.pop("_model", None) - - if parallel_state.is_pipeline_last_stage(): - # Match original implementation's reduction pattern - # Original does: loss = mse(..., reduction="none"), then output_tensor = mean(loss, dim=-1) - # This keeps most dimensions and only reduces the last one - # But automodel returns full loss, so we reduce to match expected shape - - # For FLUX with images: weighted_loss is [B, C, H, W] - # Original pattern: mean over spatial dimensions -> [B, C] or similar - # But Megatron expects a 1D tensor per sample, so reduce to [B] - output_tensor = torch.mean(weighted_loss, dim=list(range(1, weighted_loss.ndim))) - - # Always create a fresh loss_mask matching output_tensor shape - # Ignore any loss_mask from batch as it may have incompatible shape - loss_mask = torch.ones_like(output_tensor) - - return output_tensor, average_weighted_loss, loss_mask - else: - # For intermediate stages, return the tensor for pipeline communication - return weighted_loss - - def _create_loss_function( - self, loss_mask: torch.Tensor, check_for_nan_in_loss: bool, check_for_spiky_loss: bool - ) -> partial: - """ - Create a partial loss function with the specified configuration. - - Args: - loss_mask: Used to mask out some portions of the loss. - check_for_nan_in_loss: Whether to check for NaN values in the loss. - check_for_spiky_loss: Whether to check for spiky loss values. - - Returns: - A partial function that can be called with output_tensor to compute the loss. 
- """ - return partial( - masked_next_token_loss, - loss_mask, - check_for_nan_in_loss=check_for_nan_in_loss, - check_for_spiky_loss=check_for_spiky_loss, - ) - - -# ============================================================================= -# Convenience Factory -# ============================================================================= - - -def create_flux_forward_step( - use_automodel_pipeline: bool = True, - **kwargs, -): - """ - Factory function to create either the automodel-based or original forward step. - - Args: - use_automodel_pipeline: If True, use FluxForwardStepWithAutomodel. - If False, use original FluxForwardStep. - **kwargs: Arguments passed to the forward step constructor. - - Returns: - Forward step instance. - - Example: - # Use automodel pipeline - forward_step = create_flux_forward_step( - use_automodel_pipeline=True, - timestep_sampling="logit_normal", - flow_shift=1.0, - ) - - # Use original implementation - forward_step = create_flux_forward_step( - use_automodel_pipeline=False, - timestep_sampling="logit_normal", - ) - """ - if use_automodel_pipeline: - return FluxForwardStepWithAutomodel(**kwargs) - else: - from megatron.bridge.diffusion.models.flux.flux_step import FluxForwardStep - - return FluxForwardStep(**kwargs) diff --git a/tests/unit_tests/diffusion/data/flux/test_flux_energon_datamodule.py b/tests/unit_tests/diffusion/data/flux/test_flux_energon_datamodule.py index 419e02392c..e0de7ce590 100644 --- a/tests/unit_tests/diffusion/data/flux/test_flux_energon_datamodule.py +++ b/tests/unit_tests/diffusion/data/flux/test_flux_energon_datamodule.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from collections.abc import Iterator + from megatron.bridge.diffusion.data.flux import flux_energon_datamodule as flux_dm_mod from megatron.bridge.diffusion.data.flux.flux_taskencoder import FluxTaskEncoder @@ -42,6 +44,9 @@ def __init__( def train_dataloader(self): return "train" + def val_dataloader(self): + return "val" + def test_flux_datamodule_config_initialization(monkeypatch): # Patch the symbol used inside flux_energon_datamodule module @@ -69,9 +74,12 @@ def test_flux_datamodule_config_initialization(monkeypatch): assert cfg.dataset.task_encoder.latent_channels == 16 assert cfg.dataset.use_train_split_for_val is True - # build_datasets should return train loader thrice + # build_datasets returns (iter(train_dataloader()), iter(val_dataloader()), iter(val_dataloader())) train, val, test = cfg.build_datasets(context=None) - assert train == "train" and val == "train" and test == "train" + assert isinstance(train, Iterator) and isinstance(val, Iterator) and isinstance(test, Iterator) + assert list(train) == list("train") + assert list(val) == list("val") + assert list(test) == list("val") def test_flux_datamodule_config_with_custom_parameters(monkeypatch): diff --git a/tests/unit_tests/diffusion/model/flux/flow_matching/test_flux_adapter.py b/tests/unit_tests/diffusion/model/flux/flow_matching/test_flux_adapter.py new file mode 100644 index 0000000000..3c859544b4 --- /dev/null +++ b/tests/unit_tests/diffusion/model/flux/flow_matching/test_flux_adapter.py @@ -0,0 +1,373 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for MegatronFluxAdapter.""" + +from unittest.mock import MagicMock + +import pytest +import torch + +from megatron.bridge.diffusion.common.flow_matching.adapters.base import FlowMatchingContext +from megatron.bridge.diffusion.models.flux.flow_matching.flux_adapter import MegatronFluxAdapter + + +pytestmark = [pytest.mark.unit] + + +class TestMegatronFluxAdapterInit: + """Test MegatronFluxAdapter initialization.""" + + def test_init_default(self): + """Test initialization with default guidance_scale.""" + adapter = MegatronFluxAdapter() + assert adapter.guidance_scale == 3.5 + + def test_init_custom_guidance_scale(self): + """Test initialization with custom guidance_scale.""" + adapter = MegatronFluxAdapter(guidance_scale=7.5) + assert adapter.guidance_scale == 7.5 + + +class TestMegatronFluxAdapterPackLatents: + """Test _pack_latents.""" + + def test_pack_latents_shape(self): + """Test packed latents have correct shape [B, (H//2)*(W//2), C*4].""" + adapter = MegatronFluxAdapter() + batch_size, channels, height, width = 2, 16, 64, 64 + latents = torch.randn(batch_size, channels, height, width) + + packed = adapter._pack_latents(latents) + + expected_seq = (height // 2) * (width // 2) + expected_channels = channels * 4 + assert packed.shape == (batch_size, expected_seq, expected_channels) + + def test_pack_latents_different_sizes(self): + """Test _pack_latents with different H, W.""" + adapter = MegatronFluxAdapter() + latents = torch.randn(1, 8, 32, 48) + packed = adapter._pack_latents(latents) + assert packed.shape == (1, (32 // 2) * (48 // 
2), 8 * 4) + + +class TestMegatronFluxAdapterUnpackLatents: + """Test _unpack_latents.""" + + def test_unpack_latents_shape(self): + """Test unpacked latents have correct shape [B, C, H, W].""" + adapter = MegatronFluxAdapter() + batch_size, height, width = 2, 64, 64 + num_patches = (height // 2) * (width // 2) + channels_packed = 16 * 4 + packed = torch.randn(batch_size, num_patches, channels_packed) + + unpacked = adapter._unpack_latents(packed, height, width) + + assert unpacked.shape == (batch_size, 16, height, width) + + def test_pack_unpack_roundtrip(self): + """Test that pack then unpack recovers original shape and values.""" + adapter = MegatronFluxAdapter() + batch_size, channels, height, width = 2, 16, 64, 64 + original = torch.randn(batch_size, channels, height, width) + + packed = adapter._pack_latents(original) + unpacked = adapter._unpack_latents(packed, height, width) + + assert unpacked.shape == original.shape + assert torch.allclose(unpacked, original) + + +class TestMegatronFluxAdapterPrepareLatentImageIds: + """Test _prepare_latent_image_ids.""" + + def test_prepare_latent_image_ids_shape(self): + """Test output shape [B, (H//2)*(W//2), 3] with (col0, y, x); implementation uses zeros for col0.""" + adapter = MegatronFluxAdapter() + batch_size, height, width = 2, 64, 64 + device = torch.device("cpu") + dtype = torch.float32 + + ids = adapter._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + expected_seq = (height // 2) * (width // 2) + assert ids.shape == (batch_size, expected_seq, 3) + assert ids.device == device + assert ids.dtype == dtype + # Implementation only sets columns 1 (y) and 2 (x); column 0 stays 0 + assert ids[0, 0, 0] == 0 + assert ids[1, 0, 0] == 0 + # y, x indices + assert ids[0, 0, 1] == 0 + assert ids[0, 0, 2] == 0 + + def test_prepare_latent_image_ids_device_dtype(self): + """Test device and dtype are applied.""" + adapter = MegatronFluxAdapter() + ids = adapter._prepare_latent_image_ids(1, 32, 32, 
torch.device("cpu"), torch.float64) + assert ids.dtype == torch.float64 + + +def _make_context( + batch_size=2, + height=64, + width=64, + channels=16, + text_seq_len=77, + text_dim=4096, + device=None, + dtype=torch.float32, + cfg_dropout_prob=0.0, + batch_extras=None, +): + """Build a FlowMatchingContext for tests.""" + device = device or torch.device("cpu") + batch = { + "prompt_embeds": torch.randn(batch_size, text_seq_len, text_dim), + "pooled_prompt_embeds": torch.randn(batch_size, 768), + **(batch_extras or {}), + } + return FlowMatchingContext( + noisy_latents=torch.randn(batch_size, channels, height, width), + latents=torch.randn(batch_size, channels, height, width), + timesteps=torch.randint(0, 1000, (batch_size,)).float(), + sigma=torch.ones(batch_size), + task_type="t2v", + data_type="image", + device=device, + dtype=dtype, + batch=batch, + cfg_dropout_prob=cfg_dropout_prob, + ) + + +class TestMegatronFluxAdapterPrepareInputs: + """Test prepare_inputs.""" + + def test_prepare_inputs_keys_and_shapes(self): + """Test prepare_inputs returns expected keys and sequence-first layout.""" + adapter = MegatronFluxAdapter() + ctx = _make_context(batch_size=2, height=64, width=64) + + inputs = adapter.prepare_inputs(ctx) + + # Sequence-first: seq dim first + seq_len = (64 // 2) * (64 // 2) + assert inputs["img"].shape == (seq_len, 2, 16 * 4) + assert inputs["txt"].shape == (77, 2, 4096) + assert inputs["y"].shape == (2, 768) + assert inputs["timesteps"].shape == (2,) + assert inputs["img_ids"].shape == (2, seq_len, 3) + assert inputs["txt_ids"].shape == (2, 77, 3) + assert inputs["_original_shape"] == (2, 16, 64, 64) + assert (inputs["timesteps"] >= 0).all() and (inputs["timesteps"] <= 1).all() + + def test_prepare_inputs_prompt_embeds_sb_layout(self): + """Test prepare_inputs when prompt_embeds are [S, B, D] (Megatron layout).""" + adapter = MegatronFluxAdapter() + batch_size, text_seq_len, text_dim = 2, 77, 4096 + # [S, B, D] + prompt_embeds = 
torch.randn(text_seq_len, batch_size, text_dim) + ctx = _make_context( + batch_size=batch_size, + batch_extras={"prompt_embeds": prompt_embeds}, + ) + + inputs = adapter.prepare_inputs(ctx) + + # Already [S, B, D] (sequence-first); should be passed through without a transpose + assert inputs["txt"].shape == (text_seq_len, batch_size, text_dim) + + def test_prepare_inputs_missing_prompt_embeds_raises(self): + """Test prepare_inputs raises when prompt_embeds not in batch.""" + adapter = MegatronFluxAdapter() + ctx = _make_context(batch_extras={}) + ctx.batch.pop("prompt_embeds", None) + + with pytest.raises(ValueError, match="Expected 'prompt_embeds' in batch"): + adapter.prepare_inputs(ctx) + + def test_prepare_inputs_non_4d_latents_raises(self): + """Test prepare_inputs raises when noisy_latents are not 4D.""" + adapter = MegatronFluxAdapter() + ctx = _make_context() + ctx.noisy_latents = torch.randn(2, 16, 64) # 3D + + with pytest.raises(ValueError, match="expects 4D latents"): + adapter.prepare_inputs(ctx) + + def test_prepare_inputs_default_pooled_embeds(self): + """Test prepare_inputs uses zeros when pooled_prompt_embeds missing.""" + adapter = MegatronFluxAdapter() + ctx = _make_context(batch_extras={}) + ctx.batch.pop("pooled_prompt_embeds", None) + + inputs = adapter.prepare_inputs(ctx) + + assert inputs["y"].shape == (2, 768) + assert torch.all(inputs["y"] == 0) + + def test_prepare_inputs_text_ids_from_batch(self): + """Test prepare_inputs uses text_ids from batch when present.""" + adapter = MegatronFluxAdapter() + text_ids = torch.randn(2, 77, 3) + ctx = _make_context(batch_extras={"text_ids": text_ids}) + + inputs = adapter.prepare_inputs(ctx) + + assert "txt_ids" in inputs + assert inputs["txt_ids"].shape == (2, 77, 3) + + def test_prepare_inputs_cfg_dropout_zero_never_drops(self): + """Test with cfg_dropout_prob=0 text is never zeroed.""" + adapter = MegatronFluxAdapter() + ctx = _make_context(cfg_dropout_prob=0.0) + + inputs = adapter.prepare_inputs(ctx) + + assert not
torch.all(inputs["txt"] == 0) + assert not torch.all(inputs["y"] == 0) + + def test_prepare_inputs_cfg_dropout_one_always_drops(self): + """Test with cfg_dropout_prob=1.0 text and pooled are zeroed.""" + adapter = MegatronFluxAdapter() + ctx = _make_context(cfg_dropout_prob=1.0) + + inputs = adapter.prepare_inputs(ctx) + + assert torch.all(inputs["txt"] == 0) + assert torch.all(inputs["y"] == 0) + + def test_prepare_inputs_guidance_when_model_has_guidance_embed(self): + """Test guidance key is set when model has guidance_embed=True.""" + adapter = MegatronFluxAdapter(guidance_scale=5.0) + + # Use a plain object for unwrapped model so hasattr(unwrapped, "module") is False + # (MagicMock would have .module and the unwrap loop would go to None) + class UnwrappedModel: + guidance_embed = True + + wrapper = MagicMock() + wrapper.module = UnwrappedModel() + ctx = _make_context(batch_extras={"_model": wrapper}) + + inputs = adapter.prepare_inputs(ctx) + + assert "guidance" in inputs + assert inputs["guidance"].shape == (2,) + assert (inputs["guidance"] == 5.0).all() + + def test_prepare_inputs_no_guidance_without_model(self): + """Test guidance is not in inputs when batch has no _model.""" + adapter = MegatronFluxAdapter() + ctx = _make_context(batch_extras={}) + + inputs = adapter.prepare_inputs(ctx) + + assert "guidance" not in inputs + + def test_prepare_inputs_no_guidance_when_guidance_embed_false(self): + """Test guidance is not in inputs when model has guidance_embed=False.""" + adapter = MegatronFluxAdapter() + + # Unwrapped model with guidance_embed=False + class UnwrappedModel: + guidance_embed = False + + wrapper = MagicMock() + wrapper.module = UnwrappedModel() + ctx = _make_context(batch_extras={"_model": wrapper}) + + inputs = adapter.prepare_inputs(ctx) + + assert "guidance" not in inputs + + +class TestMegatronFluxAdapterForward: + """Test forward.""" + + def test_forward_returns_unpacked_shape(self): + """Test forward returns [B, C, H, W] and uses 
_original_shape.""" + adapter = MegatronFluxAdapter() + batch_size, channels, height, width = 2, 16, 64, 64 + seq_len = (height // 2) * (width // 2) + # Simulate prepare_inputs output (without popping _original_shape) + inputs = { + "img": torch.randn(seq_len, batch_size, channels * 4), + "txt": torch.randn(77, batch_size, 4096), + "y": torch.randn(batch_size, 768), + "timesteps": torch.rand(batch_size), + "img_ids": torch.zeros(batch_size, seq_len, 3), + "txt_ids": torch.zeros(batch_size, 77, 3), + "_original_shape": (batch_size, channels, height, width), + } + # Model returns [S, B, D] (sequence-first) + model = MagicMock() + model.return_value = torch.randn(seq_len, batch_size, channels * 4) + + out = adapter.forward(model, inputs) + + assert out.shape == (batch_size, channels, height, width) + model.assert_called_once() + call_kw = model.call_args[1] + assert call_kw.get("guidance") is None + + def test_forward_handles_tuple_output(self): + """Test forward uses first element when model returns tuple.""" + adapter = MegatronFluxAdapter() + batch_size, channels, height, width = 2, 16, 64, 64 + seq_len = (height // 2) * (width // 2) + pred = torch.randn(seq_len, batch_size, channels * 4) + inputs = { + "img": torch.randn(seq_len, batch_size, channels * 4), + "txt": torch.randn(77, batch_size, 4096), + "y": torch.randn(batch_size, 768), + "timesteps": torch.rand(batch_size), + "img_ids": torch.zeros(batch_size, seq_len, 3), + "txt_ids": torch.zeros(batch_size, 77, 3), + "_original_shape": (batch_size, channels, height, width), + } + model = MagicMock() + model.return_value = (pred, None) + + out = adapter.forward(model, inputs) + + assert out.shape == (batch_size, channels, height, width) + + def test_forward_passes_guidance_when_present(self): + """Test forward passes guidance to model when in inputs.""" + adapter = MegatronFluxAdapter() + batch_size, channels, height, width = 2, 16, 64, 64 + seq_len = (height // 2) * (width // 2) + inputs = { + "img": 
torch.randn(seq_len, batch_size, channels * 4), + "txt": torch.randn(77, batch_size, 4096), + "y": torch.randn(batch_size, 768), + "timesteps": torch.rand(batch_size), + "img_ids": torch.zeros(batch_size, seq_len, 3), + "txt_ids": torch.zeros(batch_size, 77, 3), + "guidance": torch.full((batch_size,), 3.5), + "_original_shape": (batch_size, channels, height, width), + } + model = MagicMock() + model.return_value = torch.randn(seq_len, batch_size, channels * 4) + + adapter.forward(model, inputs) + + call_kw = model.call_args[1] + assert "guidance" in call_kw + assert call_kw["guidance"] is not None + assert call_kw["guidance"].shape == (batch_size,) diff --git a/tests/unit_tests/diffusion/model/flux/test_flux_attention.py b/tests/unit_tests/diffusion/model/flux/test_flux_attention.py new file mode 100644 index 0000000000..b021a3bd28 --- /dev/null +++ b/tests/unit_tests/diffusion/model/flux/test_flux_attention.py @@ -0,0 +1,152 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
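The 2×2 latent packing that the adapter tests above exercise (shape `[B, C, H, W]` → `[B, (H//2)*(W//2), C*4]` and back) can be sketched as a standalone pair of functions. This is an illustrative re-implementation matching the shapes asserted in the tests, not the actual `MegatronFluxAdapter._pack_latents`/`_unpack_latents` code:

```python
import torch


def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    """Fold each 2x2 spatial patch into the channel dim: [B, C, H, W] -> [B, (H//2)*(W//2), C*4]."""
    b, c, h, w = latents.shape
    x = latents.view(b, c, h // 2, 2, w // 2, 2)      # split H and W into 2x2 patches
    x = x.permute(0, 2, 4, 1, 3, 5)                   # [B, H//2, W//2, C, 2, 2]
    return x.reshape(b, (h // 2) * (w // 2), c * 4)   # flatten patch grid into a sequence


def unpack_latents(packed: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Inverse of pack_latents: [B, (H//2)*(W//2), C*4] -> [B, C, H, W]."""
    b, _, ch = packed.shape
    c = ch // 4
    x = packed.view(b, height // 2, width // 2, c, 2, 2)
    x = x.permute(0, 3, 1, 4, 2, 5)                   # [B, C, H//2, 2, W//2, 2]
    return x.reshape(b, c, height, width)


latents = torch.randn(2, 16, 64, 64)
packed = pack_latents(latents)
assert packed.shape == (2, 1024, 64)                  # (64//2)*(64//2) patches, 16*4 channels
assert torch.equal(unpack_latents(packed, 64, 64), latents)
```

Because the transform is pure `view`/`permute`/`reshape` bookkeeping with no arithmetic, the roundtrip is bit-exact, which is why the `test_pack_unpack_roundtrip` tests can safely compare values with `torch.allclose`.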
+ +"""Unit tests for FLUX attention modules.""" + +from unittest.mock import MagicMock + +import pytest +import torch + +from megatron.bridge.diffusion.models.flux.flux_attention import ( + JointSelfAttention, + JointSelfAttentionSubmodules, +) + + +pytestmark = [pytest.mark.unit] + + +class TestJointSelfAttentionSubmodules: + """Test JointSelfAttentionSubmodules dataclass.""" + + def test_default_instantiation(self): + """Test submodules can be created with defaults (all None).""" + sub = JointSelfAttentionSubmodules() + assert sub.linear_qkv is None + assert sub.added_linear_qkv is None + assert sub.core_attention is None + assert sub.linear_proj is None + assert sub.q_layernorm is None + assert sub.k_layernorm is None + assert sub.added_q_layernorm is None + assert sub.added_k_layernorm is None + + def test_custom_instantiation(self): + """Test submodules can be created with custom types.""" + linear_cls = MagicMock() + sub = JointSelfAttentionSubmodules( + linear_qkv=linear_cls, + added_linear_qkv=linear_cls, + core_attention=MagicMock(), + linear_proj=linear_cls, + q_layernorm=MagicMock(), + k_layernorm=MagicMock(), + ) + assert sub.linear_qkv is linear_cls + assert sub.added_linear_qkv is linear_cls + assert sub.linear_proj is linear_cls + assert sub.q_layernorm is not None + assert sub.k_layernorm is not None + assert sub.added_q_layernorm is None + assert sub.added_k_layernorm is None + + +class TestJointSelfAttentionSplitQkv: + """Test JointSelfAttention._split_qkv logic in isolation.""" + + def test_split_qkv_output_shapes(self): + """Test _split_qkv splits mixed QKV into Q, K, V with correct shapes.""" + # Use attributes that match a typical FLUX config: 4 heads, 4 groups, head_dim 64 + num_query_groups_per_partition = 4 + num_attention_heads_per_partition = 4 + hidden_size_per_attention_head = 64 + + # mixed_qkv last dim = ng * (np/ng + 2) * hn = 4 * (1 + 2) * 64 = 768 + q_per_group = num_attention_heads_per_partition // num_query_groups_per_partition + 
mixed_qkv_last_dim = num_query_groups_per_partition * (q_per_group + 2) * hidden_size_per_attention_head + assert mixed_qkv_last_dim == 768 + + sq, b = 8, 2 + mixed_qkv = torch.randn(sq, b, mixed_qkv_last_dim) + + # Bind _split_qkv to a mock that has the required attributes + receiver = MagicMock() + receiver.num_query_groups_per_partition = num_query_groups_per_partition + receiver.num_attention_heads_per_partition = num_attention_heads_per_partition + receiver.hidden_size_per_attention_head = hidden_size_per_attention_head + split_qkv = JointSelfAttention._split_qkv.__get__(receiver, JointSelfAttention) + + query, key, value = split_qkv(mixed_qkv) + + # query: [sq, b, np, hn] = [8, 2, 4, 64] + assert query.shape == (sq, b, num_attention_heads_per_partition, hidden_size_per_attention_head) + # key, value: [sq, b, ng, hn] = [8, 2, 4, 64] + assert key.shape == (sq, b, num_query_groups_per_partition, hidden_size_per_attention_head) + assert value.shape == (sq, b, num_query_groups_per_partition, hidden_size_per_attention_head) + + def test_split_qkv_with_gqa(self): + """Test _split_qkv with grouped query (num_heads > num_groups).""" + num_query_groups_per_partition = 2 + num_attention_heads_per_partition = 4 + hidden_size_per_attention_head = 32 + + q_per_group = num_attention_heads_per_partition // num_query_groups_per_partition # 2 + mixed_qkv_last_dim = num_query_groups_per_partition * (q_per_group + 2) * hidden_size_per_attention_head + # 2 * (2+2) * 32 = 256 + assert mixed_qkv_last_dim == 256 + + mixed_qkv = torch.randn(4, 1, mixed_qkv_last_dim) + receiver = MagicMock() + receiver.num_query_groups_per_partition = num_query_groups_per_partition + receiver.num_attention_heads_per_partition = num_attention_heads_per_partition + receiver.hidden_size_per_attention_head = hidden_size_per_attention_head + split_qkv = JointSelfAttention._split_qkv.__get__(receiver, JointSelfAttention) + + query, key, value = split_qkv(mixed_qkv) + + assert query.shape == (4, 1, 4, 32) + 
assert key.shape == (4, 1, 2, 32) + assert value.shape == (4, 1, 2, 32) + + +class TestFluxSingleAttentionRotaryPosEmb: + """Test FluxSingleAttention rotary_pos_emb handling (code path).""" + + def test_rotary_pos_emb_single_wrapped_to_tuple(self): + """Test that single rotary_pos_emb is duplicated to (emb, emb) in forward.""" + # We only verify the logic: when rotary_pos_emb is not a tuple, it becomes (rotary_pos_emb,) * 2. + # This is the same logic used in FluxSingleAttention.forward and JointSelfAttention.forward. + rotary_pos_emb = torch.randn(1, 2, 32) + if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): + wrapped = (rotary_pos_emb,) * 2 + else: + wrapped = rotary_pos_emb + assert isinstance(wrapped, tuple) + assert len(wrapped) == 2 + assert wrapped[0] is rotary_pos_emb + assert wrapped[1] is rotary_pos_emb + + def test_rotary_pos_emb_tuple_unchanged(self): + """Test that tuple rotary_pos_emb is left as-is.""" + q_emb = torch.randn(1, 2, 32) + k_emb = torch.randn(1, 2, 32) + rotary_pos_emb = (q_emb, k_emb) + if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): + wrapped = (rotary_pos_emb,) * 2 + else: + wrapped = rotary_pos_emb + assert wrapped is rotary_pos_emb + assert wrapped[0] is q_emb + assert wrapped[1] is k_emb diff --git a/tests/unit_tests/diffusion/model/flux/test_flux_model.py b/tests/unit_tests/diffusion/model/flux/test_flux_model.py new file mode 100644 index 0000000000..f033e9df14 --- /dev/null +++ b/tests/unit_tests/diffusion/model/flux/test_flux_model.py @@ -0,0 +1,337 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for Flux model.""" + +from contextlib import nullcontext + +import pytest +import torch +from megatron.core import parallel_state +from torch import nn + +from megatron.bridge.diffusion.models.flux.flux_model import Flux +from megatron.bridge.diffusion.models.flux.flux_provider import FluxProvider + + +pytestmark = [pytest.mark.unit] + + +# Dummy blocks so we can build Flux without Transformer Engine (TE) or distributed init. +# They accept the same constructor/forward args as the real layers and preserve shapes. + + +class DummyMMDiTLayer(nn.Module): + """Pass-through double block; same forward signature as MMDiTLayer.""" + + def __init__(self, config=None, submodules=None, layer_number=0, context_pre_only=False): + super().__init__() + self.layer_number = layer_number + + def _get_layer_offset(self, config): + return 0 + + def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): + return {} + + def forward(self, hidden_states, encoder_hidden_states, rotary_pos_emb, emb): + return hidden_states, encoder_hidden_states + + +class DummyFluxSingleTransformerBlock(nn.Module): + """Pass-through single block; same forward signature as FluxSingleTransformerBlock.""" + + def __init__(self, config=None, submodules=None, layer_number=0): + super().__init__() + self.layer_number = layer_number + + def _get_layer_offset(self, config): + return 0 + + def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): + return {} + + def forward(self, hidden_states, rotary_pos_emb, emb): + return hidden_states, None + + +def 
_mock_flux_layers(monkeypatch): + """Replace TE-dependent Flux layers with dummies so Flux can be built without TE.""" + import megatron.bridge.diffusion.models.flux.flux_model as flux_model_module + + monkeypatch.setattr(flux_model_module, "MMDiTLayer", DummyMMDiTLayer, raising=False) + monkeypatch.setattr( + flux_model_module, "FluxSingleTransformerBlock", DummyFluxSingleTransformerBlock, raising=False + ) + + +def _mock_parallel_state(monkeypatch): + """Mock parallel_state and Flux TE-dependent layers so Flux can be built without TE/distributed init.""" + _mock_flux_layers(monkeypatch) + monkeypatch.setattr(parallel_state, "is_pipeline_first_stage", lambda: True, raising=False) + monkeypatch.setattr(parallel_state, "is_pipeline_last_stage", lambda: True, raising=False) + monkeypatch.setattr(parallel_state, "get_tensor_model_parallel_world_size", lambda: 1, raising=False) + monkeypatch.setattr(parallel_state, "get_pipeline_model_parallel_world_size", lambda: 1, raising=False) + monkeypatch.setattr(parallel_state, "get_data_parallel_world_size", lambda *args, **kwargs: 1, raising=False) + monkeypatch.setattr(parallel_state, "get_context_parallel_world_size", lambda: 1, raising=False) + monkeypatch.setattr(parallel_state, "get_tensor_model_parallel_group", lambda *args, **kwargs: None, raising=False) + monkeypatch.setattr(parallel_state, "get_data_parallel_group", lambda *args, **kwargs: None, raising=False) + monkeypatch.setattr(parallel_state, "get_context_parallel_group", lambda *args, **kwargs: None, raising=False) + monkeypatch.setattr( + parallel_state, "get_tensor_and_data_parallel_group", lambda *args, **kwargs: None, raising=False + ) + monkeypatch.setattr( + parallel_state, "get_pipeline_model_parallel_group", lambda *args, **kwargs: None, raising=False + ) + monkeypatch.setattr(parallel_state, "model_parallel_is_initialized", lambda: False, raising=False) + # Additional mocks for layer/TE code that may call with check_initialized=False + 
monkeypatch.setattr(parallel_state, "get_embedding_group", lambda *args, **kwargs: None, raising=False) + monkeypatch.setattr(parallel_state, "get_position_embedding_group", lambda *args, **kwargs: None, raising=False) + monkeypatch.setattr(parallel_state, "get_amax_reduction_group", lambda *args, **kwargs: None, raising=False) + + +def _minimal_flux_provider( + num_joint_layers=1, + num_single_layers=1, + hidden_size=64, + in_channels=16, + context_dim=64, + vec_in_dim=32, + guidance_embed=False, + **kwargs, +): + """FluxProvider with minimal sizes for fast unit tests.""" + return FluxProvider( + num_layers=1, + num_joint_layers=num_joint_layers, + num_single_layers=num_single_layers, + hidden_size=hidden_size, + ffn_hidden_size=hidden_size * 4, + num_attention_heads=4, + kv_channels=hidden_size // 4, + num_query_groups=4, + in_channels=in_channels, + context_dim=context_dim, + model_channels=32, + vec_in_dim=vec_in_dim, + patch_size=2, + guidance_embed=guidance_embed, + axes_dims_rope=[8, 4, 4], + **kwargs, + ) + + +class TestFluxInit: + """Test Flux model initialization.""" + + def test_flux_init_from_provider(self, monkeypatch): + """Test Flux can be built from FluxProvider with minimal config.""" + _mock_parallel_state(monkeypatch) + _mock_flux_layers(monkeypatch) + provider = _minimal_flux_provider() + + model = Flux(config=provider) + + assert model.config is provider + assert model.hidden_size == provider.hidden_size + assert model.num_attention_heads == provider.num_attention_heads + assert model.in_channels == provider.in_channels + assert model.out_channels == provider.in_channels + assert model.patch_size == provider.patch_size + assert model.guidance_embed is provider.guidance_embed + assert model.pre_process is True + assert model.post_process is True + assert len(model.double_blocks) == provider.num_joint_layers + assert len(model.single_blocks) == provider.num_single_layers + assert hasattr(model, "img_embed") + assert hasattr(model, "txt_embed") + 
assert hasattr(model, "timestep_embedding") + assert hasattr(model, "vector_embedding") + assert hasattr(model, "pos_embed") + assert hasattr(model, "norm_out") + assert hasattr(model, "proj_out") + + def test_flux_init_with_guidance_embed(self, monkeypatch): + """Test Flux has guidance_embedding when guidance_embed=True.""" + _mock_parallel_state(monkeypatch) + provider = _minimal_flux_provider(guidance_embed=True) + + model = Flux(config=provider) + + assert model.guidance_embed is True + assert hasattr(model, "guidance_embedding") + + def test_flux_init_without_guidance_embed(self, monkeypatch): + """Test Flux has no guidance_embedding when guidance_embed=False.""" + _mock_parallel_state(monkeypatch) + provider = _minimal_flux_provider(guidance_embed=False) + + model = Flux(config=provider) + + assert model.guidance_embed is False + assert not hasattr(model, "guidance_embedding") + + +class TestFluxGetFp8Context: + """Test get_fp8_context.""" + + def test_get_fp8_context_when_fp8_disabled(self, monkeypatch): + """Test get_fp8_context returns nullcontext when fp8 is not set.""" + _mock_parallel_state(monkeypatch) + provider = _minimal_flux_provider() + model = Flux(config=provider) + assert getattr(provider, "fp8", None) in (None, False, "") + + ctx = model.get_fp8_context() + + assert isinstance(ctx, nullcontext) + + def test_get_fp8_context_when_fp8_false(self, monkeypatch): + """Test get_fp8_context returns nullcontext when config.fp8 is False.""" + _mock_parallel_state(monkeypatch) + provider = _minimal_flux_provider() + provider.fp8 = False + model = Flux(config=provider) + + ctx = model.get_fp8_context() + + assert isinstance(ctx, nullcontext) + + +class TestFluxSetInputTensor: + """Test set_input_tensor (pipeline parallelism hook).""" + + def test_set_input_tensor_no_op(self, monkeypatch): + """Test set_input_tensor is a no-op and does not raise.""" + _mock_parallel_state(monkeypatch) + provider = _minimal_flux_provider() + model = Flux(config=provider) + 
+ model.set_input_tensor(torch.randn(2, 3, 4)) + model.set_input_tensor(None) + + +class TestFluxForward: + """Test Flux forward pass.""" + + def test_forward_output_shape(self, monkeypatch): + """Test forward returns correct shape [S, B, out_channels] (sequence-first).""" + _mock_parallel_state(monkeypatch) + provider = _minimal_flux_provider() + model = Flux(config=provider) + batch_size = 2 + txt_seq_len = 4 + img_seq_len = 8 + # Sequence-first as used by Megatron adapter + img = torch.randn(img_seq_len, batch_size, provider.in_channels) + txt = torch.randn(txt_seq_len, batch_size, provider.context_dim) + y = torch.randn(batch_size, provider.vec_in_dim) + timesteps = torch.rand(batch_size) + img_ids = torch.zeros(batch_size, img_seq_len, 3) + txt_ids = torch.zeros(batch_size, txt_seq_len, 3) + + out = model.forward( + img=img, + txt=txt, + y=y, + timesteps=timesteps, + img_ids=img_ids, + txt_ids=txt_ids, + ) + + # Output is image part only, sequence-first (Flux sets out_channels = config.in_channels) + assert out.shape == ( + img_seq_len, + batch_size, + provider.patch_size * provider.patch_size * provider.in_channels, + ) + + def test_forward_with_guidance(self, monkeypatch): + """Test forward with guidance tensor when guidance_embed=True.""" + _mock_parallel_state(monkeypatch) + provider = _minimal_flux_provider(guidance_embed=True) + model = Flux(config=provider) + batch_size = 2 + txt_seq_len = 4 + img_seq_len = 8 + img = torch.randn(img_seq_len, batch_size, provider.in_channels) + txt = torch.randn(txt_seq_len, batch_size, provider.context_dim) + y = torch.randn(batch_size, provider.vec_in_dim) + timesteps = torch.rand(batch_size) + img_ids = torch.zeros(batch_size, img_seq_len, 3) + txt_ids = torch.zeros(batch_size, txt_seq_len, 3) + guidance = torch.full((batch_size,), 3.5) + + out = model.forward( + img=img, + txt=txt, + y=y, + timesteps=timesteps, + img_ids=img_ids, + txt_ids=txt_ids, + guidance=guidance, + ) + + assert out.shape[0] == img_seq_len + 
assert out.shape[1] == batch_size + + def test_forward_guidance_none_when_embed_disabled(self, monkeypatch): + """Test forward accepts guidance=None when guidance_embed=False.""" + _mock_parallel_state(monkeypatch) + provider = _minimal_flux_provider(guidance_embed=False) + model = Flux(config=provider) + batch_size = 2 + img = torch.randn(8, batch_size, provider.in_channels) + txt = torch.randn(4, batch_size, provider.context_dim) + y = torch.randn(batch_size, provider.vec_in_dim) + timesteps = torch.rand(batch_size) + img_ids = torch.zeros(batch_size, 8, 3) + txt_ids = torch.zeros(batch_size, 4, 3) + + out = model.forward( + img=img, + txt=txt, + y=y, + timesteps=timesteps, + img_ids=img_ids, + txt_ids=txt_ids, + guidance=None, + ) + + assert out.dim() == 3 + + +class TestFluxShardedStateDict: + """Test sharded_state_dict (requires parallel_state for replica IDs).""" + + def test_sharded_state_dict_returns_dict(self, monkeypatch): + """Test sharded_state_dict returns a dict-like structure.""" + _mock_parallel_state(monkeypatch) + monkeypatch.setattr(parallel_state, "get_tensor_model_parallel_rank", lambda: 0, raising=False) + monkeypatch.setattr(parallel_state, "get_virtual_pipeline_model_parallel_rank", lambda: None, raising=False) + monkeypatch.setattr( + parallel_state, "get_virtual_pipeline_model_parallel_world_size", lambda: None, raising=False + ) + monkeypatch.setattr(parallel_state, "get_pipeline_model_parallel_rank", lambda: 0, raising=False) + monkeypatch.setattr( + parallel_state, "get_data_parallel_rank", lambda with_context_parallel=False: 0, raising=False + ) + provider = _minimal_flux_provider() + model = Flux(config=provider) + + result = model.sharded_state_dict(prefix="", sharded_offsets=(), metadata=None) + + assert isinstance(result, dict) + # Should contain keys for double_blocks, single_blocks, and other modules + assert len(result) > 0 diff --git a/tests/unit_tests/diffusion/model/flux/test_flux_step.py 
b/tests/unit_tests/diffusion/model/flux/test_flux_step.py index 4deefa8356..5784c34021 100644 --- a/tests/unit_tests/diffusion/model/flux/test_flux_step.py +++ b/tests/unit_tests/diffusion/model/flux/test_flux_step.py @@ -13,7 +13,6 @@ # limitations under the License. from functools import partial -from unittest.mock import MagicMock import pytest import torch @@ -95,14 +94,15 @@ def test_initialization_defaults(self): """Test FluxForwardStep initialization with default values.""" step = FluxForwardStep() - assert step.timestep_sampling == "logit_normal" - assert step.logit_mean == 0.0 - assert step.logit_std == 1.0 - assert step.mode_scale == 1.29 - assert step.scheduler_steps == 1000 - assert step.guidance_scale == 3.5 assert step.autocast_dtype == torch.bfloat16 - assert hasattr(step, "scheduler") + assert hasattr(step, "pipeline") + # Pipeline holds timestep/config; check pipeline was created with defaults + assert step.pipeline.timestep_sampling == "logit_normal" + assert step.pipeline.flow_shift == 1.0 + assert step.pipeline.logit_mean == 0.0 + assert step.pipeline.logit_std == 1.0 + assert step.pipeline.num_train_timesteps == 1000 + assert step.pipeline.model_adapter.guidance_scale == 3.5 def test_initialization_custom(self): """Test FluxForwardStep initialization with custom values.""" @@ -110,207 +110,71 @@ def test_initialization_custom(self): timestep_sampling="uniform", logit_mean=1.0, logit_std=2.0, - mode_scale=1.5, + flow_shift=1.5, scheduler_steps=500, guidance_scale=7.5, ) - assert step.timestep_sampling == "uniform" - assert step.logit_mean == 1.0 - assert step.logit_std == 2.0 - assert step.mode_scale == 1.5 - assert step.scheduler_steps == 500 - assert step.guidance_scale == 7.5 + assert step.pipeline.timestep_sampling == "uniform" + assert step.pipeline.logit_mean == 1.0 + assert step.pipeline.logit_std == 2.0 + assert step.pipeline.flow_shift == 1.5 + assert step.pipeline.num_train_timesteps == 500 + assert 
step.pipeline.model_adapter.guidance_scale == 7.5 + def test_initialization_use_loss_weighting(self): + """Test FluxForwardStep with use_loss_weighting=True.""" + step = FluxForwardStep(use_loss_weighting=True) + assert step.pipeline.use_loss_weighting is True -class TestFluxForwardStepTimestepSampling: - """Test timestep sampling methods.""" - def test_compute_density_logit_normal(self): - """Test logit-normal timestep sampling.""" - step = FluxForwardStep(timestep_sampling="logit_normal", logit_mean=0.0, logit_std=1.0) - batch_size = 10 +class TestFluxForwardStepPrepareBatch: + """Test _prepare_batch_for_pipeline.""" - u = step.compute_density_for_timestep_sampling("logit_normal", batch_size) - - assert u.shape == (batch_size,) - assert (u >= 0).all() - assert (u <= 1).all() - - def test_compute_density_mode(self): - """Test mode-based timestep sampling.""" - step = FluxForwardStep(timestep_sampling="mode", mode_scale=1.29) - batch_size = 10 - - u = step.compute_density_for_timestep_sampling("mode", batch_size) - - assert u.shape == (batch_size,) - assert (u >= 0).all() - assert (u <= 1).all() - - def test_compute_density_uniform(self): - """Test uniform timestep sampling.""" - step = FluxForwardStep(timestep_sampling="uniform") - batch_size = 10 - - u = step.compute_density_for_timestep_sampling("uniform", batch_size) - - assert u.shape == (batch_size,) - assert (u >= 0).all() - assert (u <= 1).all() - - def test_compute_density_uses_instance_defaults(self): - """Test that compute_density uses instance defaults when not provided.""" - step = FluxForwardStep(logit_mean=0.5, logit_std=0.8, mode_scale=1.5) - - # Should use instance defaults - u = step.compute_density_for_timestep_sampling("logit_normal", batch_size=5) - - assert u.shape == (5,) - - def test_compute_density_override_defaults(self): - """Test that compute_density can override instance defaults.""" - step = FluxForwardStep(logit_mean=0.0, logit_std=1.0) - - # Override with custom values - u = 
step.compute_density_for_timestep_sampling("logit_normal", batch_size=5, logit_mean=1.0, logit_std=0.5) - - assert u.shape == (5,) - - -class TestFluxForwardStepLatentOperations: - """Test latent packing/unpacking operations.""" - - def test_pack_latents(self): - """Test _pack_latents method.""" + def test_prepare_batch_maps_keys(self): + """Test that Megatron keys are mapped to pipeline keys.""" step = FluxForwardStep() - batch_size = 2 - num_channels = 16 - height = 64 - width = 64 - - latents = torch.randn(batch_size, num_channels, height, width) - packed = step._pack_latents(latents, batch_size, num_channels, height, width) - - expected_seq_len = (height // 2) * (width // 2) - expected_channels = num_channels * 4 - assert packed.shape == (batch_size, expected_seq_len, expected_channels) - - def test_unpack_latents(self): - """Test _unpack_latents method.""" - step = FluxForwardStep() - batch_size = 2 - num_patches = 1024 # (64 // 2) * (64 // 2) - channels = 64 # 16 * 4 - height = 64 - width = 64 - - packed_latents = torch.randn(batch_size, num_patches, channels) - unpacked = step._unpack_latents(packed_latents, height, width) - - expected_channels = channels // 4 - assert unpacked.shape == (batch_size, expected_channels, height, width) - - def test_pack_unpack_roundtrip(self): - """Test that pack and unpack are consistent.""" - step = FluxForwardStep() - batch_size = 2 - num_channels = 16 - height = 64 - width = 64 + batch = { + "latents": torch.randn(2, 16, 64, 64), + "prompt_embeds": torch.randn(2, 77, 4096), + "pooled_prompt_embeds": torch.randn(2, 768), + "text_ids": torch.zeros(2, 77, 3), + } - original = torch.randn(batch_size, num_channels, height, width) - packed = step._pack_latents(original, batch_size, num_channels, height, width) - unpacked = step._unpack_latents(packed, height, width) + pipeline_batch = step._prepare_batch_for_pipeline(batch) - assert unpacked.shape == original.shape - # Note: Due to the reshape operations, values should be 
approximately equal
-        # but the exact comparison might not hold due to floating point operations
+        assert "image_latents" in pipeline_batch
+        assert pipeline_batch["image_latents"] is batch["latents"]
+        assert pipeline_batch["prompt_embeds"] is batch["prompt_embeds"]
+        assert pipeline_batch["pooled_prompt_embeds"] is batch["pooled_prompt_embeds"]
+        assert pipeline_batch["text_ids"] is batch["text_ids"]
+        assert pipeline_batch["data_type"] == "image"
+        assert "latents" not in pipeline_batch
 
-    def test_prepare_latent_image_ids(self):
-        """Test _prepare_latent_image_ids method."""
-        step = FluxForwardStep()
-        batch_size = 2
-        height = 64
-        width = 64
-        device = torch.device("cpu")
-        dtype = torch.float32
-
-        # First call creates the IDs
-        ids = step._prepare_latent_image_ids(batch_size, height, width, device, dtype)
-
-        expected_seq_len = (height // 2) * (width // 2)
-        assert ids.shape == (batch_size, expected_seq_len, 3)
-        assert ids.device == device
-        assert ids.dtype == dtype
-
-        # Second call should use cache
-        ids2 = step._prepare_latent_image_ids(batch_size, height, width, device, dtype)
-        assert ids2.shape == ids.shape
-
-    def test_prepare_latent_image_ids_caching(self):
-        """Test that _prepare_latent_image_ids uses LRU cache."""
+    def test_prepare_batch_extra_keys_copied(self):
+        """Test that extra batch keys are copied (except latents)."""
         step = FluxForwardStep()
+        batch = {
+            "latents": torch.randn(1, 16, 32, 32),
+            "custom_key": "value",
+        }
 
-        # Cache should work with same parameters
-        ids1 = step._prepare_latent_image_ids(2, 64, 64, torch.device("cpu"), torch.float32)
-        ids2 = step._prepare_latent_image_ids(2, 64, 64, torch.device("cpu"), torch.float32)
-
-        # Should be the same object from cache
-        assert ids1.data_ptr() == ids2.data_ptr()
-
+        pipeline_batch = step._prepare_batch_for_pipeline(batch)
 
-@pytest.mark.run_only_on("GPU")
-class TestFluxForwardStepPrepareImageLatent:
-    """Test prepare_image_latent method."""
+        assert pipeline_batch["custom_key"] == "value"
+        assert pipeline_batch["image_latents"] is batch["latents"]
 
-    def test_prepare_image_latent_basic(self):
-        """Test prepare_image_latent with basic input."""
+    def test_prepare_batch_optional_keys(self):
+        """Test prepare_batch when optional keys are missing."""
         step = FluxForwardStep()
-        batch_size = 2
-        channels = 16
-        height = 64
-        width = 64
-
-        latents = torch.randn(batch_size, channels, height, width, device="cuda")
-
-        # Mock model
-        mock_model = MagicMock()
-        mock_model.guidance_embed = False
-
-        result = step.prepare_image_latent(latents, mock_model)
-
-        # Unpack result tuple
-        ret_latents, noise, packed_noisy_input, latent_ids, guidance_vec, timesteps = result
+        batch = {"latents": torch.randn(1, 16, 32, 32)}
 
-        # Check shapes (transposed from [B, ...] to [seq, B, ...] format)
-        assert ret_latents.shape[1] == batch_size
-        assert noise.shape[1] == batch_size
-        assert packed_noisy_input.shape[1] == batch_size
-        assert latent_ids.shape[0] == batch_size
-        assert guidance_vec is None
-        assert timesteps.shape[0] == batch_size
+        pipeline_batch = step._prepare_batch_for_pipeline(batch)
 
-    def test_prepare_image_latent_with_guidance(self):
-        """Test prepare_image_latent with guidance embedding."""
-        step = FluxForwardStep(guidance_scale=7.5)
-        batch_size = 2
-        channels = 16
-        height = 64
-        width = 64
-
-        latents = torch.randn(batch_size, channels, height, width, device="cuda")
-
-        # Mock model with guidance
-        mock_model = MagicMock()
-        mock_model.guidance_embed = True
-
-        result = step.prepare_image_latent(latents, mock_model)
-        ret_latents, noise, packed_noisy_input, latent_ids, guidance_vec, timesteps = result
-
-        assert guidance_vec is not None
-        assert guidance_vec.shape == (batch_size,)
-        assert torch.all(guidance_vec == 7.5)
+        assert pipeline_batch["image_latents"] is batch["latents"]
+        assert pipeline_batch.get("prompt_embeds") is None
+        assert pipeline_batch.get("pooled_prompt_embeds") is None
+        assert pipeline_batch.get("text_ids") is None
 
 
 class TestFluxForwardStepLossFunction:
@@ -333,106 +197,21 @@ def test_create_loss_function_parameters(self):
         loss_fn = step._create_loss_function(loss_mask, check_for_nan_in_loss=False, check_for_spiky_loss=True)
 
-        # Verify it's a partial with expected arguments
         assert loss_fn.func.__name__ == "masked_next_token_loss"
         assert loss_fn.keywords["check_for_nan_in_loss"] is False
         assert loss_fn.keywords["check_for_spiky_loss"] is True
 
 
-class TestFluxForwardStepIntegration:
-    """Integration tests for FluxForwardStep."""
+class TestFluxForwardStepPipelineConfig:
+    """Test pipeline configuration exposed by FluxForwardStep."""
 
-    def test_timestep_sampling_methods_produce_valid_values(self):
-        """Test that all timestep sampling methods produce valid u values."""
-        batch_size = 100
+    def test_pipeline_num_train_timesteps(self):
+        """Test that pipeline has correct num_train_timesteps."""
+        step = FluxForwardStep(scheduler_steps=500)
+        assert step.pipeline.num_train_timesteps == 500
 
-        for method in ["logit_normal", "mode", "uniform"]:
+    def test_pipeline_timestep_sampling_options(self):
+        """Test pipeline accepts different timestep_sampling values."""
+        for method in ("logit_normal", "uniform", "mode"):
             step = FluxForwardStep(timestep_sampling=method)
-            u = step.compute_density_for_timestep_sampling(method, batch_size)
-
-            assert u.shape == (batch_size,)
-            assert (u >= 0).all(), f"{method} produced u < 0"
-            assert (u <= 1).all(), f"{method} produced u > 1"
-            assert not torch.isnan(u).any(), f"{method} produced NaN values"
-
-    def test_latent_operations_preserve_batch_dimension(self):
-        """Test that latent operations preserve batch dimension."""
-        step = FluxForwardStep()
-
-        for batch_size in [1, 2, 4]:
-            latents = torch.randn(batch_size, 16, 64, 64)
-            packed = step._pack_latents(latents, batch_size, 16, 64, 64)
-            unpacked = step._unpack_latents(packed, 64, 64)
-
-            assert packed.shape[0] == batch_size
-            assert unpacked.shape[0] == batch_size
-
-
-class TestFluxForwardStepEdgeCases:
-    """Test edge cases and error handling."""
-
-    def test_pack_latents_small_dimensions(self):
-        """Test _pack_latents with small dimensions."""
-        step = FluxForwardStep()
-        latents = torch.randn(1, 4, 4, 4)
-
-        packed = step._pack_latents(latents, 1, 4, 4, 4)
-
-        assert packed.shape == (1, 4, 16)  # (4/2) * (4/2) = 4, 4*4 = 16
-
-    def test_unpack_latents_small_dimensions(self):
-        """Test _unpack_latents with small dimensions."""
-        step = FluxForwardStep()
-        packed = torch.randn(1, 4, 16)
-
-        unpacked = step._unpack_latents(packed, 4, 4)
-
-        assert unpacked.shape == (1, 4, 4, 4)
-
-    def test_compute_density_mode_with_extreme_scale(self):
-        """Test mode sampling with extreme scale values."""
-        step = FluxForwardStep()
-
-        # Test with very small scale
-        u_small = step.compute_density_for_timestep_sampling("mode", 10, mode_scale=0.01)
-        assert (u_small >= 0).all() and (u_small <= 1).all()
-
-        # Test with larger scale
-        u_large = step.compute_density_for_timestep_sampling("mode", 10, mode_scale=2.0)
-        assert (u_large >= 0).all() and (u_large <= 1).all()
-
-    def test_prepare_latent_image_ids_different_sizes(self):
-        """Test _prepare_latent_image_ids with different image sizes."""
-        step = FluxForwardStep()
-
-        for height, width in [(32, 32), (64, 64), (128, 128)]:
-            ids = step._prepare_latent_image_ids(2, height, width, torch.device("cpu"), torch.float32)
-
-            expected_seq_len = (height // 2) * (width // 2)
-            assert ids.shape == (2, expected_seq_len, 3)
-
-
-class TestFluxForwardStepScheduler:
-    """Test scheduler integration."""
-
-    def test_scheduler_initialized_with_correct_steps(self):
-        """Test that scheduler is initialized with correct number of steps."""
-        scheduler_steps = 500
-        step = FluxForwardStep(scheduler_steps=scheduler_steps)
-
-        assert step.scheduler.num_train_timesteps == scheduler_steps
-        assert len(step.scheduler.timesteps) == scheduler_steps
-
-    def test_scheduler_timesteps_in_valid_range(self):
-        """Test that scheduler timesteps are in valid range."""
-        step = FluxForwardStep()
-
-        assert (step.scheduler.timesteps >= 0).all()
-        assert (step.scheduler.timesteps <= step.scheduler.num_train_timesteps).all()
-
-    def test_scheduler_sigmas_in_valid_range(self):
-        """Test that scheduler sigmas are in valid range."""
-        step = FluxForwardStep()
-
-        assert (step.scheduler.sigmas >= 0).all()
-        assert (step.scheduler.sigmas <= 1).all()
+            assert step.pipeline.timestep_sampling == method