Make Gemma and Gemma 2 accept inputs_embeds like Gemma 3 #36787
DarkLight1337 merged 3 commits into vllm-project:main
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Btw there are some other models that do a similar thing as well. How should we handle them?
This was only really needed for the basic correctness test because it uses HF to generate the embeds and then passes them to vLLM. So the change in behaviour on the HF side was a problem. For the rest of vLLM where we use vLLM to generate the embeds this should be a non-issue.
Code Review
This pull request refactors the embedding scaling for Gemma and Gemma 2 models to align with Gemma 3 and recent changes in the transformers library. The scaling logic is correctly moved from the forward method to embed_input_ids. The accompanying test changes are also appropriate, but they contain an incorrect version string for transformers which should be corrected for accuracy and maintainability.
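To make the refactor concrete, here is a minimal sketch of the pattern, not the actual vLLM code. ToyGemmaModel, its fake embedding table, and the list-based tensors are hypothetical stand-ins; the only detail taken from Gemma is the sqrt(hidden_size) embedding normalizer. The point is that scaling happens once in embed_input_ids, so forward can treat inputs_embeds as already scaled.

```python
import math


class ToyGemmaModel:
    """Toy stand-in, not vLLM's GemmaModel: only the embedding path is modeled."""

    def __init__(self, vocab_size: int, hidden_size: int) -> None:
        self.hidden_size = hidden_size
        # Deterministic fake embedding table: row i is [i + 1, i + 1, ...].
        self.embed_table = [[float(i + 1)] * hidden_size for i in range(vocab_size)]
        # Gemma-style normalizer: sqrt(hidden_size).
        self.normalizer = math.sqrt(hidden_size)

    def embed_input_ids(self, input_ids: list[int]) -> list[list[float]]:
        # Scaling lives here, so callers receive pre-scaled embeddings.
        return [[v * self.normalizer for v in self.embed_table[i]] for i in input_ids]

    def forward(self, input_ids=None, inputs_embeds=None):
        # forward does not scale: inputs_embeds is assumed to be pre-scaled.
        if inputs_embeds is None:
            inputs_embeds = self.embed_input_ids(input_ids)
        # Decoder layers are omitted; return the hidden states that would enter them.
        return inputs_embeds


model = ToyGemmaModel(vocab_size=4, hidden_size=4)
ids = [0, 2]
from_ids = model.forward(input_ids=ids)
from_embeds = model.forward(inputs_embeds=model.embed_input_ids(ids))
```

With this layout both call paths produce the same hidden states. If forward also scaled, embeds produced by embed_input_ids would be scaled twice.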
Gemma3Model.forward accepts pre-scaled inputs_embeds which are scaled by Gemma3Model.embed_input_ids.
Before this PR GemmaModel and Gemma2Model did the scaling inside forward.
After this PR the scaling for the earlier Gemma variants happens in embed_input_ids. This is consistent with: