Commit

test sdpa with fp16 (pytorch#553)
* test sdpa with fp16

* kv cache fp32

* typo
mikekgfb authored and malfet committed Jul 17, 2024
1 parent 49fd4b7 commit c4b430a
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions export_et_util.py
@@ -9,6 +9,8 @@ class CustomKVCache(nn.Module):
     def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
         super().__init__()
 
+        dtype = torch.float
+
         # This is flipped around from what is in build.model's KVCache
         cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
         self.register_buffer(
@@ -21,8 +23,8 @@ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
     def update(self, input_pos, k_val, v_val):
         k_out = self.k_cache
         v_out = self.v_cache
-        k_out[:, :, input_pos] = k_val
-        v_out[:, :, input_pos] = v_val
+        k_out[:, :, input_pos] = k_val.float()
+        v_out[:, :, input_pos] = v_val.float()
 
         return k_out, v_out
@@ -67,15 +69,15 @@ def forward(self, x, freqs_cis, mask, input_pos=None):
         # KV cache should always be enabled
         assert self.kv_cache is not None
         output = torch.ops.llama.sdpa_with_kv_cache(
-            q,
-            k,
-            v,
+            q.float(),
+            k.float(),
+            v.float(),
             self.kv_cache.k_cache,
             self.kv_cache.v_cache,
             input_pos[-1].item(),
             seqlen,
         )
-        output = output.view(bsz, seqlen, self.dim)
+        output = output.view(bsz, seqlen, self.dim).to(dtype=q.dtype)
         return self.wo(output)


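The net effect of the diff: activations can flow through the model in fp16, while the KV cache buffers stay in fp32; q/k/v are upcast at the sdpa_with_kv_cache boundary, and the output is cast back to the activation dtype before the output projection. Below is a minimal sketch of that dtype round-trip, using the stock torch.nn.functional.scaled_dot_product_attention in place of the ExecuTorch custom op; the function name, tensor layout, and the absence of masking are illustrative assumptions, not code from the repo.

import torch
import torch.nn.functional as F

def attend_with_fp32_cache(q, k_val, v_val, k_cache, v_cache, input_pos):
    # q, k_val, v_val: (bsz, n_heads, seqlen, head_dim), possibly fp16.
    # k_cache, v_cache: (bsz, n_heads, max_seq_len, head_dim), kept in fp32.
    # Write the new entries into the fp32 cache, upcasting on the way in.
    k_cache[:, :, input_pos] = k_val.float()
    v_cache[:, :, input_pos] = v_val.float()
    # Compute attention in fp32 against the full cache (masking omitted for brevity) ...
    out = F.scaled_dot_product_attention(q.float(), k_cache, v_cache)
    # ... and hand the result back in the caller's activation dtype.
    return out.to(dtype=q.dtype)

Keeping the cache in fp32 sidesteps dtype mismatches inside the attention kernel while the rest of the model runs in half precision; the single .to(dtype=q.dtype) at the end restores the activation dtype expected by the output projection (self.wo in the diff).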
