Commit a2f8029

Author: Griffin Adams (committed)

Remove mutable args and remove costly index sorting for window attn.

1 parent 538314b · commit a2f8029

File tree: 3 files changed, 45 additions (+), 24 deletions (−)


cache.py

Lines changed: 29 additions & 17 deletions
@@ -13,7 +13,13 @@ class KVCache(ABC, nn.Module):
     relevant_kwargs = ["max_cache_length"]
 
     def __init__(
-        self, max_batch_size, n_heads, head_dim, dtype=torch.bfloat16, head_specific=False, **kwargs
+        self,
+        max_batch_size,
+        n_heads,
+        head_dim,
+        dtype=torch.bfloat16,
+        head_specific=False,
+        **kwargs,
     ):
         super().__init__()
 
@@ -28,7 +34,15 @@ def __init__(
         # We use n_heads as an optional second dimension to allow for head-specific evictions.
         self.register_buffer(
             "pos",
-            torch.full((max_batch_size, n_heads if head_specific else 1, self.max_cache_length), -1, dtype=torch.int),
+            torch.full(
+                (
+                    max_batch_size,
+                    n_heads if head_specific else 1,
+                    self.max_cache_length,
+                ),
+                -1,
+                dtype=torch.int,
+            ),
         )
 
         self.updates = 0
@@ -49,7 +63,7 @@ def reset(self):
         self.pos.fill_(-1)
         self.insertions = 0
         self.updates = 0
-
+
     def update(self, input_pos, k_val, v_val):
         """
         Updates the cache with the given input positions, keys, and values.
@@ -73,7 +87,9 @@ def update(self, input_pos, k_val, v_val):
         # Truncate the unfilled part of the cache
         # Since we always fill in-order it will be at the end
         truncate_idx = min(self.insertions, self.max_cache_length)
-        return self.k_cache[:, :, :truncate_idx, :], self.v_cache[:, :, :truncate_idx, :]
+        return self.k_cache[:, :, :truncate_idx, :], self.v_cache[
+            :, :, :truncate_idx, :
+        ]
 
     @abstractmethod
     def _update(self, input_pos, k_val, v_val):
@@ -116,19 +132,16 @@ def __init__(
     def mark_global_tokens(self) -> bool:
         """
         Update POS tensor to give global tokens highest priority.
-
+
         Return a boolean indicating whether or not all global tokens were filled.
 
         If it returns True, this function won't be called again to save computation.
         """
         # We put max priority on leading "global" tokens
-        global_mask = torch.logical_and(
-            self.pos < self.global_tokens, self.pos >= 0
-        )
+        global_mask = torch.logical_and(self.pos < self.global_tokens, self.pos >= 0)
         # Give self.score an arbitrary high value for global tokens so that they are not replaced
         self.pos.masked_fill_(global_mask, LARGE_INTEGER)
-        return global_mask.sum() == self.global_tokens
-
+        return (global_mask.sum() == self.global_tokens).item()
 
     def _update(self, input_pos, k_val, v_val):
         # Prefill case: If prompt > window, then we need to chop off early positions
@@ -144,19 +157,18 @@ def _update(self, input_pos, k_val, v_val):
             input_pos = input_pos[keep_idxs]
             k_val = k_val[:, :, keep_idxs]
             v_val = v_val[:, :, keep_idxs]
-
+
         # Identify the lowest positions in the cache that are not filled
-        # For window, all heads are the same so let's just use the first head for "pos"
         pos = self.pos[:, 0, :].squeeze(1)
         _, min_k_indices = pos.topk(input_pos.shape[0], largest=False)
+        min_k_indices = min_k_indices.squeeze(0)
 
-        # Sort the indices in ascending order
-        min_k_indices, _ = min_k_indices.squeeze(0).sort()
-
-        self.fill(fill_indices=min_k_indices, input_pos=input_pos, k_val=k_val, v_val=v_val)
+        self.fill(
+            fill_indices=min_k_indices, input_pos=input_pos, k_val=k_val, v_val=v_val
+        )
 
         # This is a potentially costly operation which doesn't need to be repeated once we've filled the global tokens
-        self.global_filled |= self.mark_global_tokens()
+        self.global_filled = self.global_filled or self.mark_global_tokens()
 
 
 def get_cache_constructor(cache_strategy):
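
Note on the window-attention change above: the fill indices come from a topk over the position buffer, and the subsequent fill writes each key/value at an explicit index, so sorting those indices first buys nothing. A minimal sketch of the idea, assuming an index_copy_-style write (the tensor names are illustrative, not the repository's):

import torch

# Sketch only: writing into a cache by explicit indices gives the same result
# whether or not the indices are sorted, so the extra .sort() can be dropped.
cache = torch.zeros(8, 4)            # stand-in for one head's cache slots
new_vals = torch.randn(3, 4)         # three incoming key vectors

unsorted_idx = torch.tensor([5, 1, 3])
sorted_idx, order = unsorted_idx.sort()

a = cache.clone().index_copy_(0, unsorted_idx, new_vals)
b = cache.clone().index_copy_(0, sorted_idx, new_vals[order])

assert torch.equal(a, b)             # same cache contents either way

The switch from "self.global_filled |= self.mark_global_tokens()" to "self.global_filled = self.global_filled or self.mark_global_tokens()" follows the same cost argument: "or" short-circuits, so the relatively expensive mark_global_tokens() stops being evaluated once the global tokens are known to be filled, whereas "|=" would call it on every update.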

generate.py

Lines changed: 7 additions & 6 deletions
@@ -99,7 +99,7 @@ def decode_n_tokens(
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
-    terminator_ids: Optional[list] = [],
+    terminator_ids: Optional[list] = None,
     callback=lambda _: _,
     **sampling_kwargs,
 ):
@@ -200,8 +200,8 @@ def generate(
     speculate_k: Optional[int] = 8,
     max_cache_length: Optional[float] = 1.0,
     callback=lambda x: x,
-    terminator_ids: Optional[list] = [],
-    cache_kwargs: dict = {"max_cache_length": 1.0},
+    terminator_ids: Optional[list] = None,
+    cache_kwargs: dict = None,
     **sampling_kwargs,
 ) -> torch.Tensor:
     """
@@ -235,7 +235,9 @@ def generate(
     ), f"Specified max cache length ({max_cache_length}) must be less than max_seq_length ({max_seq_length})."
     cache_kwargs["max_cache_length"] = max_cache_length
 
-    assert cache_kwargs["global_tokens"] <= max_cache_length, "Global tokens must be less than max_cache_length."
+    assert (
+        cache_kwargs["global_tokens"] <= max_cache_length
+    ), "Global tokens must be less than max_cache_length."
 
     with torch.device(device):
         model.setup_caches(max_batch_size=1, **cache_kwargs)
@@ -615,10 +617,9 @@ def callback(x):
         args.max_cache_length == 1.0
     ), "Full cache strategy only supports max_cache_length=1.0."
 
-    # TODO Nicer way to bundle these?
     cache_kwargs = {
-        "max_cache_length": args.max_cache_length,
         "cache_strategy": args.cache_strategy,
+        "max_cache_length": args.max_cache_length,
         "global_tokens": args.global_tokens,
     }
 
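
The terminator_ids and cache_kwargs changes above replace mutable default arguments with None. Python evaluates default values once, at function definition time, so a default list or dict is shared across every call and can silently accumulate state. A small, self-contained illustration of the pitfall (the function names are made up for this example):

def bad_append(x, acc=[]):        # one shared list for every call
    acc.append(x)
    return acc

bad_append(1)
print(bad_append(2))              # [1, 2] -- state leaked from the first call

def good_append(x, acc=None):     # None sentinel: a fresh list per call
    if acc is None:
        acc = []
    acc.append(x)
    return acc

print(good_append(2))             # [2]

The usual companion to this change, not visible in the hunks shown, is an "if cache_kwargs is None:" guard inside the function body that builds a fresh dict per call.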

model.py

Lines changed: 9 additions & 1 deletion
@@ -244,6 +244,7 @@ def forward(
         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
 
         is_prefill = self.kv_cache.is_prefill()
+
         cache_k, cache_v = self.kv_cache.update(input_pos, k, v)
 
         # If we are in the prefill stage, we use the provided prompt kv-pairs
@@ -253,7 +254,14 @@ def forward(
 
         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            is_causal=False,
+            attn_mask=mask,
+            dropout_p=0.0,
+        )
 
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
 
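
On the scaled_dot_product_attention change: with a KV cache the key/value length generally differs from the query length, so causality has to be encoded in the attn_mask that the caller supplies; spelling out is_causal=False makes that explicit, and PyTorch rejects combining is_causal=True with an explicit attn_mask anyway. A shape-level sketch of a decode-step call (shapes are illustrative, not taken from model.py):

import torch
import torch.nn.functional as F

# One new query token attends over kv_len cached key/value slots.
bsz, n_head, q_len, kv_len, head_dim = 1, 2, 1, 16, 8
q = torch.randn(bsz, n_head, q_len, head_dim)
k = torch.randn(bsz, n_head, kv_len, head_dim)
v = torch.randn(bsz, n_head, kv_len, head_dim)

# Boolean mask, True = attend; here it simply marks every cache slot as valid.
mask = torch.ones(bsz, 1, q_len, kv_len, dtype=torch.bool)

y = F.scaled_dot_product_attention(
    q, k, v, is_causal=False, attn_mask=mask, dropout_p=0.0
)
print(y.shape)  # torch.Size([1, 2, 1, 8])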
