Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: proper parameters in Astra DB Vectorize options #3901

Merged
merged 2 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/backend/base/langflow/components/vectorstores/AstraDB.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,20 +322,20 @@ def update_build_config(self, build_config: dict, field_value: str, field_name:
def build_vectorize_options(self, **kwargs):
for attribute in [
"provider",
"z_00_api_key_name",
"z_01_model_name",
"z_02_authentication",
"z_00_model_name",
"z_01_model_parameters",
"z_02_api_key_name",
"z_03_provider_api_key",
"z_04_model_parameters",
"z_04_authentication",
]:
if not hasattr(self, attribute):
setattr(self, attribute, None)

# Fetch values from kwargs if any self.* attributes are None
provider_value = self.VECTORIZE_PROVIDERS_MAPPING.get(self.provider, [None])[0] or kwargs.get("provider")
authentication = {**(self.z_02_authentication or kwargs.get("z_02_authentication", {}))}
authentication = {**(self.z_04_authentication or kwargs.get("z_04_authentication", {}))}

api_key_name = self.z_00_api_key_name or kwargs.get("z_00_api_key_name")
api_key_name = self.z_02_api_key_name or kwargs.get("z_02_api_key_name")
provider_key_name = self.z_03_provider_api_key or kwargs.get("z_03_provider_api_key")
if provider_key_name:
authentication["providerKey"] = provider_key_name
Expand All @@ -346,9 +346,9 @@ def build_vectorize_options(self, **kwargs):
# must match astrapy.info.CollectionVectorServiceOptions
"collection_vector_service_options": {
"provider": provider_value,
"modelName": self.z_01_model_name or kwargs.get("z_01_model_name"),
"modelName": self.z_00_model_name or kwargs.get("z_00_model_name"),
"authentication": authentication,
"parameters": self.z_04_model_parameters or kwargs.get("z_04_model_parameters", {}),
"parameters": self.z_01_model_parameters or kwargs.get("z_01_model_parameters", {}),
},
"collection_embedding_api_key": self.z_03_provider_api_key or kwargs.get("z_03_provider_api_key"),
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_astra_vectorize():
store = None
try:
options = {"provider": "nvidia", "modelName": "NV-Embed-QA"}
options_comp = {"provider": "nvidia", "z_01_model_name": "NV-Embed-QA"}
options_comp = {"provider": "nvidia", "z_00_model_name": "NV-Embed-QA"}

store = AstraDBVectorStore(
collection_name=VECTORIZE_COLLECTION,
Expand Down Expand Up @@ -156,10 +156,10 @@ def test_astra_vectorize_with_provider_api_key():

options_comp = {
"provider": "openai",
"z_01_model_name": "text-embedding-3-small",
"z_04_model_parameters": {},
"z_02_authentication": {},
"z_00_model_name": "text-embedding-3-small",
"z_01_model_parameters": {},
"z_03_provider_api_key": "openai",
"z_04_authentication": {},
}

store = AstraDBVectorStore(
Expand Down Expand Up @@ -212,9 +212,9 @@ def test_astra_vectorize_passes_authentication():
}
options_comp = {
"provider": "openai",
"z_01_model_name": "text-embedding-3-small",
"z_04_model_parameters": {},
"z_02_authentication": {"providerKey": "openai"},
"z_00_model_name": "text-embedding-3-small",
"z_01_model_parameters": {},
"z_04_authentication": {"providerKey": "openai"},
}

store = AstraDBVectorStore(
Expand Down
Loading