Integrate Flex Decoding #196

Open · wants to merge 11 commits into main

Changes from 10 commits

generate.py (35 changes: 23 additions & 12 deletions)
@@ -12,6 +12,7 @@
import torch
import torch._dynamo.config
import torch._inductor.config
from torch.nn.attention.flex_attention import BlockMask, create_block_mask

def device_sync(device):
if "cuda" in device:
@@ -56,29 +57,39 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs

def roundup(val, multiplier):
return ((val - 1) // multiplier + 1) * multiplier

def causal_mask(b, h, q, kv):
return q >= kv

def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
# input_pos: [B, S]
logits = model(x, input_pos)
mask = create_block_mask(causal_mask, 1, 1, input_pos.shape[0], model.max_seq_length, device="cuda")
logits = model(mask, x, input_pos)
return sample(logits, **sampling_kwargs)[0]

def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, block_mask: BlockMask, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [B, 1]
assert input_pos.shape[-1] == 1
logits = model(x, input_pos)
block_index = input_pos // block_mask.BLOCK_SIZE[0]
mask = block_mask[:, :, block_index]
mask.mask_mod = block_mask.mask_mod
Comment from the PR author: we discussed offline that BlockMask's __getitem__ sets mask_mod to None, so the user needs to reattach the correct mask_mod; in GPT-Fast we rely on model.get_mask_mod to do so. (See the sketch after this function.)

logits = model(mask, x, input_pos)
return sample(logits, **sampling_kwargs)
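
A minimal sketch (not part of the diff) of the slicing pattern the author's comment above describes, under the assumption stated there: slicing a full-length BlockMask with __getitem__ does not carry mask_mod over, so it is reattached by hand before the model call. The sequence length and position are made up for illustration, and a CUDA device is assumed as in the PR.

import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal_mask(b, h, q, kv):
    return q >= kv

max_seq_length = 4096  # illustrative value
# Full-length causal mask, built once (as decode_n_tokens does below).
full_mask = create_block_mask(causal_mask, 1, 1, max_seq_length, max_seq_length, device="cuda")

input_pos = torch.tensor([123], device="cuda")      # current decode position (illustrative)
block_index = input_pos // full_mask.BLOCK_SIZE[0]  # query block containing that position
mask = full_mask[:, :, block_index]                 # __getitem__ drops mask_mod per the comment...
mask.mask_mod = full_mask.mask_mod                  # ...so reattach it before calling the model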

def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
block_mask = create_block_mask(causal_mask, 1, 1, model.max_seq_length, model.max_seq_length, device="cuda")
Reviewer comment: try doing create_block_mask_compile = torch.compile(create_block_mask) as a global. (See the sketch after this function.)

new_tokens, new_probs = [], []
for i in range(num_new_tokens):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
next_token, next_prob = decode_one_token(
model, cur_token, input_pos, **sampling_kwargs
)
input_pos += 1
new_tokens.append(next_token.clone())
callback(new_tokens[-1])
new_probs.append(next_prob.clone())
cur_token = next_token.clone()
next_token, next_prob = decode_one_token(
model, cur_token, input_pos, block_mask, **sampling_kwargs
)
input_pos += 1
new_tokens.append(next_token.clone())
callback(new_tokens[-1])
new_probs.append(next_prob.clone())
cur_token = next_token.clone()

return new_tokens, new_probs
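
A minimal sketch (not part of the diff) of the reviewer's suggestion above: compile create_block_mask once at module scope and route the mask construction in decode_n_tokens through it. build_decode_mask is a hypothetical helper for illustration; whether the compiled builder is actually faster should be verified by benchmarking.

import torch
from torch.nn.attention.flex_attention import create_block_mask

# Compiled once as a global, as the reviewer suggests.
create_block_mask_compile = torch.compile(create_block_mask)

def causal_mask(b, h, q, kv):
    return q >= kv

def build_decode_mask(max_seq_length: int):
    # Same call as in decode_n_tokens above, routed through the compiled builder.
    return create_block_mask_compile(causal_mask, 1, 1, max_seq_length, max_seq_length, device="cuda")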

model.py (25 changes: 19 additions & 6 deletions)
@@ -10,13 +10,26 @@
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from torch.nn.attention.flex_attention import (
_mask_mod_signature,
BlockMask,
flex_attention,
)


def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
return n + k - (n % k)


def get_mask_mod(mask_mod: _mask_mod_signature, offset: int):
def _mask_mod(b, h, q, kv):
return mask_mod(b, h, q + offset, kv)

return _mask_mod


@dataclass
class ModelArgs:
block_size: int = 2048
@@ -102,6 +115,7 @@ def __init__(self, config: ModelArgs) -> None:
self.mask_cache: Optional[Tensor] = None
self.max_batch_size = -1
self.max_seq_length = -1
self.get_mask_mod = get_mask_mod

def setup_caches(self, max_batch_size, max_seq_length):
if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
@@ -120,11 +134,10 @@ def setup_caches(self, max_batch_size, max_seq_length):
b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)

self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
def forward(self, mask: BlockMask, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
assert self.freqs_cis is not None, "Caches must be initialized first"
mask = self.causal_mask[None, None, input_pos]
mask.mask_mod = self.get_mask_mod(mask.mask_mod, input_pos[0])
freqs_cis = self.freqs_cis[input_pos]
x = self.tok_embeddings(idx)

@@ -147,7 +160,7 @@ def __init__(self, config: ModelArgs) -> None:
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.attention_norm = RMSNorm(config.dim, config.norm_eps)

def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: BlockMask) -> Tensor:
h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
out = h + self.feed_forward(self.ffn_norm(h))
return out
@@ -177,7 +190,7 @@ def load_hook(self, state_dict, prefix, *args):
wv = state_dict.pop(prefix + "wv.weight")
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
def forward(self, x: Tensor, freqs_cis: Tensor, mask: BlockMask, input_pos: Optional[Tensor] = None) -> Tensor:
bsz, seqlen, _ = x.shape

kv_size = self.n_local_heads * self.head_dim
@@ -197,7 +210,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: BlockMask, input_pos: Optional[Tensor] = None) -> Tensor:

k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
y = flex_attention(q, k, v, block_mask=mask, enable_gqa=(self.n_head != self.n_local_heads))

y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

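
A minimal sketch (not part of the diff) of why Transformer.forward rebinds mask.mask_mod with get_mask_mod: during single-token decoding the query length is 1, so the query index flex_attention passes to the mask_mod is always 0. Offsetting it by the absolute decode position restores causality against the full KV cache. The integer positions below are illustrative; in the model the offset is input_pos[0].

def causal_mask(b, h, q, kv):
    return q >= kv

def get_mask_mod(mask_mod, offset):
    def _mask_mod(b, h, q, kv):
        return mask_mod(b, h, q + offset, kv)
    return _mask_mod

decode_pos = 100                        # absolute position of the token being decoded
shifted = get_mask_mod(causal_mask, decode_pos)
assert shifted(0, 0, 0, 100) is True    # may attend to KV slot 100 (its own position)
assert shifted(0, 0, 0, 101) is False   # but not to future positions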