
Support for multiple embedding providers (Huggingface, etc.) #404

Merged
merged 8 commits into from
Jun 28, 2023

Conversation

Abhinav-Naikawadi
Contributor

No description provided.

pyproject.toml Outdated
```diff
@@ -64,10 +65,12 @@ anthropic = [
     "anthropic >= 0.2.6"
 ]
 huggingface = [
-    "transformers >= 4.25.0"
+    "transformers >= 4.25.0",
+    "accelerate == 0.20.3"
```
Contributor Author

This dependency is needed for Hugging Face pipelines support.

Contributor

@Abhinav-Naikawadi is this only for GPU inference?

```diff
@@ -35,6 +35,7 @@ dependencies = [
     "torch >= 1.10.0",
     "matplotlib >= 3.5.0",
     "wget >= 3.2",
+    "ipywidgets == 8.0.6",
```
Contributor Author

This dependency is needed for Jupyter notebook support for sentence-transformers progress bars.

```diff
@@ -108,6 +118,15 @@ def confidence(self) -> bool:
         """Returns true if the model is able to return a confidence score along with its predictions"""
         return self._model_config.get(self.COMPUTE_CONFIDENCE_KEY, False)

+    # Embedding config
+    def embedding_provider(self) -> str:
```
Contributor Author

We use the LLM provider when an embedding model provider is not specified.
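A minimal sketch of that fallback behavior (class and key names here are illustrative, not the merged autolabel code): if the config has no embedding provider, the LLM provider is reused.

```python
# Sketch of the fallback described above; ConfigSketch and its key names
# are hypothetical stand-ins for the real AutolabelConfig implementation.
class ConfigSketch:
    EMBEDDING_CONFIG_KEY = "embedding"

    def __init__(self, config: dict):
        self._config = config

    def provider(self) -> str:
        """Provider of the LLM itself."""
        return self._config["model"]["provider"]

    def embedding_provider(self) -> str:
        """Fall back to the LLM provider when no embedding provider is set."""
        embedding = self._config.get(self.EMBEDDING_CONFIG_KEY, {})
        return embedding.get("provider", self.provider())
```

With only a model section, embedding_provider() returns the LLM provider; an explicit embedding.provider overrides it.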

@nihit
Contributor

nihit commented Jun 28, 2023

Addresses #370

Contributor

@nihit nihit left a comment

@Abhinav-Naikawadi

  1. Please add example configs in the PR description for using HuggingFace and VertexAI embeddings on one of the benchmark tasks.
  2. Add tests verifying that the embedding config is correctly used in the few-shot initialization class.
  3. Make the relevant changes to the config schema for the new embedding key - https://github.com/refuel-ai/autolabel/blob/main/src/autolabel/configs/schema.py
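For item 3, one possible shape for the new key is a fragment like the following (hypothetical sketch; the merged contents of schema.py may differ), mirroring the provider/model structure of the example configs later in this thread:

```python
# Hypothetical JSON-schema fragment for the new "embedding" key;
# not the actual merged contents of src/autolabel/configs/schema.py.
EMBEDDING_SCHEMA = {
    "type": "object",
    "properties": {
        "provider": {"type": "string"},
        "model": {"type": "string"},
    },
    "required": ["provider"],
    "additionalProperties": False,
}
```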


```python
ALGORITHM_TO_IMPLEMENTATION: Dict[FewShotAlgorithm, BaseExampleSelector] = {
    FewShotAlgorithm.FIXED: FixedExampleSelector,
    FewShotAlgorithm.SEMANTIC_SIMILARITY: SemanticSimilarityExampleSelector,
    FewShotAlgorithm.MAX_MARGINAL_RELEVANCE: MaxMarginalRelevanceExampleSelector,
}

PROVIDER_TO_MODEL: Dict[ModelProvider, Embeddings] = {
    ModelProvider.ANTHROPIC: OpenAIEmbeddings,
```
Contributor

We should do this more transparently instead of mapping anthropic and refuel to OpenAIEmbeddings under the hood:

  1. Let's only have entries here for providers that actually offer embedding endpoints: openai, google, huggingface pipelines.
  2. Define a "default" provider - this can be OpenAIEmbeddings() - to be used when the input provider does not offer embeddings.
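The suggestion above could look roughly like this (a dependency-free sketch; strings stand in for the langchain Embeddings classes, and all names are assumptions rather than the merged code):

```python
from typing import Dict

# Only providers that actually expose embedding endpoints get an entry;
# every other provider falls back to the default transparently.
PROVIDER_TO_EMBEDDINGS: Dict[str, str] = {
    "openai": "OpenAIEmbeddings",
    "google": "VertexAIEmbeddings",
    "huggingface_pipeline": "HuggingFaceEmbeddings",
}

DEFAULT_EMBEDDINGS = "OpenAIEmbeddings"

def resolve_embeddings(provider: str) -> str:
    """Return the embedding backend for a provider, defaulting explicitly."""
    return PROVIDER_TO_EMBEDDINGS.get(provider, DEFAULT_EMBEDDINGS)
```

This keeps the fallback visible in one place instead of hiding anthropic or refuel behind an OpenAI entry in the map itself.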

```python
from autolabel.configs import AutolabelConfig
from autolabel.schema import FewShotAlgorithm, ModelProvider
from langchain.embeddings import (
    HuggingFaceEmbeddings,
```
Contributor

What is the default model from sentence-transformers that is used here?

Contributor Author

The default for Hugging Face is sentence-transformers/all-mpnet-base-v2. The default for Vertex AI is textembedding-gecko@001.

pyproject.toml Outdated

```diff
 ]
 google = [
-    "google-cloud-aiplatform>=1.25.0"
+    "google-cloud-aiplatform>=1.25.0",
+    "google-generativeai"
```
Contributor

Do we still need this if we're using VertexAIEmbeddings?

@Abhinav-Naikawadi
Contributor Author

Example config for huggingface embeddings:

```json
{
  "task_name": "BankingComplaintsClassification",
  "task_type": "classification",
  "dataset": {
    "label_column": "label",
    "delimiter": ","
  },
  "model": {
    "provider": "huggingface_pipeline",
    "name": "google/flan-t5-small"
  },
  "embedding": {
    "provider": "huggingface_pipeline",
    "model": "sentence-transformers/all-mpnet-base-v2"
  },
  ...
```

@Abhinav-Naikawadi
Contributor Author

Example config with google (vertexai) embeddings:

```json
{
  "task_name": "BankingComplaintsClassification",
  "task_type": "classification",
  "dataset": {
    "label_column": "label",
    "delimiter": ","
  },
  "model": {
    "provider": "google",
    "name": "gpt-3.5-turbo"
  },
  "embedding": {
    "provider": "google"
  },
  ...
```
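Since the google example above omits embedding.model, the Vertex AI default quoted earlier in this thread (textembedding-gecko@001) would apply. A dependency-free sketch of that resolution (helper and constant names are assumptions, not the merged code):

```python
from typing import Optional

# Per-provider default embedding models quoted in this thread.
DEFAULT_EMBEDDING_MODELS = {
    "huggingface_pipeline": "sentence-transformers/all-mpnet-base-v2",
    "google": "textembedding-gecko@001",
}

def embedding_model(config: dict) -> Optional[str]:
    """Resolve the embedding model for a config, applying per-provider defaults."""
    embedding = config.get("embedding", {})
    # Reuse the LLM provider when no embedding provider is specified
    provider = embedding.get("provider", config["model"]["provider"])
    return embedding.get("model", DEFAULT_EMBEDDING_MODELS.get(provider))
```

An explicit embedding.model (as in the huggingface example) always wins over the default.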

@nihit
Contributor

nihit commented Jun 28, 2023

Can merge once the tests are passing and validated.

@Abhinav-Naikawadi Abhinav-Naikawadi merged commit f332713 into main Jun 28, 2023
@Abhinav-Naikawadi Abhinav-Naikawadi deleted the huggingface_embeddings branch June 28, 2023 22:42