20 changes: 20 additions & 0 deletions cacheflow/models/activation.py
@@ -0,0 +1,20 @@
import torch
import torch.nn as nn

from cacheflow import activation_ops


class SiluAndMul(nn.Module):

def __init__(self):
super().__init__()

def forward(
self,
x: torch.Tensor, # (num_tokens, 2 * d)
) -> torch.Tensor: # (num_tokens, d)
num_tokens = x.shape[0]
d = x.shape[1] // 2
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
activation_ops.silu_and_mul(out, x)
return out
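
For reference (not part of the diff), the fused activation_ops.silu_and_mul kernel computes the same result as this pure-PyTorch sketch, assuming the conventional SwiGLU layout in which the gate values occupy the first half of the last dimension:

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x: (num_tokens, 2 * d) -> (num_tokens, d)
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]
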
92 changes: 46 additions & 46 deletions cacheflow/models/attention.py
@@ -1,6 +1,6 @@
from typing import List, Optional
from typing import Optional

from flash_attn.flash_attention import FlashAttention
from flash_attn.flash_attn_interface import _flash_attn_forward
import torch
import torch.nn as nn

@@ -16,40 +16,38 @@ def __init__(self, scale: float) -> None:
super().__init__()
self.scale = float(scale)

self.flash_attn = FlashAttention(softmax_scale=self.scale)

def multi_query_kv_attention(
self,
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
prompt_lens: List[int],
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
cumulative_prompt_lens: torch.Tensor, # [num_prompts + 1]
max_prompt_len: int,
) -> None:
if query.dtype == torch.float:
raise ValueError('The float data type is not supported by '
'FlashAttention. Use the half data type instead.')
head_size = query.shape[2]
head_size = query.shape[-1]
if head_size > 128:
raise ValueError('FlashAttention does not support head_size > 128.')

device = query.device
prefix_sum = [0]
for prompt_len in prompt_lens:
prefix_sum.append(prefix_sum[-1] + prompt_len)
prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
max_prompt_len = max(prompt_lens)

# FIXME(woosuk): Unnecessary copy. Optimize this.
qkv = torch.stack([query, key, value], dim=1)
out = self.flash_attn(
qkv,
cu_seqlens=prefix_sum,
max_s=max_prompt_len,
# Directly call FlashAttention's internal function to avoid allocating
# a new tensor for the output.
_flash_attn_forward(
query,
key,
value,
output,
cumulative_prompt_lens,
cumulative_prompt_lens,
max_prompt_len,
max_prompt_len,
dropout_p=0.0,
softmax_scale=self.scale,
causal=True,
)[0]
# FIXME(woosuk): Unnecessary copy. Optimize this.
output.copy_(out, non_blocking=True)
return_softmax=False,
)
Comment on lines +35 to +50

Member:
Just curious, so flash attention natively supports non-contiguous QKV tensors?

Collaborator (Author):
Yes. It actually requires a qkv tensor of shape [num_tokens, 3, num_heads, head_size]. Previously, we inserted torch.stack to meet this shape requirement, and this PR eliminates that inefficiency.
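
As an illustration (not part of the PR), a minimal sketch of why slicing q, k, and v out of one fused qkv tensor avoids the extra copy: the slices are strided views into the same storage, which is the non-contiguous layout discussed above.

import torch

num_tokens, num_heads, head_size = 8, 4, 16
qkv = torch.randn(num_tokens, 3 * num_heads * head_size)

# chunk returns views into qkv; no data is copied.
q, k, v = qkv.chunk(chunks=3, dim=-1)
q = q.view(-1, num_heads, head_size)

assert q.data_ptr() == qkv.data_ptr()  # q shares storage with qkv
assert not q.is_contiguous()           # token stride is 3 * num_heads * head_size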


def single_query_cached_kv_attention(
self,
@@ -90,21 +88,18 @@ def forward(
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
# Pre-allocate the output tensor.
output = torch.empty_like(query)

# Prune out paddings if any.
query = query[:input_metadata.num_valid_tokens]
key = key[:input_metadata.num_valid_tokens]
value = value[:input_metadata.num_valid_tokens]
# NOTE: The query, key, and value tensors must be sliced from a qkv
# tensor of shape [num_tokens, 3 * num_heads * head_size].

# Reshape the input tensors.
# Reshape the query, key, and value tensors.
num_heads = value_cache.shape[1]
head_size = value_cache.shape[2]
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_heads, head_size)
value = value.view(-1, num_heads, head_size)
output = output.view(-1, num_heads, head_size)

# Pre-allocate the output tensor.
output = torch.empty_like(query)

# Compute the attention op for prompts.
num_prompt_tokens = input_metadata.num_prompt_tokens
@@ -114,22 +109,31 @@ def forward(
query[:num_prompt_tokens],
key[:num_prompt_tokens],
value[:num_prompt_tokens],
input_metadata.prompt_lens,
input_metadata.cumulative_prompt_lens,
input_metadata.max_prompt_len,
)

# Wait until the cache op is done.
if cache_event is not None:
cache_event.wait()

# Reshape the keys and values and store them in the cache.
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, input_metadata.slot_mapping)
num_valid_tokens = input_metadata.num_valid_tokens
if num_valid_tokens > 0:
# The stride is 3 because the key and value are sliced from qkv.
cache_ops.reshape_and_cache(
key[:num_valid_tokens],
value[:num_valid_tokens],
key_cache,
value_cache,
input_metadata.slot_mapping,
)

if input_metadata.num_generation_tokens > 0:
# Compute the attention op for generation tokens.
self.single_query_cached_kv_attention(
output[num_prompt_tokens:],
query[num_prompt_tokens:],
output[num_prompt_tokens:num_valid_tokens],
query[num_prompt_tokens:num_valid_tokens],
key_cache,
value_cache,
input_metadata)
@@ -186,19 +190,15 @@ def forward(
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
# Apply rotary embedding to the query and key before passing them
# to the attention op.
out_query = torch.empty_like(query)
out_key = torch.empty_like(key)
pos_encoding_ops.rotary_embedding_neox(
out_query,
out_key,
positions,
query,
key,
self.cos_sin_cache,
)
return super().forward(
out_query,
out_key,
query,
key,
value,
key_cache,
value_cache,
5 changes: 5 additions & 0 deletions cacheflow/models/input_metadata.py
@@ -12,6 +12,7 @@ def __init__(
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
prompt_lens: List[int],
cumulative_prompt_lens: torch.Tensor,
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
@@ -20,13 +21,15 @@
self.seq_groups = seq_groups
self.seq_logprobs = seq_logprobs
self.prompt_lens = prompt_lens
self.cumulative_prompt_lens = cumulative_prompt_lens
self.slot_mapping = slot_mapping
self.context_lens = context_lens
self.max_context_len = max_context_len
self.block_tables = block_tables

self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens)
self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
self.num_generation_tokens = context_lens.shape[0]
self.num_valid_tokens = slot_mapping.shape[0]
if block_tables.numel() > 0:
@@ -40,11 +43,13 @@ def __repr__(self) -> str:
return (f'InputMetadata('
f'num_prompts={self.num_prompts}, '
f'num_prompt_tokens={self.num_prompt_tokens}, '
f'max_prompt_len={self.max_prompt_len}, '
f'num_generation_tokens={self.num_generation_tokens}, '
f'num_valid_tokens={self.num_valid_tokens}, '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
f'max_context_len={self.max_context_len}, '
f'prompt_lens={self.prompt_lens}, '
f'cumulative_prompt_lens={self.cumulative_prompt_lens}, '
f'slot_mapping={self.slot_mapping}, '
f'context_lens={self.context_lens}, '
f'block_tables={self.block_tables})')
100 changes: 50 additions & 50 deletions cacheflow/models/llama.py
@@ -11,6 +11,7 @@
from transformers import LlamaConfig

from cacheflow.models import InputMetadata
from cacheflow.models.activation import SiluAndMul
from cacheflow.models.attention import LlamaCacheFlowAttention
from cacheflow.models.layernorm import RMSNorm
from cacheflow.models.sample import Sampler
@@ -33,23 +34,20 @@ def __init__(
hidden_act: str,
):
super().__init__()
# TODO: Merge the gate and down linear layers.
self.gate_proj = ColumnParallelLinear(hidden_size, intermediate_size,
bias=False, gather_output=False,
perform_initialization=False)
self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size,
bias=False, gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
bias=False, input_is_parallel=True,
perform_initialization=False)
self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size,
bias=False, gather_output=False,
perform_initialization=False)
assert hidden_act == 'silu'
self.act_fn = nn.SiLU()
if hidden_act != 'silu':
raise ValueError(f'Unsupported activation: {hidden_act}. '
'Only silu is supported for now.')
self.act_fn = SiluAndMul()

def forward(self, x):
gate, _ = self.gate_proj(x)
up, _ = self.up_proj(x)
x = self.act_fn(gate) * up
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x

@@ -70,24 +68,9 @@ def __init__(
self.head_dim = hidden_size // self.total_num_heads
self.scaling = self.head_dim ** -0.5

# TODO: Merge the QKV linear layers.
self.q_proj = ColumnParallelLinear(
hidden_size,
self.total_num_heads * self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.k_proj = ColumnParallelLinear(
hidden_size,
self.total_num_heads * self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.v_proj = ColumnParallelLinear(
self.qkv_proj = ColumnParallelLinear(
hidden_size,
self.total_num_heads * self.head_dim,
3 * self.total_num_heads * self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
@@ -109,9 +92,8 @@ def forward(
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
q, _ = self.q_proj(hidden_states)
k, _ = self.k_proj(hidden_states)
v, _ = self.v_proj(hidden_states)
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
@@ -230,32 +212,50 @@ def forward(
return next_tokens

_column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
"q_proj.weight", "k_proj.weight",
"v_proj.weight", "gate_proj.weight",
"qkv_proj.weight", "gate_proj.weight",
"up_proj.weight"]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

def load_weights(self, weights_path: str):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, param in state_dict.items():
loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path,
name)))
for p in self._column_parallel_weights:
if p in name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
if "qkv_proj" in name or "gate_up_proj" in name:
if "qkv_proj" in name:
original_name = "qkv_proj"
weight_names = ["q_proj", "k_proj", "v_proj"]
shard_size = param.shape[0] // 3
else:
original_name = "gate_up_proj"
weight_names = ["gate_proj", "up_proj"]
shard_size = param.shape[0] // 2
weights_to_concat = []
for weight_name in weight_names:
weight = np.load(os.path.join(
weights_path, name.replace(original_name, weight_name)))
weights_to_concat.append(weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
for p in self._row_parallel_weights:
if p in name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
:shard_size * (tensor_model_parallel_rank + 1)])
loaded_weight = torch.from_numpy(
np.concatenate(weights_to_concat, axis=0))
else:
loaded_weight = torch.from_numpy(
np.load(os.path.join(weights_path, name)))
for p in self._column_parallel_weights:
if p in name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
for p in self._row_parallel_weights:
if p in name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break

assert param.shape == loaded_weight.shape
param.data.copy_(loaded_weight)