@@ -873,16 +873,16 @@ struct LLM_TN {
 // gguf helpers
 //
 
-static const std::map<int32_t, const char *> LLAMA_ROPE_SCALING_TYPES = {
+static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_TYPES = {
     { LLAMA_ROPE_SCALING_TYPE_NONE,   "none"   },
     { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" },
     { LLAMA_ROPE_SCALING_TYPE_YARN,   "yarn"   },
 };
 
-static int32_t llama_rope_scaling_type_from_string(const std::string & name) {
+static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
     for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
         if (kv.second == name) {
-            return kv.first;
+            return (llama_rope_scaling_type) kv.first;
         }
     }
 
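With the map keyed on llama_rope_scaling_type, the name lookup stays in the typed domain end to end. A minimal usage sketch, not part of the patch; the fallback for an unrecognized name is assumed to be the unspecified sentinel, since the hunk cuts off before the function's final return:

    // sketch, not part of the diff
    enum llama_rope_scaling_type t = llama_rope_scaling_type_from_string("yarn");
    // t == LLAMA_ROPE_SCALING_TYPE_YARN; an unrecognized name is assumed to yield
    // LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
    const char * s = LLAMA_ROPE_SCALING_TYPES.at(LLAMA_ROPE_SCALING_TYPE_LINEAR); // "linear"
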
@@ -1612,16 +1612,16 @@ struct llama_hparams {
     float rope_freq_base_train;
     float rope_freq_scale_train;
     uint32_t n_yarn_orig_ctx;
-    int32_t rope_scaling_type_train;
 
     float f_clamp_kqv = 0.0f;
     float f_max_alibi_bias = 0.0f;
 
     bool causal_attn = true;
     bool need_kq_pos = false;
 
-    enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
-    enum llama_rope_type    rope_type    = LLAMA_ROPE_TYPE_NONE;
+    enum llama_pooling_type      pooling_type            = LLAMA_POOLING_TYPE_NONE;
+    enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
+    enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
 
     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only != other.vocab_only) return true;
@@ -1670,8 +1670,8 @@ struct llama_cparams {
     uint32_t n_threads;       // number of threads to use for generation
     uint32_t n_threads_batch; // number of threads to use for batch processing
 
-    float rope_freq_base;
-    float rope_freq_scale;
+    float    rope_freq_base;
+    float    rope_freq_scale;
 
     uint32_t n_yarn_orig_ctx;
     // These hyperparameters are not exposed in GGUF, because all
@@ -1683,7 +1683,7 @@ struct llama_cparams {
     float defrag_thold;
 
     bool offload_kqv;
-    bool do_pooling;
+    enum llama_pooling_type pooling_type;
 
     ggml_backend_sched_eval_callback cb_eval;
     void * cb_eval_user_data;
@@ -2933,7 +2933,11 @@ template<>
 bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
     uint32_t tmp;
     const bool found = get_key(kid, tmp, required);
-    result = (enum llama_pooling_type) tmp;
+    if (found) {
+        result = (enum llama_pooling_type) tmp;
+    } else {
+        result = LLAMA_POOLING_TYPE_UNSPECIFIED;
+    }
     return found;
 }
 
@@ -3210,7 +3214,7 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
                 ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
-                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
+                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
 
                 switch (hparams.n_layer) {
                     case 3:
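Taken together with the loader specialization above, the pooling-type key becomes optional: passing false for required means GGUFs written before the key existed still load, and hparams.pooling_type is left at LLAMA_POOLING_TYPE_UNSPECIFIED for context creation to resolve later. A sketch of the same read, reusing the call site's names for illustration:

    // sketch, not part of the diff: optional metadata read
    enum llama_pooling_type pt = LLAMA_POOLING_TYPE_UNSPECIFIED;
    const bool found = ml.get_key(LLM_KV_POOLING_TYPE, pt, /*required =*/ false);
    if (!found) {
        // older embedding GGUFs land here; the effective value is chosen later
        // in llama_new_context_with_model (see the last hunk below)
    }
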
@@ -5175,7 +5179,7 @@ struct llm_build_context {
         n_kv             (worst_case ? n_ctx            : kv_self.n),
         kv_head          (worst_case ? n_ctx - n_tokens : kv_self.head),
         n_orig_ctx       (cparams.n_yarn_orig_ctx),
-        pooling_type     (cparams.do_pooling ? hparams.pooling_type : LLAMA_POOLING_TYPE_NONE),
+        pooling_type     (cparams.pooling_type),
         rope_type        (hparams.rope_type),
         cb               (cb),
         buf_compute_meta (lctx.buf_compute_meta) {
@@ -8015,7 +8019,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
+    if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
@@ -8043,7 +8047,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
+    if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
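With the do_pooling flag gone, the input-building paths reduce to a single enum comparison instead of a bool-plus-hparams pair. Roughly, as a restatement of the two hunks above rather than new code in the patch:

    // sketch, not part of the diff
    switch (cparams.pooling_type) {
        case LLAMA_POOLING_TYPE_MEAN: /* fill lctx.inp_mean with per-sequence weights   */ break;
        case LLAMA_POOLING_TYPE_CLS:  /* mark each sequence's first token in lctx.inp_cls */ break;
        default:                      /* LLAMA_POOLING_TYPE_NONE: no pooling inputs       */ break;
    }
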
@@ -11846,6 +11850,7 @@ struct llama_context_params llama_context_default_params() {
         /*.n_threads                   =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
         /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
         /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
+        /*.pooling_type                =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
         /*.rope_freq_base              =*/ 0.0f,
         /*.rope_freq_scale             =*/ 0.0f,
         /*.yarn_ext_factor             =*/ -1.0f,
@@ -11861,7 +11866,6 @@ struct llama_context_params llama_context_default_params() {
         /*.logits_all                  =*/ false,
         /*.embedding                   =*/ false,
         /*.offload_kqv                 =*/ true,
-        /*.do_pooling                  =*/ true,
         /*.abort_callback              =*/ nullptr,
         /*.abort_callback_data         =*/ nullptr,
     };
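From the caller's side, the boolean is replaced by an explicit request. A minimal usage sketch; it assumes model was loaded earlier (e.g. via llama_load_model_from_file), and the MEAN override is only an example:

    // sketch, not part of the diff
    llama_context_params cparams = llama_context_default_params();
    cparams.embedding    = true;
    // the new default, LLAMA_POOLING_TYPE_UNSPECIFIED, defers to the model's GGUF
    // metadata; set an explicit value to force a particular pooling mode:
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN;
    llama_context * ctx = llama_new_context_with_model(model, cparams);
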
@@ -12012,7 +12016,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.yarn_beta_slow = params.yarn_beta_slow;
     cparams.defrag_thold   = params.defrag_thold;
     cparams.offload_kqv    = params.offload_kqv;
-    cparams.do_pooling     = params.do_pooling;
+    cparams.pooling_type   = params.pooling_type;
 
     cparams.n_ctx          = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
     cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@@ -12038,6 +12042,14 @@ struct llama_context * llama_new_context_with_model(
         cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
     }
 
+    if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
+        if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
+            cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
+        } else {
+            cparams.pooling_type = hparams.pooling_type;
+        }
+    }
+
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }
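The effective precedence implemented by the new block is: an explicit params.pooling_type wins, otherwise the value read from GGUF metadata, otherwise LLAMA_POOLING_TYPE_NONE. A hypothetical helper restating that precedence, not part of the patch:

    // sketch, not part of the diff: same resolution order as the hunk above
    static enum llama_pooling_type resolve_pooling_type(
            enum llama_pooling_type requested,   // from llama_context_params
            enum llama_pooling_type from_gguf) { // from llm_load_hparams
        if (requested != LLAMA_POOLING_TYPE_UNSPECIFIED) {
            return requested;
        }
        if (from_gguf != LLAMA_POOLING_TYPE_UNSPECIFIED) {
            return from_gguf;
        }
        return LLAMA_POOLING_TYPE_NONE;
    }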