
Commit f19b4e9

sxu authored and facebook-github-bot committed
Source transform to use static attention
Summary: Introduce a source transform so this step is more aligned with the other transforms we run. It also makes the flow less error prone (e.g. the HF RoPE transformation needs to happen before turning linears into conv2ds).

Differential Revision: D84769599
1 parent 11c0b4f commit f19b4e9

File tree: 2 files changed (+79, −13 lines)


examples/models/llama/static_attention.py

Lines changed: 66 additions & 0 deletions
@@ -764,6 +764,38 @@ def __init__(
         self.q_norm = torch.nn.Identity()
         self.k_norm = torch.nn.Identity()
 
+    @classmethod
+    def from_attention_mha(
+        cls,
+        other: AttentionMHA,
+        split_mha: bool = True,
+        **kwargs: Any,
+    ) -> "StaticAttention":
+        config = ModelArgs(
+            dim=other.dim,
+            n_layers=1,  # Not used in attention layer
+            n_heads=other.n_heads,
+            n_kv_heads=other.n_kv_heads,
+            head_dim=other.head_dim,
+            max_batch_size=other.max_batch_size,
+            max_context_len=other.max_context_len,
+            attention_qkv_bias=other.attention_qkv_bias,
+            use_qk_norm=other.use_qk_norm,
+            qk_norm_before_rope=other.qk_norm_before_rope,
+            norm_eps=other.q_norm_fn.eps if other.use_qk_norm else 1e-5,
+        )
+
+        instance = cls(
+            config=config,
+            layer_id=other.layer_id,
+            rope=other.rope,
+            split_mha=split_mha,
+            **kwargs,
+        )
+        instance.load_weights_from_attention_mha(other)
+
+        return instance
+
     def forward(
         self,
         x: torch.Tensor,
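
For orientation, here is a minimal sketch of how the new classmethod might be used on a single layer, mirroring the updated tests further down. The import paths and the ModelArgs field values are illustrative assumptions for this sketch, not part of the commit:

# Import paths assumed from the example layout in this repo; adjust to your checkout.
from executorch.examples.models.llama.attention import AttentionMHA
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope
from executorch.examples.models.llama.static_attention import StaticAttention

# Hypothetical toy config; from_attention_mha only reads the fields it copies above.
config = ModelArgs(dim=64, n_heads=4, n_kv_heads=2, max_seq_len=32)
rope = Rope(config)
attn_mha = AttentionMHA(config, 0, rope).eval()

# Build a StaticAttention that reuses the MHA weights; split_mha=True keeps the
# default per-head split used elsewhere in this file.
static_attn = StaticAttention.from_attention_mha(attn_mha, split_mha=True).eval()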
@@ -1059,3 +1091,37 @@ def transfer_weight(linear, conv2d):
 class StaticAttentionMHA(StaticAttention):
     def __init__(self, config: ModelArgs, layer_id: int, rope: Rope, **kwargs: Any):
         super().__init__(config, layer_id, rope, split_mha=False, **kwargs)
+
+
+def transform_attention_mha_to_static_attention(
+    model: nn.Module,
+    split_mha: bool = True,
+    inplace: bool = True,
+    use_conv2d: bool = False,
+    use_hf_rope: bool = False,
+    **kwargs: Any,
+) -> nn.Module:
+    if not inplace:
+        import copy
+
+        model = copy.deepcopy(model)
+
+    def helper(m):
+        for name, child in list(m.named_children()):
+            if isinstance(child, AttentionMHA):
+                static_attn = StaticAttention.from_attention_mha(
+                    child, split_mha=split_mha, **kwargs
+                )
+                # Note: HF RoPE needs to be applied before linear to conv2d
+                if use_hf_rope:
+                    static_attn.adopt_hf_rope()
+                if use_conv2d:
+                    static_attn.linear_to_conv2d()
+
+                setattr(m, name, static_attn)
+            else:
+                helper(child)
+
+        return m
+
+    return helper(model)
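
And a hedged sketch of the whole-model entry point, which is how the updated _get_test_transformers test helper below uses it; `model` stands for any nn.Module containing AttentionMHA submodules, and the import path is again an assumption:

from executorch.examples.models.llama.static_attention import (
    transform_attention_mha_to_static_attention,
)

# Swap every AttentionMHA submodule for a StaticAttention loaded with the same weights.
# The transform calls adopt_hf_rope() before linear_to_conv2d(), so the ordering
# constraint called out in the commit summary is handled internally.
static_model = transform_attention_mha_to_static_attention(
    model,            # assumed: an nn.Module tree containing AttentionMHA layers
    split_mha=True,
    inplace=False,    # deep-copy first; the original model stays unmodified
    use_conv2d=True,
    use_hf_rope=True,
).eval()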

examples/models/llama/tests/test_static_attention.py

Lines changed: 13 additions & 13 deletions
@@ -14,6 +14,7 @@
     StaticAttentionMask,
     StaticKCache,
     StaticKVCache,
+    transform_attention_mha_to_static_attention,
 )
 
 
@@ -76,7 +77,6 @@ def test(
             layer_id = 0
             rope = Rope(config)
             attn_mha = AttentionMHA(config, layer_id, rope).eval()
-            static_attn = StaticAttention(config, layer_id, rope).eval()
             if use_qk_norm:
                 with torch.no_grad():
                     attn_mha.q_norm_fn.weight.copy_(
@@ -85,7 +85,9 @@
                     attn_mha.k_norm_fn.weight.copy_(
                         torch.rand(config.head_dim) * 0.2 + 0.9
                     )
-            static_attn.load_weights_from_attention_mha(attn_mha)
+            static_attn = StaticAttention.from_attention_mha(
+                attn_mha, split_mha=split_mha
+            ).eval()
             if adopt_hf_rope:
                 static_attn.adopt_hf_rope()
             if use_conv2d:
@@ -131,8 +133,7 @@ def test_with_cache(self):
         layer_id = 0
         rope = Rope(config)
         attn_mha = AttentionMHA(config, layer_id, rope).eval()
-        static_attn = StaticAttention(config, layer_id, rope).eval()
-        static_attn.load_weights_from_attention_mha(attn_mha)
+        static_attn = StaticAttention.from_attention_mha(attn_mha).eval()
         static_attn.adopt_hf_rope()
 
         x = torch.rand(1, config.max_seq_len, config.dim)
@@ -198,17 +199,16 @@ def test_with_style(style):
     def _get_test_transformers(self, config, attention_type="static", use_conv2d=False):
         mha_transformer = construct_transformer(config).eval()
 
+        static_transformer = transform_attention_mha_to_static_attention(
+            mha_transformer,
+            split_mha=(attention_type == "static"),
+            inplace=False,
+            use_conv2d=use_conv2d,
+            use_hf_rope=True,
+        ).eval()
+
         config = copy.copy(config)
         config.attention_type = attention_type
-        static_transformer = construct_transformer(config).eval()
-        static_transformer.load_state_dict(mha_transformer.state_dict(), strict=False)
-        for mha_layer, static_layer in zip(
-            mha_transformer.layers, static_transformer.layers
-        ):
-            static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)
-            static_layer.attention.adopt_hf_rope()
-            if use_conv2d:
-                static_layer.linear_to_conv2d()
         config.use_hf_rope = True
 
         return mha_transformer, static_transformer, config
