-
Notifications
You must be signed in to change notification settings - Fork 5k
[Diffusion] Add QKV fusion optimization for Flux models #14505
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6f48bbd
eeea208
83b3a76
ef908c7
ce08aed
8a23117
f8314e1
7f4c11b
6bdf938
e7335fc
418c633
4e9009b
6c61160
66fb578
dac9554
6c5e299
9a9b300
4b553e2
210dd8d
a45c708
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -69,13 +69,13 @@ def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states | |
| def _get_fused_projections( | ||
| attn: "FluxAttention", hidden_states, encoder_hidden_states=None | ||
| ): | ||
| query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) | ||
| qkv, _ = attn.to_qkv(hidden_states) | ||
| query, key, value = qkv.chunk(3, dim=-1) | ||
|
|
||
| encoder_query = encoder_key = encoder_value = None | ||
| if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): | ||
| encoder_query, encoder_key, encoder_value = attn.to_added_qkv( | ||
| encoder_hidden_states | ||
| ).chunk(3, dim=-1) | ||
| added_qkv, _ = attn.to_added_qkv(encoder_hidden_states) | ||
| encoder_query, encoder_key, encoder_value = added_qkv.chunk(3, dim=-1) | ||
|
|
||
| return query, key, value, encoder_query, encoder_key, encoder_value | ||
|
|
||
|
|
@@ -89,6 +89,7 @@ def _get_qkv_projections( | |
|
|
||
|
|
||
| class FluxAttention(torch.nn.Module, AttentionModuleMixin): | ||
| _supports_qkv_fusion = True | ||
|
|
||
| def __init__( | ||
| self, | ||
|
|
@@ -161,6 +162,61 @@ def __init__( | |
| }, | ||
| ) | ||
|
|
||
| self.fused_projections = False | ||
|
|
||
| @torch.no_grad() | ||
| def fuse_projections(self): | ||
| if self.fused_projections: | ||
| return | ||
|
|
||
| device = self.to_q.weight.data.device | ||
| dtype = self.to_q.weight.data.dtype | ||
|
|
||
| concatenated_weights = torch.cat( | ||
| [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data] | ||
| ) | ||
| in_features = concatenated_weights.shape[1] | ||
| out_features = concatenated_weights.shape[0] | ||
|
|
||
| self.to_qkv = ReplicatedLinear(in_features, out_features, bias=self.use_bias) | ||
| self.to_qkv.weight.data = concatenated_weights.to(device=device, dtype=dtype) | ||
| if self.use_bias: | ||
| concatenated_bias = torch.cat( | ||
| [self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data] | ||
| ) | ||
| self.to_qkv.bias.data = concatenated_bias.to(device=device, dtype=dtype) | ||
|
|
||
| if self.added_kv_proj_dim is not None: | ||
| concatenated_weights = torch.cat( | ||
| [ | ||
| self.add_q_proj.weight.data, | ||
| self.add_k_proj.weight.data, | ||
| self.add_v_proj.weight.data, | ||
| ] | ||
| ) | ||
| in_features = concatenated_weights.shape[1] | ||
| out_features = concatenated_weights.shape[0] | ||
|
|
||
| self.to_added_qkv = ReplicatedLinear( | ||
| in_features, out_features, bias=self.added_proj_bias | ||
| ) | ||
| self.to_added_qkv.weight.data = concatenated_weights.to( | ||
| device=device, dtype=dtype | ||
| ) | ||
| if self.added_proj_bias: | ||
| concatenated_bias = torch.cat( | ||
| [ | ||
| self.add_q_proj.bias.data, | ||
| self.add_k_proj.bias.data, | ||
| self.add_v_proj.bias.data, | ||
| ] | ||
| ) | ||
| self.to_added_qkv.bias.data = concatenated_bias.to( | ||
| device=device, dtype=dtype | ||
| ) | ||
|
|
||
| self.fused_projections = True | ||
|
|
||
| def forward( | ||
| self, | ||
| x: torch.Tensor, | ||
|
|
@@ -473,6 +529,19 @@ def __init__(self, config: FluxConfig, hf_config: dict[str, Any]) -> None: | |
| bias=True, | ||
| ) | ||
|
|
||
| def fuse_qkv_projections(self): | ||
| for block in self.transformer_blocks: | ||
| if hasattr(block.attn, "fuse_projections") and getattr( | ||
| block.attn, "_supports_qkv_fusion", True | ||
| ): | ||
| block.attn.fuse_projections() | ||
|
|
||
| for block in self.single_transformer_blocks: | ||
| if hasattr(block.attn, "fuse_projections") and getattr( | ||
| block.attn, "_supports_qkv_fusion", True | ||
| ): | ||
| block.attn.fuse_projections() | ||
|
|
||
|
Comment on lines
+532
to
+544
Contributor
The two loops over `transformer_blocks` and `single_transformer_blocks` duplicate the same check and call; they can be combined into a single loop:

def fuse_qkv_projections(self):
    for block in list(self.transformer_blocks) + list(self.single_transformer_blocks):
        if hasattr(block.attn, "fuse_projections"):
            block.attn.fuse_projections()
||
| def forward( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -52,13 +52,13 @@ def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_state | |
| def _get_fused_projections( | ||
| attn: "Flux2Attention", hidden_states, encoder_hidden_states=None | ||
| ): | ||
| query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) | ||
| qkv = attn.to_qkv(hidden_states) | ||
| query, key, value = qkv.chunk(3, dim=-1) | ||
|
|
||
| encoder_query = encoder_key = encoder_value = (None,) | ||
| encoder_query = encoder_key = encoder_value = None | ||
| if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): | ||
| encoder_query, encoder_key, encoder_value = attn.to_added_qkv( | ||
| encoder_hidden_states | ||
| ).chunk(3, dim=-1) | ||
| added_qkv = attn.to_added_qkv(encoder_hidden_states) | ||
| encoder_query, encoder_key, encoder_value = added_qkv.chunk(3, dim=-1) | ||
|
|
||
| return query, key, value, encoder_query, encoder_key, encoder_value | ||
|
|
||
|
|
@@ -114,6 +114,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: | |
|
|
||
|
|
||
| class Flux2Attention(torch.nn.Module, AttentionModuleMixin): | ||
| _supports_qkv_fusion = True | ||
|
|
||
| def __init__( | ||
| self, | ||
|
|
@@ -182,6 +183,61 @@ def __init__( | |
| }, | ||
| ) | ||
|
|
||
| self.fused_projections = False | ||
|
|
||
| @torch.no_grad() | ||
| def fuse_projections(self): | ||
| if self.fused_projections: | ||
| return | ||
|
|
||
| device = self.to_q.weight.data.device | ||
| dtype = self.to_q.weight.data.dtype | ||
|
|
||
| concatenated_weights = torch.cat( | ||
| [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data] | ||
| ) | ||
| in_features = concatenated_weights.shape[1] | ||
| out_features = concatenated_weights.shape[0] | ||
|
|
||
| self.to_qkv = torch.nn.Linear(in_features, out_features, bias=self.use_bias) | ||
| self.to_qkv.weight.data = concatenated_weights.to(device=device, dtype=dtype) | ||
| if self.use_bias: | ||
| concatenated_bias = torch.cat( | ||
| [self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data] | ||
| ) | ||
| self.to_qkv.bias.data = concatenated_bias.to(device=device, dtype=dtype) | ||
|
|
||
| if self.added_kv_proj_dim is not None: | ||
| concatenated_weights = torch.cat( | ||
| [ | ||
| self.add_q_proj.weight.data, | ||
| self.add_k_proj.weight.data, | ||
| self.add_v_proj.weight.data, | ||
| ] | ||
| ) | ||
| in_features = concatenated_weights.shape[1] | ||
| out_features = concatenated_weights.shape[0] | ||
|
|
||
| self.to_added_qkv = torch.nn.Linear( | ||
| in_features, out_features, bias=self.added_proj_bias | ||
| ) | ||
| self.to_added_qkv.weight.data = concatenated_weights.to( | ||
| device=device, dtype=dtype | ||
| ) | ||
| if self.added_proj_bias: | ||
| concatenated_bias = torch.cat( | ||
| [ | ||
| self.add_q_proj.bias.data, | ||
| self.add_k_proj.bias.data, | ||
| self.add_v_proj.bias.data, | ||
| ] | ||
| ) | ||
| self.to_added_qkv.bias.data = concatenated_bias.to( | ||
| device=device, dtype=dtype | ||
| ) | ||
|
|
||
| self.fused_projections = True | ||
|
Comment on lines
+189
to
+239
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the @torch.no_grad()
def fuse_projections(self):
if self.fused_projections:
return
device = self.to_q.weight.data.device
dtype = self.to_q.weight.data.dtype
def _fuse_linear(q_proj, k_proj, v_proj, use_bias):
concatenated_weights = torch.cat(
[q_proj.weight.data, k_proj.weight.data, v_proj.weight.data]
)
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
fused_layer = torch.nn.Linear(in_features, out_features, bias=use_bias)
fused_layer.weight.data = concatenated_weights.to(device=device, dtype=dtype)
if use_bias:
concatenated_bias = torch.cat(
[q_proj.bias.data, k_proj.bias.data, v_proj.bias.data]
)
fused_layer.bias.data = concatenated_bias.to(device=device, dtype=dtype)
return fused_layer
self.to_qkv = _fuse_linear(self.to_q, self.to_k, self.to_v, self.use_bias)
del self.to_q, self.to_k, self.to_v
if self.added_kv_proj_dim is not None:
self.to_added_qkv = _fuse_linear(
self.add_q_proj, self.add_k_proj, self.add_v_proj, self.added_proj_bias
)
del self.add_q_proj, self.add_k_proj, self.add_v_proj
self.fused_projections = True |
||
|
|
||
| def forward( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
|
|
@@ -728,6 +784,19 @@ def __init__(self, config: FluxConfig, hf_config: dict[str, Any]): | |
|
|
||
| self.gradient_checkpointing = False | ||
|
|
||
| def fuse_qkv_projections(self): | ||
| for block in self.transformer_blocks: | ||
| if hasattr(block.attn, "fuse_projections") and getattr( | ||
| block.attn, "_supports_qkv_fusion", True | ||
| ): | ||
| block.attn.fuse_projections() | ||
|
|
||
| for block in self.single_transformer_blocks: | ||
| if hasattr(block.attn, "fuse_projections") and getattr( | ||
| block.attn, "_supports_qkv_fusion", True | ||
| ): | ||
| block.attn.fuse_projections() | ||
|
|
||
|
Comment on lines
+787
to
+799
Contributor
The two loops over `transformer_blocks` and `single_transformer_blocks` duplicate the same check and call; they can be combined into a single loop:

def fuse_qkv_projections(self):
    for block in list(self.transformer_blocks) + list(self.single_transformer_blocks):
        if hasattr(block.attn, "fuse_projections"):
            block.attn.fuse_projections()
||
| def forward( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic for creating the fused `to_qkv` and `to_added_qkv` layers is very similar and contains duplicated code. This can be refactored into a helper method to improve maintainability and reduce redundancy. Additionally, after fusing the projection layers, the original layers (`to_q`, `to_k`, `to_v`, etc.) are no longer needed and can be deleted to free up GPU memory. This is an important optimization, especially for large models.