Commit a9436b8

[Fix][Builtin] Fix "GetQueryPosition" of PagedKVCache (#16746)
Since #16692 separated the copy stream from the compute stream, `GetQueryPositions` must also eagerly synchronize the two streams before returning its result. This PR fixes the previous behavior, which skipped that synchronization.
1 parent 48cedc7 commit a9436b8

2 files changed: +6 -5 lines changed


src/runtime/relax_vm/kv_state.h

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ class AttentionKVCacheObj : public KVStateObj {
    * This function is supposed to be invoked after calling BeginForward.
    * \return The in-sequence query positions, in shape `(total_length,)`.
    */
-  virtual NDArray GetQueryPositions() const = 0;
+  virtual NDArray GetQueryPositions() = 0;

   /************** Debug Helpers **************/

src/runtime/relax_vm/paged_kv_cache.cc

Lines changed: 5 additions & 4 deletions
@@ -838,10 +838,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     }
   }

-  NDArray GetQueryPositions() const final {
-    CHECK(!dirty_aux_data_device_)
-        << "The auxiliary arrays are not synchronized to device. Please call "
-           "`BeginForward` to synchronize before calling `GetQueryPositions`.";
+  NDArray GetQueryPositions() final {
+    // Sync the copy stream and the compute stream.
+    ComputeStreamWaitForCopyStream();
+    // The auxiliary data structure on device must have been synchronized.
+    ICHECK(!dirty_aux_data_device_);
     return q_rope_position_map_view_;
   };
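
The following is a minimal, self-contained sketch of the pattern this change applies, not the TVM implementation. `ToyCopyStream`, `ToyKVCache`, and `PeekWithoutSync` are hypothetical names invented for illustration, and the "stream" is just a queue of deferred tasks rather than a real device stream. It is meant to show two things under those assumptions: a getter that reads data produced on a separate copy stream should wait for that stream first, and a getter that performs such a wait plausibly cannot stay `const`, which would be consistent with the signature change in `kv_state.h`.

```cpp
#include <cassert>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

// Hypothetical stand-in for an asynchronous copy stream: queued work only
// takes effect once the stream is drained.
class ToyCopyStream {
 public:
  void Enqueue(std::function<void()> task) { pending_.push(std::move(task)); }

  // Run every queued task in order, emulating "wait until the copies finish".
  void Drain() {
    while (!pending_.empty()) {
      pending_.front()();
      pending_.pop();
    }
  }

 private:
  std::queue<std::function<void()>> pending_;
};

// Hypothetical toy cache, loosely modeled on the shape of the diff above.
class ToyKVCache {
 public:
  // Schedules the position copy on the copy stream without waiting for it.
  void BeginForward(std::vector<int> host_positions) {
    copy_stream_.Enqueue(
        [this, p = std::move(host_positions)]() { device_positions_ = p; });
    dirty_aux_data_device_ = false;
  }

  // Non-const: waiting on the copy stream mutates the object's state.
  const std::vector<int>& GetQueryPositions() {
    ComputeStreamWaitForCopyStream();
    assert(!dirty_aux_data_device_);
    return device_positions_;
  }

  // Reads without waiting, to show the stale result a missing sync can give.
  const std::vector<int>& PeekWithoutSync() const { return device_positions_; }

 private:
  // In this toy, "waiting" simply drains the queued copy tasks.
  void ComputeStreamWaitForCopyStream() { copy_stream_.Drain(); }

  ToyCopyStream copy_stream_;
  std::vector<int> device_positions_;
  bool dirty_aux_data_device_ = true;
};
```

In this sketch, calling `PeekWithoutSync()` right after `BeginForward({0, 1, 2})` still sees an empty vector, while `GetQueryPositions()` drains the pending copy first and returns the three positions.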
