Allow for user specified embedding pooling type #5849

Merged · 2 commits · Mar 3, 2024
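In short, `llama_context_params` gains a `pooling_type` field so API users can override the pooling used for embeddings instead of always inheriting it from the model. Below is a minimal sketch (not code from this PR) of how an embedding client might use it; the model path, the choice of `LLAMA_POOLING_TYPE_MEAN`, and the read-back call are illustrative assumptions based on the llama.h API of this period:

```cpp
#include "llama.h"

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("model.gguf", mparams); // hypothetical path
    if (model == nullptr) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.embedding    = true;                    // embedding mode only
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN; // new: override the model's pooling type;
                                                    // LLAMA_POOLING_TYPE_UNSPECIFIED keeps the model default

    llama_context * ctx = llama_new_context_with_model(model, cparams);
    if (ctx == nullptr) {
        llama_free_model(model);
        return 1;
    }

    // ... tokenize the prompt, call llama_decode(), then read the (pooled)
    // embeddings back, e.g. via llama_get_embeddings(ctx) ...

    llama_free(ctx);
    llama_free_model(model);
    return 0;
}
```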
Changes from 1 commit
common/common.h — 8 changes: 5 additions & 3 deletions
@@ -76,9 +76,11 @@ struct gpt_params {
     float   yarn_beta_slow        = 1.0f;  // YaRN high correction dim
     int32_t yarn_orig_ctx         = 0;     // YaRN original context length
     float   defrag_thold          = -1.0f; // KV cache defragmentation threshold
-    int32_t rope_scaling_type     = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
-    ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
-    llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
+
+    ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
+
+    llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
+    llama_pooling_type      pooling_type      = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
 
     // // sampling parameters
     struct llama_sampling_params sparams;
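The command-line plumbing is not part of this commit's diff, but the new gpt_params field suggests a flag-to-enum mapping along these lines. This is a hypothetical sketch: the `--pooling` flag name, the helper, and the assumption that the enum carries MEAN/CLS values alongside NONE/UNSPECIFIED are not taken from this PR:

```cpp
#include <string>
#include "common.h"
#include "llama.h"

// Hypothetical helper for a "--pooling none|mean|cls" style flag: translate the
// user-supplied string into the llama_pooling_type stored in gpt_params.
static bool parse_pooling_arg(const std::string & value, gpt_params & params) {
    if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; return true; }
    if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; return true; }
    if (value == "cls")  { params.pooling_type = LLAMA_POOLING_TYPE_CLS;  return true; }
    return false; // unknown value; pooling_type stays LLAMA_POOLING_TYPE_UNSPECIFIED (model default)
}
```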
llama.cpp — 18 changes: 9 additions & 9 deletions
@@ -873,16 +873,16 @@ struct LLM_TN {
 // gguf helpers
 //
 
-static const std::map<int32_t, const char *> LLAMA_ROPE_SCALING_TYPES = {
+static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_TYPES = {
     { LLAMA_ROPE_SCALING_TYPE_NONE,   "none"   },
     { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" },
     { LLAMA_ROPE_SCALING_TYPE_YARN,   "yarn"   },
 };
 
-static int32_t llama_rope_scaling_type_from_string(const std::string & name) {
+static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
     for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
         if (kv.second == name) {
-            return kv.first;
+            return (llama_rope_scaling_type) kv.first;
         }
     }
 
@@ -1612,16 +1612,16 @@ struct llama_hparams {
     float rope_freq_base_train;
     float rope_freq_scale_train;
     uint32_t n_yarn_orig_ctx;
-    int32_t rope_scaling_type_train;
 
     float f_clamp_kqv      = 0.0f;
     float f_max_alibi_bias = 0.0f;
 
     bool causal_attn = true;
     bool need_kq_pos = false;
 
-    enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
-    enum llama_rope_type    rope_type    = LLAMA_ROPE_TYPE_NONE;
+    enum llama_pooling_type      pooling_type            = LLAMA_POOLING_TYPE_NONE;
+    enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
+    enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
 
     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only != other.vocab_only) return true;
@@ -1670,8 +1670,8 @@ struct llama_cparams {
     uint32_t n_threads;       // number of threads to use for generation
     uint32_t n_threads_batch; // number of threads to use for batch processing
 
-    float    rope_freq_base;
-    float    rope_freq_scale;
+    float rope_freq_base;
+    float rope_freq_scale;
 
     uint32_t n_yarn_orig_ctx;
     // These hyperparameters are not exposed in GGUF, because all
@@ -11848,6 +11848,7 @@ struct llama_context_params llama_context_default_params() {
         /*.n_threads                   =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
         /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
         /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
+        /*.pooling_type                =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
         /*.rope_freq_base              =*/ 0.0f,
         /*.rope_freq_scale             =*/ 0.0f,
         /*.yarn_ext_factor             =*/ -1.0f,
@@ -11863,7 +11864,6 @@
         /*.logits_all                  =*/ false,
         /*.embedding                   =*/ false,
         /*.offload_kqv                 =*/ true,
-        /*.pooling_type                =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
         /*.abort_callback              =*/ nullptr,
         /*.abort_callback_data         =*/ nullptr,
     };
llama.h — 6 changes: 4 additions & 2 deletions
@@ -237,7 +237,10 @@ extern "C" {
         uint32_t n_batch;         // prompt processing maximum batch size
         uint32_t n_threads;       // number of threads to use for generation
         uint32_t n_threads_batch; // number of threads to use for batch processing
-        int32_t  rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
+
+        enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
+        enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
Comment on lines +241 to +242 (Collaborator):
Unless we use C23/C++11 syntax to specify the underlying type of the enum, this defeats the purpose of using fixed-size types like int32_t in the API. In C this will be an int; in C++, there are few guarantees:

Declares an unscoped enumeration type whose underlying type is not fixed (in this case, the underlying type is an implementation-defined integral type that can represent all enumerator values; this type is not larger than int unless the value of an enumerator cannot fit in an int or unsigned int. [...])

@ggerganov (Member) replied on Mar 4, 2024:

Ok, I'll go back to int32_t within #5796

Edit: On second thought, using enums is OK. I doubt this would ever cause problems

+                                                        // (ignored if no pooling layer)
 
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float    rope_freq_base;   // RoPE base frequency, 0 = from model
@@ -259,7 +262,6 @@ extern "C" {
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embedding;   // embedding mode only
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-        enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
 
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
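To illustrate the review comment above about enum widths: a plain enum member has no guaranteed underlying type, whereas the C++11 (and C23) fixed-underlying-type syntax pins it. The sketch below uses made-up names purely for illustration; llama.h itself does not use this syntax:

```cpp
#include <cstdint>
#include <type_traits>

// Underlying type is implementation-defined (no larger than int unless the
// enumerator values require it), so the field width may vary across ABIs.
enum pooling_kind_loose { POOL_LOOSE_NONE, POOL_LOOSE_MEAN };

// C++11 syntax fixes the underlying type, keeping struct layout predictable.
enum pooling_kind_fixed : int32_t { POOL_FIXED_NONE, POOL_FIXED_MEAN };

static_assert(std::is_same<std::underlying_type<pooling_kind_fixed>::type, int32_t>::value,
              "underlying type is exactly int32_t");
```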