Integrate Flex Decoding #196

Open · wants to merge 11 commits into main
Changes from 3 commits
33 changes: 18 additions & 15 deletions generate.py
@@ -54,9 +54,12 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs

def roundup(val, multiplier):
return ((val - 1) // multiplier + 1) * multiplier

def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
# input_pos: [B, S]
logits = model(x, input_pos)
logits = model.prefill(x, input_pos)
return sample(logits, **sampling_kwargs)[0]

def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -68,15 +71,14 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
new_tokens, new_probs = [], []
for i in range(num_new_tokens):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
next_token, next_prob = decode_one_token(
model, cur_token, input_pos, **sampling_kwargs
)
input_pos += 1
new_tokens.append(next_token.clone())
callback(new_tokens[-1])
new_probs.append(next_prob.clone())
cur_token = next_token.view(1, -1)
next_token, next_prob = decode_one_token(
model, cur_token, input_pos, **sampling_kwargs
)
input_pos += 1
new_tokens.append(next_token.clone())
callback(new_tokens[-1])
new_probs.append(next_prob.clone())
cur_token = next_token.view(1, -1)

return new_tokens, new_probs

@@ -154,10 +156,11 @@ def generate(
# create an empty tensor of the expected final shape and fill in the current tokens
T = prompt.size(0)
T_new = T + max_new_tokens
T_buf = roundup(T_new, 128) # round up to multiple of 128 to use flex_attention
if interactive:
max_seq_length = 350
max_seq_length = roundup(350, 128)
else:
max_seq_length = min(T_new, model.config.block_size)
max_seq_length = min(T_buf, model.config.block_size)

device, dtype = prompt.device, prompt.dtype
max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
@@ -167,7 +170,7 @@
draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

# create an empty tensor of the expected final shape and fill in the current tokens
empty = torch.empty(T_new, dtype=dtype, device=device)
empty = torch.empty(T_buf, dtype=dtype, device=device)
empty[:T] = prompt
seq = empty
input_pos = torch.arange(0, T, device=device)
@@ -198,12 +201,12 @@ def generate(
next_token = next_tokens[-1]
else:
generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
seq[T + 1:] = torch.cat(generated_tokens)
seq[T + 1:T_new] = torch.cat(generated_tokens)

generate_stats = {
'accept_counts': accept_counts
}
return seq, generate_stats
return seq[:T_new], generate_stats

def encode_tokens(tokenizer, string, bos=True, device=default_device):
tokens = tokenizer.encode(string)
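The generate.py changes above pad the sequence buffer up to a multiple of 128 (per the "round up to multiple of 128 to use flex_attention" comment) and slice the padding off before returning. A minimal sketch of that bookkeeping, with illustrative values for T and max_new_tokens that are not taken from the PR:

def roundup(val, multiplier):
    return ((val - 1) // multiplier + 1) * multiplier

T, max_new_tokens = 350, 200
T_new = T + max_new_tokens       # 550 tokens the caller actually asked for
T_buf = roundup(T_new, 128)      # 640: padded so the block mask tiles evenly
assert roundup(350, 128) == 384  # the interactive-mode case above
# generate() allocates the buffer at length T_buf, writes new tokens into
# seq[T + 1:T_new], and returns seq[:T_new], so callers never see the padding.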
57 changes: 51 additions & 6 deletions model.py
@@ -10,13 +10,51 @@
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention


def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
return n + k - (n % k)


def causal_mask(b, h, q, kv):
return q >= kv


class CausalMask:
def __init__(self, max_seq_length):
self.input_pos = None
self.max_seq_length = max_seq_length
self.block_masks = create_block_mask(causal_mask, 1, 1, max_seq_length, max_seq_length, device="cuda")

def decoding_causal_mask(self, b, h, q, kv):
offset = self.input_pos[0]
return offset + q >= kv

def get_mask(self, kv_len, input_pos) -> BlockMask:
self.input_pos = input_pos
offset = self.input_pos // self.block_masks.BLOCK_SIZE[0]
max_block_in_kv = (kv_len - 1) // self.block_masks.BLOCK_SIZE[1] + 1
new_kv_num_blocks = self.block_masks.kv_num_blocks[:, :, offset]
new_kv_indices = self.block_masks.kv_indices[:, :, offset, :max_block_in_kv]
new_full_kv_num_blocks = self.block_masks.full_kv_num_blocks[:, :, offset]
new_full_kv_indices = self.block_masks.full_kv_indices[:, :, offset, :max_block_in_kv]
layer_mask = BlockMask.from_kv_blocks(
new_kv_num_blocks,
new_kv_indices,
new_full_kv_num_blocks,
new_full_kv_indices,
BLOCK_SIZE=self.block_masks.BLOCK_SIZE,
mask_mod=self.decoding_causal_mask,
)
return layer_mask

def gen_prefill_mask(self, kv_len, q_len) -> BlockMask:
return create_block_mask(causal_mask, 1, 1, q_len, kv_len)


@dataclass
class ModelArgs:
block_size: int = 2048
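
For context on the API the CausalMask helper above builds on: create_block_mask materializes a block-sparse causal mask once, and flex_attention consumes it in place of a dense attention mask. A minimal, self-contained prefill-style example; the shapes, dtype, and "cuda" device are illustrative and not taken from the PR:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(b, h, q, kv):
    return q >= kv

B, H, S, D = 1, 8, 256, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)

# Block-sparse causal mask; built once and reusable across layers and steps.
block_mask = create_block_mask(causal_mask, 1, 1, S, S, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)  # (B, H, S, D)

get_mask avoids rebuilding this structure on every decode step: it slices the precomputed kv_num_blocks / kv_indices tensors at the current position and reassembles them with BlockMask.from_kv_blocks, swapping in decoding_causal_mask as the mask_mod so the predicate is shifted by input_pos.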
@@ -102,6 +140,7 @@ def __init__(self, config: ModelArgs) -> None:
self.mask_cache: Optional[Tensor] = None
self.max_batch_size = -1
self.max_seq_length = -1
self.causal_mask = CausalMask(config.block_size)

def setup_caches(self, max_batch_size, max_seq_length):
if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
@@ -120,11 +159,9 @@ def setup_caches(self, max_batch_size, max_seq_length):
b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)

self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
def _forward(self, mask: BlockMask, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
assert self.freqs_cis is not None, "Caches must be initialized first"
mask = self.causal_mask[None, None, input_pos]
freqs_cis = self.freqs_cis[input_pos]
x = self.tok_embeddings(idx)

@@ -134,6 +171,14 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
logits = self.output(x)
return logits

def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
mask = self.causal_mask.get_mask(self.max_seq_length, input_pos)
return self._forward(mask, idx, input_pos)

def prefill(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
mask = self.causal_mask.gen_prefill_mask(self.max_seq_length, input_pos.shape[0])
return self._forward(mask, idx, input_pos)

@classmethod
def from_name(cls, name: str):
return cls(ModelArgs.from_name(name))
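
Semantically, the decode path (forward -> get_mask) gives one query token attending over the padded KV cache with the causal predicate shifted by the current position. The sketch below rebuilds the mask from scratch for clarity, which is exactly the per-step cost get_mask's slicing avoids; the sizes and position are made up:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

S = 512                                   # padded cache length, a multiple of 128
pos = torch.tensor([300], device="cuda")  # current decode position

def decoding_causal_mask(b, h, q, kv):
    # q is 0 for the single new token; shift it by the position in the cache.
    return pos[0] + q >= kv

mask = create_block_mask(decoding_causal_mask, 1, 1, 1, S, device="cuda")
q = torch.randn(1, 8, 1, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, S, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 8, S, 64, device="cuda", dtype=torch.bfloat16)
out = flex_attention(q, k, v, block_mask=mask)  # (1, 8, 1, 64)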
@@ -147,7 +192,7 @@ def __init__(self, config: ModelArgs) -> None:
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.attention_norm = RMSNorm(config.dim, config.norm_eps)

def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: BlockMask) -> Tensor:
h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
out = h + self.feed_forward(self.ffn_norm(h))
return out
@@ -177,7 +222,7 @@ def load_hook(self, state_dict, prefix, *args):
wv = state_dict.pop(prefix + "wv.weight")
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
def forward(self, x: Tensor, freqs_cis: Tensor, mask: BlockMask, input_pos: Optional[Tensor] = None) -> Tensor:
bsz, seqlen, _ = x.shape

kv_size = self.n_local_heads * self.head_dim
@@ -197,7 +242,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:

k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
y = flex_attention(q, k, v, block_mask=mask, enable_gqa=(self.n_head != self.n_local_heads))

y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

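On the enable_gqa flag in the new flex_attention call: when the query and key/value head counts differ, flex_attention can broadcast the smaller set of KV heads across query-head groups itself. A small illustrative call with made-up head counts, not the model's actual configuration:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(b, h, q, kv):
    return q >= kv

S = 256
# 8 query heads grouped over 2 KV heads; with enable_gqa=True, each group of
# 4 query heads reads its shared KV head, so K/V need not be repeated upfront.
q = torch.randn(1, 8, S, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 2, S, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 2, S, 64, device="cuda", dtype=torch.bfloat16)
block_mask = create_block_mask(causal_mask, 1, 1, S, S, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask, enable_gqa=True)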