@@ -188,38 +188,23 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
188188
189189void llm_graph_input_cls::set_input (const llama_ubatch * ubatch) {
190190 const int64_t n_tokens = ubatch->n_tokens ;
191- const int64_t n_seq_tokens = ubatch->n_seq_tokens ;
192191 const int64_t n_seqs_unq = ubatch->n_seqs_unq ;
193192
194193 if (cparams.embeddings && (
195- cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
196- cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
197- )) {
194+ cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
195+ cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
196+ cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
197+ )) {
198198 GGML_ASSERT (cls);
199199 GGML_ASSERT (ggml_backend_buffer_is_host (cls->buffer ));
200200
201201 uint32_t * data = (uint32_t *) cls->data ;
202202 memset (cls->data , 0 , n_seqs_unq*ggml_element_size (cls));
203203
204- for (int i = 0 ; i < n_tokens; i += n_seq_tokens) {
205- for (int s = 0 ; s < ubatch->n_seq_id [i]; ++s) {
206- const llama_seq_id seq_id = ubatch->seq_id [i][s];
207- const int32_t seq_idx = ubatch->seq_idx [seq_id];
208-
209- data[seq_idx] = i;
210- }
211- }
212- }
213-
214- if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
215- GGML_ASSERT (cls);
216- GGML_ASSERT (ggml_backend_buffer_is_host (cls->buffer ));
217-
218- uint32_t * data = (uint32_t *) cls->data ;
219- memset (cls->data , 0 , n_seqs_unq*ggml_element_size (cls));
204+ std::vector<int > target_pos (n_seqs_unq, -1 );
205+ std::vector<int > target_row (n_seqs_unq, -1 );
220206
221- std::vector<int > last_pos (n_seqs_unq, -1 );
222- std::vector<int > last_row (n_seqs_unq, -1 );
207+ bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST;
223208
224209 for (int i = 0 ; i < n_tokens; ++i) {
225210 const llama_pos pos = ubatch->pos [i];
@@ -228,16 +213,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
228213 const llama_seq_id seq_id = ubatch->seq_id [i][s];
229214 const int32_t seq_idx = ubatch->seq_idx [seq_id];
230215
231- if (pos >= last_pos[seq_idx]) {
232- last_pos[seq_idx] = pos;
233- last_row[seq_idx] = i;
216+ if (
217+ (target_pos[seq_idx] == -1 ) ||
218+ ( last && pos >= target_pos[seq_idx]) ||
219+ (!last && pos < target_pos[seq_idx])
220+ ) {
221+ target_pos[seq_idx] = pos;
222+ target_row[seq_idx] = i;
234223 }
235224 }
236225 }
237226
238227 for (int s = 0 ; s < n_seqs_unq; ++s) {
239- if (last_row [s] >= 0 ) {
240- data[s] = last_row [s];
228+ if (target_row [s] >= 0 ) {
229+ data[s] = target_row [s];
241230 }
242231 }
243232 }
0 commit comments