Skip to content

Commit

Permalink
Merge pull request #943 from weaviate/new_modules
Browse files Browse the repository at this point in the history
Add support for new modules
  • Loading branch information
dirkkul authored Mar 14, 2024
2 parents af56e26 + 9abe115 commit 48413e6
Show file tree
Hide file tree
Showing 4 changed files with 354 additions and 9 deletions.
87 changes: 85 additions & 2 deletions test/collection/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,29 @@ def test_basic_config():
Configure.Vectorizer.text2vec_transformers(
pooling_strategy="cls",
vectorize_collection_name=False,
inference_url="https://api.transformers.com",
),
{
"text2vec-transformers": {
"vectorizeClassName": False,
"poolingStrategy": "cls",
"inferenceUrl": "https://api.transformers.com",
}
},
),
(
Configure.Vectorizer.text2vec_voyageai(
vectorize_collection_name=False,
model="voyage-large-2",
truncate=False,
base_url="https://voyage.made-up.com",
),
{
"text2vec-voyageai": {
"vectorizeClassName": False,
"model": "voyage-large-2",
"baseURL": "https://voyage.made-up.com",
"truncate": False,
}
},
),
Expand All @@ -293,6 +311,23 @@ def test_basic_config():
}
},
),
(
Configure.Vectorizer.multi2vec_palm(
image_fields=["image"],
text_fields=["text"],
project_id="project",
location="us-central1",
),
{
"multi2vec-palm": {
"imageFields": ["image"],
"textFields": ["text"],
"projectId": "project",
"location": "us-central1",
"vectorizeClassName": True,
}
},
),
(
Configure.Vectorizer.multi2vec_clip(
image_fields=[Multi2VecField(name="image")],
Expand Down Expand Up @@ -535,6 +570,10 @@ def test_config_with_vectorizer_and_properties(
Configure.Generative.anyscale(),
{"generative-anyscale": {}},
),
(
Configure.Generative.mistral(temperature=0.5, max_tokens=100, model="model"),
{"generative-mistral": {"temperature": 0.5, "maxTokens": 100, "model": "model"}},
),
(
Configure.Generative.openai(
model="gpt-4",
Expand Down Expand Up @@ -700,7 +739,7 @@ def test_config_with_generative(
def test_config_with_reranker(
reranker_config: _RerankerConfigCreate,
expected_mc: dict,
):
) -> None:
config = _CollectionConfigCreate(name="test", reranker_config=reranker_config)
assert config._to_dict() == {
**DEFAULTS,
Expand All @@ -710,7 +749,7 @@ def test_config_with_reranker(
}


def test_config_with_properties():
def test_config_with_properties() -> None:
config = _CollectionConfigCreate(
name="test",
description="test",
Expand Down Expand Up @@ -1046,6 +1085,25 @@ def test_vector_config_flat_pq() -> None:
}
},
),
(
[
Configure.NamedVectors.text2vec_voyageai(
name="test", source_properties=["prop"], truncate=True
)
],
{
"test": {
"vectorizer": {
"text2vec-voyageai": {
"properties": ["prop"],
"vectorizeClassName": True,
"truncate": True,
}
},
"vectorIndexType": "hnsw",
}
},
),
(
[
Configure.NamedVectors.img2vec_neural(
Expand Down Expand Up @@ -1085,6 +1143,31 @@ def test_vector_config_flat_pq() -> None:
}
},
),
(
[
Configure.NamedVectors.multi2vec_palm(
name="test",
image_fields=["image"],
text_fields=["text"],
project_id="project",
location="us-central1",
)
],
{
"test": {
"vectorizer": {
"multi2vec-palm": {
"imageFields": ["image"],
"textFields": ["text"],
"projectId": "project",
"location": "us-central1",
"vectorizeClassName": True,
}
},
"vectorIndexType": "hnsw",
}
},
),
(
[
Configure.NamedVectors.multi2vec_bind(
Expand Down
18 changes: 18 additions & 0 deletions weaviate/collections/classes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ class GenerativeSearches(str, Enum):
PALM = "generative-palm"
AWS = "generative-aws"
ANYSCALE = "generative-anyscale"
MISTRAL = "generative-mistral"


class Rerankers(str, Enum):
Expand Down Expand Up @@ -367,6 +368,15 @@ class _GenerativeAnyscale(_GenerativeConfigCreate):
model: Optional[str]


class _GenerativeMistral(_GenerativeConfigCreate):
generative: GenerativeSearches = Field(
default=GenerativeSearches.MISTRAL, frozen=True, exclude=True
)
temperature: Optional[float]
model: Optional[str]
maxTokens: Optional[int]


class _GenerativeOpenAIConfigBase(_GenerativeConfigCreate):
generative: GenerativeSearches = Field(
default=GenerativeSearches.OPENAI, frozen=True, exclude=True
Expand Down Expand Up @@ -464,6 +474,14 @@ def anyscale(
) -> _GenerativeConfigCreate:
return _GenerativeAnyscale(model=model, temperature=temperature)

@staticmethod
def mistral(
model: Optional[str] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> _GenerativeConfigCreate:
return _GenerativeMistral(model=model, temperature=temperature, maxTokens=max_tokens)

@staticmethod
def openai(
model: Optional[str] = None,
Expand Down
Loading

0 comments on commit 48413e6

Please sign in to comment.