Skip to content

Commit 538314b

Browse files
author
Griffin Adams
committed
First pass at implementing fixed-window KV-cache.
- Adds global attention, minor refactoring of cache.py. - Ruff formatting.
1 parent c2af69b commit 538314b

File tree

5 files changed

+352
-58
lines changed

5 files changed

+352
-58
lines changed

cache.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
from abc import ABC, abstractmethod
2+
3+
import torch
4+
import torch.nn as nn
5+
6+
7+
LARGE_INTEGER = int(1e9) # This is used to assign high priority ids
8+
9+
10+
class KVCache(ABC, nn.Module):
11+
# Define which hyperparameters are relevant for the cache.
12+
# Override as needed for sub-classes.
13+
relevant_kwargs = ["max_cache_length"]
14+
15+
def __init__(
16+
self, max_batch_size, n_heads, head_dim, dtype=torch.bfloat16, head_specific=False, **kwargs
17+
):
18+
super().__init__()
19+
20+
# Assign each kwarg as an attribute of the class
21+
for key, value in kwargs.items():
22+
setattr(self, key, value)
23+
24+
cache_shape = (max_batch_size, n_heads, self.max_cache_length, head_dim)
25+
self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
26+
self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
27+
# This is used to keep track of the order in which the cache is filled.
28+
# We use n_heads as an optional second dimension to allow for head-specific evictions.
29+
self.register_buffer(
30+
"pos",
31+
torch.full((max_batch_size, n_heads if head_specific else 1, self.max_cache_length), -1, dtype=torch.int),
32+
)
33+
34+
self.updates = 0
35+
self.insertions = 0
36+
37+
def is_prefill(self):
38+
# If we are in the prefill stage, we have updated the cache at most once (self.updates <=1)
39+
# Prefill --> full self-attention (no KV-cache needed).
40+
# Otherwise --> query the KV-cache.
41+
return self.updates == 0
42+
43+
def reset(self):
44+
"""
45+
If needed, this will reset the cache, although it is likely not necessary for most cache types.
46+
"""
47+
self.k_cache.zero_()
48+
self.v_cache.zero_()
49+
self.pos.fill_(-1)
50+
self.insertions = 0
51+
self.updates = 0
52+
53+
def update(self, input_pos, k_val, v_val):
54+
"""
55+
Updates the cache with the given input positions, keys, and values.
56+
57+
Parameters:
58+
input_pos (torch.Tensor): A tensor of input positions.
59+
k_val (torch.Tensor): A tensor of keys.
60+
v_val (torch.Tensor): A tensor of values.
61+
62+
Returns:
63+
Tuple[torch.Tensor, torch.Tensor]: A tuple containing the updated cache of keys and values,
64+
both truncated to the minimum of the current insertions and the maximum cache length.
65+
"""
66+
67+
self._update(input_pos, k_val, v_val)
68+
69+
# Update counters
70+
self.updates += 1
71+
self.insertions += input_pos.shape[0]
72+
73+
# Truncate the unfilled part of the cache
74+
# Since we always fill in-order it will be at the end
75+
truncate_idx = min(self.insertions, self.max_cache_length)
76+
return self.k_cache[:, :, :truncate_idx, :], self.v_cache[:, :, :truncate_idx, :]
77+
78+
@abstractmethod
79+
def _update(self, input_pos, k_val, v_val):
80+
"""
81+
Cache-specific update logic.
82+
Takes in the input positions and the corresponding k and v values.
83+
Modifies self.pos, self.k_cache, self.v_cache place.
84+
"""
85+
pass
86+
87+
def fill(self, fill_indices, input_pos, k_val, v_val):
88+
self.k_cache[:, :, fill_indices] = k_val
89+
self.v_cache[:, :, fill_indices] = v_val
90+
self.pos[:, :, fill_indices] = input_pos.int()
91+
92+
93+
class KVCacheFull(KVCache):
94+
def __init__(
95+
self, max_batch_size, n_heads, head_dim, dtype=torch.bfloat16, **kwargs
96+
):
97+
super().__init__(max_batch_size, n_heads, head_dim, dtype, **kwargs)
98+
99+
def _update(self, input_pos, k_val, v_val):
100+
# input_pos: [S], k_val: [B, H, S, D]
101+
assert input_pos.shape[0] == k_val.shape[2]
102+
self.fill(fill_indices=input_pos, input_pos=input_pos, k_val=k_val, v_val=v_val)
103+
104+
105+
class KVCacheWindow(KVCache):
106+
relevant_kwargs = ["max_cache_length", "global_tokens"]
107+
108+
def __init__(
109+
self, max_batch_size, n_heads, head_dim, dtype=torch.bfloat16, **kwargs
110+
):
111+
super().__init__(max_batch_size, n_heads, head_dim, dtype, **kwargs)
112+
113+
# This turns True when the global tokens are fully filled
114+
self.global_filled = self.global_tokens == 0
115+
116+
def mark_global_tokens(self) -> bool:
117+
"""
118+
Update POS tensor to give global tokens highest priority.
119+
120+
Return a boolean indicating whether or not all global tokens were filled.
121+
122+
If it returns True, this function won't be called again to save computation.
123+
"""
124+
# We put max priority on leading "global" tokens
125+
global_mask = torch.logical_and(
126+
self.pos < self.global_tokens, self.pos >= 0
127+
)
128+
# Give self.score an arbitrary high value for global tokens so that they are not replaced
129+
self.pos.masked_fill_(global_mask, LARGE_INTEGER)
130+
return global_mask.sum() == self.global_tokens
131+
132+
133+
def _update(self, input_pos, k_val, v_val):
134+
# Prefill case: If prompt > window, then we need to chop off early positions
135+
window = self.k_cache.shape[2]
136+
if input_pos.shape[0] > window:
137+
# [global; ...; window - global] --> [global; window - global]
138+
# Indices for first global_tokens tokens and last (window - global_tokens) tokens
139+
keep_idxs = list(range(self.global_tokens)) + list(
140+
range(
141+
input_pos.shape[0] - window + self.global_tokens, input_pos.shape[0]
142+
)
143+
)
144+
input_pos = input_pos[keep_idxs]
145+
k_val = k_val[:, :, keep_idxs]
146+
v_val = v_val[:, :, keep_idxs]
147+
148+
# Identify the lowest positions in the cache that are not filled
149+
# For window, all heads are the same so let's just use the first head for "pos"
150+
pos = self.pos[:, 0, :].squeeze(1)
151+
_, min_k_indices = pos.topk(input_pos.shape[0], largest=False)
152+
153+
# Sort the indices in ascending order
154+
min_k_indices, _ = min_k_indices.squeeze(0).sort()
155+
156+
self.fill(fill_indices=min_k_indices, input_pos=input_pos, k_val=k_val, v_val=v_val)
157+
158+
# This is a potentially costly operation which doesn't need to be repeated once we've filled the global tokens
159+
self.global_filled |= self.mark_global_tokens()
160+
161+
162+
def get_cache_constructor(cache_strategy):
163+
if cache_strategy == "full":
164+
return KVCacheFull
165+
elif cache_strategy == "window":
166+
return KVCacheWindow
167+
else:
168+
raise ValueError(f"Invalid cache strategy: {cache_strategy}")

0 commit comments

Comments
 (0)