
Integrate Flex Decoding #196

Merged: 15 commits, Dec 14, 2024
40 changes: 24 additions & 16 deletions generate.py
@@ -12,6 +12,7 @@
import torch
import torch._dynamo.config
import torch._inductor.config
+ from torch.nn.attention.flex_attention import create_block_mask

def device_sync(device):
if "cuda" in device:
@@ -54,29 +55,35 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs

+ def roundup(val, multiplier):
+     return ((val - 1) // multiplier + 1) * multiplier

def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
# input_pos: [B, S]
- logits = model(x, input_pos)
+ mask = create_block_mask(model.get_mask_mod(0), 1, 1, input_pos.shape[0], model.max_seq_length, device="cuda")
+ logits = model(mask, x, input_pos)
return sample(logits, **sampling_kwargs)[0]

def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [B, 1]
assert input_pos.shape[-1] == 1
- logits = model(x, input_pos)
+ block_index = input_pos // model.block_mask.BLOCK_SIZE[0]
+ mask = model.block_mask[block_index]
+ mask.mask_mod = model.get_mask_mod(input_pos[0])
+ logits = model(mask, x, input_pos)
return sample(logits, **sampling_kwargs)
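
For context on the new mask plumbing: during prefill the query covers the whole prompt, so prefill builds a causal BlockMask over (prompt_len, max_seq_length) directly; during decoding the query chunk has length 1 and sits at absolute position input_pos, so the mask_mod must add that offset before the causal comparison, and the mask precomputed in setup_caches is reused per step instead of rebuilding it every token. A minimal sketch of the offset idea, assuming torch >= 2.5 with flex_attention and a CUDA device, mirroring the get_causal_mask helper added in model.py below (values are illustrative):

    import torch
    from torch.nn.attention.flex_attention import create_block_mask

    def get_causal_mask(offset):
        # q is the position within the current query chunk, so offset + q is
        # the absolute position of that query token in the KV cache.
        def causal_mask(b, h, q, kv):
            return offset + q >= kv
        return causal_mask

    max_seq_length = 512

    # Prefill: queries span the whole prompt, so no offset is needed.
    prompt_len = 7
    prefill_mask = create_block_mask(get_causal_mask(0), 1, 1, prompt_len, max_seq_length, device="cuda")

    # Decode: a single query token at absolute position 37 should attend to
    # kv positions 0..37, so the offset shifts the causal comparison.
    decode_mask = create_block_mask(get_causal_mask(37), 1, 1, 1, max_seq_length, device="cuda")

In the diff above, decode_one_token avoids even this per-step create_block_mask call: it slices the cached model.block_mask by block index and only swaps in a mask_mod with the current offset.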

def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
new_tokens, new_probs = [], []
for i in range(num_new_tokens):
- with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
- next_token, next_prob = decode_one_token(
- model, cur_token, input_pos, **sampling_kwargs
- )
- input_pos += 1
- new_tokens.append(next_token.clone())
- callback(new_tokens[-1])
- new_probs.append(next_prob.clone())
- cur_token = next_token.view(1, -1)
+ next_token, next_prob = decode_one_token(
+ model, cur_token, input_pos, **sampling_kwargs
+ )
+ input_pos += 1
+ new_tokens.append(next_token.clone())
+ callback(new_tokens[-1])
+ new_probs.append(next_prob.clone())
+ cur_token = next_token.view(1, -1)

return new_tokens, new_probs

@@ -154,10 +161,11 @@ def generate(
# create an empty tensor of the expected final shape and fill in the current tokens
T = prompt.size(0)
T_new = T + max_new_tokens
+ T_buf = roundup(T_new, 128) # round up to multiple of 128 to use flex_attention
if interactive:
- max_seq_length = 350
+ max_seq_length = 384
else:
- max_seq_length = min(T_new, model.config.block_size)
+ max_seq_length = min(T_buf, model.config.block_size)

device, dtype = prompt.device, prompt.dtype
max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
@@ -167,7 +175,7 @@
draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

# create an empty tensor of the expected final shape and fill in the current tokens
- empty = torch.empty(T_new, dtype=dtype, device=device)
+ empty = torch.empty(T_buf, dtype=dtype, device=device)
empty[:T] = prompt
seq = empty
input_pos = torch.arange(0, T, device=device)
@@ -198,12 +206,12 @@ def generate(
next_token = next_tokens[-1]
else:
generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
- seq[T + 1:] = torch.cat(generated_tokens)
+ seq[T + 1:T_new] = torch.cat(generated_tokens)

generate_stats = {
'accept_counts': accept_counts
}
- return seq, generate_stats
+ return seq[:T_new], generate_stats
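
The padding logic above exists because flex_attention's BlockMask partitions the sequence into fixed-size blocks (128 by default), so the generation buffer is rounded up to a whole number of blocks and the padding is sliced off again when returning seq[:T_new]; the interactive cap presumably moves from 350 to 384 (3 * 128) for the same reason. A small worked example of the helper (illustrative values, not from the PR):

    def roundup(val, multiplier):
        return ((val - 1) // multiplier + 1) * multiplier

    # A 13-token prompt generating 200 new tokens.
    T, max_new_tokens = 13, 200
    T_new = T + max_new_tokens           # 213 real positions
    T_buf = roundup(T_new, 128)          # 256: the next multiple of 128
    assert T_buf % 128 == 0 and T_buf >= T_new
    # generate() allocates buffers of length T_buf and returns seq[:T_new].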

def encode_tokens(tokenizer, string, bos=True, device=default_device):
tokens = tokenizer.encode(string)
32 changes: 23 additions & 9 deletions model.py
@@ -4,19 +4,33 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
- from typing import Optional
+ from typing import Callable, Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
+ from torch.nn.attention.flex_attention import (
+     _mask_mod_signature,
+     BlockMask,
+     create_block_mask,
+     flex_attention,
+ )


def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
return n + k - (n % k)


+ def get_causal_mask(offset):
+     def causal_mask(b, h, q, kv):
+         return offset + q >= kv
+
+     return causal_mask
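
A quick sanity check of the new helper (illustrative, evaluated on plain ints): with offset=0 it is the usual causal predicate, and during decoding the single query has q=0, so the offset carries the absolute position.

    causal = get_causal_mask(0)
    assert causal(0, 0, 5, 5) and not causal(0, 0, 5, 6)    # token 5 attends to positions <= 5

    step = get_causal_mask(37)    # decoding: query chunk of length 1 at absolute position 37
    assert step(0, 0, 0, 37) and not step(0, 0, 0, 38)      # q=0 stands in for position 37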


@dataclass
class ModelArgs:
block_size: int = 2048
@@ -89,7 +103,7 @@ def update(self, input_pos, k_val, v_val):
return k_out, v_out

class Transformer(nn.Module):
- def __init__(self, config: ModelArgs) -> None:
+ def __init__(self, config: ModelArgs, get_mask_mod: Callable[[int], _mask_mod_signature]) -> None:
Contributor comment: get_mask_mod shouldn't take an integer - it should take a mask_mod. We also don't need to set it as an argument, just set it as an attribute within the module.

Contributor comment: Specifically, you should be able to take any existing mask_mod and wrap it to make it automatically support an offset.
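
An illustration of the suggestion above (hypothetical helper, not code from this PR): any existing mask_mod can be wrapped so the query index is shifted by an offset before the original predicate runs, which keeps the Transformer agnostic to how the mask was written.

    from torch.nn.attention.flex_attention import _mask_mod_signature

    def with_offset(mask_mod: _mask_mod_signature, offset) -> _mask_mod_signature:
        # Shift the query index so a mask_mod written in absolute positions
        # also works when the query chunk starts mid-sequence.
        def offset_mask_mod(b, h, q, kv):
            return mask_mod(b, h, q + offset, kv)
        return offset_mask_mod

    # Example: reuse a plain causal mask_mod while decoding at position 37.
    def causal(b, h, q, kv):
        return q >= kv

    decode_mask_mod = with_offset(causal, 37)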

super().__init__()
self.config = config

@@ -102,6 +116,7 @@ def __init__(self, config: ModelArgs) -> None:
self.mask_cache: Optional[Tensor] = None
self.max_batch_size = -1
self.max_seq_length = -1
+ self.get_mask_mod = get_mask_mod

def setup_caches(self, max_batch_size, max_seq_length):
if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
@@ -120,11 +135,10 @@ def setup_caches(self, max_batch_size, max_seq_length):
b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)

self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
- self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
+ self.block_mask = create_block_mask(self.get_mask_mod(0), 1, 1, max_seq_length, max_seq_length, device="cuda")

- def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+ def forward(self, mask: BlockMask, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
assert self.freqs_cis is not None, "Caches must be initialized first"
- mask = self.causal_mask[None, None, input_pos]
freqs_cis = self.freqs_cis[input_pos]
x = self.tok_embeddings(idx)

@@ -136,7 +150,7 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:

@classmethod
def from_name(cls, name: str):
- return cls(ModelArgs.from_name(name))
+ return cls(ModelArgs.from_name(name), get_causal_mask)


class TransformerBlock(nn.Module):
@@ -147,7 +161,7 @@ def __init__(self, config: ModelArgs) -> None:
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.attention_norm = RMSNorm(config.dim, config.norm_eps)

- def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
+ def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: BlockMask) -> Tensor:
h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
out = h + self.feed_forward(self.ffn_norm(h))
return out
@@ -177,7 +191,7 @@ def load_hook(self, state_dict, prefix, *args):
wv = state_dict.pop(prefix + "wv.weight")
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

- def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+ def forward(self, x: Tensor, freqs_cis: Tensor, mask: BlockMask, input_pos: Optional[Tensor] = None) -> Tensor:
bsz, seqlen, _ = x.shape

kv_size = self.n_local_heads * self.head_dim
@@ -197,7 +211,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: BlockMask, input_pos: Optional[Tensor] = None) -> Tensor:

k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
- y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+ y = flex_attention(q, k, v, block_mask=mask, enable_gqa=(self.n_head != self.n_local_heads))

y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
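
For reference, a self-contained sketch of the new attention call (shapes and dtypes are illustrative; assumes torch >= 2.5 and a CUDA device). With enable_gqa=True, flex_attention can attend with fewer KV heads than query heads (grouped-query attention), which matches the condition the PR passes: self.n_head != self.n_local_heads.

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    def causal(b, h, q, kv):
        return q >= kv

    B, Hq, Hkv, S, D = 1, 8, 2, 256, 64   # 8 query heads sharing 2 KV heads
    query = torch.randn(B, Hq, S, D, device="cuda", dtype=torch.bfloat16)
    key = torch.randn(B, Hkv, S, D, device="cuda", dtype=torch.bfloat16)
    value = torch.randn(B, Hkv, S, D, device="cuda", dtype=torch.bfloat16)

    block_mask = create_block_mask(causal, B, 1, S, S, device="cuda")
    # The block mask is broadcast across heads; enable_gqa handles the head-count mismatch.
    y = flex_attention(query, key, value, block_mask=block_mask, enable_gqa=True)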
