@@ -167,9 +167,15 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
-    if (cparams.embeddings && (
-                cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
-                cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
+    if (!cparams.embeddings) {
+        return;
+    }
+
+    const bool is_last_tok = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
+                             arch == LLM_ARCH_QWEN3; // qwen3 reranking & embedding models use last token
+
+    if (is_last_tok) {
+        // set output to the last token of each sequence
         const int64_t n_tokens     = ubatch->n_tokens;
         const int64_t n_seq_tokens = ubatch->n_seq_tokens;
         const int64_t n_seqs       = ubatch->n_seqs;
@@ -180,23 +186,33 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
         uint32_t * data = (uint32_t *) cls->data;
         memset(cls->data, 0, n_tokens * ggml_element_size(cls));
 
+        std::vector<int> last_pos(n_tokens, -1);
+        std::vector<int> last_row(n_tokens, -1);
+
         for (int s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
             // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
 
             for (int i = 0; i < n_seq_tokens; ++i) {
                 const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
 
-                if (pos == 0) {
-                    data[seq_id] = s*n_seq_tokens + i;
+                if (pos >= last_pos[seq_id]) {
+                    last_pos[seq_id] = pos;
+                    last_row[seq_id] = s*n_seq_tokens + i;
                 }
             }
         }
-    }
 
-    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
+        for (int i = 0; i < n_tokens; ++i) {
+            if (last_row[i] >= 0) {
+                data[i] = last_row[i];
+            }
+        }
+
+    } else {
+        // set output to first token of each sequence
         const int64_t n_tokens     = ubatch->n_tokens;
         const int64_t n_seq_tokens = ubatch->n_seq_tokens;
         const int64_t n_seqs       = ubatch->n_seqs;
@@ -207,30 +223,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
         uint32_t * data = (uint32_t *) cls->data;
         memset(cls->data, 0, n_tokens * ggml_element_size(cls));
 
-        std::vector<int> last_pos(n_tokens, -1);
-        std::vector<int> last_row(n_tokens, -1);
-
         for (int s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
             // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
 
             for (int i = 0; i < n_seq_tokens; ++i) {
                 const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
 
-                if (pos >= last_pos[seq_id]) {
-                    last_pos[seq_id] = pos;
-                    last_row[seq_id] = s*n_seq_tokens + i;
+                if (pos == 0) {
+                    data[seq_id] = s*n_seq_tokens + i;
                 }
             }
         }
-
-        for (int i = 0; i < n_tokens; ++i) {
-            if (last_row[i] >= 0) {
-                data[i] = last_row[i];
-            }
-        }
     }
 }
 
@@ -943,7 +949,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_cls() const {
-    auto inp = std::make_unique<llm_graph_input_cls>(cparams);
+    auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);
 
     auto & cur = inp->cls;
 