Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
parser.add_argument("-sp", "--system_prompt", type = str, help = "Use custom system prompt")

parser.add_argument("-temp", "--temperature", type = float, default = 0.95, help = "Sampler temperature, default = 0.95 (1 to disable)")
parser.add_argument("-smthfctr", "--smoothing_factor", type = float, default = 0.0, help = "Smoothing Factor, default = 0.0 (0 to disable")
parser.add_argument("-dyntemp", "--dynamic_temperature", type = str, help = "Dynamic temperature min,max,exponent, e.g. -dyntemp 0.2,1.5,1")
parser.add_argument("-topk", "--top_k", type = int, default = 50, help = "Sampler top-K, default = 50 (0 to disable)")
parser.add_argument("-topp", "--top_p", type = float, default = 0.8, help = "Sampler top-P, default = 0.8 (0 to disable)")
Expand Down Expand Up @@ -196,6 +197,7 @@ def get_tokenized_context(max_len):
settings.token_repetition_penalty = args.repetition_penalty
settings.token_frequency_penalty = args.frequency_penalty
settings.token_presence_penalty = args.presence_penalty
settings.smoothing_factor = args.smoothing_factor

if args.dynamic_temperature:
dt_args = [float(alloc) for alloc in args.dynamic_temperature.split(",")]
Expand Down
30 changes: 30 additions & 0 deletions exllamav2/exllamav2_ext/cpp/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,36 @@ void apply_rep_penalty_cpu
}
}

void quadratic_sampling
(
const int vocab_size,
const float temperature,
float* logits,
const bool* logits_filter,
float smoothing_factor,
float* output
)
{
// Calculate maxl as the maximum logit value
float maxl = -1e38;
for (int i = 0; i < vocab_size; i++)
{
if (!logits_filter[i]) continue;
maxl = fmaxf(logits[i], maxl);
}

for (int i = 0; i < vocab_size; i++)
{
if (!logits_filter[i]) continue;
float logit_shifted = logits[i] - maxl;
logits[i] = -smoothing_factor * logit_shifted * logit_shifted + maxl;
// Limit the range of logits to prevent extreme values
logits[i] = fminf(fmaxf(logits[i], -1e20), 1e20);
}

softmax_cpu(vocab_size, temperature, logits, logits_filter, output);
}

void softmax_cpu
(
const int vocab_size,
Expand Down
10 changes: 10 additions & 0 deletions exllamav2/exllamav2_ext/cpp/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@ void softmax_cpu
float* output
);

void quadratic_sampling
(
const int vocab_size,
const float temperature,
float* logits,
const bool* logits_filter,
float smoothing_factor,
float* output
);

void normalize_cpu
(
const int num_candidates,
Expand Down
40 changes: 31 additions & 9 deletions exllamav2/exllamav2_ext/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -987,7 +987,8 @@ std::vector<float> sample_basic
float post_temperature,
float min_temp = 0,
float max_temp = 0.0f,
float temp_exponent = 1.0f
float temp_exponent = 1.0f,
float smoothing_factor = 0.0f
)
{
TORCH_CHECK_DTYPE(logits, kFloat);
Expand Down Expand Up @@ -1019,14 +1020,30 @@ std::vector<float> sample_basic

for (int i = 0; i < bsz; i++)
{
softmax_cpu
(
vocab_size,
temperature,
logits_ptr + i * vocab_size,
logits_filter_ptr + i * vocab_size,
temp_probs
);
if (smoothing_factor > 0)
{
// Apply quadratic_sampling to the logits
quadratic_sampling
(
vocab_size,
temperature,
logits_ptr + i * vocab_size,
logits_filter_ptr + i * vocab_size,
smoothing_factor,
temp_probs
);
}
else
{
softmax_cpu
(
vocab_size,
temperature,
logits_ptr + i * vocab_size,
logits_filter_ptr + i * vocab_size,
temp_probs
);
}

if (top_k == 1)
{
Expand Down Expand Up @@ -1063,6 +1080,11 @@ std::vector<float> sample_basic
normalize_cpu(num_candidates, temp_probs);
}

if (smoothing_factor > 0)
{

}

if (tfs > 0.0f && tfs < 1.0f)
{
num_candidates = tfs_cpu(num_candidates, temp_probs, temp_indices, tfs);
Expand Down
5 changes: 4 additions & 1 deletion exllamav2/generator/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Settings:
token_presence_penalty = 0.0

temperature = 0.8
smoothing_factor = 0.0
min_temp = 0
max_temp = 0.0
temp_exponent = 1.0
Expand Down Expand Up @@ -50,6 +51,7 @@ def clone(self):
c.token_presence_penalty = self.token_presence_penalty

c.temperature = self.temperature
c.smoothing_factor = self.smoothing_factor
c.min_temp = self.min_temp
c.max_temp = self.max_temp
c.temp_exponent = self.temp_exponent
Expand Down Expand Up @@ -220,7 +222,8 @@ def sample(logits: torch.tensor, settings: Settings, sequence_ids: torch.tensor,
settings.temperature if settings.temperature_last else 1.0,
settings.min_temp,
settings.max_temp,
settings.temp_exponent)
settings.temp_exponent,
settings.smoothing_factor)

if settings.mirostat: settings.mirostat_mu = m

Expand Down