@@ -168,13 +168,53 @@ void ShiftPointerIoMgr::init_io() {
168168 }
169169}
170170
171- void ShiftPointerIoMgr::reset_io () {
171+ void ShiftPointerIoMgr::reset_io (
172+ const std::vector<executorch::runtime::Result<
173+ executorch::runtime::MethodMeta>>& prefill_methods_meta,
174+ const std::vector<
175+ executorch::runtime::Result<executorch::runtime::MethodMeta>>&
176+ kv_methods_meta) {
172177 IO* ptr = static_cast <IO*>(data_ptr_.get ());
178+ std::fill (ptr->prefill_input_pos .begin (), ptr->prefill_input_pos .end (), 0 );
179+ ptr->kv_input_pos = 0 ;
173180 std::fill (
174181 ptr->prefill_attention_mask .begin (),
175182 ptr->prefill_attention_mask .end (),
176183 0 );
177184 std::fill (ptr->kv_attention_mask .begin (), ptr->kv_attention_mask .end (), 0 );
185+
186+ input_tensors_[kv_forward_name_].clear ();
187+ input_tensors_[kv_forward_name_].resize (modules_.size ());
188+ output_tensors_[kv_forward_name_].clear ();
189+ output_tensors_[kv_forward_name_].resize (modules_.size ());
190+
191+ k_cache_in_[kv_forward_name_].clear ();
192+ v_cache_in_[kv_forward_name_].clear ();
193+ k_cache_out_[kv_forward_name_].clear ();
194+ v_cache_out_[kv_forward_name_].clear ();
195+
196+ input_tensors_[prefill_forward_name_].clear ();
197+ input_tensors_[prefill_forward_name_].resize (modules_.size ());
198+ output_tensors_[prefill_forward_name_].clear ();
199+ output_tensors_[prefill_forward_name_].resize (modules_.size ());
200+
201+ k_cache_in_[prefill_forward_name_].clear ();
202+ v_cache_in_[prefill_forward_name_].clear ();
203+ k_cache_out_[prefill_forward_name_].clear ();
204+ v_cache_out_[prefill_forward_name_].clear ();
205+
206+ switch (eval_mode_) {
207+ case EvalMode::kKVCached :
208+ prepare_kv_io (kv_methods_meta);
209+ break ;
210+ case EvalMode::kHybrid :
211+ prepare_prefill_io (prefill_methods_meta);
212+ prepare_kv_io (kv_methods_meta);
213+ break ;
214+ default :
215+ ET_CHECK_MSG (false , " unsupported mode" );
216+ break ;
217+ }
178218}
179219void ShiftPointerIoMgr::prepare_kv_io (
180220 const std::vector<Result<MethodMeta>>& methods_meta) {
@@ -893,7 +933,12 @@ void SmartMaskIoMgr::init_io() {
893933 ptr->init_io_ptrs (shared_ptr, io_bytes_map);
894934}
895935
896- void SmartMaskIoMgr::reset_io () {
936+ void SmartMaskIoMgr::reset_io (
937+ const std::vector<executorch::runtime::Result<
938+ executorch::runtime::MethodMeta>>& prefill_methods_meta,
939+ const std::vector<
940+ executorch::runtime::Result<executorch::runtime::MethodMeta>>&
941+ kv_methods_meta) {
897942 IO* ptr = static_cast <IO*>(data_ptr_.get ());
898943 int32_t prefill_attn_size = prefill_ar_len_ * context_len_;
899944 int32_t kv_attn_size = kv_ar_len_ * context_len_;
0 commit comments