From 2bee171734cc5a53cc049a16ed235fbbd7c7b519 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 19 Mar 2024 09:42:45 -0400 Subject: [PATCH] [Fix][Builtin] Fix "GetQueryPosition" of PagedKVCache Since #16692 introduced the copy stream separation, the function `GetQueryPositions` also needs to eagerly call sync to work properly. This PR fixes the previous wrong behavior. --- src/runtime/relax_vm/kv_state.h | 2 +- src/runtime/relax_vm/paged_kv_cache.cc | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h index 2227944b8653..f6857a9dceae 100644 --- a/src/runtime/relax_vm/kv_state.h +++ b/src/runtime/relax_vm/kv_state.h @@ -159,7 +159,7 @@ class AttentionKVCacheObj : public KVStateObj { * This function is supposed to be invoked after calling BeginForward. * \return The in-sequence query positions, in shape `(total_length,)`. */ - virtual NDArray GetQueryPositions() const = 0; + virtual NDArray GetQueryPositions() = 0; /************** Debug Helpers **************/ diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 0c64800cec2d..9c3ee5d427c2 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -838,10 +838,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } - NDArray GetQueryPositions() const final { - CHECK(!dirty_aux_data_device_) - << "The auxiliary arrays are not synchronized to device. Please call " - "`BeginForward` to synchronize before calling `GetQueryPositions`."; + NDArray GetQueryPositions() final { + // Sync the copy stream and the compute stream. + ComputeStreamWaitForCopyStream(); + // The auxiliary data structure on device must have been synchronized. + ICHECK(!dirty_aux_data_device_); return q_rope_position_map_view_; };