Add Photon model and pipeline support #12456
base: main
Conversation
This commit adds support for the Photon image generation model:
- PhotonTransformer2DModel: core transformer architecture
- PhotonPipeline: text-to-image generation pipeline
- Attention processor updates for the Photon-specific attention mechanism
- Conversion script for loading Photon checkpoints
- Documentation and tests
print("✓ Created scheduler config") | ||
|
||
|
||
def download_and_save_vae(vae_type: str, output_path: str): |
I'm not sure on this one: I'm saving the VAE weights while they are already available on the Hub (Flux VAE and DC-AE).
Is there a way to avoid storing them and instead look directly for the original ones?
For now, it's okay to keep this as is. This way, everything is under the same model repo.
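For reference, a minimal sketch of the alternative being discussed, i.e. loading the VAEs straight from their existing Hub repos instead of re-saving them. The repo ids below are assumptions for illustration and are not taken from this PR.

```python
import torch
from diffusers import AutoencoderDC, AutoencoderKL

# Flux VAE (AutoencoderKL) -- assumed to live under the FLUX.1-dev repo's "vae" subfolder.
flux_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
)

# DC-AE (AutoencoderDC) -- assumed repo id for the diffusers-format DC-AE weights.
dc_ae = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.bfloat16
)
```

The conversion script could then either save these into the Photon model repo (as it does now) or simply reference the existing repos.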
print(f"✓ Saved VAE to {vae_path}") | ||
|
||
|
||
def download_and_save_text_encoder(output_path: str): |
Same here for the Text Encoder.
print("✓ Created scheduler config") | ||
|
||
|
||
def download_and_save_vae(vae_type: str, output_path: str): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For now, it's okay to keep this as is. This way, everything is under the same model repo.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the clean PR! I left some initial feedback for you. LMK if that makes sense.
Also, it would be great to see some samples of Photon!
Thanks! Left a couple more comments. Let's also add the pipeline-level tests.
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/> | ||
</div> | ||
|
||
Photon is a text-to-image diffusion model using simplified MMDIT architecture with flow matching for efficient high-quality image generation. The model uses T5Gemma as the text encoder and supports either Flux VAE (AutoencoderKL) or DC-AE (AutoencoderDC) for latent compression. |
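For readers skimming the thread, a minimal usage sketch of the new pipeline. The repo id and the 28 steps / cfg=5.0 settings are taken from the model table further down in this PR; the exact argument names are assumptions pending the final API.

```python
import torch
from diffusers import PhotonPipeline

# Load the Photon pipeline (repo id from the model table below; availability may vary).
pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="A red vintage bicycle leaning against a pastel blue wall",
    num_inference_steps=28,
    guidance_scale=5.0,
).images[0]
image.save("photon_sample.png")
```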
Cc: @stevhliu for a review on the docs.
return xq_out.reshape(*xq.shape).type_as(xq)

class PhotonAttnProcessor2_0:
Could we write it in a fashion similar to `FluxAttnProcessor`?

class FluxAttnProcessor:
I second this suggestion. In particular, I think it would be more in line with other diffusers model implementations to reuse the layers defined in `Attention`, such as `to_q`/`to_k`/`to_v`, etc., instead of defining them in `PhotonBlock` (e.g. `PhotonBlock.img_qkv_proj`), and to keep the entire attention implementation in the `PhotonAttnProcessor2_0` class.
`Attention` supports things like QK norms and fused projections, so that could potentially be reused as well. If you need some custom logic not found in `Attention`, you could potentially add it there or create a new `Attention`-style class like Flux does:
class FluxAttention(torch.nn.Module, AttentionModuleMixin):
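To make the suggestion concrete, here is a rough sketch of a Flux-style processor that keeps the whole attention computation in the processor and reuses the projections owned by `Attention` (`to_q`/`to_k`/`to_v`/`to_out`, optional `norm_q`/`norm_k`). The class name, call signature, and the omitted rotary-embedding step are assumptions, not the final implementation.

```python
import torch
import torch.nn.functional as F


class PhotonAttnProcessor:
    """Sketch of a Flux-style processor that reuses the Attention module's projection layers."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
        batch_size = hidden_states.shape[0]
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        # Project with the layers owned by Attention instead of PhotonBlock.
        query = attn.to_q(hidden_states)
        key = attn.to_k(context)
        value = attn.to_v(context)

        head_dim = query.shape[-1] // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Optional QK norms, configured on the Attention module itself.
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Rotary embeddings would be applied to query/key here (omitted in this sketch).
        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = attn.to_out[0](hidden_states)
        return hidden_states
```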
I made the change and updated both the conversion script and the checkpoints on the hub.
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
We support passing prompt embeddings too, in case users want to supply them precomputed:

prompt_embeds: Optional[torch.FloatTensor] = None,
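A short sketch of the usual diffusers pattern for this, assuming a hypothetical `encode_prompt` helper on the pipeline:

```python
# Inside PhotonPipeline.__call__ (sketch): accept either a prompt or precomputed embeddings.
if (prompt is None) == (prompt_embeds is None):
    raise ValueError("Provide exactly one of `prompt` or `prompt_embeds`.")

if prompt_embeds is None:
    # Hypothetical helper that runs the T5Gemma text encoder.
    prompt_embeds = self.encode_prompt(prompt, device=device)
```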
default_sample_size = getattr(self.config, "default_sample_size", DEFAULT_RESOLUTION)
height = height or default_sample_size
width = width or default_sample_size
Prefer this pattern:

height = height or self.default_sample_size * self.vae_scale_factor
I did it this way because the model works with two different VAEs that have different scale factors.
Is it OK not to make it depend on `self.vae_scale_factor`? Otherwise it's hard to define a default value.
Oh good point! I think we could make a small utility function in the pipeline class that determines the default resolution given the VAE that's loaded into it? WDYT?
Sure, way cleaner! I did it.
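For reference, one possible shape for such a utility. This is a sketch only, not necessarily what the PR ended up with; `DEFAULT_RESOLUTION` and the `latent_sample_size` argument are assumptions.

```python
from diffusers import AutoencoderDC, AutoencoderKL

DEFAULT_RESOLUTION = 512  # hypothetical fallback, for illustration only


def default_resolution_for_vae(vae, latent_sample_size: int = 64) -> int:
    """Sketch: derive a default pixel resolution from whichever VAE the pipeline holds."""
    if isinstance(vae, AutoencoderKL):
        # Flux VAE: 2 ** (num_blocks - 1) spatial compression, i.e. 8x for the standard config.
        scale = 2 ** (len(vae.config.block_out_channels) - 1)
    elif isinstance(vae, AutoencoderDC):
        # DC-AE: 32x spatial compression for the f32 variants.
        scale = 32
    else:
        return DEFAULT_RESOLUTION
    return latent_sample_size * scale
```

Inside `__call__`, something like `height = height or default_resolution_for_vae(self.vae)` would then replace the hard-coded default.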
Co-authored-by: Steven Liu <[email protected]>
Thanks, docs LGTM
parser.add_argument(
    "--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file)"
)
Would it be possible to set a meaningful default argument for `checkpoint_path` (for example, if the model checkpoint has been open-sourced and is available on e.g. the HF Hub, we could set it as the default)?
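A sketch of what that could look like once a checkpoint is public; the repo id and filename passed to `hf_hub_download` are placeholders, not real paths from this PR.

```python
import argparse

from huggingface_hub import hf_hub_download

parser = argparse.ArgumentParser()
parser.add_argument(
    "--checkpoint_path",
    type=str,
    default=None,
    help="Path to the original Photon checkpoint (.pth file). Downloaded from the Hub if omitted.",
)
args = parser.parse_args()

if args.checkpoint_path is None:
    # Placeholder repo id / filename, for illustration only.
    args.checkpoint_path = hf_hub_download(repo_id="Photoroom/photon-512-t2i", filename="photon.pth")
```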
from safetensors.torch import save_file

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
Is this necessary for the correct functioning of the conversion script? Can it be safely removed?
If you are getting import errors, e.g. on `from diffusers.pipelines.photon import PhotonPipeline`, I think this will probably be fixed by the changes suggested in #12456 (comment).
def forward(
    self,
    image_latent: Tensor,
nit: could the names here be changed similarly to #12456 (comment)? So something like `image_latent` --> `hidden_states`, `cross_attn_conditioning` --> `encoder_hidden_states`?
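For concreteness, the renamed signature might look roughly like this; the arguments beyond the first two are illustrative only.

```python
from typing import Optional

import torch


class PhotonTransformer2DModel(torch.nn.Module):
    # Sketch of the renamed forward signature.
    def forward(
        self,
        hidden_states: torch.Tensor,  # was: image_latent
        encoder_hidden_states: torch.Tensor,  # was: cross_attn_conditioning
        timestep: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        ...
```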
# Apply scaled dot-product attention
attn_output = torch.nn.functional.scaled_dot_product_attention(
    img_q.contiguous(), k.contiguous(), v.contiguous(), attn_mask=attn_mask_tensor
)
Just curious, have you tested Photon with any other attention backends (e.g. Flash Attention, Sage Attention, etc.)? Not a blocker, but if so you could consider refactoring to use `dispatch_attention_fn` to add support for these backends.
You can look at the Flux attention processor for an example:
diffusers/src/diffusers/models/transformers/transformer_flux.py
Lines 118 to 125 in dbe4136
hidden_states = dispatch_attention_fn(
    query,
    key,
    value,
    attn_mask=attention_mask,
    backend=self._attention_backend,
    parallel_config=self._parallel_config,
)
See PR #11916 and the attention backend docs for more info.
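For example, the SDPA call quoted above could be swapped for the dispatcher roughly like this. This is a sketch: the import path is an assumption based on recent diffusers versions, and `self._attention_backend` / `self._parallel_config` come from the mixin setup shown in the Flux snippet.

```python
# Sketch: replace the direct scaled_dot_product_attention call in the Photon processor.
from diffusers.models.attention_dispatch import dispatch_attention_fn

attn_output = dispatch_attention_fn(
    img_q,
    k,
    v,
    attn_mask=attn_mask_tensor,
    backend=self._attention_backend,
    parallel_config=self._parallel_config,
)
```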
r""" | ||
Generates 2D patch coordinate indices for a batch of images. | ||
Parameters: |
nit: diffusers uses `Args:` rather than `Parameters:` in its docstrings, so could `Parameters` be changed to `Args` throughout? Love the comments / docstrings BTW!
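For example, the docstring quoted above would then read as follows (the function name and argument list here are illustrative):

```python
def patch_coordinates(height: int, width: int, patch_size: int):
    r"""
    Generates 2D patch coordinate indices for a batch of images.

    Args:
        height (`int`): Image height in pixels.
        width (`int`): Image width in pixels.
        patch_size (`int`): Side length of each square patch.
    """
    ...
```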
>>> from diffusers import PhotonPipeline
>>> # Load pipeline with from_pretrained
>>> pipe = PhotonPipeline.from_pretrained("path/to/photon_checkpoint")
Would it be possible to put in a valid checkpoint path for the example?
from ..test_pipelines_common import PipelineTesterMixin

class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
Would it be possible to add a corresponding `PhotonPipelineSlowTests` class where we test whether inference on a full checkpoint is consistent between diffusers and the original code? You can refer to `FluxPipelineSlowTests` as a reference:
diffusers/tests/pipelines/flux/test_pipeline_flux.py
Lines 233 to 237 in dbe4136
@nightly
@require_big_accelerator
class FluxPipelineSlowTests(unittest.TestCase):
    pipeline_class = FluxPipeline
    repo_id = "black-forest-labs/FLUX.1-schnell"
Okay to skip it for now IMO since we also don't add it for Qwen.
# inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class PhotoEmbedND(nn.Module):
nit: rename `PhotoEmbedND` --> `PhotonEmbedND` :)
List of embedding dimensions for each axis (each must be even).
"""

def __init__(self, dim: int, theta: int, axes_dim: list[int]):
Suggested change:
- def __init__(self, dim: int, theta: int, axes_dim: list[int]):
+ def __init__(self, dim: int, theta: int, axes_dim: List[int]):

I believe the CI will raise an error without this change.
| [`Photoroom/photon-256-t2i`](https://huggingface.co/Photoroom/photon-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-256-t2i-sft`](https://huggingface.co/Photoroom/photon-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i`](https://huggingface.co/Photoroom/photon-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-sft`](hhttps://huggingface.co/Photoroom/photon-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
Suggested change:
- | [`Photoroom/photon-512-t2i-sft`](hhttps://huggingface.co/Photoroom/photon-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
+ | [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
|:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|
| [`Photoroom/photon-256-t2i`](https://huggingface.co/Photoroom/photon-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-256-t2i-sft`](https://huggingface.co/Photoroom/photon-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` |
Are these model links expected to be broken for now? I get a 404 for https://huggingface.co/Photoroom/photon-256-t2i-sft currently, and see that only the `Photoroom/photon-256-t2i` model is currently in the Photon collection.
"MultiControlNetModel", | ||
"OmniGenTransformer2DModel", | ||
"ParallelConfig", | ||
"PhotonTransformer2DModel", |
Could you also add `PhotonPipeline` to the main `__init__`? As an example, here is how `FluxPipeline` is added:

diffusers/src/diffusers/__init__.py
Line 457 in dbe4136
"FluxPipeline",

diffusers/src/diffusers/__init__.py
Line 1119 in dbe4136
FluxPipeline,

Also, could you add `PhotonTransformer2DModel` to the `TYPE_CHECKING` section of `__init__`? Here is how `FluxTransformer2DModel` is added:

diffusers/src/diffusers/__init__.py
Line 906 in dbe4136
FluxTransformer2DModel,
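As a sketch, the `TYPE_CHECKING` side of those additions to `src/diffusers/__init__.py` could look like this, placed next to the Flux imports referenced above (`TYPE_CHECKING` is already imported in the real file):

```python
from typing import TYPE_CHECKING  # already present in src/diffusers/__init__.py

if TYPE_CHECKING:
    from .models import PhotonTransformer2DModel
    from .pipelines import PhotonPipeline
```

The string entries `"PhotonTransformer2DModel"` and `"PhotonPipeline"` would similarly be listed in the lazy `_import_structure` table next to the corresponding Flux entries.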
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] | ||
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] | ||
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] | ||
_import_structure["transformers.transformer_photon"] = ["PhotonTransformer2DModel"] |
Could you also add `PhotonTransformer2DModel` to the `TYPE_CHECKING` section of the `src/diffusers/models/__init__.py` file? Here is how `FluxTransformer2DModel` is added:

diffusers/src/diffusers/models/__init__.py
Line 180 in dbe4136
FluxTransformer2DModel,
"FluxKontextPipeline", | ||
"FluxKontextInpaintPipeline", | ||
] | ||
_import_structure["photon"] = ["PhotonPipeline"] |
Could you also add `PhotonPipeline` to the `TYPE_CHECKING` section of `src/diffusers/pipelines/__init__.py`? Here is how `FluxPipeline` is added:

diffusers/src/diffusers/pipelines/__init__.py
Line 636 in dbe4136
FluxPipeline,
Thanks for the changes! The PR is close to merge; I think the most important things left are to fix the imports (e.g. #12456 (comment)) and the other changes needed to make the CI green :).
encoder_params = dict(
    vocab_size=tokenizer.vocab_size,
    hidden_size=8,
    intermediate_size=16,
    num_hidden_layers=1,
    num_attention_heads=2,
    num_key_value_heads=1,
    head_dim=4,
    max_position_embeddings=64,
    layer_types=["full_attention"],
    attention_bias=False,
    attention_dropout=0.0,
    dropout_rate=0.0,
    hidden_activation="gelu_pytorch_tanh",
    rms_norm_eps=1e-06,
    attn_logit_softcapping=50.0,
    final_logit_softcapping=30.0,
    query_pre_attn_scalar=4,
    rope_theta=10000.0,
    sliding_window=4096,
)
Suggested change (use a dict literal instead of the dict(...) call):
encoder_params = {
    "vocab_size": tokenizer.vocab_size,
    "hidden_size": 8,
    "intermediate_size": 16,
    "num_hidden_layers": 1,
    "num_attention_heads": 2,
    "num_key_value_heads": 1,
    "head_dim": 4,
    "max_position_embeddings": 64,
    "layer_types": ["full_attention"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "dropout_rate": 0.0,
    "hidden_activation": "gelu_pytorch_tanh",
    "rms_norm_eps": 1e-06,
    "attn_logit_softcapping": 50.0,
    "final_logit_softcapping": 30.0,
    "query_pre_attn_scalar": 4,
    "rope_theta": 10000.0,
    "sliding_window": 4096,
}
`make style` / `make quality` complain about the `dict(...)` call here, and I think they will be happier if a dict literal is used instead.
…ts + some renaming
What does this PR do?
This commit adds support for the Photon image generation model. Some examples are shown below with the 512 model fine-tuned on the Alchemist dataset and distilled with PAG.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.