Merged
Commits
44 commits
e52b79a
worable code for checkpoint conversion, inferenceing, and training. N…
Mar 4, 2026
788fbe6
with the M-LM dev PR (https://github.com/NVIDIA/Megatron-LM/pull/3759…
Mar 10, 2026
f675c1e
just keep apply_rope_fusion for now'
Mar 10, 2026
3e434a9
fix lint
Mar 10, 2026
5855e76
verify flux ckpt conversion/inference/training
suiyoubi Mar 10, 2026
1060cf8
pre-commit
suiyoubi Mar 11, 2026
8cb593a
fix unit test
Mar 11, 2026
5d83b67
add Wan instructions
Mar 11, 2026
0f968e3
merge migration/dfm
Mar 11, 2026
27ed8f6
fix lint
Mar 11, 2026
069ce05
add dependencies
Mar 12, 2026
2f26e00
uv lock
Mar 12, 2026
171fdc0
fix unit tests
Mar 12, 2026
466cad3
fix unit tests
Mar 12, 2026
6260447
fix lint
Mar 12, 2026
ec29480
Merge branch 'migration/dfm_wan' of https://github.com/NVIDIA-NeMo/Me…
suiyoubi Mar 12, 2026
38843fb
fix ft docstring
suiyoubi Mar 12, 2026
a258371
adding back __subflavor__
Mar 12, 2026
5f2c1a9
Merge branch 'migration/dfm_wan' of https://github.com/NVIDIA-NeMo/Me…
suiyoubi Mar 12, 2026
3886a9e
fix broken test
suiyoubi Mar 12, 2026
20d2b24
add functional tests for Wan, FLUX
Mar 12, 2026
20862cf
remove original fluxstep and use the common flowmatchingpipeline
suiyoubi Mar 12, 2026
a5f6f81
Merge branch 'migration/dfm_wan' of https://github.com/NVIDIA-NeMo/Me…
suiyoubi Mar 12, 2026
07d21a5
fix functional tests for wan, flux; cherry pick for - scrip…
Mar 12, 2026
5e524ba
restore all tests
Mar 12, 2026
bf552c3
Merge remote-tracking branch 'origin/migration/dfm' into migration/df…
Mar 13, 2026
8606064
fix conftest.py, update uv.lock
Mar 13, 2026
0bb18d4
update megatron.core.transformer.custom_layers.transformer_engine to …
Mar 13, 2026
6a78af4
lint
Mar 13, 2026
c5b226f
Merge branch 'migration/dfm_wan' of https://github.com/NVIDIA-NeMo/Me…
suiyoubi Mar 13, 2026
d6ccea3
update uv lock
suiyoubi Mar 13, 2026
a7094aa
fix jit expansion after Pytorch update
Mar 13, 2026
e56b06e
export checkpoint
suiyoubi Mar 13, 2026
cf67933
Merge branch 'migration/dfm_wan' of https://github.com/NVIDIA-NeMo/Me…
suiyoubi Mar 13, 2026
4e2969c
add readme
suiyoubi Mar 13, 2026
7777ea0
grammar
suiyoubi Mar 13, 2026
2d3b598
revert uvlock
suiyoubi Mar 13, 2026
ee2b5cb
Refactor JointSelfAttention to handle context_pre_only condition. Add…
suiyoubi Mar 13, 2026
01b43d3
imporve unit test coverage
suiyoubi Mar 13, 2026
311b3e1
Merge branch 'migration/dfm' of https://github.com/NVIDIA-NeMo/Megatr…
suiyoubi Mar 13, 2026
9253443
merge conflict
suiyoubi Mar 13, 2026
bb1df8e
linter fix
suiyoubi Mar 13, 2026
865d738
test flux only
suiyoubi Mar 13, 2026
30fc6ec
revert cicd
suiyoubi Mar 13, 2026
158 changes: 158 additions & 0 deletions examples/diffusion/recipes/flux/README.md
@@ -0,0 +1,158 @@
# FLUX Examples

This directory contains example scripts for the FLUX diffusion model (text-to-image) with Megatron-Bridge: checkpoint conversion, inference, pretraining, and fine-tuning.

All commands below assume you run them from the **Megatron-Bridge repository root** unless noted. Use `uv run` when you need the project’s virtualenv (e.g. `uv run python ...`, `uv run torchrun ...`).

## Workspace Configuration

Use a `WORKSPACE` environment variable as the base directory for checkpoints and results. The default is `/workspace`; override it if needed:

```bash
export WORKSPACE=/your/custom/path
```

Suggested layout:

- `${WORKSPACE}/checkpoints/flux/` – Megatron FLUX checkpoints (after import)
- `${WORKSPACE}/checkpoints/flux_hf/` – Hugging Face FLUX model (download or export)
- `${WORKSPACE}/results/flux/` – Training outputs (pretrain/finetune)
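The layout above can be created up front. A minimal sketch (the subdirectory names are just this README's suggestion, not required by the scripts):

```shell
# Create the suggested workspace layout; WORKSPACE falls back to /workspace.
WORKSPACE="${WORKSPACE:-/workspace}"
mkdir -p "${WORKSPACE}/checkpoints/flux" \
         "${WORKSPACE}/checkpoints/flux_hf" \
         "${WORKSPACE}/results/flux"
```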

---

## 1. Checkpoint Conversion

The script [conversion/convert_checkpoints.py](conversion/convert_checkpoints.py) converts between Hugging Face (diffusers) and Megatron checkpoint formats.

**Source model:** [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) (or a local clone).

### Download the Hugging Face model (optional)

If you want a local copy before conversion:

```bash
huggingface-cli download black-forest-labs/FLUX.1-dev \
--local-dir ${WORKSPACE}/checkpoints/flux_hf/flux.1-dev \
--local-dir-use-symlinks False
```

**Note**: Keeping this local copy is recommended: the VAE and text encoders from it are reused later by the inference pipeline.

### Import: Hugging Face → Megatron

Convert a Hugging Face FLUX model to Megatron format:

```bash
uv run python examples/diffusion/recipes/flux/conversion/convert_checkpoints.py import \
--hf-model ${WORKSPACE}/checkpoints/flux_hf/flux.1-dev \
--megatron-path ${WORKSPACE}/checkpoints/flux/flux.1-dev
```

The Megatron checkpoint is written under `--megatron-path` (e.g. `.../flux.1-dev/iter_0000000/`). Use that path for inference and fine-tuning.

### Export: Megatron → Hugging Face

Export a Megatron checkpoint back to Hugging Face (e.g. for use in diffusers). You must pass the **reference** HF model (for config and non-DiT components) and the **Megatron iteration directory**:

```bash
uv run python examples/diffusion/recipes/flux/conversion/convert_checkpoints.py export \
--hf-model ${WORKSPACE}/checkpoints/flux_hf/flux.1-dev \
--megatron-path ${WORKSPACE}/checkpoints/flux/flux.1-dev/iter_0000000 \
--hf-path ${WORKSPACE}/checkpoints/flux_hf/flux.1-dev_export
```

**Note:** The exported directory contains only the DiT transformer weights. For a full pipeline (VAE, text encoders, etc.), copy the original HF repo and replace its `transformer` folder with the exported one.
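That assembly is a plain directory swap. The following is a self-contained demonstration on a throwaway tree (in a real run, `SRC`, `EXPORT`, and `DEST` would be the `${WORKSPACE}` paths used above):

```shell
# Demo of the transformer-swap step on a throwaway tree.
ROOT=$(mktemp -d)
SRC="${ROOT}/flux.1-dev"           # stands in for the original HF repo
EXPORT="${ROOT}/flux.1-dev_export" # stands in for the exported DiT weights
DEST="${ROOT}/flux.1-dev_full"     # assembled full pipeline

mkdir -p "${SRC}/transformer" "${SRC}/vae" "${EXPORT}"
echo original > "${SRC}/transformer/weights.marker"
echo exported > "${EXPORT}/weights.marker"

cp -r "${SRC}" "${DEST}"                 # copy the full repo (VAE, text encoders, ...)
rm -rf "${DEST}/transformer"             # drop the original DiT weights
cp -r "${EXPORT}" "${DEST}/transformer"  # swap in the exported ones

cat "${DEST}/transformer/weights.marker"  # -> exported
```

The resulting `DEST` directory keeps every non-DiT component from the original repo and can be loaded as a normal diffusers pipeline.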

---

## 2. Inference

The script [inference_flux.py](inference_flux.py) runs text-to-image generation with a Megatron-format FLUX checkpoint. You need:

- **FLUX checkpoint:** Megatron DiT (e.g. from the import step above).
- **VAE:** Path to VAE weights (usually the `vae` directory that sits alongside `transformer` in the same HF repo, or a separate VAE checkpoint).
- **Text encoders:** T5 and CLIP are loaded from Hugging Face by default; you can override with local paths.

### Single prompt (default 1024×1024, 10 steps)

```bash
uv run python examples/diffusion/recipes/flux/inference_flux.py \
--flux_ckpt ${WORKSPACE}/checkpoints/flux/flux.1-dev/iter_0000000 \
--vae_ckpt ${WORKSPACE}/checkpoints/flux_hf/flux.1-dev/vae \
--prompts "a dog holding a sign that says hello world" \
--output_path ./flux_output
```


**VAE path:** If you downloaded FLUX.1-dev with `huggingface-cli`, the VAE is usually in the same repo (e.g. `${WORKSPACE}/checkpoints/flux_hf/flux.1-dev/vae`); use the path to the VAE subfolder or the main repo, depending on how the pipeline expects it.

---

## 3. Pretraining

The script [pretrain_flux.py](pretrain_flux.py) runs FLUX pretraining with the `pretrain_config()` recipe. Configuration can be overridden with Hydra-style CLI keys.

**Recipe:** [megatron.bridge.diffusion.recipes.flux.flux.pretrain_config](https://github.com/NVIDIA-NeMo/Megatron-Bridge/blob/main/src/megatron/bridge/diffusion/recipes/flux/flux.py)

### Quick run with mock data (single node, 8 GPUs)

```bash
uv run torchrun --nproc_per_node=8 examples/diffusion/recipes/flux/pretrain_flux.py --mock
```

### With CLI overrides only

```bash
uv run torchrun --nproc_per_node=8 examples/diffusion/recipes/flux/pretrain_flux.py --mock \
model.tensor_model_parallel_size=4 \
train.train_iters=10000 \
optimizer.lr=1e-4
```


### Flow matching options

```bash
uv run torchrun --nproc_per_node=8 examples/diffusion/recipes/flux/pretrain_flux.py --mock \
--timestep-sampling logit_normal \
--flow-shift 1.0 \
--use-loss-weighting
```

Before pretraining with real data, set the dataset in the recipe or in your YAML/CLI (e.g. `data_paths`, dataset blend, and cache paths). For data preprocessing, see the Megatron-Bridge data tutorials.
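For illustration only, such a dataset override might look like the YAML below; the key names here are hypothetical and must be checked against the `pretrain_config()` recipe before use:

```yaml
# Hypothetical sketch -- verify key names against the recipe before use.
dataset:
  data_paths:
    - /data/my_dataset/shard_000
    - /data/my_dataset/shard_001
  cache_path: /data/cache
train:
  train_iters: 10000
```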

---

## 4. Fine-Tuning

The script [finetune_flux.py](finetune_flux.py) fine-tunes a pretrained FLUX checkpoint (Megatron format). It loads the pretrained model weights and resets the optimizer state and step count; the config can be overridden via YAML and CLI, as with pretraining.

Point `--load-checkpoint` at the **Megatron checkpoint directory** (either the base dir, e.g. `.../flux.1-dev`, or a specific iteration, e.g. `.../flux.1-dev/iter_0000000`):

```bash
uv run torchrun --nproc_per_node=8 examples/diffusion/recipes/flux/finetune_flux.py \
--load-checkpoint ${WORKSPACE}/checkpoints/flux/flux.1-dev/iter_0000000 \
--mock
```

**Note**: If you pass a path that ends with an `iter_XXXXXXX` directory, the script loads that iteration; otherwise it uses the latest iteration under the given path.
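The lookup rule in the note above can be sketched in shell (an illustration of the behavior, not the script's actual code; the function name is made up):

```shell
# Given a checkpoint path, print the iteration directory that would be loaded:
# an explicit iter_XXXXXXX path is used as-is; otherwise pick the latest iteration.
resolve_checkpoint_dir() {
  case "$(basename "$1")" in
    iter_[0-9][0-9][0-9][0-9][0-9][0-9][0-9])
      echo "$1" ;;   # explicit iteration directory: use as-is
    *)
      # Base directory: highest-numbered iter_* subdirectory.
      # Iteration names are zero-padded, so a lexical sort works.
      echo "$1/$(ls "$1" | grep '^iter_[0-9]\{7\}$' | sort | tail -n 1)" ;;
  esac
}
```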

**Note**: The loss may diverge when training on the mock dataset.

---

## Summary: End-to-End Flow

1. **Conversion (HF → Megatron)**
Download FLUX.1-dev (optional), then run the `import` command. Use the created `iter_0000000` path as your Megatron checkpoint.

2. **Inference**
Run [inference_flux.py](inference_flux.py) with `--flux_ckpt` (Megatron `iter_*` path), `--vae_ckpt`, and `--prompts`.

3. **Pretraining**
Run [pretrain_flux.py](pretrain_flux.py) with `--mock` or your data config; optionally use `--config-file` and CLI overrides.

4. **Fine-Tuning**
   Run [finetune_flux.py](finetune_flux.py) with `--load-checkpoint` set to a Megatron checkpoint (from import or a pretrain/finetune output), plus `--mock` or your data config and overrides.

For more details, see the docstrings in each script and the recipe in `src/megatron/bridge/diffusion/recipes/flux/flux.py`.
@@ -129,6 +129,8 @@ def import_hf_to_megatron(
bridge = FluxBridge()
provider = bridge.provider_bridge(hf)
provider.perform_initialization = False
# Finalize config so init_method/output_layer_init_method are set (required by Megatron MLP)
provider.finalize()
megatron_models = provider.provide_distributed_model(wrap_with_ddp=False, use_cpu_initialization=True)
bridge.load_weights_hf_to_megatron(hf, megatron_models)

103 changes: 27 additions & 76 deletions examples/diffusion/recipes/flux/finetune_flux.py
@@ -22,40 +22,28 @@
The script loads a pretrained checkpoint and continues training with your custom dataset.
Fine-tuning typically uses lower learning rates and fewer training iterations compared to pretraining.

Forward Step Options:
- Automodel FlowMatchingPipeline (default): Unified flow matching implementation
- Original FluxForwardStep (--use-original-step): Classic implementation

Examples:
Basic usage with checkpoint loading (uses automodel pipeline):
$ torchrun --nproc_per_node=8 finetune_flux.py \
Basic usage with checkpoint loading:
$ uv run torchrun --nproc_per_node=8 finetune_flux.py \
--load-checkpoint /path/to/pretrained/checkpoint --mock

Using original FluxForwardStep:
$ torchrun --nproc_per_node=8 finetune_flux.py \
--load-checkpoint /path/to/pretrained/checkpoint --mock --use-original-step

Using a custom YAML config file:
$ torchrun --nproc_per_node=8 finetune_flux.py \
$ uv run torchrun --nproc_per_node=8 finetune_flux.py \
--load-checkpoint /path/to/pretrained/checkpoint \
--config-file my_custom_config.yaml

Using CLI overrides only:
$ torchrun --nproc_per_node=8 finetune_flux.py \
$ uv run torchrun --nproc_per_node=8 finetune_flux.py \
--load-checkpoint /path/to/pretrained/checkpoint \
model.tensor_model_parallel_size=4 train.train_iters=5000 optimizer.lr=1e-5

Combining YAML and CLI overrides (CLI takes precedence):
$ torchrun --nproc_per_node=8 finetune_flux.py \
$ uv run torchrun --nproc_per_node=8 finetune_flux.py \
--load-checkpoint /path/to/pretrained/checkpoint \
--config-file conf/my_config.yaml \
model.pipeline_dtype=torch.float16 \
train.global_batch_size=512

Using automodel pipeline with custom parameters (automodel is default):
$ torchrun --nproc_per_node=8 finetune_flux.py \
--load-checkpoint /path/to/pretrained/checkpoint --mock \
--flow-shift=1.0 --use-loss-weighting

Configuration Precedence:
1. Base configuration from pretrain_config() recipe
@@ -81,7 +81,7 @@

from omegaconf import OmegaConf

from megatron.bridge.diffusion.models.flux.flux_step_with_automodel import create_flux_forward_step
from megatron.bridge.diffusion.models.flux.flux_step import FluxForwardStep
from megatron.bridge.diffusion.recipes.flux.flux import pretrain_config
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.pretrain import pretrain
@@ -159,22 +159,16 @@ def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]:
)
parser.add_argument("--debug", action="store_true", help="Enable debug logging")

# Forward step implementation choice
parser.add_argument(
"--use-original-step",
action="store_true",
help="Use original FluxForwardStep instead of automodel FlowMatchingPipeline (default)",
)
parser.add_argument(
"--flow-shift",
type=float,
default=1.0,
help="Flow shift parameter (for automodel pipeline)",
help="Flow shift parameter",
)
parser.add_argument(
"--use-loss-weighting",
action="store_true",
help="Use loss weighting (for automodel pipeline)",
help="Use loss weighting",
)

# Parse known args for the script, remaining will be treated as overrides
@@ -197,31 +197,23 @@ def main() -> None:
and handles type conversions automatically.

Examples of CLI usage:
# Fine-tune with default config and custom learning rate (automodel pipeline is default)
# Fine-tune with default config and custom learning rate
torchrun --nproc_per_node=8 finetune_flux.py \
--load-checkpoint /path/to/checkpoint --mock optimizer.lr=1e-5

# Use original FluxForwardStep instead of automodel pipeline
torchrun --nproc_per_node=8 finetune_flux.py \
--load-checkpoint /path/to/checkpoint --mock --use-original-step

# Custom config file with additional overrides
torchrun --nproc_per_node=8 finetune_flux.py \
--load-checkpoint /path/to/checkpoint \
--config-file my_config.yaml train.train_iters=5000

# Multiple overrides for distributed fine-tuning (uses automodel by default)
# Multiple overrides for distributed fine-tuning
torchrun --nproc_per_node=8 finetune_flux.py \
--load-checkpoint /path/to/checkpoint --mock \
model.tensor_model_parallel_size=4 \
model.pipeline_model_parallel_size=2 \
train.global_batch_size=512 \
optimizer.lr=5e-6

# Automodel pipeline with custom flow matching parameters
torchrun --nproc_per_node=8 finetune_flux.py \
--load-checkpoint /path/to/checkpoint --mock \
--flow-shift=1.0 --use-loss-weighting
"""
args, cli_overrides = parse_cli_args()

@@ -337,43 +337,21 @@ def main() -> None:
cfg.checkpoint.load = None # Clear load to ensure pretrained_checkpoint takes precedence
cfg.checkpoint.finetune = True

# Create forward step (configurable: original or automodel pipeline)
# Default is automodel pipeline unless --use-original-step is specified
if not args.use_original_step:
# Use automodel FlowMatchingPipeline
flux_forward_step = create_flux_forward_step(
use_automodel_pipeline=True,
timestep_sampling=args.timestep_sampling,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
flow_shift=args.flow_shift,
scheduler_steps=args.scheduler_steps,
guidance_scale=args.guidance_scale,
use_loss_weighting=args.use_loss_weighting,
)
if get_rank_safe() == 0:
logger.info("=" * 70)
logger.info("✅ Using AUTOMODEL FlowMatchingPipeline")
logger.info(f" Timestep Sampling: {args.timestep_sampling}")
logger.info(f" Flow Shift: {args.flow_shift}")
logger.info(f" Loss Weighting: {args.use_loss_weighting}")
logger.info("=" * 70)
else:
# Use original FluxForwardStep
flux_forward_step = create_flux_forward_step(
use_automodel_pipeline=False,
timestep_sampling=args.timestep_sampling,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
scheduler_steps=args.scheduler_steps,
guidance_scale=args.guidance_scale,
)
if get_rank_safe() == 0:
logger.info("=" * 70)
logger.info("✅ Using ORIGINAL FluxForwardStep")
logger.info(f" Timestep Sampling: {args.timestep_sampling}")
logger.info("=" * 70)
flux_forward_step = FluxForwardStep(
timestep_sampling=args.timestep_sampling,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
flow_shift=args.flow_shift,
scheduler_steps=args.scheduler_steps,
guidance_scale=args.guidance_scale,
use_loss_weighting=args.use_loss_weighting,
)
if get_rank_safe() == 0:
logger.info("=" * 70)
logger.info(f" Timestep Sampling: {args.timestep_sampling}")
logger.info(f" Flow Shift: {args.flow_shift}")
logger.info(f" Loss Weighting: {args.use_loss_weighting}")
logger.info("=" * 70)

# Display final configuration
if get_rank_safe() == 0:
@@ -394,9 +394,8 @@ def main() -> None:
logger.info(f" mode_scale: {args.mode_scale}")
logger.info(f" scheduler_steps: {args.scheduler_steps}")
logger.info(f" guidance_scale: {args.guidance_scale}")
if not args.use_original_step:
logger.info(f" flow_shift: {args.flow_shift}")
logger.info(f" use_loss_weighting: {args.use_loss_weighting}")
logger.info(f" flow_shift: {args.flow_shift}")
logger.info(f" use_loss_weighting: {args.use_loss_weighting}")

# Start training (fine-tuning)
logger.debug("Starting fine-tuning...")