
Commit d344134

support automatic dispatch.

1 parent 18c3e8e

3 files changed: +146 -16 lines

src/diffusers/models/attention_dispatch.py

Lines changed: 8 additions & 15 deletions
@@ -17,7 +17,8 @@
 import inspect
 import math
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from functools import partial
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 import torch

@@ -84,12 +85,16 @@
             "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
         )
     from ..utils.kernels_utils import _DEFAULT_HUB_ID_FA3, _DEFAULT_HUB_ID_SAGE, _get_kernel_from_hub
+    from ..utils.sage_utils import _get_sage_attn_fn_for_device

     flash_attn_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_FA3)
     flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func

     sage_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_SAGE)
-    sage_attn_func_hub = sage_interface_hub.sageattn
+    sage_fn_with_kwargs = _get_sage_attn_fn_for_device()
+    sage_attn_func_hub = getattr(sage_interface_hub, sage_fn_with_kwargs["func"])
+    sage_attn_func_hub = partial(sage_attn_func_hub, **sage_fn_with_kwargs["kwargs"])
+
 else:
     flash_attn_3_func_hub = None
     sage_attn_func_hub = None
@@ -166,10 +171,6 @@ def wrap(func):
 # - CP with sage attention, flex, xformers, other missing backends
 # - Add support for normal and CP training with backends that don't support it yet

-_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
-_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
-_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
-

 class AttentionBackendName(str, Enum):
     # EAGER = "eager"
@@ -1777,15 +1778,7 @@ def _sage_attention_hub(
 ) -> torch.Tensor:
     lse = None
     if _parallel_config is None:
-        out = sage_attn_func_hub(
-            q=query,
-            k=key,
-            v=value,
-            tensor_layout="NHD",
-            is_causal=is_causal,
-            sm_scale=scale,
-            return_lse=return_lse,
-        )
+        out = sage_attn_func_hub(q=query, k=key, v=value)
         if return_lse:
             out, lse, *_ = out
     else:
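
In short, the setup above resolves the right SageAttention kernel for the current GPU once at import time, and `functools.partial` freezes that kernel's architecture-specific keyword arguments, so `_sage_attention_hub` can now call it with just `q`, `k`, and `v`. A minimal standalone sketch of the pattern; `bind_sage_kernel` and `example_entry` are illustrative names, not part of the diffusers API:

from functools import partial

# Illustrative dispatch entry, shaped like the ones added in sage_utils.py below.
example_entry = {
    "func": "sageattn_qk_int8_pv_fp8_cuda_sm90",
    "kwargs": {"tensor_layout": "NHD", "is_causal": False, "sm_scale": None, "return_lse": False},
}

def bind_sage_kernel(kernel_module, entry):
    # Resolve the kernel by name on the Hub-loaded module, then pre-bind its
    # architecture-specific defaults so callers only supply q, k and v.
    fn = getattr(kernel_module, entry["func"])
    return partial(fn, **entry["kwargs"])

# Usage, given a module loaded via _get_kernel_from_hub(_DEFAULT_HUB_ID_SAGE):
#   sage_attn = bind_sage_kernel(sage_interface_hub, example_entry)
#   out = sage_attn(q=query, k=key, v=value)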

src/diffusers/utils/kernels_utils.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 _KERNEL_REVISION = {
     # TODO: temporary revision for now. Remove when merged upstream into `main`.
     _DEFAULT_HUB_ID_FA3: "fake-ops-return-probs",
-    _DEFAULT_HUB_ID_SAGE: None,
+    _DEFAULT_HUB_ID_SAGE: "compile",
 }

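`_KERNEL_REVISION` pins each Hub kernel repo to a revision, so the SageAttention kernel is now pulled from its `compile` revision instead of the repo default. A rough sketch of how such a map might be consumed, assuming the `kernels` library's `get_kernel` accepts a `revision` argument; the repo id below is illustrative and the real `_get_kernel_from_hub` in diffusers may differ:

from kernels import get_kernel  # requires `pip install kernels`

_DEFAULT_HUB_ID_SAGE = "kernels-community/sage_attention"  # illustrative id, not taken from this diff
_KERNEL_REVISION = {_DEFAULT_HUB_ID_SAGE: "compile"}

def load_hub_kernel(repo_id):
    # Pin the Hub download to the recorded revision; fall back to the repo default when unpinned.
    revision = _KERNEL_REVISION.get(repo_id)
    return get_kernel(repo_id, revision=revision) if revision is not None else get_kernel(repo_id)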
src/diffusers/utils/sage_utils.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
+"""
+Copyright (c) 2024 by SageAttention, The HuggingFace team.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
+License. You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+"""
+
+"""
+Modified from
+https://github.com/thu-ml/SageAttention/blob/68de3797d163b89d28f9a38026c3b7313f6940d2/sageattention/core.py
+"""
+
+
+import torch  # noqa
+
+
+SAGE_ATTENTION_DISPATCH = {
+    "sm80": {
+        "func": "sageattn_qk_int8_pv_fp16_cuda",
+        "kwargs": {
+            "tensor_layout": "NHD",
+            "is_causal": False,
+            "sm_scale": None,
+            "return_lse": False,
+            "pv_accum_dtype": "fp32",
+        },
+    },
+    "sm89": {
+        "func": "sageattn_qk_int8_pv_fp8_cuda",
+        "kwargs": {
+            "tensor_layout": "NHD",
+            "is_causal": False,
+            "sm_scale": None,
+            "return_lse": False,
+            "pv_accum_dtype": "fp32+fp16",
+        },
+    },
+    "sm90": {
+        "func": "sageattn_qk_int8_pv_fp8_cuda_sm90",
+        "kwargs": {
+            "tensor_layout": "NHD",
+            "is_causal": False,
+            "sm_scale": None,
+            "return_lse": False,
+            "pv_accum_dtype": "fp32+fp32",
+        },
+    },
+    "sm120": {
+        "func": "sageattn_qk_int8_pv_fp8_cuda",
+        "kwargs": {
+            "tensor_layout": "NHD",
+            "is_causal": False,
+            "qk_quant_gran": "per_warp",
+            "sm_scale": None,
+            "return_lse": False,
+            "pv_accum_dtype": "fp32+fp16",
+        },
+    },
+}
+
+
+def get_cuda_version():
+    if torch.cuda.is_available():
+        major, minor = torch.cuda.get_device_capability()
+        return major, minor
+    else:
+        raise EnvironmentError("CUDA not found.")
+
+
+def get_cuda_arch_versions():
+    if not torch.cuda.is_available():
+        raise EnvironmentError("CUDA not found.")
+    cuda_archs = []
+    for i in range(torch.cuda.device_count()):
+        major, minor = torch.cuda.get_device_capability(i)
+        cuda_archs.append(f"sm{major}{minor}")
+    return cuda_archs
+
+
+# Unlike the actual implementation, we just maintain function names rather than actual
+# implementations.
+def _get_sage_attn_fn_for_device():
+    """
+    Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute
+    capability. The parameters documented below describe the dispatched SageAttention kernel, not this helper.
+
+    Parameters
+    ----------
+    q : torch.Tensor
+        The query tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
+
+    k : torch.Tensor
+        The key tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
+
+    v : torch.Tensor
+        The value tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
+
+    tensor_layout : str
+        The tensor layout, either "HND" or "NHD". Default: "HND".
+
+    is_causal : bool
+        Whether to apply a causal mask to the attention matrix. Only applicable when qo_len == kv_len. Default: False.
+
+    sm_scale : Optional[float]
+        The scale used in softmax; if not provided, it will be set to ``1.0 / sqrt(head_dim)``.
+
+    return_lse : bool
+        Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
+        Default: False.
+
+    Returns
+    -------
+    torch.Tensor
+        The output tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
+
+    torch.Tensor
+        The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). Shape:
+        ``[batch_size, num_qo_heads, qo_len]``. Only returned if `return_lse` is True.
+
+    Note
+    ----
+    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
+    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
+    - All tensors must be on the same CUDA device.
+    """
+    device_index = torch.cuda.current_device()
+    arch = get_cuda_arch_versions()[device_index]
+    return SAGE_ATTENTION_DISPATCH[arch]
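
Taken together, `_get_sage_attn_fn_for_device` maps the current device's compute capability to an arch string and indexes the table above; as written, an architecture missing from the table (for example sm86) would surface as a KeyError. A small usage sketch, assuming a CUDA machine whose GPU appears in the table (the comments show what an sm90 device would return):

import torch
from diffusers.utils.sage_utils import SAGE_ATTENTION_DISPATCH, _get_sage_attn_fn_for_device

# Resolve the arch string the same way get_cuda_arch_versions() does for the current device.
major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
arch = f"sm{major}{minor}"

entry = _get_sage_attn_fn_for_device()
assert entry is SAGE_ATTENTION_DISPATCH[arch]
print(entry["func"])                      # "sageattn_qk_int8_pv_fp8_cuda_sm90" on an sm90 GPU
print(entry["kwargs"]["pv_accum_dtype"])  # "fp32+fp32" on an sm90 GPU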
