Batch-aware torch.ops.llama.sdpa_with_kv_cache (pytorch#4822)
Summary:
Pull Request resolved: pytorch#4822

This change makes torch.ops.llama.sdpa_with_kv_cache batch-aware. This is needed for batched SDPA cases, for example LLM beam search.

* Makes update_cache update across the batch dimension

As a performance optimization, update_cache implements the following operation
```
    k_cache[:, start_pos : start_pos + seq_len, :, :] = k
    v_cache[:, start_pos : start_pos + seq_len, :, :] = v
```
as part of the fused sdpa_with_kv_cache op. A naive export of this code inserts expensive slice-scatter ops. sdpa_with_kv_cache fuses this update with the flash attention op for tensors that follow a predetermined format [batch, length, heads, dim]. This change removes the assumption that batch == 1.
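
For reference, a minimal eager-PyTorch sketch of the batched update semantics described above (the function name and shapes are illustrative, not the op's C++ internals):
```
import torch

def update_cache_reference(k, v, k_cache, v_cache, start_pos):
    # Write every batch line into the cache window [start_pos, start_pos + seq_len).
    # A naive export of this slice assignment lowers to slice-scatter ops;
    # the fused op performs the equivalent copy directly.
    seq_len = k.size(1)
    k_cache[:, start_pos : start_pos + seq_len, :, :] = k
    v_cache[:, start_pos : start_pos + seq_len, :, :] = v

# Illustrative shapes: batch=5, max_seq_len=128, heads=8, head_dim=64.
k_cache = torch.zeros(5, 128, 8, 64)
v_cache = torch.zeros(5, 128, 8, 64)
k = torch.rand(5, 4, 8, 64)
v = torch.rand(5, 4, 8, 64)
update_cache_reference(k, v, k_cache, v_cache, start_pos=10)
```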

* Makes sdpa_with_kv_cache apply cpu_flash_attention to all batch lines as well.

ExecuTorch-exported Llama models use greedy search, so it has not been necessary for this op to be batch-aware. However, when working with other models, or when doing LLM beam search, this is no longer true.
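
A minimal eager-PyTorch sketch of the batched attention semantics (not the fused cpu_flash_attention kernel), assuming equal query and KV head counts and omitting the attention mask for brevity; names and shapes are illustrative:
```
import torch
import torch.nn.functional as F

def batched_sdpa_over_cache(q, k_cache, v_cache, start_pos):
    # q: [batch, seq_len, heads, dim]; caches: [batch, max_seq_len, heads, dim].
    # Each batch line (e.g. each beam) attends over its own cache prefix.
    seq_len = q.size(1)
    attended = start_pos + seq_len
    k = k_cache[:, :attended].transpose(1, 2)  # [batch, heads, attended, dim]
    v = v_cache[:, :attended].transpose(1, 2)
    q_t = q.transpose(1, 2)                    # [batch, heads, seq_len, dim]
    out = F.scaled_dot_product_attention(q_t, k, v)  # masking omitted for brevity
    return out.transpose(1, 2)                 # [batch, seq_len, heads, dim]

# Illustrative shapes: 5 beams, 3 new tokens at position 10, 8 heads, head_dim 64.
k_cache = torch.rand(5, 128, 8, 64)
v_cache = torch.rand(5, 128, 8, 64)
q = torch.rand(5, 3, 8, 64)
out = batched_sdpa_over_cache(q, k_cache, v_cache, start_pos=10)
```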

Reviewed By: kimishpatel, tarun292

Differential Revision: D61605316

fbshipit-source-id: 5274c8f65967a7ef7b9fa70e60c86d07269dd769
meta-emilian authored and facebook-github-bot committed Sep 18, 2024
1 parent 1e4c316 commit 53c1a5f
Showing 2 changed files with 62 additions and 23 deletions.
58 changes: 45 additions & 13 deletions extension/llm/custom_ops/op_sdpa.cpp
@@ -700,10 +700,23 @@ void update_cache(
const Tensor& cache,
int64_t start_pos,
int64_t seq_length) { // NOLINT: unused parameter 'seq_length'
// 1) Cache shape should be [bs, max_seq_len, num heads, head dim]
// 2) projected_value shape should be [bs, seq_len, num heads, head dim]
// 3) We're updating the cache with projected_value, at position start_pos

ET_CHECK_MSG(
projected_value.size(0) == 1,
"projected_value must have batch size of 1");
ET_CHECK_MSG(cache.size(0) == 1, "cache must have batch size of 1");
ET_CHECK_MSG(
projected_value.size(0) == cache.size(0),
"projected_value batch size should be equal to the cache batch size.");
ET_CHECK_MSG(
projected_value.size(2) == cache.size(2),
"projected_value number of heads should be equal to the cache number of heads.");
ET_CHECK_MSG(
projected_value.size(3) == cache.size(3),
"projected_value embedding dimension should be equal to the cache embedding dimension.");
ET_CHECK_MSG(
projected_value.element_size() == cache.element_size(),
"projected_value data type size should be equal to the cache data type size.");

ET_CHECK_MSG(
is_contiguous_dim_order(
projected_value.dim_order().data(), projected_value.dim()),
@@ -714,16 +727,31 @@ void update_cache(
ET_CHECK_MSG(projected_value_data != nullptr, "projected_value data is null");
ET_CHECK_MSG(cache_data, "cache data is null");

auto strides = cache.strides();
exec_aten::StridesType seq_dim_stride = strides[1];
exec_aten::SizesType pos_offset = start_pos * seq_dim_stride;
exec_aten::SizesType pos_offset_bytes =
pos_offset * projected_value.element_size();
exec_aten::SizesType num_bytes =
projected_value.numel() * projected_value.element_size();
// NOLINTNEXTLINE
std::memcpy(
(uint8_t*)cache_data + pos_offset_bytes, projected_value_data, num_bytes);
auto cache_strides = cache.strides();
exec_aten::StridesType cache_batch_dim_stride = cache_strides[0];
exec_aten::StridesType cache_seq_dim_stride = cache_strides[1];

auto value_strides = projected_value.strides();
exec_aten::StridesType value_batch_dim_stride = value_strides[0];

exec_aten::SizesType num_bytes_to_copy =
(projected_value.numel() / projected_value.size(0)) *
projected_value.element_size();

for (int64_t batch_line = 0; batch_line < projected_value.size(0);
++batch_line) {
exec_aten::SizesType cache_pos_offset =
(batch_line * cache_batch_dim_stride +
start_pos * cache_seq_dim_stride) *
cache.element_size();
exec_aten::SizesType value_pos_offset =
(batch_line * value_batch_dim_stride) * cache.element_size();

std::memcpy(
(uint8_t*)cache_data + cache_pos_offset,
(uint8_t*)projected_value_data + value_pos_offset,
num_bytes_to_copy);
}
}

} // anonymous namespace
@@ -859,6 +887,8 @@ Tensor& sdpa_with_kv_cache_out(
sliced_key_dim_order.data(),
util::kKVDim,
sliced_key_strides.data());
// since the cache is sliced, the batch stride needs to stay the same.
sliced_key_strides[0] = key_cache.strides()[0];
void* key_cache_data = key_cache.mutable_data_ptr();
TensorImpl k_impl = TensorImpl(
key_cache.scalar_type(),
Expand All @@ -883,6 +913,8 @@ Tensor& sdpa_with_kv_cache_out(
sliced_value_dim_order.data(),
util::kKVDim,
sliced_value_strides.data());
// since the cache is sliced, the batch stride needs to stay the same.
sliced_value_strides[0] = value_cache.strides()[0];
void* value_cache_data = value_cache.mutable_data_ptr();
TensorImpl value_impl = TensorImpl(
value_cache.scalar_type(),
27 changes: 17 additions & 10 deletions extension/llm/custom_ops/test_sdpa_with_kv_cache.py
@@ -373,10 +373,10 @@ class SDPATestCommon(unittest.TestCase):

def setup_caches(self):
self.k_cache = torch.zeros(
(1, self.max_seq_len, self.n_heads_kv, self.head_dim)
(self.n_batch, self.max_seq_len, self.n_heads_kv, self.head_dim)
)
self.v_cache = torch.zeros(
(1, self.max_seq_len, self.n_heads_kv, self.head_dim)
(self.n_batch, self.max_seq_len, self.n_heads_kv, self.head_dim)
)
self.mask = torch.full(
(self.max_seq_len, self.max_seq_len),
@@ -386,6 +386,7 @@ def setup_caches(self):

def setUp(self):
torch.manual_seed(42)
self.n_batch = 5
self.n_heads_kv = 32
self.n_heads_q = 32
self.head_dim = 128
@@ -410,27 +411,27 @@ def _test_sdpa_common(
scale_tensors=False,
):
# Range arbitrarily chosen to reproduce a numerical error on x86 in some of the long context tests
tensor_scale_max = 20
tensor_scale_min = -20
tensor_scale_max = 15
tensor_scale_min = -15
self.n_heads_kv = n_heads_kv
self.n_heads_q = n_heads_q
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.setup_caches()
q = self._scale_tensor(
torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
torch.rand((self.n_batch, seq_len, self.n_heads_kv, self.head_dim)),
tensor_scale_max,
tensor_scale_min,
scale_tensors,
)
k = self._scale_tensor(
torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
torch.rand((self.n_batch, seq_len, self.n_heads_kv, self.head_dim)),
tensor_scale_max,
tensor_scale_min,
scale_tensors,
)
v = self._scale_tensor(
torch.rand((1, seq_len, self.n_heads_kv, self.head_dim)),
torch.rand((self.n_batch, seq_len, self.n_heads_kv, self.head_dim)),
tensor_scale_max,
tensor_scale_min,
scale_tensors,
@@ -448,19 +449,25 @@ def _test_sdpa_common(
self.assertTrue(torch.allclose(ref_output, op_output, atol=1e-6))

q = self._scale_tensor(
torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
torch.rand(
(self.n_batch, next_iter_seq_len, self.n_heads_kv, self.head_dim)
),
tensor_scale_max,
tensor_scale_min,
scale_tensors,
)
k = self._scale_tensor(
torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
torch.rand(
(self.n_batch, next_iter_seq_len, self.n_heads_kv, self.head_dim)
),
tensor_scale_max,
tensor_scale_min,
scale_tensors,
)
v = self._scale_tensor(
torch.rand((1, next_iter_seq_len, self.n_heads_kv, self.head_dim)),
torch.rand(
(self.n_batch, next_iter_seq_len, self.n_heads_kv, self.head_dim)
),
tensor_scale_max,
tensor_scale_min,
scale_tensors,
