Conversation

@DavidBert commented Oct 9, 2025

This commit adds support for the Photon image generation model:

  • PhotonTransformer2DModel: Core transformer architecture
  • PhotonPipeline: Text-to-image generation pipeline
  • Attention processor updates for Photon-specific attention mechanism
  • Conversion script for loading Photon checkpoints
  • Documentation and tests

Some examples below with the 512 model fine-tuned on the Alchemist dataset and distilled with PAG:

[Four sample images]

What does this PR do?

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

print("✓ Created scheduler config")


def download_and_save_vae(vae_type: str, output_path: str):
Author:

I'm not sure on this one: I'm saving the VAE weights while they are already available on the Hub (Flux VAE and DC-AE).
Is there a way to avoid storing them and instead look directly for the original ones?

Member:

For now, it's okay to keep this as is. This way, everything is under the same model repo.
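(For reference, if we later want to look the originals up directly instead of re-saving them, it could be roughly the sketch below; the DC-AE repo id is an assumption, not necessarily the checkpoint Photon uses.)

from diffusers import AutoencoderDC, AutoencoderKL

def load_vae_from_hub(vae_type: str):
    # Reuse the VAE weights already published on the Hub instead of re-saving them.
    if vae_type == "flux":
        return AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="vae")
    # Assumed DC-AE checkpoint; substitute whichever DC-AE variant Photon was trained with.
    return AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers")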

print(f"✓ Saved VAE to {vae_path}")


def download_and_save_text_encoder(output_path: str):
Author:

Same here for the Text Encoder.

print("✓ Created scheduler config")


def download_and_save_vae(vae_type: str, output_path: str):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, it's okay to keep this as is. This way, everything is under the same model repo.

@sayakpaul (Member) left a comment:

Thanks for the clean PR! I left some initial feedback for you. LMK if that makes sense.

Also, it would be great to see some samples of Photon!

@sayakpaul (Member) left a comment:

Thanks! Left a couple more comments. Let's also add the pipeline-level tests.

<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>

Photon is a text-to-image diffusion model using simplified MMDIT architecture with flow matching for efficient high-quality image generation. The model uses T5Gemma as the text encoder and supports either Flux VAE (AutoencoderKL) or DC-AE (AutoencoderDC) for latent compression.
Member:

Cc: @stevhliu for a review on the docs.

return xq_out.reshape(*xq.shape).type_as(xq)


class PhotonAttnProcessor2_0:
Member:

Could we write it in a fashion similar to […]?

Collaborator:

I second this suggestion - in particular, I think it would be more in line with other diffusers models implementations to reuse the layers defined in Attention, such as to_q/to_k/to_v, etc. instead of defining them in PhotonBlock (e.g. PhotonBlock.img_qkv_proj), and to keep the entire attention implementation in the PhotonAttnProcessor2_0 class.

Attention supports stuff like QK norms and fusing projections, so that could potentially be reused as well. If you need some custom logic not found in Attention, you could potentially add it in there or create a new Attention-style class like Flux does:

class FluxAttention(torch.nn.Module, AttentionModuleMixin):
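To make the suggestion concrete, a rough sketch of the shape such a processor could take when it reuses the standard Attention projections (RoPE, masking details and output dropout omitted; attribute names like norm_q/norm_k assume the QK norms are configured on the Attention module; this is not the final implementation):

import torch
import torch.nn.functional as F

class PhotonAttnProcessor2_0:
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        # Projections come from the Attention module instead of PhotonBlock.img_qkv_proj
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        query = attn.to_q(hidden_states)
        key = attn.to_k(context)
        value = attn.to_v(context)

        batch_size = hidden_states.shape[0]
        head_dim = query.shape[-1] // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # QK norms, if enabled on the Attention module
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        return attn.to_out[0](hidden_states)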

Author:

I made the change and updated both the conversion script and the checkpoints on the hub.

def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
Member:

We support passing prompt embeddings too in case users want to supply them precomputed:

prompt_embeds: Optional[torch.FloatTensor] = None,
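The usual pattern is to run the text encoder only when no precomputed embeddings are supplied, roughly like the sketch below (method and attribute names are illustrative, not the final API):

def encode_prompt(self, prompt=None, prompt_embeds=None, device=None):
    # Skip T5Gemma entirely if the caller already provides prompt_embeds.
    if prompt_embeds is None:
        text_inputs = self.tokenizer(
            prompt, padding="max_length", truncation=True, return_tensors="pt"
        )
        prompt_embeds = self.text_encoder(text_inputs.input_ids.to(device))[0]
    return prompt_embeds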

Comment on lines 484 to 486
default_sample_size = getattr(self.config, "default_sample_size", DEFAULT_RESOLUTION)
height = height or default_sample_size
width = width or default_sample_size
Member:

Prefer this pattern:

height = height or self.default_sample_size * self.vae_scale_factor

Author:

I did it this way because the model works with two different VAEs that have different scale factors.
Is it OK not to make it depend on self.vae_scale_factor? Otherwise it makes it hard to define a default value.

@sayakpaul (Member) commented Oct 15, 2025:

Oh good point! I think we could make a small utility function in the pipeline class that determines the default resolution given the VAE that's loaded into it? WDYT?
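A minimal sketch of such a utility, assuming the two supported VAE classes and their usual compression factors (the config field name is an assumption for illustration):

from diffusers import AutoencoderDC

def _get_default_resolution(self) -> int:
    # Derive the default pixel resolution from whichever VAE is loaded into the pipeline.
    if isinstance(self.vae, AutoencoderDC):
        vae_scale_factor = 32  # DC-AE compresses 32x spatially
    else:
        vae_scale_factor = 8   # Flux VAE (AutoencoderKL) compresses 8x
    return self.transformer.config.sample_size * vae_scale_factor  # assumed config field

so that __call__ can simply do height = height or self._get_default_resolution().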

Author:

Sure, way cleaner! I did it.

@DavidBert (Author):

Thanks @dg845 and @stevhliu for your last reviews! I updated the PR and hopefully addressed all your suggestions.

@DavidBert requested review from dg845 and stevhliu on October 16, 2025 at 09:51.
@stevhliu (Member) left a comment:

Thanks, docs LGTM

Comment on lines +308 to +310
parser.add_argument(
"--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file)"
)
Collaborator:

Would it be possible to set a meaningful default argument for checkpoint_path (for example, if the model checkpoint has been open-sourced and is available on e.g. HF hub, we could set it as a default)?
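For example, something along these lines once a checkpoint is published (the repo id and filename below are placeholders):

import argparse
from huggingface_hub import hf_hub_download

parser = argparse.ArgumentParser()
parser.add_argument(
    "--checkpoint_path",
    type=str,
    default=None,
    help="Path to the original Photon checkpoint (.pth file). Defaults to downloading it from the Hub.",
)
args = parser.parse_args()
# Placeholder repo id/filename; substitute whatever gets open-sourced.
checkpoint_path = args.checkpoint_path or hf_hub_download("Photoroom/photon-512-t2i", "photon.pth")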

from safetensors.torch import save_file


sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
Collaborator:

Is this necessary for the correct functioning of the conversion script? Can it be safely removed?

@dg845 (Collaborator) commented Oct 17, 2025:

If you are getting import errors, e.g. on from diffusers.pipelines.photon import PhotonPipeline, I think this will probably be fixed by the changes suggested in #12456 (comment).


def forward(
self,
image_latent: Tensor,
Collaborator:

nit: could the names here be changed similarly to #12456 (comment)? So something like image_latent --> hidden_states, cross_attn_conditioning --> encoder_hidden_states?

Comment on lines +167 to +170
# Apply scaled dot-product attention
attn_output = torch.nn.functional.scaled_dot_product_attention(
img_q.contiguous(), k.contiguous(), v.contiguous(), attn_mask=attn_mask_tensor
)
Collaborator:

Just curious, have you tested Photon with any other attention backends (e.g. Flash Attention, Sage Attention, etc.)? Not a blocker, but if so you could consider refactoring to use dispatch_attention_fn to add support for these backends.

You can look at the Flux attention processor for an example:

hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)

See PR #11916 and the attention backend docs for more info.
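Applied to the Photon snippet above, the swap could look roughly like this (assuming the processor stores _attention_backend and _parallel_config the same way the Flux processor does):

from diffusers.models.attention_dispatch import dispatch_attention_fn

# Backend-agnostic replacement for the hard-coded SDPA call
attn_output = dispatch_attention_fn(
    img_q.contiguous(),
    k.contiguous(),
    v.contiguous(),
    attn_mask=attn_mask_tensor,
    backend=self._attention_backend,
    parallel_config=self._parallel_config,
)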

r"""
Generates 2D patch coordinate indices for a batch of images.
Parameters:
Collaborator:

nit: diffusers uses Args: rather than Parameters: in its docstrings, so could Parameters be changed to Args in them? Love the comments / docstrings BTW!
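i.e. something like the following (the function and argument names here are placeholders):

def patch_coords(height: int, width: int):
    r"""
    Generates 2D patch coordinate indices for a batch of images.

    Args:
        height (`int`): Image height in patches.
        width (`int`): Image width in patches.
    """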

>>> from diffusers import PhotonPipeline
>>> # Load pipeline with from_pretrained
>>> pipe = PhotonPipeline.from_pretrained("path/to/photon_checkpoint")
Collaborator:

Would it be possible to put in a valid checkpoint path for the example?

from ..test_pipelines_common import PipelineTesterMixin


class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
Collaborator:

Would it be possible to add a corresponding PhotonPipelineSlowTests class where we test whether inference on a full checkpoint is consistent between diffusers and the original code? You can refer to FluxPipelineSlowTests as a reference:

@nightly
@require_big_accelerator
class FluxPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-schnell"

Member:

Okay to skip it for now IMO since we also don't add it for Qwen.



# inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class PhotoEmbedND(nn.Module):
Collaborator:

nit: rename PhotoEmbedND --> PhotonEmbedND :)

List of embedding dimensions for each axis (each must be even).
"""

def __init__(self, dim: int, theta: int, axes_dim: list[int]):
Collaborator:

Suggested change
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
def __init__(self, dim: int, theta: int, axes_dim: List[int]):

I believe the CI will raise an error without this change

| [`Photoroom/photon-256-t2i`](https://huggingface.co/Photoroom/photon-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-256-t2i-sft`](https://huggingface.co/Photoroom/photon-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i`](https://huggingface.co/Photoroom/photon-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-sft`](hhttps://huggingface.co/Photoroom/photon-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
Collaborator:

Suggested change
| [`Photoroom/photon-512-t2i-sft`](hhttps://huggingface.co/Photoroom/photon-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |

| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
|:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|
| [`Photoroom/photon-256-t2i`](https://huggingface.co/Photoroom/photon-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-256-t2i-sft`](https://huggingface.co/Photoroom/photon-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` |
Collaborator:

Are these model links expected to be broken for now? I get a 404 for https://huggingface.co/Photoroom/photon-256-t2i-sft currently and see that only the Photoroom/photon-256-t2i model is currently in the Photon collection.

"MultiControlNetModel",
"OmniGenTransformer2DModel",
"ParallelConfig",
"PhotonTransformer2DModel",
Collaborator:

Could you also add PhotonPipeline to the main __init__? As an example, here is how FluxPipeline is added:

"FluxPipeline",

FluxPipeline,

Also, could you add PhotonTransformer2DModel to the TYPE_CHECKING section of __init__? Here is how FluxTransformer2DModel is added:

FluxTransformer2DModel,
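
Concretely, the additions to src/diffusers/__init__.py would look something like the sketch below (surrounding entries elided):

# lazy-import table
_import_structure["models"].extend(["PhotonTransformer2DModel"])
_import_structure["pipelines"].extend(["PhotonPipeline"])

# TYPE_CHECKING section
if TYPE_CHECKING:
    from .models import PhotonTransformer2DModel
    from .pipelines import PhotonPipeline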

_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_photon"] = ["PhotonTransformer2DModel"]
Collaborator:

Could you also add PhotonTransformer2DModel to the TYPE_CHECKING section of the src/diffusers/models/__init__.py file? Here is how FluxTransformer2DModel is added:

FluxTransformer2DModel,

"FluxKontextPipeline",
"FluxKontextInpaintPipeline",
]
_import_structure["photon"] = ["PhotonPipeline"]
Collaborator:

Could you also add PhotonPipeline to the TYPE_CHECKING section of src/diffusers/pipelines/__init__.py? Here is how FluxPipeline is added:

@dg845 (Collaborator) left a comment:

Thanks for the changes! The PR is close to merge, I think the most important things left are to fix the imports (e.g. #12456 (comment)) and other changes to make the CI green :).

Comment on lines 72 to 92
encoder_params = dict(
vocab_size=tokenizer.vocab_size,
hidden_size=8,
intermediate_size=16,
num_hidden_layers=1,
num_attention_heads=2,
num_key_value_heads=1,
head_dim=4,
max_position_embeddings=64,
layer_types=["full_attention"],
attention_bias=False,
attention_dropout=0.0,
dropout_rate=0.0,
hidden_activation="gelu_pytorch_tanh",
rms_norm_eps=1e-06,
attn_logit_softcapping=50.0,
final_logit_softcapping=30.0,
query_pre_attn_scalar=4,
rope_theta=10000.0,
sliding_window=4096,
)
Collaborator:

Suggested change
encoder_params = dict(
vocab_size=tokenizer.vocab_size,
hidden_size=8,
intermediate_size=16,
num_hidden_layers=1,
num_attention_heads=2,
num_key_value_heads=1,
head_dim=4,
max_position_embeddings=64,
layer_types=["full_attention"],
attention_bias=False,
attention_dropout=0.0,
dropout_rate=0.0,
hidden_activation="gelu_pytorch_tanh",
rms_norm_eps=1e-06,
attn_logit_softcapping=50.0,
final_logit_softcapping=30.0,
query_pre_attn_scalar=4,
rope_theta=10000.0,
sliding_window=4096,
)
encoder_params = {
"vocab_size": tokenizer.vocab_size,
"hidden_size": 8,
"intermediate_size": 16,
"num_hidden_layers": 1,
"num_attention_heads": 2,
"num_key_value_heads": 1,
"head_dim": 4,
"max_position_embeddings": 64,
"layer_types": ["full_attention"],
"attention_bias": False,
"attention_dropout": 0.0,
"dropout_rate": 0.0,
"hidden_activation": "gelu_pytorch_tanh",
"rms_norm_eps": 1e-06,
"attn_logit_softcapping": 50.0,
"final_logit_softcapping": 30.0,
"query_pre_attn_scalar": 4,
"rope_theta": 10000.0,
"sliding_window": 4096,
}

make style / make quality complain about the dict(...) call here, and I think they will be happier if a dict literal is used instead.
