refactor: use extra_body to pass in input_type params for asymmetric embedding models for NVIDIA Inference Provider #3804
"model": embedding_model_id,
"input": [],
}
if is_asymmetric_model(client_with_models, embedding_model_id):
You can always add extra_body unconditionally, right? If it is null, it is null and no harm is done, so you don't need the double check.
Good point! I also consolidated the verification and retrieval of extra_body for asymmetric models into a single helper function that's used consistently across all openai_embeddings test cases.
For other models, return None.
"""
is_asymmetric = is_asymmetric_model(client_with_models, model_id)
if is_asymmetric:
nah I don't think we need to be that defensive. i'd just simplify this a bunch more. simply make the callsites be:
client.embeddings.create(..., extra_body=get_extra_body())
that's it!
return providers[provider_id]

def is_asymmetric_model(client_with_models, model_id):
Updated: the check now lives in get_extra_body_for_model, which returns None for the extra_body if the model is not asymmetric.
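For readers following the thread, here is a minimal sketch of what such a helper could look like. It is not the code from this PR: the signature, the `is_asymmetric` predicate argument, the default `input_type`, and the example model IDs are all assumptions made for illustration.

```python
from typing import Callable, Optional


def get_extra_body_for_model(
    model_id: str,
    is_asymmetric: Callable[[str], bool],
    input_type: str = "query",
) -> Optional[dict]:
    """Return an extra_body dict for asymmetric models, or None otherwise."""
    if is_asymmetric(model_id):
        return {"input_type": input_type}
    return None


# Hypothetical model IDs and predicate, used only to exercise the helper.
ASYMMETRIC_IDS = {"example/asymmetric-embedder"}
is_asym = ASYMMETRIC_IDS.__contains__

query_body = get_extra_body_for_model("example/asymmetric-embedder", is_asym)
plain_body = get_extra_body_for_model("example/symmetric-embedder", is_asym)
print(query_body)
print(plain_body)
```

Because the helper returns None for models that need no extra fields, a call site can pass its result as `extra_body` unconditionally, which is the simplification the reviewer asked for.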
What does this PR do?
Previously, the NVIDIA inference provider implemented a custom `openai_embeddings` method with a hardcoded `input_type="query"` parameter, which is required by NVIDIA asymmetric embedding models (#3205). Recently, an `extra_body` parameter was added to the embeddings API (#3794). This PR therefore updates the NVIDIA inference provider to use the base `OpenAIMixin.openai_embeddings` method instead and to pass `input_type` through the `extra_body` parameter for asymmetric embedding models.

Test Plan
Run the following command for the `embedding_model`: `nvidia/llama-3.2-nv-embedqa-1b-v2`, `nvidia/nv-embedqa-e5-v5`, `nvidia/nv-embedqa-mistral-7b-v2`, and `snowflake/arctic-embed-l`.
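To make the description above concrete, the sketch below shows the effect `extra_body` is expected to have on an embeddings request: its keys are merged into the JSON payload next to `model` and `input`. The function `build_embeddings_payload` is hypothetical and is not part of the provider or of any client library. The model ID is taken from the test plan purely as an example value.

```python
import json
from typing import Optional


def build_embeddings_payload(
    model: str,
    inputs: list,
    extra_body: Optional[dict] = None,
) -> dict:
    """Build a request payload, merging any extra_body keys at the top level."""
    payload = {"model": model, "input": inputs}
    if extra_body:
        # Provider-specific fields such as input_type travel alongside
        # the standard fields in the same JSON object.
        payload.update(extra_body)
    return payload


payload = build_embeddings_payload(
    "nvidia/nv-embedqa-e5-v5",
    ["What is the capital of France?"],
    extra_body={"input_type": "query"},
)
print(json.dumps(payload, indent=2))
```

When `extra_body` is None the payload contains only the standard fields, so models that do not need `input_type` are unaffected.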