Add a "smoothing factor" (quadratic sampling) option to the sampler.

Review notes on this revision of the patch:
  * examples/chat.py: the --smoothing_factor help text now closes its
    parenthesis.
  * sampling.cpp: quadratic_sampling() is documented and its float constants
    carry an `f` suffix; the arithmetic is unchanged.
  * ext.cpp: the hunk that only added an empty `if (smoothing_factor > 0) { }`
    block to sample_basic() has been dropped, since it had no effect.
  * Line breaks and indentation were reconstructed from a whitespace-collapsed
    copy of the patch, and the index hashes are those of the original patch;
    confirm the context lines against the tree before applying.

diff --git a/examples/chat.py b/examples/chat.py
index 4b839f42..151e0a05 100644
--- a/examples/chat.py
+++ b/examples/chat.py
@@ -36,6 +36,7 @@
 parser.add_argument("-sp", "--system_prompt", type = str, help = "Use custom system prompt")
 
 parser.add_argument("-temp", "--temperature", type = float, default = 0.95, help = "Sampler temperature, default = 0.95 (1 to disable)")
+parser.add_argument("-smthfctr", "--smoothing_factor", type = float, default = 0.0, help = "Smoothing Factor, default = 0.0 (0 to disable)")
 parser.add_argument("-dyntemp", "--dynamic_temperature", type = str, help = "Dynamic temperature min,max,exponent, e.g. -dyntemp 0.2,1.5,1")
 parser.add_argument("-topk", "--top_k", type = int, default = 50, help = "Sampler top-K, default = 50 (0 to disable)")
 parser.add_argument("-topp", "--top_p", type = float, default = 0.8, help = "Sampler top-P, default = 0.8 (0 to disable)")
@@ -196,6 +197,7 @@ def get_tokenized_context(max_len):
 settings.token_repetition_penalty = args.repetition_penalty
 settings.token_frequency_penalty = args.frequency_penalty
 settings.token_presence_penalty = args.presence_penalty
+settings.smoothing_factor = args.smoothing_factor
 
 if args.dynamic_temperature:
     dt_args = [float(alloc) for alloc in args.dynamic_temperature.split(",")]
diff --git a/exllamav2/exllamav2_ext/cpp/sampling.cpp b/exllamav2/exllamav2_ext/cpp/sampling.cpp
index 79a402d2..3166b260 100644
--- a/exllamav2/exllamav2_ext/cpp/sampling.cpp
+++ b/exllamav2/exllamav2_ext/cpp/sampling.cpp
@@ -95,6 +95,44 @@ void apply_rep_penalty_cpu
     }
 }
 
+// Quadratic ("smoothing") transform of the logits, followed by softmax.
+//
+// Each logit that passes logits_filter is replaced, in place, with
+//     maxl - smoothing_factor * (logit - maxl)^2
+// where maxl is the largest logit that passes the filter. The reshaped
+// logits are then handed to softmax_cpu() along with temperature,
+// logits_filter and output. Entries rejected by logits_filter are left
+// untouched here. Note that the caller's logits buffer is modified.
+void quadratic_sampling
+(
+    const int vocab_size,
+    const float temperature,
+    float* logits,
+    const bool* logits_filter,
+    float smoothing_factor,
+    float* output
+)
+{
+    // Find maxl, the largest logit among the entries that pass the filter
+    float maxl = -1e38f;
+    for (int i = 0; i < vocab_size; i++)
+    {
+        if (!logits_filter[i]) continue;
+        maxl = fmaxf(logits[i], maxl);
+    }
+
+    for (int i = 0; i < vocab_size; i++)
+    {
+        if (!logits_filter[i]) continue;
+        float logit_shifted = logits[i] - maxl;
+        logits[i] = -smoothing_factor * logit_shifted * logit_shifted + maxl;
+        // Limit the range of logits to prevent extreme values
+        logits[i] = fminf(fmaxf(logits[i], -1e20f), 1e20f);
+    }
+
+    softmax_cpu(vocab_size, temperature, logits, logits_filter, output);
+}
+
 void softmax_cpu
 (
     const int vocab_size,
diff --git a/exllamav2/exllamav2_ext/cpp/sampling.h b/exllamav2/exllamav2_ext/cpp/sampling.h
index e26cd6cf..ac923db6 100644
--- a/exllamav2/exllamav2_ext/cpp/sampling.h
+++ b/exllamav2/exllamav2_ext/cpp/sampling.h
@@ -28,6 +28,16 @@ void softmax_cpu
     float* output
 );
 
+void quadratic_sampling
+(
+    const int vocab_size,
+    const float temperature,
+    float* logits,
+    const bool* logits_filter,
+    float smoothing_factor,
+    float* output
+);
+
 void normalize_cpu
 (
     const int num_candidates,
diff --git a/exllamav2/exllamav2_ext/ext.cpp b/exllamav2/exllamav2_ext/ext.cpp
index 040afd64..71bee7af 100644
--- a/exllamav2/exllamav2_ext/ext.cpp
+++ b/exllamav2/exllamav2_ext/ext.cpp
@@ -987,7 +987,8 @@ std::vector sample_basic
     float post_temperature,
     float min_temp = 0,
     float max_temp = 0.0f,
-    float temp_exponent = 1.0f
+    float temp_exponent = 1.0f,
+    float smoothing_factor = 0.0f
 )
 {
     TORCH_CHECK_DTYPE(logits, kFloat);
@@ -1019,14 +1020,30 @@ std::vector sample_basic
 
     for (int i = 0; i < bsz; i++)
     {
-        softmax_cpu
-        (
-            vocab_size,
-            temperature,
-            logits_ptr + i * vocab_size,
-            logits_filter_ptr + i * vocab_size,
-            temp_probs
-        );
+        if (smoothing_factor > 0)
+        {
+            // Apply quadratic_sampling to the logits
+            quadratic_sampling
+            (
+                vocab_size,
+                temperature,
+                logits_ptr + i * vocab_size,
+                logits_filter_ptr + i * vocab_size,
+                smoothing_factor,
+                temp_probs
+            );
+        }
+        else
+        {
+            softmax_cpu
+            (
+                vocab_size,
+                temperature,
+                logits_ptr + i * vocab_size,
+                logits_filter_ptr + i * vocab_size,
+                temp_probs
+            );
+        }
 
         if (top_k == 1)
         {
diff --git a/exllamav2/generator/sampler.py b/exllamav2/generator/sampler.py
index e969b94f..c1f225b5 100644
--- a/exllamav2/generator/sampler.py
+++ b/exllamav2/generator/sampler.py
@@ -15,6 +15,7 @@ class Settings:
     token_presence_penalty = 0.0
 
     temperature = 0.8
+    smoothing_factor = 0.0
     min_temp = 0
     max_temp = 0.0
     temp_exponent = 1.0
@@ -50,6 +51,7 @@ def clone(self):
         c.token_presence_penalty = self.token_presence_penalty
 
         c.temperature = self.temperature
+        c.smoothing_factor = self.smoothing_factor
         c.min_temp = self.min_temp
         c.max_temp = self.max_temp
         c.temp_exponent = self.temp_exponent
@@ -220,7 +222,8 @@ def sample(logits: torch.tensor, settings: Settings, sequence_ids: torch.tensor,
                                settings.temperature if settings.temperature_last else 1.0,
                                settings.min_temp,
                                settings.max_temp,
-                               settings.temp_exponent)
+                               settings.temp_exponent,
+                               settings.smoothing_factor)
 
         if settings.mirostat:
             settings.mirostat_mu = m