Add padding_side and pad_token_id in OrtBackend
#705
Conversation
- The shared functionality for the input preparation has been extracted from both `OrtBackend::embed` and `OrtBackend::predict` into separate functions:
  - `prepare_inputs` to prepare the inputs based on what the ONNX model expects, i.e. `input_ids`, `attention_mask`, etc.
  - `prepare_ort_inputs` to go from those inputs to `ort::inputs!`
- Since the input processing in both `OrtBackend::embed` and `OrtBackend::predict` defaulted to right-padding, as did the pooling and post-processing in `OrtBackend::embed`, the `PaddingSide` is now handled so that the proper methods are used for the configured `padding_side` (see the sketch below)
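To make the padding-side handling concrete, here is a minimal sketch (not the code from this PR) of how padding-side-aware input preparation could look; the `prepare_inputs` signature, the exact `PaddingSide` enum shape, and the `u32` token type are assumptions for illustration only:

```rust
use ndarray::Array2;

#[derive(Clone, Copy, PartialEq)]
enum PaddingSide {
    Left,
    Right,
}

/// Build padded `input_ids` and `attention_mask` matrices for a batch,
/// placing each sequence at the offset dictated by the padding side.
fn prepare_inputs(
    sequences: &[Vec<u32>],
    pad_token_id: u32,
    padding_side: PaddingSide,
) -> (Array2<i64>, Array2<i64>) {
    let batch_size = sequences.len();
    let max_length = sequences.iter().map(|s| s.len()).max().unwrap_or(0);

    // Start fully padded / fully masked, then overwrite the real tokens.
    let mut input_ids = Array2::<i64>::from_elem((batch_size, max_length), pad_token_id as i64);
    let mut attention_mask = Array2::<i64>::zeros((batch_size, max_length));

    for (batch_idx, sequence) in sequences.iter().enumerate() {
        // Left padding shifts the sequence towards the end of the row.
        let offset = match padding_side {
            PaddingSide::Left => max_length - sequence.len(),
            PaddingSide::Right => 0,
        };
        for (i, &token) in sequence.iter().enumerate() {
            input_ids[[batch_idx, offset + i]] = token as i64;
            attention_mask[[batch_idx, offset + i]] = 1;
        }
    }

    (input_ids, attention_mask)
}
```

This is essentially the "allocate once, then write at a padding-side-dependent offset" pattern discussed in the review comments below.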
Narsil
left a comment
Looks alright, but I think we can simplify further.
backends/ort/src/lib.rs (outdated)
```rust
Pool::Cls => match self.padding_side {
    PaddingSide::Left => {
        if masking {
            let mut cls_embeddings = Vec::new();
            for (batch_idx, &seq_length) in
                model_inputs.input_lengths.iter().enumerate()
            {
                let padding = max_length as f32 - seq_length;
                let cls_pos = padding as usize;
                cls_embeddings
                    .push(outputs.slice(s![batch_idx, cls_pos, ..]).to_owned());
            }
            ndarray::stack(
                Axis(0),
                &cls_embeddings.iter().map(|x| x.view()).collect::<Vec<_>>(),
            )
            .unwrap()
            .into_dyn()
        } else {
            outputs.slice(s![.., 0, ..]).into_owned().into_dyn()
        }
    }
    PaddingSide::Right => outputs.slice(s![.., 0, ..]).into_owned().into_dyn(),
},
Pool::LastToken => match self.padding_side {
    // NOTE: when using left-padding, the last-token is always in the last position
```
Feels like there's a lot of switching on padding side. I haven't carefully looked at each line, but it seems to me that the code could be made significantly simpler by allocating the buffers up front and choosing a different insertion (overwriting) offset based on the padding side.
Something like:

```rust
let offset = if padding_side == Side::Left { 0 } else { max_length - length };
for (i, item) in elements.iter().enumerate() {
    input_ids[i + offset] = item;
}
```
Hmm, yes, it could be the case for some. The issue is that, given the `padding_side`, we can apply the pooling in the most performant way: e.g. for last-token pooling with left-padding it's literally the last token in the sequence, but with right-padding we need to iterate over each sequence to find where it ends and then take the last token accordingly. This is why I kept one implementation per `padding_side`, but I'm happy to unify those into a single match, even if either the right or the left case becomes slightly less performant (i.e. requires a few more ops), as that would simplify the code a bit.
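For illustration, a rough sketch of what the two last-token pooling paths described above could look like; the function name, signature, and the (batch, seq_len, hidden) shape assumption are hypothetical, not the PR's actual implementation:

```rust
use ndarray::{s, Array2, Array3, Axis};

#[derive(Clone, Copy)]
enum PaddingSide {
    Left,
    Right,
}

/// Extract the embedding of the last real (non-padding) token of each sequence.
/// `outputs` has shape (batch, seq_len, hidden); `input_lengths` holds the
/// unpadded length of each sequence.
fn last_token_pool(
    outputs: &Array3<f32>,
    input_lengths: &[usize],
    padding_side: PaddingSide,
) -> Array2<f32> {
    let max_length = outputs.shape()[1];
    match padding_side {
        // With left padding, the last real token always sits in the last column,
        // so a single slice covers the whole batch.
        PaddingSide::Left => outputs.slice(s![.., max_length - 1, ..]).to_owned(),
        // With right padding, each sequence ends at its own length, so the
        // position has to be looked up per row before stacking the results.
        PaddingSide::Right => {
            let rows: Vec<_> = input_lengths
                .iter()
                .enumerate()
                .map(|(batch_idx, &len)| outputs.slice(s![batch_idx, len - 1, ..]).to_owned())
                .collect();
            ndarray::stack(Axis(0), &rows.iter().map(|r| r.view()).collect::<Vec<_>>()).unwrap()
        }
    }
}
```

The left-padding arm is a single slice over the whole batch, while the right-padding arm needs a per-row lookup, which is the performance asymmetry mentioned above.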
Fair enough.
Narsil
left a comment
LGTM
What does this PR do?
This PR reads the `padding_side` from the `tokenizer_config.json` if present, otherwise defaulting to `padding_side: "right"`, to handle models where the `padding_side` is other than "right", e.g. https://huggingface.co/onnx-community/Qwen3-Embedding-0.6B-ONNX, and to ensure parity in both the inputs and the outputs. It also reads the `pad_token_id` from the `config.json` instead of hard-coding it to 0: the `pad_token_id` is used if defined, falling back to `eos_token_id`, and finally to 0 if neither is defined.

This PR also updates the input preparation and pooling strategies accordingly, so that they are applied one way or another based on the padding side: the pooling result should be padding-agnostic, but with the padding-side information we can apply each pooling strategy more efficiently.
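As a rough sketch of the fallback logic described above (hypothetical struct and field names, assuming `eos_token_id` is a single integer; only the relevant fields are shown, not the PR's actual code):

```rust
use serde::Deserialize;

// Hypothetical minimal views of the two config files.
#[derive(Deserialize)]
struct TokenizerConfig {
    padding_side: Option<String>,
}

#[derive(Deserialize)]
struct ModelConfig {
    pad_token_id: Option<u32>,
    eos_token_id: Option<u32>,
}

fn resolve_padding_side(tokenizer_config: Option<&TokenizerConfig>) -> String {
    tokenizer_config
        .and_then(|c| c.padding_side.clone())
        // Default to right padding when the field (or the file) is missing.
        .unwrap_or_else(|| "right".to_string())
}

fn resolve_pad_token_id(config: &ModelConfig) -> u32 {
    // pad_token_id -> eos_token_id -> 0
    config.pad_token_id.or(config.eos_token_id).unwrap_or(0)
}
```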
Additionally, this PR fixes the last-token pooling for the `OrtBackend`, which was leading to issues (unrelated to the `padding_side`), as well as using the correct `pad_token_id`, as reported in e.g. #704.

As some other, smaller but still relevant changes, this PR:
- Adds `OrtBackend::prepare_inputs` to prepare the `ndarray`s for the `input_ids`, `attention_mask`, etc. within a single function reused by both `OrtBackend::embed` and `OrtBackend::predict`, to avoid duplicating the code
- Adds `OrtBackend::prepare_ort_inputs` to go from `ndarray`s to `ort::inputs!`, for the same reason as the function above
- Adds `ModelInputs` to capture all the inputs within the same struct so that they can be easily managed (see the sketch after this list)
- Adds `Config` to read from `config.json`, required for both the `pad_token_id` and for the `past_key_values` required configuration values
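A hypothetical shape for the `ModelInputs` struct mentioned above; the field names and types are guesses based on the diff excerpt earlier in this conversation, not the actual definition:

```rust
use ndarray::Array2;

// Hypothetical layout; the actual struct in the PR may carry more fields
// (e.g. token_type_ids or past_key_values placeholders).
struct ModelInputs {
    input_ids: Array2<i64>,
    attention_mask: Array2<i64>,
    // Real (unpadded) length of each sequence, kept as f32 to match the
    // `input_lengths` usage in the diff excerpt above.
    input_lengths: Vec<f32>,
}
```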
Before submitting
- `insta` snapshots?

Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@Narsil