From 76cd752e01641003ff88ff3438b38e832b5e9feb Mon Sep 17 00:00:00 2001
From: Dirk Kulawiak <dirk@semi.technology>
Date: Tue, 5 Nov 2024 06:56:14 +0100
Subject: [PATCH] Add support for multi2vec-cohere

---
 test/collection/test_config.py                | 30 ++++++++++
 .../classes/config_named_vectors.py           | 54 ++++++++++++++++++
 .../collections/classes/config_vectorizers.py | 57 +++++++++++++++++++
 3 files changed, 141 insertions(+)

diff --git a/test/collection/test_config.py b/test/collection/test_config.py
index 87f102b19..4b384b4d5 100644
--- a/test/collection/test_config.py
+++ b/test/collection/test_config.py
@@ -82,6 +82,22 @@ def test_basic_config():
             }
         },
     ),
+    (
+        Configure.Vectorizer.multi2vec_cohere(
+            model="embed-multilingual-v2.0",
+            truncate="NONE",
+            vectorize_collection_name=False,
+            base_url="https://api.cohere.ai",
+        ),
+        {
+            "multi2vec-cohere": {
+                "model": "embed-multilingual-v2.0",
+                "truncate": "NONE",
+                "vectorizeClassName": False,
+                "baseURL": "https://api.cohere.ai/",
+            }
+        },
+    ),
     (
         Configure.Vectorizer.text2vec_gpt4all(),
         {
@@ -1219,6 +1235,20 @@ def test_vector_config_flat_pq() -> None:
             }
         },
     ),
+    (
+        [Configure.NamedVectors.multi2vec_cohere(name="test", text_fields=["prop"])],
+        {
+            "test": {
+                "vectorizer": {
+                    "multi2vec-cohere": {
+                        "vectorizeClassName": True,
+                        "textFields": ["prop"],
+                    }
+                },
+                "vectorIndexType": "hnsw",
+            }
+        },
+    ),
     (
         [Configure.NamedVectors.text2vec_gpt4all(name="test", source_properties=["prop"])],
         {
diff --git a/weaviate/collections/classes/config_named_vectors.py b/weaviate/collections/classes/config_named_vectors.py
index 8c93040d0..a8dd09185 100644
--- a/weaviate/collections/classes/config_named_vectors.py
+++ b/weaviate/collections/classes/config_named_vectors.py
@@ -50,6 +50,7 @@
     _VectorizerCustomConfig,
     _Text2VecDatabricksConfig,
     _Text2VecVoyageConfig,
+    _Multi2VecCohereConfig,
 )
 from ...warnings import _Warnings
 
@@ -196,6 +197,59 @@ def text2vec_cohere(
             vector_index_config=vector_index_config,
         )
 
+    @staticmethod
+    def multi2vec_cohere(
+        name: str,
+        *,
+        vector_index_config: Optional[_VectorIndexConfigCreate] = None,
+        vectorize_collection_name: bool = True,
+        base_url: Optional[AnyHttpUrl] = None,
+        model: Optional[Union[CohereModel, str]] = None,
+        truncate: Optional[CohereTruncation] = None,
+        image_fields: Optional[Union[List[str], List[Multi2VecField]]] = None,
+        text_fields: Optional[Union[List[str], List[Multi2VecField]]] = None,
+    ) -> _NamedVectorConfigCreate:
+        """Create a named vector using the `multi2vec-cohere` model.
+
+        See the [documentation](https://weaviate.io/developers/weaviate/model-providers/cohere/embeddings-multimodal)
+        for detailed usage.
+
+        Arguments:
+            `name`
+                The name of the named vector.
+            `vector_index_config`
+                The configuration for Weaviate's vector index. Use wvc.config.Configure.VectorIndex to create a vector index configuration. None by default
+            `vectorize_collection_name`
+                Whether to vectorize the collection name. Defaults to `True`.
+            `model`
+                The model to use. Defaults to `None`, which uses the server-defined default.
+            `truncate`
+                The truncation strategy to use. Defaults to `None`, which uses the server-defined default.
+            `vectorize_collection_name`
+                Whether to vectorize the collection name. Defaults to `True`.
+            `base_url`
+                The base URL to use where API requests should go. Defaults to `None`, which uses the server-defined default.
+            `image_fields`
+                The image fields to use in vectorization.
+            `text_fields`
+                The text fields to use in vectorization.
+
+        Raises:
+            `pydantic.ValidationError` if `truncate` is not a valid value from the `CohereTruncation` type.
+        """
+        return _NamedVectorConfigCreate(
+            name=name,
+            vectorizer=_Multi2VecCohereConfig(
+                baseURL=base_url,
+                model=model,
+                truncate=truncate,
+                vectorizeClassName=vectorize_collection_name,
+                imageFields=_map_multi2vec_fields(image_fields),
+                textFields=_map_multi2vec_fields(text_fields),
+            ),
+            vector_index_config=vector_index_config,
+        )
+
     @staticmethod
     def text2vec_contextionary(
         name: str,
diff --git a/weaviate/collections/classes/config_vectorizers.py b/weaviate/collections/classes/config_vectorizers.py
index 93626f248..f0b81c1e3 100644
--- a/weaviate/collections/classes/config_vectorizers.py
+++ b/weaviate/collections/classes/config_vectorizers.py
@@ -112,6 +112,7 @@ class Vectorizers(str, Enum):
     TEXT2VEC_VOYAGEAI = "text2vec-voyageai"
     IMG2VEC_NEURAL = "img2vec-neural"
     MULTI2VEC_CLIP = "multi2vec-clip"
+    MULTI2VEC_COHERE = "multi2vec-cohere"
     MULTI2VEC_BIND = "multi2vec-bind"
     MULTI2VEC_PALM = "multi2vec-palm"  # change to google once 1.27 is the lowest supported version
     REF2VEC_CENTROID = "ref2vec-centroid"
@@ -374,6 +375,21 @@ def _to_dict(self) -> Dict[str, Any]:
         return ret_dict
 
 
+class _Multi2VecCohereConfig(_Multi2VecBase):
+    vectorizer: Union[Vectorizers, _EnumLikeStr] = Field(
+        default=Vectorizers.MULTI2VEC_COHERE, frozen=True, exclude=True
+    )
+    baseURL: Optional[AnyHttpUrl]
+    model: Optional[str]
+    truncate: Optional[CohereTruncation]
+
+    def _to_dict(self) -> Dict[str, Any]:
+        ret_dict = super()._to_dict()
+        if self.baseURL is not None:
+            ret_dict["baseURL"] = self.baseURL.unicode_string()
+        return ret_dict
+
+
 class _Multi2VecClipConfig(_Multi2VecBase):
     vectorizer: Union[Vectorizers, _EnumLikeStr] = Field(
         default=Vectorizers.MULTI2VEC_CLIP, frozen=True, exclude=True
@@ -698,6 +714,47 @@ def text2vec_cohere(
             vectorizeClassName=vectorize_collection_name,
         )
 
+    @staticmethod
+    def multi2vec_cohere(
+        *,
+        model: Optional[Union[CohereModel, str]] = None,
+        truncate: Optional[CohereTruncation] = None,
+        vectorize_collection_name: bool = True,
+        base_url: Optional[AnyHttpUrl] = None,
+        image_fields: Optional[Union[List[str], List[Multi2VecField]]] = None,
+        text_fields: Optional[Union[List[str], List[Multi2VecField]]] = None,
+    ) -> _VectorizerConfigCreate:
+        """Create a `_Multi2VecCohereConfig` object for use when vectorizing using the `multi2vec-cohere` model.
+
+        See the [documentation](https://weaviate.io/developers/weaviate/model-providers/cohere/embeddings-multimodal)
+        for detailed usage.
+
+        Arguments:
+            `model`
+                The model to use. Defaults to `None`, which uses the server-defined default.
+            `truncate`
+                The truncation strategy to use. Defaults to `None`, which uses the server-defined default.
+            `vectorize_collection_name`
+                Whether to vectorize the collection name. Defaults to `True`.
+            `base_url`
+                The base URL to use where API requests should go. Defaults to `None`, which uses the server-defined default.
+            `image_fields`
+                The image fields to use in vectorization.
+            `text_fields`
+                The text fields to use in vectorization.
+
+        Raises:
+            `pydantic.ValidationError` if `truncate` is not a valid value from the `CohereTruncation` type.
+        """
+        return _Multi2VecCohereConfig(
+            baseURL=base_url,
+            model=model,
+            truncate=truncate,
+            vectorizeClassName=vectorize_collection_name,
+            imageFields=_map_multi2vec_fields(image_fields),
+            textFields=_map_multi2vec_fields(text_fields),
+        )
+
     @staticmethod
     def text2vec_databricks(
         *,