The NoBadWordsLogitsProcessor class's nested loops can be slow to run when you have a lot of bad words. That is my case, for instance, with distilgpt2, which has ~800 bad words in its vocabulary.
Building a static Map can speed up the lookups. Something like:
const bad_words_map = new Map();
for (const bad_word_ids of this.bad_words_ids) {
    const key = bad_word_ids.at(-1);
    if (!bad_words_map.has(key)) {
        bad_words_map.set(key, []);
    }
    bad_words_map.get(key).push(bad_word_ids.slice(0, -1));
}
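For illustration, here is the map that snippet builds for a tiny, made-up block list (the token ids below are invented, not real vocabulary ids):

```javascript
// Made-up block list: [5] is a single-token entry; 7 ends two
// multi-token entries, [3, 7] and [2, 7].
const bad_words_list = [[5], [3, 7], [2, 7]];

const bad_words_map = new Map();
for (const bad_word_ids of bad_words_list) {
    const key = bad_word_ids.at(-1); // final token of the entry
    if (!bad_words_map.has(key)) {
        bad_words_map.set(key, []);
    }
    // Store the prefix (everything before the final token; may be empty).
    bad_words_map.get(key).push(bad_word_ids.slice(0, -1));
}

console.log(bad_words_map.get(5)); // [ [] ]        -> empty prefix
console.log(bad_words_map.get(7)); // [ [3], [2] ]  -> two possible prefixes
```

Single-token entries end up with an empty prefix list, which is why they need special care in the lookup step.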
and then
_call(input_ids, logits) {
    for (let i = 0; i < input_ids.length; ++i) {
        const batch_logits_data = /** @type {Float32Array} */(logits[i].data);
        const ids = input_ids[i];
        const last_id = ids.at(-1);
        if (bad_words_map.has(last_id)) {
            const prefixes = bad_words_map.get(last_id);
            for (const prefix of prefixes) {
                if (ids.slice(-prefix.length).every((v, idx) => v === prefix[idx])) {
                    batch_logits_data[last_id] = -Infinity;
                    break;
                }
            }
        }
    }
    return logits;
}
I've done a bit of testing, and it's a bit trickier than that, unfortunately. The bad_words_ids is a list of lists, structured as follows:
[
[a], // ALWAYS block a
[b, c], // only block c if preceded by [b]
[d], // ALWAYS block d
[e, f, g], // only block g if preceded by [e, f]
...
]
This means we still need to iterate over the entire list, especially to handle these "single bad words". This code:
for (let j = 1; j <= bad_word_ids.length - 1 && bad_word_ids.length < ids.length; ++j) {
    // NOTE: We use != instead of !== to compare bigint and number
    // @ts-ignore
    if (bad_word_ids.at(-j - 1) != ids.at(-j)) {
        // We have found a mismatch
        mark = false;
        break;
    }
}
will check if the tokens before the last in the block list match the last ids, and if not, we won't block the last id in the block list.
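For context, here is a self-contained sketch of the full loop around that check. Plain arrays stand in for the library's tensor types, mark is the flag referenced above, and the token ids are invented:

```javascript
// Invented block list: 5 is always blocked; 7 is blocked only when the
// two previously generated tokens are 3 then 4.
const bad_words_list = [[5], [3, 4, 7]];

function applyBadWords(ids, logits) {
    for (const bad_word_ids of bad_words_list) {
        // Whether to block the final token of this entry.
        let mark = true;
        for (let j = 1; j <= bad_word_ids.length - 1 && bad_word_ids.length < ids.length; ++j) {
            if (bad_word_ids.at(-j - 1) != ids.at(-j)) {
                mark = false; // mismatch: this entry does not apply
                break;
            }
        }
        if (mark) {
            logits[bad_word_ids.at(-1)] = -Infinity; // block the entry's final token
        }
    }
    return logits;
}

const logits = applyBadWords([9, 9, 3, 4], new Array(10).fill(0));
console.log(logits[5], logits[7]); // -Infinity -Infinity
```

Note that single-token entries like [5] never enter the inner loop, so mark stays true and they are always blocked, which is the case the Map keyed on the last generated token cannot cover.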
The good news is that you shouldn't see a massive difference in performance: for a block list of 800, I only see a ~10 ms difference in the unit test I created. For a block list of 100,000 the difference is more noticeable, but I don't see that happening in practice.
System Info
v3
Environment/Platform
Description
Consider using a Map in NoBadWordsLogitsProcessor.