feat: use regular tokenizer in static embedding, add max length #3541
+73
−34
Hello!
This PR brings the StaticEmbedding a bit more in line with other modules. I had some trouble integrating it into pylate, because pylate assumes you are using a regular tokenizer. It then occurred to me that there is no real reason not to use a transformers-style tokenizer; it even makes the API a bit nicer.
So, this PR:
- Accepts both `tokenizers.Tokenizer` and `transformers.PreTrainedTokenizer`. Both work. The difference is that we turn the former into the latter. This also makes the whole thing a bit more flexible, because it should now also work with slow tokenizers, which it didn't before.
- Adds a `max_seq_length` argument and associated property, which was easy to do because of the new tokenizer. During training, I noticed that long sequences could still slow things down, presumably because the tensors sometimes get very large. I haven't tested whether truncating the sequence length leads to performance differences, however.
- `word_embedding_dimension` property

Also fixed some typing stuff. This still doesn't make it completely compatible with pylate, but almost. The other part has to be done in pylate, unfortunately.
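To illustrate what a maximum sequence length does for a static embedding, here is a minimal pure-Python sketch. It is not the code from this PR: the whitespace tokenizer, the `encode` and `embed` helpers, the toy vocabulary, and the mean pooling are all assumptions made for the example. It only shows that truncating token ids before pooling bounds the work per input, and that it can change the resulting vector.

```python
from typing import Dict, List, Optional


def encode(
    text: str,
    vocab: Dict[str, int],
    max_seq_length: Optional[int] = None,
    unk_id: int = 0,
) -> List[int]:
    """Hypothetical stand-in tokenizer: whitespace split plus vocabulary lookup."""
    ids = [vocab.get(token, unk_id) for token in text.lower().split()]
    if max_seq_length is not None:
        # Truncation: keep only the first `max_seq_length` token ids.
        ids = ids[:max_seq_length]
    return ids


def embed(ids: List[int], table: List[List[float]]) -> List[float]:
    """Mean-pool the embedding rows selected by `ids` (assumed pooling mode)."""
    dim = len(table[0])
    if not ids:
        return [0.0] * dim
    sums = [0.0] * dim
    for token_id in ids:
        for d, value in enumerate(table[token_id]):
            sums[d] += value
    return [s / len(ids) for s in sums]


# Toy vocabulary and 2-dimensional embedding table (illustrative values only).
vocab = {"[UNK]": 0, "static": 1, "embeddings": 2, "are": 3, "fast": 4}
table = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 2.0]]

full_ids = encode("static embeddings are fast", vocab)
short_ids = encode("static embeddings are fast", vocab, max_seq_length=2)

full_vec = embed(full_ids, table)
short_vec = embed(short_ids, table)
```

In this toy setup the truncated input yields a different pooled vector than the full input, which is why the effect of truncation on downstream quality is worth measuring separately.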