Skip to content

Commit 06aa103

Browse files
author
Griffin Adams
committed
Adds fixed-window KV-cache.
- Creates cache.py - Introduces global_tokens - Formats repo with ruff - Speed parity with full KV-cache
1 parent c2af69b commit 06aa103

File tree

5 files changed

+374
-59
lines changed

5 files changed

+374
-59
lines changed

cache.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
from abc import ABC, abstractmethod
2+
3+
import torch
4+
import torch.nn as nn
5+
6+
7+
LARGE_INTEGER = int(1e9) # This is used to assign high priority ids
8+
9+
10+
class KVCache(ABC, nn.Module):
11+
# Define which hyperparameters are relevant for the cache.
12+
# Override as needed for sub-classes.
13+
relevant_kwargs = ["max_cache_length"]
14+
15+
def __init__(
16+
self,
17+
max_batch_size,
18+
n_heads,
19+
head_dim,
20+
dtype=torch.bfloat16,
21+
head_specific=False,
22+
**kwargs,
23+
):
24+
super().__init__()
25+
26+
# Assign each kwarg as an attribute of the class
27+
for key, value in kwargs.items():
28+
setattr(self, key, value)
29+
30+
cache_shape = (max_batch_size, n_heads, self.max_cache_length, head_dim)
31+
self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
32+
self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
33+
# This is used to keep track of the order in which the cache is filled.
34+
# We use n_heads as an optional second dimension to allow for head-specific evictions.
35+
self.register_buffer(
36+
"pos",
37+
torch.full(
38+
(
39+
max_batch_size,
40+
n_heads if head_specific else 1,
41+
self.max_cache_length,
42+
),
43+
-1,
44+
dtype=torch.int,
45+
),
46+
)
47+
48+
self.updates = 0
49+
self.insertions = 0
50+
51+
def is_prefill(self):
52+
# If we are in the prefill stage, we have updated the cache at most once (self.updates <=1)
53+
# Prefill --> full self-attention (no KV-cache needed).
54+
# Otherwise --> query the KV-cache.
55+
return self.updates == 0
56+
57+
def reset(self):
58+
"""
59+
If needed, this will reset the cache, although it is likely not necessary for most cache types.
60+
"""
61+
self.k_cache.zero_()
62+
self.v_cache.zero_()
63+
self.pos.fill_(-1)
64+
self.insertions = 0
65+
self.updates = 0
66+
67+
def update(self, input_pos, k_val, v_val):
68+
"""
69+
Updates the cache with the given input positions, keys, and values.
70+
71+
Parameters:
72+
input_pos (torch.Tensor): A tensor of input positions.
73+
k_val (torch.Tensor): A tensor of keys.
74+
v_val (torch.Tensor): A tensor of values.
75+
76+
Returns:
77+
Tuple[torch.Tensor, torch.Tensor]: A tuple containing the updated cache of keys and values,
78+
both truncated to the minimum of the current insertions and the maximum cache length.
79+
"""
80+
81+
self._update(input_pos, k_val, v_val)
82+
83+
# Update counters
84+
self.updates += 1
85+
self.insertions += input_pos.shape[0]
86+
87+
# Truncate the unfilled part of the cache
88+
# Since we always fill in-order it will be at the end
89+
truncate_idx = min(self.insertions, self.max_cache_length)
90+
return self.k_cache[:, :, :truncate_idx, :], self.v_cache[
91+
:, :, :truncate_idx, :
92+
]
93+
94+
@abstractmethod
95+
def _update(self, input_pos, k_val, v_val):
96+
"""
97+
Cache-specific update logic.
98+
Takes in the input positions and the corresponding k and v values.
99+
Modifies self.pos, self.k_cache, self.v_cache place.
100+
"""
101+
pass
102+
103+
def fill(self, fill_indices, input_pos, k_val, v_val):
104+
self.k_cache[:, :, fill_indices] = k_val
105+
self.v_cache[:, :, fill_indices] = v_val
106+
self.pos[:, :, fill_indices] = input_pos.int()
107+
108+
109+
class KVCacheFull(KVCache):
110+
def __init__(
111+
self, max_batch_size, n_heads, head_dim, dtype=torch.bfloat16, **kwargs
112+
):
113+
super().__init__(max_batch_size, n_heads, head_dim, dtype, **kwargs)
114+
115+
def _update(self, input_pos, k_val, v_val):
116+
# input_pos: [S], k_val: [B, H, S, D]
117+
assert input_pos.shape[0] == k_val.shape[2]
118+
self.fill(fill_indices=input_pos, input_pos=input_pos, k_val=k_val, v_val=v_val)
119+
120+
121+
class KVCacheWindow(KVCache):
122+
relevant_kwargs = ["max_cache_length", "global_tokens"]
123+
124+
def __init__(
125+
self, max_batch_size, n_heads, head_dim, dtype=torch.bfloat16, **kwargs
126+
):
127+
super().__init__(max_batch_size, n_heads, head_dim, dtype, **kwargs)
128+
129+
# This turns True when the global tokens are fully filled
130+
self.global_filled = self.global_tokens == 0
131+
132+
def mark_global_tokens(self) -> bool:
133+
"""
134+
Update POS tensor to give global tokens highest priority.
135+
136+
Return a boolean indicating whether or not all global tokens were filled.
137+
138+
If it returns True, this function won't be called again to save computation.
139+
"""
140+
# We put max priority on leading "global" tokens
141+
global_mask = torch.logical_and(self.pos < self.global_tokens, self.pos >= 0)
142+
# Give self.score an arbitrary high value for global tokens so that they are not replaced
143+
self.pos.masked_fill_(global_mask, LARGE_INTEGER)
144+
return (global_mask.sum() == self.global_tokens).item()
145+
146+
def _update(self, input_pos, k_val, v_val):
147+
# Prefill case: If prompt > window, then we need to chop off early positions
148+
window = self.k_cache.shape[2]
149+
if input_pos.shape[0] > window:
150+
# [global; ...; window - global] --> [global; window - global]
151+
# Indices for first global_tokens tokens and last (window - global_tokens) tokens
152+
keep_idxs = list(range(self.global_tokens)) + list(
153+
range(
154+
input_pos.shape[0] - window + self.global_tokens, input_pos.shape[0]
155+
)
156+
)
157+
input_pos = input_pos[keep_idxs]
158+
k_val = k_val[:, :, keep_idxs]
159+
v_val = v_val[:, :, keep_idxs]
160+
161+
# Identify the lowest positions in the cache that are not filled
162+
pos = self.pos[:, 0, :].squeeze(1)
163+
_, min_k_indices = pos.topk(input_pos.shape[0], largest=False)
164+
min_k_indices = min_k_indices.squeeze(0)
165+
166+
self.fill(
167+
fill_indices=min_k_indices, input_pos=input_pos, k_val=k_val, v_val=v_val
168+
)
169+
170+
# This is a potentially costly operation which doesn't need to be repeated once we've filled the global tokens
171+
self.global_filled = self.global_filled or self.mark_global_tokens()
172+
173+
174+
def get_cache_constructor(cache_strategy):
175+
if cache_strategy == "full":
176+
return KVCacheFull
177+
elif cache_strategy == "window":
178+
return KVCacheWindow
179+
else:
180+
raise ValueError(f"Invalid cache strategy: {cache_strategy}")

0 commit comments

Comments
 (0)