Make string ban more robust and add regex ban#1243
Conversation
|
If I may, create different branches on your repo for your different PRs, because you erased the content of your previous one with this one (which is fine, it's a continuation, but it's not best practice for those who want to access your previous code with ease!). I will test this. |
|
@Nexesenex https://github.com/SneedwareInc/ik_llama.cpp/tree/legacy |
|
@firecoperana Do you want to look again at this PR? |
|
Yes, I will look at it when it's ready. |
|
@firecoperana What do you want me to change/add? |
// could be improved to support more languages
std::string string_lower(const std::string & str) {
    std::string result = str;
    for (char & c : result) {
        c = (char) std::tolower((unsigned char) c); // needs <cctype>
    }
    return result;
}

I would keep this. No functional change.
s = string_lower(s);
auto ban_tokens = common_tokenize(llama_get_model(ctx), s, false, true);
if (ban_tokens.size() > slot.n_buffer) {
    slot.n_buffer = ban_tokens.size();

Why use the length of the string over the token count? The buffer holds tokens, not individual characters.
auto ban_tokens = common_tokenize(llama_get_model(ctx), val, false, true);
if (ban_tokens.size() > slot.n_buffer) {
    slot.n_buffer = ban_tokens.size();
// Use string length instead of token count

count++;
if (!has_next) {
    if (slot.stopped_limit && !slot.stopped_eos && !slot.stopped_word) {
examples/server/server-context.cpp
Outdated
slot.token_buffer.resize(n_keep_buffer);

// Adjust decoded count
slot.n_decoded -= n_rewind;

Don't change slot.n_decoded. It will make the prompt processing and token generation time and speed calculations incorrect.
n_rewind = check_ban_phrase(slot);
}
// if found string in the ban
if (n_rewind > 0 && (slot.rewind_count < 20 || slot.rewind_count <= 2 * slot.ban_phrases.size())) {

Need some kind of logic to limit the number of times to rewind.
generated_token_probs.clear();

positional_bans.clear();
ban_phrases.clear();

Put them in server_slot::reset().
examples/server/server-context.cpp
Outdated
// Check if we have specific bans for this exact position (slot.n_past)
// Note: slot.n_past is the index of the token we are about to generate.
auto pos_ban_it = slot.positional_bans.find(slot.n_past);
std::vector<llama_token> temp_banned;

This code and the code below could be moved inside rewind_context, as is currently done. Use slot.ctx_sampling->params.logit_bias[result->tok] += slot.ban_phrases_bias; to adjust the logit bias.
Okay
Edge cases like when ALLCAPS gets tokenized as
This specific code block is strictly necessary to prevent valid tokens from being silently discarded when a generation reaches its maximum token limit. Because the server buffers tokens to check for banned phrases, several safe, generated tokens are often waiting in the queue. If this
Do you have any suggestions for an elegant solution that makes sure that
Why set an arbitrary limit? With regex there are many more banned combinations possible per item than with strings.
They are already there?
It is needed. Without it the program does not function correctly. I know you don't test your code properly, so let me demonstrate: Without it: With it:
Okay |
|
That's stupid and does not catch edge cases. Let's keep it simple and reliable: longest string/regex length + 1.
What delays? Care to demonstrate?
Okay
Okay, I'll set the default limit to a reasonable amount (512) that works for me, someone who uses this functionality a lot, and add an option to set it to whatever you want if you are so afraid it will get stuck. |
|
@SneedwareInc I would appreciate it if you were slightly more respectful in your responses to @firecoperana. Thank you. |
|
@ikawrakow How am I disrespectful? |
|
@SneedwareInc I think he is referring to using phrases like "That's stupid and does not catch edge cases" instead of just "does not catch edge cases", or "I know you don't test your code properly, so let me demonstrate" instead of just "let me demonstrate". Imagine phrases like that being directed at you when you originally missed a lot of use cases. Anyway, I appreciate your work, but firecoperana is also putting in a lot of effort, and not just here; he has done an enormous amount of work and contributions. For what it's worth, I also put a lot of work into testing the previous patches, but I did not catch the edge cases you have mentioned. My point is, not catching or missing edge cases does not mean not testing properly or being stupid; it is just that implementing features is a lot of work. Your previous patch was also missing a lot of cases, so it is really hard to take everything into account. This is why discussing and testing things together always helps: for non-trivial features that cover a massive number of use cases, it is not really possible for a single person to think of them all, and an optimization someone suggests, or one you think of yourself, may or may not cover all possible cases. So, my two cents: if you think something wrong is being suggested, just explain why and what issues it would cause; there is no need for negative phrases. |
I would not care, but that's me; I know that LLMs can make mistakes in code, so I am prepared for insults. But I'll soften my language in future interactions, thanks for pointing it out. |
|
@firecoperana I've updated the code. Is this what you wanted?
|
|
Yes, that works. One last thing: revert the buffer-size change for ban strings, or change it to what I suggested. |
@firecoperana Here I have to disagree. My position comes from noticing how aggressively some LLMs have tried to bypass the banned strings. I had "core guidelines" banned when I used the tokenization method for buffer-size determination, but the model sneakily bypassed it by writing CORE GUIDELINES, which exceeded the buffer and did not trigger a rewind. Having to guess which variation will lead to maximal tokenization, as you are suggesting, is impossible. I am not suggesting string length + 1 without a reason. Other than that, you seem not to understand that there is no significant speed penalty from buffering, just visual representation latency. Let me prove my point with data:
Times are in ms. Mistral Nemo, Q6_K, Prompt: Can you explain to me why do you think your approach is superior? |
|
@SneedwareInc If you can find time, could you please rebase your patch? I tried to apply https://github.com/ikawrakow/ik_llama.cpp/pull/1243.patch but it has many conflicts: |
|
@Lissanro For some reason mainline no longer works for me. Same prompt as above, same settings, mainline crashed after generating just one token, no error log: @ikawrakow @firecoperana Can you please look into that? |
|
Does #1304 fix it? |
|
@firecoperana It no longer crashes, but string ban doesn't work: |
|
#1310 fixes it. |
|
@firecoperana done! |
firecoperana
left a comment
This PR changes a lot of existing code, so it will take a while for me to fully review. When you use AI to write a PR, watch out for any code that is removed by the AI.
examples/server/server-context.cpp
Outdated
if (ban_pos >= 0 && allow_rewind) {
    rewind_context(slot, ban_pos);
    slot.rewind_status = true;
    slot.ctx_sampling->rewind_samplers = true;

Why is this removed?
    }
}
n++;
if (slot.banned_n > 0 && n == slot.banned_n) {

banned_n no longer works.
examples/server/server-context.cpp
Outdated
LLAMA_LOG_INFO("Banned pattern detected at pos %d. Banning token %d ('%s') and rewinding.\n",
    abs_pos, banned_tok, slot.token_buffer[token_idx].text_to_send.c_str());

slot.positional_bans[abs_pos].insert(banned_tok);

Can this be moved to rewind_context? This function should just check whether a ban string exists. If possible, can you make it return n_rewind? There is no need to make additional changes in this function except for adding regex ban detection. With n_rewind being returned, there is less change in the rewind_context function too.
|
@firecoperana I brought back rewind_samplers and record_samplers, restored banned_n, and moved the code to rewind_context. Is there anything else you wish me to do? |
|
slot.ctx_sampling->rewind_samplers = true; and slot.ctx_sampling->record_samplers = true; are breaking something. The bans work, but the output quality is degraded when they are present. I will remove them. |
|
I still see Do you mind if I copy your code and create a clean PR? Your PR removes more code that is not related to regex ban than the last time I checked. |
|
Perhaps #1359 fixes the issue. @SneedwareInc if you would test it with string/regex bans, I would love to hear the results. |
I'll look into that.
Such as?
No, I do not use or even know what adaptive p is. I only use temperature and TFS. It is very concerning if rewind samplers cause it to turn on or affect other samplers that should not be affected.
You're welcome to cherry-pick this or copy it exactly as-is into a clean branch. However, please don't rewrite the implementation logic, just copy it verbatim. The last rewrite introduced bugs that weren't caught because they weren't tested against edge cases. If the issue is just formatting or the unrelated deletions, I can clean those up myself in this PR. I'd strongly prefer we fix the current one rather than risk another untested rewrite.
I'll look into that. |
}
else if (penalty_prompt->is_array()) {
    const auto n_tokens = penalty_prompt->size();
    slot.sparams.penalty_prompt_tokens.clear();

const auto preserved_tokens = data.find("preserved_tokens");
if (preserved_tokens != data.end()) {
    slot.sparams.preserved_tokens.clear();
}
const auto grammar_triggers = data.find("grammar_triggers");
if (grammar_triggers != data.end()) {
    slot.sparams.grammar_triggers.clear();

slot.logit_bias = slot.sparams.logit_bias; // keep a copy to restore
slot.ban_phrases_bias = json_value(data, "banned_bias", params_base.ban_phrases_bias);
slot.banned_n = json_value(data, "banned_n", params_base.banned_n);
slot.n_past_prompt++;
slot.n_past++;
slot.do_checkpoint = false;
if (params_base.do_checkpoint && slot.n_prompt_tokens - slot.n_past_prompt == params_base.ctx_checkpoints_tolerance) {
if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
// save checkpoint during prompt processing
if (slot.command == SLOT_COMMAND_LOAD_PROMPT) {
    if (slot.do_checkpoint) {

slot.t_start_generation = ggml_time_us();
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
// create checkpoint after prompt processing ends

    }
}

// create checkpoint during generation
|
@dungquixote42 The quality degradation is still there, but it feels less severe than before. Could it be that adaptive_p somehow gets auto-enabled (I never had it enabled)? Or is it the way I copied it over? @firecoperana Fixed.
I fetched this PR and ran it with test code. The adaptive P sampler is working as intended as far as I can tell. That is, no-op when its target is <0. |
firecoperana
left a comment
Besides adding back the code that was removed, also unify the slot.banned_n == 1 and != 1 cases. There is no need for a special case with slot.banned_n == 1 for positional bans, restoring, and setting the logit bias.
examples/server/server-context.cpp
Outdated
{
    for (auto result = slot.token_buffer.begin() + n_keep_buffer; result != slot.token_buffer.end(); result++) {
        if (!tokens.contains(result->tok)) {
            slot.ctx_sampling->params.logit_bias[result->tok] += slot.ban_phrases_bias;

Suggested change: replace

    slot.ctx_sampling->params.logit_bias[result->tok] += slot.ban_phrases_bias;

with

    if (!tokens.contains(result->tok)) {
        tokens.insert(result->tok);
        slot.ctx_sampling->params.logit_bias[result->tok] += slot.ban_phrases_bias;
    }

You can combine slot.banned_n == 1 in this as well. No need to create a new if. positional_bans is missing in the banned_n != 0 case.
continue; // sample using speculative decoding
}

// RESTORE AND APPLY POSITIONAL BANS

Move this inside rewind_context.
examples/server/server-context.cpp
Outdated
if (!slot.rewind_status) {
    slot.ctx_sampling->params.logit_bias = slot.logit_bias; // restore logit bias

if (slot.banned_n != 1) {

What is the reason to special-case banned_n != 1?
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
    result.probs.push_back({
        cur_p->data[i].id,
        common_token_to_piece(ctx, cur_p->data[i].id, special),

for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
    result.probs.push_back({
        cur[i].id,
        common_token_to_piece(ctx, cur[i].id, special),
}

slot.ctx_sampling->n_rewind = sent_results ? -1 : n_rewind;
if (slot.sparams.adaptive_target >= 0.0f) {

I am not sure this check (and the others elsewhere) is the right solution to the sampler running when it is not supposed to. I cannot reproduce this, so would you print adapt_p_ctx->target from llama_sample_adaptive_p_impl() and show us what it says? Preferably before and after the enable check.
|
Quality is degraded right now, not sure which of the changes caused it |
|
Should be fixed. Moving // RESTORE AND APPLY POSITIONAL BANS to rewind_context is not possible due to quality degradation. |
firecoperana
left a comment
Not sure about the degradation due to adaptive p, but if there is, it can be fixed later.
Continuation of #1131.
This PR adds regex bans and makes the string ban location dependent. Currently the string ban is flawed: if a token is banned, it is banned in the entire buffer. During my testing with long, overlapping strings it frequently backfired; for example, if `I` was banned at the beginning and later in the context `I` was needed, it produced nonsense. In this PR the ban is localized to specific token locations.

New arguments:

- banned_regex: accepts JSON with regex, case sensitive
- banned_regex_case_insensitive: accepts JSON with regex, case insensitive
- banbuffer_size: number, sets the size of the buffer, useful when using regex. By default (or if 0) it is longest string/regex length + 1

My ST fork for testing: https://github.com/SneedwareInc/ik_SillyTavern
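A text-completion request body using the new fields might look like this; the field names are the ones listed above, while the prompt and regex patterns are purely illustrative:

```json
{
  "prompt": "Once upon a time",
  "n_predict": 128,
  "banned_regex": ["\\btapestr(y|ies)\\b"],
  "banned_regex_case_insensitive": ["shivers? down (my|your|their) spine"],
  "banbuffer_size": 0
}
```

With `banbuffer_size` at 0, the buffer falls back to the default of longest string/regex length + 1.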
Example ban list: https://huggingface.co/datasets/ChuckMcSneed/ExampleAntislop
Currently I know it works with text completion; I am not sure about chat completion or the OpenAI formats.