From dd04ea94f54af9b274b4199f06dcfa2f57ced851 Mon Sep 17 00:00:00 2001 From: Shen Xu Date: Tue, 11 Mar 2025 09:48:53 -0700 Subject: [PATCH] Fix static attention mask update (#9101) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/9101 The range-based for loops were making a copy of each mask, and thus the updates did not take effect. Delete the copy and move constructors and assignment operators of StaticKVCache and StaticAttentionMask as they are not needed. Also add the missing deallocate call in the StaticAttentionMask destructor. Reviewed By: billmguo Differential Revision: D70914174 --- .../llama/runner/static_attention_io_manager.h | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/examples/models/llama/runner/static_attention_io_manager.h b/examples/models/llama/runner/static_attention_io_manager.h index f5a8c04b085..127414ac432 100644 --- a/examples/models/llama/runner/static_attention_io_manager.h +++ b/examples/models/llama/runner/static_attention_io_manager.h @@ -38,6 +38,11 @@ class StaticKVCache { reset(); } + StaticKVCache(const StaticKVCache& other) = delete; + StaticKVCache& operator=(const StaticKVCache& other) = delete; + StaticKVCache(StaticKVCache&& other) = delete; + StaticKVCache& operator=(StaticKVCache&& other) = delete; + ~StaticKVCache() { allocator_.deallocate(data_, data_size_); } @@ -200,6 +205,15 @@ class StaticAttentionMask { reset(); } + StaticAttentionMask(const StaticAttentionMask& other) = delete; + StaticAttentionMask& operator=(const StaticAttentionMask& other) = delete; + StaticAttentionMask(StaticAttentionMask&& other) = delete; + StaticAttentionMask& operator=(StaticAttentionMask&& other) = delete; + + ~StaticAttentionMask() { + allocator_.deallocate(data_, data_size_); + } + /** * Reset the mask to the state where the cache contains no valid data.
*/ @@ -315,7 +329,7 @@ class StaticAttentionIOManager { input_pos_ += update_len; kCaches_.update(method, k_cache_output_indices, update_len); vCaches_.update(method, v_cache_output_indices, update_len); - for (auto it : attentionMasks_) { + for (auto& it : attentionMasks_) { it.second.updateCacheMask(update_len); } } @@ -324,7 +338,7 @@ class StaticAttentionIOManager { input_pos_ = 0; kCaches_.reset(); vCaches_.reset(); - for (auto it : attentionMasks_) { + for (auto& it : attentionMasks_) { it.second.reset(); } }