From b4002060e0a924754e1d8d7ac611462f66cdc025 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Fri, 16 Sep 2022 11:10:03 +0000 Subject: [PATCH 1/2] Add Milvus Support and Update pipielines qa ui --- pipelines/examples/semantic-search/README.md | 13 +- .../semantic_search_example.py | 164 ++-- .../pipelines/document_stores/__init__.py | 3 + pipelines/pipelines/document_stores/base.py | 7 + .../pipelines/document_stores/milvus2.py | 763 ++++++++++++++++++ pipelines/pipelines/document_stores/sql.py | 5 +- pipelines/pipelines/pipelines/base.py | 1 + pipelines/requirements.txt | 3 +- .../pipeline/semantic_search_custom.yaml | 2 +- .../pipeline/semantic_search_milvus.yaml | 66 ++ pipelines/ui/webapp_question_answering.py | 9 +- pipelines/ui/webapp_semantic_search.py | 3 + pipelines/utils/offline_ann.py | 103 +-- 13 files changed, 1029 insertions(+), 113 deletions(-) create mode 100644 pipelines/pipelines/document_stores/milvus2.py create mode 100644 pipelines/rest_api/pipeline/semantic_search_milvus.yaml diff --git a/pipelines/examples/semantic-search/README.md b/pipelines/examples/semantic-search/README.md index 7bb70fe56b38..302a2209678e 100644 --- a/pipelines/examples/semantic-search/README.md +++ b/pipelines/examples/semantic-search/README.md @@ -73,10 +73,12 @@ python setup.py install # 我们建议在 GPU 环境下运行本示例,运行速度较快 # 设置 1 个空闲的 GPU 卡,此处假设 0 卡为空闲 GPU export CUDA_VISIBLE_DEVICES=0 -python examples/semantic-search/semantic_search_example.py --device gpu +python examples/semantic-search/semantic_search_example.py --device gpu \ + --search_engine faiss # 如果只有 CPU 机器,可以通过 --device 参数指定 cpu 即可, 运行耗时较长 unset CUDA_VISIBLE_DEVICES -python examples/semantic-search/semantic_search_example.py --device cpu +python examples/semantic-search/semantic_search_example.py --device cpu \ + --search_engine faiss ``` `semantic_search_example.py`中`DensePassageRetriever`和`ErnieRanker`的模型介绍请参考[API介绍](../../API.md) @@ -107,6 +109,7 @@ curl http://localhost:9200/_aliases?pretty=true # 以DuReader-Robust 数据集为例建立 ANN 索引库 python utils/offline_ann.py --index_name dureader_robust_query_encoder \ --doc_dir data/dureader_dev \ + --search_engine elastic \ --delete_index ``` 可以使用下面的命令来查看数据: @@ -119,8 +122,9 @@ curl http://localhost:9200/dureader_robust_query_encoder/_search 参数含义说明 * `index_name`: 索引的名称 * `doc_dir`: txt文本数据的路径 -* `host`: Elasticsearch的IP地址 -* `port`: Elasticsearch的端口号 +* `host`: ANN索引引擎的IP地址 +* `port`: ANN索引引擎的端口号 +* `search_engine`: 选择的近似索引引擎elastic,milvus,默认elastic * `delete_index`: 是否删除现有的索引和数据,用于清空es的数据,默认为false #### 3.4.3 启动 RestAPI 模型服务 @@ -139,7 +143,6 @@ sh examples/semantic-search/run_search_server.sh ``` curl -X POST -k http://localhost:8891/query -H 'Content-Type: application/json' -d '{"query": "衡量酒水的价格的因素有哪些?","params": {"Retriever": {"top_k": 5}, "Ranker":{"top_k": 5}}}' - ``` #### 3.4.4 启动 WebUI ```bash diff --git a/pipelines/examples/semantic-search/semantic_search_example.py b/pipelines/examples/semantic-search/semantic_search_example.py index a657d3d6df1e..2d31500881da 100644 --- a/pipelines/examples/semantic-search/semantic_search_example.py +++ b/pipelines/examples/semantic-search/semantic_search_example.py @@ -17,13 +17,15 @@ import paddle from pipelines.document_stores import FAISSDocumentStore +from pipelines.document_stores import MilvusDocumentStore from pipelines.nodes import DensePassageRetriever, ErnieRanker from pipelines.utils import convert_files_to_dicts, fetch_archive_from_http, print_documents # yapf: disable parser = argparse.ArgumentParser() parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to run dense_qa system, defaults to gpu.") -parser.add_argument("--index_name", default='faiss_index', type=str, help="The ann index name of FAISS.") +parser.add_argument("--index_name", default='dureader_index', type=str, help="The ann index name of ANN.") +parser.add_argument("--search_engine", choices=['faiss', 'milvus'], default="faiss", help="The type of ANN search engine.") parser.add_argument("--max_seq_len_query", default=64, type=int, help="The maximum total length of query after tokenization.") parser.add_argument("--max_seq_len_passage", default=256, type=int, help="The maximum total length of passage after tokenization.") parser.add_argument("--retriever_batch_size", default=16, type=int, help="The batch size of retriever to extract passage embedding for building ANN index.") @@ -44,41 +46,38 @@ default=312, type=int, help="The embedding_dim of index") -args = parser.parse_args() -# yapf: enable +parser.add_argument('--host', + type=str, + default="localhost", + help='host ip of ANN search engine') -def semantic_search_tutorial(): +parser.add_argument('--port', + type=str, + default="8530", + help='port of ANN search engine') - use_gpu = True if args.device == 'gpu' else False +args = parser.parse_args() +# yapf: enable + +def get_faiss_retriever(use_gpu): faiss_document_store = "faiss_document_store.db" if os.path.exists(args.index_name) and os.path.exists(faiss_document_store): # connect to existed FAISS Index document_store = FAISSDocumentStore.load(args.index_name) - if (os.path.exists(args.params_path)): - retriever = DensePassageRetriever( - document_store=document_store, - query_embedding_model=args.query_embedding_model, - params_path=args.params_path, - output_emb_size=args.embedding_dim, - max_seq_len_query=args.max_seq_len_query, - max_seq_len_passage=args.max_seq_len_passage, - batch_size=args.retriever_batch_size, - use_gpu=use_gpu, - embed_title=False, - ) - else: - retriever = DensePassageRetriever( - document_store=document_store, - query_embedding_model=args.query_embedding_model, - passage_embedding_model=args.passage_embedding_model, - max_seq_len_query=args.max_seq_len_query, - max_seq_len_passage=args.max_seq_len_passage, - batch_size=args.retriever_batch_size, - use_gpu=use_gpu, - embed_title=False, - ) + retriever = DensePassageRetriever( + document_store=document_store, + query_embedding_model=args.query_embedding_model, + passage_embedding_model=args.passage_embedding_model, + params_path=args.params_path, + output_emb_size=args.embedding_dim, + max_seq_len_query=args.max_seq_len_query, + max_seq_len_passage=args.max_seq_len_passage, + batch_size=args.retriever_batch_size, + use_gpu=use_gpu, + embed_title=False, + ) else: doc_dir = "data/dureader_dev" dureader_data = "https://paddlenlp.bj.bcebos.com/applications/dureader_dev.zip" @@ -97,35 +96,98 @@ def semantic_search_tutorial(): faiss_index_factory_str="Flat") document_store.write_documents(dicts) - if (os.path.exists(args.params_path)): - retriever = DensePassageRetriever( - document_store=document_store, - query_embedding_model=args.query_embedding_model, - params_path=args.params_path, - output_emb_size=args.embedding_dim, - max_seq_len_query=args.max_seq_len_query, - max_seq_len_passage=args.max_seq_len_passage, - batch_size=args.retriever_batch_size, - use_gpu=use_gpu, - embed_title=False, - ) - else: - retriever = DensePassageRetriever( - document_store=document_store, - query_embedding_model=args.query_embedding_model, - passage_embedding_model=args.passage_embedding_model, - max_seq_len_query=args.max_seq_len_query, - max_seq_len_passage=args.max_seq_len_passage, - batch_size=args.retriever_batch_size, - use_gpu=use_gpu, - embed_title=False, - ) + retriever = DensePassageRetriever( + document_store=document_store, + query_embedding_model=args.query_embedding_model, + passage_embedding_model=args.passage_embedding_model, + params_path=args.params_path, + output_emb_size=args.embedding_dim, + max_seq_len_query=args.max_seq_len_query, + max_seq_len_passage=args.max_seq_len_passage, + batch_size=args.retriever_batch_size, + use_gpu=use_gpu, + embed_title=False, + ) # update Embedding document_store.update_embeddings(retriever) # save index document_store.save(args.index_name) + return document_store + + +def get_milvus_retriever(use_gpu): + + milvus_document_store = "milvus_document_store.db" + if os.path.exists(milvus_document_store): + document_store = MilvusDocumentStore(embedding_dim=args.embedding_dim, + host=args.host, + index=args.index_name, + port=args.port, + index_param={ + "M": 16, + "efConstruction": 50 + }, + index_type="HNSW") + # connect to existed Milvus Index + retriever = DensePassageRetriever( + document_store=document_store, + query_embedding_model=args.query_embedding_model, + passage_embedding_model=args.passage_embedding_model, + params_path=args.params_path, + output_emb_size=args.embedding_dim, + max_seq_len_query=args.max_seq_len_query, + max_seq_len_passage=args.max_seq_len_passage, + batch_size=args.retriever_batch_size, + use_gpu=use_gpu, + embed_title=False, + ) + else: + doc_dir = "data/dureader_dev" + dureader_data = "https://paddlenlp.bj.bcebos.com/applications/dureader_dev.zip" + + fetch_archive_from_http(url=dureader_data, output_dir=doc_dir) + dicts = convert_files_to_dicts(dir_path=doc_dir, + split_paragraphs=True, + encoding='utf-8') + document_store = MilvusDocumentStore(embedding_dim=args.embedding_dim, + host=args.host, + index=args.index_name, + port=args.port, + index_param={ + "M": 16, + "efConstruction": 50 + }, + index_type="HNSW") + retriever = DensePassageRetriever( + document_store=document_store, + query_embedding_model=args.query_embedding_model, + passage_embedding_model=args.passage_embedding_model, + params_path=args.params_path, + output_emb_size=args.embedding_dim, + max_seq_len_query=args.max_seq_len_query, + max_seq_len_passage=args.max_seq_len_passage, + batch_size=args.retriever_batch_size, + use_gpu=use_gpu, + embed_title=False, + ) + + document_store.write_documents(dicts) + # update Embedding + document_store.update_embeddings(retriever) + + return retriever + + +def semantic_search_tutorial(): + + use_gpu = True if args.device == 'gpu' else False + + if (args.search_engine == 'milvus'): + retriever = get_milvus_retriever(use_gpu) + else: + retriever = get_faiss_retriever(use_gpu) ### Ranker ranker = ErnieRanker( diff --git a/pipelines/pipelines/document_stores/__init__.py b/pipelines/pipelines/document_stores/__init__.py index 724898a57036..6cdfd2416913 100644 --- a/pipelines/pipelines/document_stores/__init__.py +++ b/pipelines/pipelines/document_stores/__init__.py @@ -31,6 +31,9 @@ FAISSDocumentStore = safe_import("pipelines.document_stores.faiss", "FAISSDocumentStore", "faiss") +MilvusDocumentStore = safe_import("pipelines.document_stores.milvus2", + "Milvus2DocumentStore", "milvus") + from pipelines.document_stores.utils import ( eval_data_from_json, eval_data_from_jsonl, diff --git a/pipelines/pipelines/document_stores/base.py b/pipelines/pipelines/document_stores/base.py index 168e2452c5ab..60e277297b37 100644 --- a/pipelines/pipelines/document_stores/base.py +++ b/pipelines/pipelines/document_stores/base.py @@ -228,6 +228,13 @@ def __next__(self): self.ids_iterator = self.ids_iterator[1:] return ret + def scale_to_unit_interval(self, score: float, + similarity: Optional[str]) -> float: + if similarity == "cosine": + return (score + 1) / 2 + else: + return float(expit(score / 100)) + @abstractmethod def get_all_labels( self, diff --git a/pipelines/pipelines/document_stores/milvus2.py b/pipelines/pipelines/document_stores/milvus2.py new file mode 100644 index 000000000000..575746254da3 --- /dev/null +++ b/pipelines/pipelines/document_stores/milvus2.py @@ -0,0 +1,763 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union + +import logging +import warnings +import numpy as np + +from tqdm import tqdm + +try: + from pymilvus import FieldSchema, CollectionSchema, Collection, connections, utility + from pymilvus.client.abstract import QueryResult + from pymilvus.client.types import DataType +except (ImportError, ModuleNotFoundError) as ie: + from pipelines.utils.import_utils import _optional_component_not_installed + + _optional_component_not_installed(__name__, "milvus2", ie) + +from pipelines.schema import Document +from pipelines.document_stores.sql import SQLDocumentStore +from pipelines.document_stores.base import get_batches_from_generator + +if TYPE_CHECKING: + from pipelines.nodes.retriever.base import BaseRetriever + +logger = logging.getLogger(__name__) + + +class Milvus2DocumentStore(SQLDocumentStore): + """ + you can now run a query using vector similarity and filter for some meta data at the same time! + (See https://milvus.io/docs/v2.0.x/comparison.md for more details) + + Usage: + 1. Start a Milvus service via docker (see https://milvus.io/docs/v2.0.x/install_standalone-docker.md) + 2. Run pip install Paddle-Pipelines + 3. Init a MilvusDocumentStore() in Pipelines + + Overview: + Milvus (https://milvus.io/) is a highly reliable, scalable Document Store specialized on storing and processing vectors. + Therefore, it is particularly suited for Pipelines users that work with dense retrieval methods (like DPR). + + In contrast to FAISS, Milvus ... + - runs as a separate service (e.g. a Docker container) and can scale easily in a distributed environment + - allows dynamic data management (i.e. you can insert/delete vectors without recreating the whole index) + - encapsulates multiple ANN libraries (FAISS, ANNOY ...) + + This class uses Milvus for all vector related storage, processing and querying. + The meta-data (e.g. for filtering) and the document text are however stored in a separate SQL Database as Milvus + does not allow these data types (yet). + """ + + def __init__( + self, + sql_url: str = "sqlite:///milvus_document_store.db", + host: str = "localhost", + port: str = "19530", + connection_pool: str = "SingletonThread", + index: str = "document", + vector_dim: int = None, + embedding_dim: int = 768, + index_file_size: int = 1024, + similarity: str = "dot_product", + index_type: str = "IVF_FLAT", + index_param: Optional[Dict[str, Any]] = None, + search_param: Optional[Dict[str, Any]] = None, + return_embedding: bool = False, + embedding_field: str = "embedding", + id_field: str = "id", + custom_fields: Optional[List[Any]] = None, + progress_bar: bool = True, + duplicate_documents: str = "overwrite", + isolation_level: str = None, + consistency_level: int = 0, + recreate_index: bool = False, + ): + """ + :param sql_url: SQL connection URL for storing document texts and metadata. It defaults to a local, file based SQLite DB. For large scale + deployment, Postgres is recommended. If using MySQL then same server can also be used for + Milvus metadata. For more details see https://milvus.io/docs/v1.1.0/data_manage.md. + :param milvus_url: Milvus server connection URL for storing and processing vectors. + Protocol, host and port will automatically be inferred from the URL. + See https://milvus.io/docs/v2.0.x/install_standalone-docker.md for instructions to start a Milvus instance. + :param connection_pool: Connection pool type to connect with Milvus server. Default: "SingletonThread". + :param index: Index name for text, embedding and metadata (in Milvus terms, this is the "collection name"). + :param vector_dim: Deprecated. Use embedding_dim instead. + :param embedding_dim: The embedding vector size. Default: 768. + :param index_file_size: Specifies the size of each segment file that is stored by Milvus and its default value is 1024 MB. + When the size of newly inserted vectors reaches the specified volume, Milvus packs these vectors into a new segment. + Milvus creates one index file for each segment. When conducting a vector search, Milvus searches all index files one by one. + As a rule of thumb, we would see a 30% ~ 50% increase in the search performance after changing the value of index_file_size from 1024 to 2048. + Note that an overly large index_file_size value may cause failure to load a segment into the memory or graphics memory. + (From https://milvus.io/docs/v2.0.x/performance_faq.md) + :param similarity: The similarity function used to compare document vectors. 'dot_product' is the default and recommended for DPR embeddings. + 'cosine' is recommended for Sentence Transformers, but is not directly supported by Milvus. + However, you can normalize your embeddings and use `dot_product` to get the same results. + See https://milvus.io/docs/v2.0.x/metric.md. + :param index_type: Type of approximate nearest neighbour (ANN) index used. The choice here determines your tradeoff between speed and accuracy. + Some popular options: + - FLAT (default): Exact method, slow + - IVF_FLAT, inverted file based heuristic, fast + - HSNW: Graph based, fast + - ANNOY: Tree based, fast + See: https://milvus.io/docs/v2.0.x/index.md + :param index_param: Configuration parameters for the chose index_type needed at indexing time. + For example: {"nlist": 16384} as the number of cluster units to create for index_type IVF_FLAT. + See https://milvus.io/docs/v2.0.x/index.md + :param search_param: Configuration parameters for the chose index_type needed at query time + For example: {"nprobe": 10} as the number of cluster units to query for index_type IVF_FLAT. + See https://milvus.io/docs/v2.0.x/index.md + :param return_embedding: To return document embedding. + :param embedding_field: Name of field containing an embedding vector. + :param progress_bar: Whether to show a tqdm progress bar or not. + Can be helpful to disable in production deployments to keep the logs clean. + :param duplicate_documents: Handle duplicates document based on parameter options. + Parameter options : ( 'skip','overwrite','fail') + skip: Ignore the duplicates documents + overwrite: Update any existing documents with the same ID when adding documents. + fail: an error is raised if the document ID of the document being added already + exists. + :param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level) + :param recreate_index: If set to True, an existing Milvus index will be deleted and a new one will be + created using the config you are using for initialization. Be aware that all data in the old index will be + lost if you choose to recreate the index. Be aware that both the document_index and the label_index will + be recreated. + """ + + super().__init__(url=sql_url, + index=index, + duplicate_documents=duplicate_documents, + isolation_level=isolation_level) + + # save init parameters to enable export of component config as YAML + self.set_config( + sql_url=sql_url, + host=host, + port=port, + index=index, + embedding_dim=embedding_dim, + vector_dim=vector_dim, + index_file_size=1024, + similarity=similarity, + index_type=index_type, + ) + + connections.add_connection(default={"host": host, "port": port}) + connections.connect() + + if vector_dim is not None: + warnings.warn( + message= + "The 'vector_dim' parameter is deprecated, use 'embedding_dim' instead.", + category=DeprecationWarning, + stacklevel=2, + ) + self.embedding_dim = vector_dim + else: + self.embedding_dim = embedding_dim + + self.index_file_size = index_file_size + self.similarity = similarity + self.cosine = False + + if similarity == "dot_product": + self.metric_type = "IP" + elif similarity == "l2": + self.metric_type = "L2" + elif similarity == "cosine": + self.metric_type = "IP" + self.cosine = True + else: + raise ValueError( + "The Milvus document store can currently only support dot_product, cosine and L2 similarity. " + 'Please set similarity="dot_product" or "cosine" or "l2"') + + self.index_type = index_type + self.index_param = index_param or {"nlist": 16384} + self.search_param = search_param or {"nprobe": 10} + self.index = index + self.embedding_field = embedding_field + self.id_field = id_field + self.custom_fields = custom_fields + + self.collection = self._create_collection_and_index( + self.index, consistency_level, recreate_index=recreate_index) + + self.return_embedding = return_embedding + self.progress_bar = progress_bar + + def _create_collection_and_index( + self, + index: Optional[str] = None, + consistency_level: int = 0, + index_param: Optional[Dict[str, Any]] = None, + recreate_index: bool = False, + ): + index = index or self.index + index_param = index_param or self.index_param + custom_fields = self.custom_fields or [] + + if recreate_index: + self._delete_index(index) + super().delete_labels() + + has_collection = utility.has_collection(collection_name=index) + if not has_collection: + fields = [ + FieldSchema(name=self.id_field, + dtype=DataType.INT64, + is_primary=True, + auto_id=True, + description="primary id"), + FieldSchema(name=self.embedding_field, + dtype=DataType.FLOAT_VECTOR, + dim=self.embedding_dim, + description="vector"), + ] + + for field in custom_fields: + if field.name == self.id_field or field.name == self.embedding_field: + logger.warning( + f"Skipping `{field.name}` as it is similar to `id_field` or `embedding_field`" + ) + else: + fields.append(field) + + collection_schema = CollectionSchema(fields=fields) + else: + collection_schema = None + + collection = Collection(name=index, + schema=collection_schema, + consistency_level=consistency_level) + + has_index = collection.has_index() + if not has_index: + collection.create_index( + field_name=self.embedding_field, + index_params={ + "index_type": self.index_type, + "metric_type": self.metric_type, + "params": index_param + }, + ) + + collection.load() + + return collection + + def _create_document_field_map(self) -> Dict: + return {self.index: self.embedding_field} + + def write_documents( + self, + documents: Union[List[dict], List[Document]], + index: Optional[str] = None, + batch_size: int = 10_000, + duplicate_documents: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + index_param: Optional[Dict[str, Any]] = None, + ): + """ + Add new documents to the DocumentStore. + + :param documents: List of `Dicts` or List of `Documents`. If they already contain the embeddings, we'll index + them right away in Milvus. If not, you can later call `update_embeddings()` to create & index them. + :param index: (SQL) index name for storing the docs and metadata + :param batch_size: When working with large number of documents, batching can help reduce memory footprint. + :param duplicate_documents: Handle duplicates document based on parameter options. + Parameter options : ( 'skip','overwrite','fail') + skip: Ignore the duplicates documents + overwrite: Update any existing documents with the same ID when adding documents. + fail: an error is raised if the document ID of the document being added already + exists. + :raises DuplicateDocumentError: Exception trigger on duplicate document + :return: + """ + if headers: + raise NotImplementedError( + "Milvus2DocumentStore does not support headers.") + + index = index or self.index + index_param = index_param or self.index_param + duplicate_documents = duplicate_documents or self.duplicate_documents + assert ( + duplicate_documents in self.duplicate_documents_options + ), f"duplicate_documents parameter must be {', '.join(self.duplicate_documents_options)}" + field_map = self._create_document_field_map() + + if len(documents) == 0: + logger.warning( + "Calling DocumentStore.write_documents() with empty list") + return + + document_objects = [ + Document.from_dict(d, field_map=field_map) + if isinstance(d, dict) else d for d in documents + ] + document_objects = self._handle_duplicate_documents( + document_objects, duplicate_documents) + add_vectors = False if document_objects[0].embedding is None else True + + batched_documents = get_batches_from_generator(document_objects, + batch_size) + with tqdm(total=len(document_objects), + disable=not self.progress_bar) as progress_bar: + mutation_result: Any = None + + for document_batch in batched_documents: + if add_vectors: + doc_ids = [] + embeddings = [] + for doc in document_batch: + doc_ids.append(doc.id) + if isinstance(doc.embedding, np.ndarray): + if self.cosine: + embedding = doc.embedding / np.linalg.norm( + doc.embedding) + embeddings.append(embedding.tolist()) + else: + embeddings.append(doc.embedding.tolist()) + elif isinstance(doc.embedding, list): + if self.cosine: + embedding = np.array(doc.embedding) + embedding /= np.linalg.norm(embedding) + embeddings.append(embedding.tolist()) + else: + embeddings.append(doc.embedding) + else: + raise AttributeError( + f"Format of supplied document embedding {type(doc.embedding)} is not " + f"supported. Please use list or numpy.ndarray") + if duplicate_documents == "overwrite": + existing_docs = super().get_documents_by_id(ids=doc_ids, + index=index) + self._delete_vector_ids_from_milvus( + documents=existing_docs, index=index) + + mutation_result = self.collection.insert([embeddings]) + + docs_to_write_in_sql = [] + + for idx, doc in enumerate(document_batch): + meta = doc.meta + if add_vectors and mutation_result is not None: + meta["vector_id"] = str( + mutation_result.primary_keys[idx]) + docs_to_write_in_sql.append(doc) + + super().write_documents(docs_to_write_in_sql, + index=index, + duplicate_documents=duplicate_documents) + progress_bar.update(batch_size) + progress_bar.close() + + def update_embeddings( + self, + retriever: "BaseRetriever", + index: Optional[str] = None, + batch_size: int = 10_000, + update_existing_embeddings: bool = True, + filters: + Optional[Dict[ + str, + Any]] = None, # TODO: Adapt type once we allow extended filters in Milvus2DocStore + ): + """ + Updates the embeddings in the the document store using the encoding model specified in the retriever. + This can be useful if want to add or change the embeddings for your documents (e.g. after changing the retriever config). + + :param retriever: Retriever to use to get embeddings for text + :param index: (SQL) index name for storing the docs and metadata + :param batch_size: When working with large number of documents, batching can help reduce memory footprint. + :param update_existing_embeddings: Whether to update existing embeddings of the documents. If set to False, + only documents without embeddings are processed. This mode can be used for + incremental updating of embeddings, wherein, only newly indexed documents + get processed. + :param filters: Optional filters to narrow down the documents for which embeddings are to be updated. + Example: {"name": ["some", "more"], "category": ["only_one"]} + :return: None + """ + index = index or self.index + + document_count = self.get_document_count(index=index) + if document_count == 0: + logger.warning( + "Calling DocumentStore.update_embeddings() on an empty index") + return + + logger.info(f"Updating embeddings for {document_count} docs...") + + result = self._query( + index=index, + vector_ids=None, + batch_size=batch_size, + filters=filters, + only_documents_without_embedding=not update_existing_embeddings, + ) + batched_documents = get_batches_from_generator(result, batch_size) + with tqdm(total=document_count, + disable=not self.progress_bar, + position=0, + unit=" docs", + desc="Updating Embedding") as progress_bar: + for document_batch in batched_documents: + self._delete_vector_ids_from_milvus(documents=document_batch, + index=index) + + embeddings = retriever.embed_documents( + document_batch) # type: ignore + if self.cosine: + embeddings = [ + embedding / np.linalg.norm(embedding) + for embedding in embeddings + ] + embeddings_list = [ + embedding.tolist() for embedding in embeddings + ] + assert len(document_batch) == len(embeddings_list) + + mutation_result = self.collection.insert([embeddings_list]) + + vector_id_map = {} + for vector_id, doc in zip(mutation_result.primary_keys, + document_batch): + vector_id_map[doc.id] = str(vector_id) + + self.update_vector_ids(vector_id_map, index=index) + progress_bar.set_description_str("Documents Processed") + progress_bar.update(batch_size) + + def query_by_embedding( + self, + query_emb: np.ndarray, + filters: Optional[Dict[ + str, + Any]] = None, # TODO: Adapt type once we allow extended filters in Milvus2DocStore + top_k: int = 10, + index: Optional[str] = None, + return_embedding: Optional[bool] = None, + headers: Optional[Dict[str, str]] = None, + scale_score: bool = True, + ) -> List[Document]: + """ + Find the document that is most similar to the provided `query_emb` by using a vector similarity metric. + + :param query_emb: Embedding of the query (e.g. gathered from DPR) + :param filters: Optional filters to narrow down the search space. + Example: {"name": ["some", "more"], "category": ["only_one"]} + :param top_k: How many documents to return + :param index: (SQL) index name for storing the docs and metadata + :param return_embedding: To return document embedding + :param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]). + If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant. + Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. + :return: + """ + if headers: + raise NotImplementedError( + "Milvus2DocumentStore does not support headers.") + + index = index or self.index + has_collection = utility.has_collection(collection_name=index) + if not has_collection: + raise Exception( + "No index exists. Use 'update_embeddings()` to create an index." + ) + if return_embedding is None: + return_embedding = self.return_embedding + + query_emb = query_emb.reshape(-1).astype(np.float32) + if self.cosine: + query_emb = query_emb / np.linalg.norm(query_emb) + + search_result: QueryResult = self.collection.search( + data=[query_emb.tolist()], + anns_field=self.embedding_field, + param={ + "metric_type": self.metric_type, + **self.search_param + }, + limit=top_k, + ) + + vector_ids_for_query = [] + scores_for_vector_ids: Dict[str, float] = {} + for vector_id, distance in zip(search_result[0].ids, + search_result[0].distances): + vector_ids_for_query.append(str(vector_id)) + scores_for_vector_ids[str(vector_id)] = distance + + documents = self.get_documents_by_vector_ids(vector_ids_for_query, + index=index) + + if return_embedding: + self._populate_embeddings_to_docs(index=index, docs=documents) + + for doc in documents: + score = scores_for_vector_ids[doc.meta["vector_id"]] + if scale_score: + score = self.scale_to_unit_interval(score, self.similarity) + doc.score = score + + return documents + + def delete_documents( + self, + index: Optional[str] = None, + ids: Optional[List[str]] = None, + filters: Optional[Dict[ + str, + Any]] = None, # TODO: Adapt type once we allow extended filters in Milvus2DocStore + headers: Optional[Dict[str, str]] = None, + batch_size: int = 10_000, + ): + """ + Delete all documents (from SQL AND Milvus). + :param index: (SQL) index name for storing the docs and metadata + :param filters: Optional filters to narrow down the search space. + Example: {"name": ["some", "more"], "category": ["only_one"]} + :return: None + """ + if headers: + raise NotImplementedError( + "Milvus2DocumentStore does not support headers.") + + if ids: + self._delete_vector_ids_from_milvus(ids=ids, index=index) + elif filters: + batch = [] + for existing_docs in super().get_all_documents_generator( + filters=filters, index=index, batch_size=batch_size): + batch.append(existing_docs) + if len(batch) == batch_size: + self._delete_vector_ids_from_milvus(documents=batch, + index=index) + if len(batch) != 0: + self._delete_vector_ids_from_milvus(documents=batch, + index=index) + else: + self.collection = self._create_collection_and_index( + self.index, recreate_index=True) + + index = index or self.index + super().delete_documents(index=index, filters=filters, ids=ids) + + def delete_index(self, index: str): + """ + Delete an existing index. The index including all data will be removed. + + :param index: The name of the index to delete. + :return: None + """ + if index == self.index: + logger.warning( + f"Deletion of default index '{index}' detected. " + f"If you plan to use this index again, please reinstantiate '{self.__class__.__name__}' in order to avoid side-effects." + ) + self._delete_index(index) + + def _delete_index(self, index: str): + if utility.has_collection(collection_name=index): + utility.drop_collection(collection_name=index) + logger.info(f"Index '{index}' deleted.") + super().delete_labels(index) + + def get_all_documents_generator( + self, + index: Optional[str] = None, + filters: Optional[Dict[ + str, + Any]] = None, # TODO: Adapt type once we allow extended filters in Milvus2DocStore + return_embedding: Optional[bool] = None, + batch_size: int = 10_000, + headers: Optional[Dict[str, str]] = None, + ) -> Generator[Document, None, None]: + """ + Get all documents from the document store. Under-the-hood, documents are fetched in batches from the + document store and yielded as individual documents. This method can be used to iteratively process + a large number of documents without having to load all documents in memory. + + :param index: Name of the index to get the documents from. If None, the + DocumentStore's default index (self.index) will be used. + :param filters: Optional filters to narrow down the documents to return. + Example: {"name": ["some", "more"], "category": ["only_one"]} + :param return_embedding: Whether to return the document embeddings. + :param batch_size: When working with large number of documents, batching can help reduce memory footprint. + """ + if headers: + raise NotImplementedError( + "Milvus2DocumentStore does not support headers.") + + index = index or self.index + documents = super().get_all_documents_generator(index=index, + filters=filters, + batch_size=batch_size) + if return_embedding is None: + return_embedding = self.return_embedding + + for doc in documents: + if return_embedding: + self._populate_embeddings_to_docs(index=index, docs=[doc]) + yield doc + + def get_all_documents( + self, + index: Optional[str] = None, + filters: Optional[Dict[ + str, + Any]] = None, # TODO: Adapt type once we allow extended filters in Milvus2DocStore + return_embedding: Optional[bool] = None, + batch_size: int = 10_000, + headers: Optional[Dict[str, str]] = None, + ) -> List[Document]: + """ + Get documents from the document store (optionally using filter criteria). + + :param index: Name of the index to get the documents from. If None, the + DocumentStore's default index (self.index) will be used. + :param filters: Optional filters to narrow down the documents to return. + Example: {"name": ["some", "more"], "category": ["only_one"]} + :param return_embedding: Whether to return the document embeddings. + :param batch_size: When working with large number of documents, batching can help reduce memory footprint. + """ + if headers: + raise NotImplementedError( + "Milvus2DocumentStore does not support headers.") + + index = index or self.index + result = self.get_all_documents_generator( + index=index, + filters=filters, + return_embedding=return_embedding, + batch_size=batch_size) + documents = list(result) + return documents + + def get_document_by_id( + self, + id: str, + index: Optional[str] = None, + headers: Optional[Dict[str, str]] = None) -> Optional[Document]: + """ + Fetch a document by specifying its text id string + + :param id: ID of the document + :param index: Name of the index to get the documents from. If None, the + DocumentStore's default index (self.index) will be used. + """ + if headers: + raise NotImplementedError( + "Milvus2DocumentStore does not support headers.") + + documents = self.get_documents_by_id([id], index) + document = documents[0] if documents else None + return document + + def get_documents_by_id( + self, + ids: List[str], + index: Optional[str] = None, + batch_size: int = 10_000, + headers: Optional[Dict[str, str]] = None, + ) -> List[Document]: + """ + Fetch multiple documents by specifying their IDs (strings) + + :param ids: List of IDs of the documents + :param index: Name of the index to get the documents from. If None, the + DocumentStore's default index (self.index) will be used. + :param batch_size: When working with large number of documents, batching can help reduce memory footprint. + """ + if headers: + raise NotImplementedError( + "Milvus2DocumentStore does not support headers.") + + index = index or self.index + documents = super().get_documents_by_id(ids=ids, + index=index, + batch_size=batch_size) + if self.return_embedding: + self._populate_embeddings_to_docs(index=index, docs=documents) + + return documents + + def _populate_embeddings_to_docs(self, + docs: List[Document], + index: Optional[str] = None): + index = index or self.index + docs_with_vector_ids = [] + for doc in docs: + if doc.meta and doc.meta.get("vector_id") is not None: + docs_with_vector_ids.append(doc) + + if len(docs_with_vector_ids) == 0: + return + + ids = [] + vector_id_map = {} + + for doc in docs_with_vector_ids: + vector_id: str = doc.meta["vector_id"] # type: ignore + # vector_id is always a string, but it isn't part of type hint + ids.append(str(vector_id)) + vector_id_map[int(vector_id)] = doc + + search_result: QueryResult = self.collection.query( + expr=f'{self.id_field} in [ {",".join(ids)} ]', + output_fields=[self.embedding_field]) + + for result in search_result: + doc = vector_id_map[result["id"]] + doc.embedding = np.array(result["embedding"], "float32") + + def _delete_vector_ids_from_milvus( + self, + documents: Optional[List[Document]] = None, + ids: Optional[List[str]] = None, + index: Optional[str] = None): + index = index or self.index + if ids is None: + ids = [] + if documents is None: + raise ValueError( + "You must either specify documents or ids to delete.") + for doc in documents: + if "vector_id" in doc.meta: + ids.append(str(doc.meta["vector_id"])) + else: + docs = super().get_documents_by_id(ids=ids, index=index) + ids = [ + doc.meta["vector_id"] for doc in docs if "vector_id" in doc.meta + ] + + expr = f"{self.id_field} in [{','.join(ids)}]" + + self.collection.delete(expr) + + def get_embedding_count( + self, + index: Optional[str] = None, + filters: Optional[Dict[str, List[str]]] = None) -> int: + """ + Return the count of embeddings in the document store. + """ + if filters: + raise Exception( + "filters are not supported for get_embedding_count in MilvusDocumentStore." + ) + return len(self.get_all_documents(index=index)) diff --git a/pipelines/pipelines/document_stores/sql.py b/pipelines/pipelines/document_stores/sql.py index ab276a08b8e8..cb4f71fbbf73 100644 --- a/pipelines/pipelines/document_stores/sql.py +++ b/pipelines/pipelines/document_stores/sql.py @@ -457,8 +457,11 @@ def write_documents( for doc in document_objects[i:i + batch_size]: meta_fields = doc.meta or {} vector_id = meta_fields.pop("vector_id", None) + # Support storing list type data by adding value semicolon meta_orms = [ - MetaDocumentORM(name=key, value=value) + MetaDocumentORM( + name=key, + value=';'.join(value) if type(value) == list else value) for key, value in meta_fields.items() ] doc_orm = DocumentORM( diff --git a/pipelines/pipelines/pipelines/base.py b/pipelines/pipelines/pipelines/base.py index 447fa88ac8f1..25fdc51950bc 100644 --- a/pipelines/pipelines/pipelines/base.py +++ b/pipelines/pipelines/pipelines/base.py @@ -832,6 +832,7 @@ def _load_or_get_component(cls, name: str, definitions: dict, component_type=component_type, **component_params) components[name] = instance except Exception as e: + # breakpoint() raise Exception(f"Failed loading pipeline component '{name}': {e}") return instance diff --git a/pipelines/requirements.txt b/pipelines/requirements.txt index 3b046a182622..44fa2c41e6b0 100644 --- a/pipelines/requirements.txt +++ b/pipelines/requirements.txt @@ -15,10 +15,11 @@ faiss-cpu>=1.7.2 opencv-python>=4.4 opencv-contrib-python-headless python-multipart -git+https://github.com/tvst/htbuilder.git +htbuilder@git+https://github.com/tvst/htbuilder.git st-annotated-text streamlit==1.9.0 fastapi uvicorn markdown numba +pymilvus diff --git a/pipelines/rest_api/pipeline/semantic_search_custom.yaml b/pipelines/rest_api/pipeline/semantic_search_custom.yaml index 0db19dafc217..96ccbd16bf2e 100644 --- a/pipelines/rest_api/pipeline/semantic_search_custom.yaml +++ b/pipelines/rest_api/pipeline/semantic_search_custom.yaml @@ -2,7 +2,7 @@ version: '1.1.0' components: # define all the building-blocks for Pipeline - name: DocumentStore - type: ElasticsearchDocumentStore # consider using MilvusDocumentStore or WeaviateDocumentStore for scaling to large number of documents + type: ElasticsearchDocumentStore # consider using Milvus2DocumentStore or WeaviateDocumentStore for scaling to large number of documents params: host: localhost port: 9200 diff --git a/pipelines/rest_api/pipeline/semantic_search_milvus.yaml b/pipelines/rest_api/pipeline/semantic_search_milvus.yaml new file mode 100644 index 000000000000..0fbbbdd243ed --- /dev/null +++ b/pipelines/rest_api/pipeline/semantic_search_milvus.yaml @@ -0,0 +1,66 @@ +version: '1.1.0' + +components: # define all the building-blocks for Pipeline + - name: DocumentStore + type: Milvus2DocumentStore # consider using MilvusDocumentStore or WeaviateDocumentStore for scaling to large number of documents + params: + host: localhost + port: 8530 + index: dureader_index + embedding_dim: 312 + - name: Retriever + type: DensePassageRetriever + params: + document_store: DocumentStore # params can reference other components defined in the YAML + top_k: 10 + query_embedding_model: rocketqa-zh-nano-query-encoder + passage_embedding_model: rocketqa-zh-nano-para-encoder + embed_title: False + - name: Ranker # custom-name for the component; helpful for visualization & debugging + type: ErnieRanker # pipelines Class name for the component + params: + model_name_or_path: rocketqa-nano-cross-encoder + top_k: 3 + - name: TextFileConverter + type: TextConverter + - name: ImageFileConverter + type: ImageToTextConverter + - name: PDFFileConverter + type: PDFToTextConverter + - name: DocxFileConverter + type: DocxToTextConverter + - name: Preprocessor + type: PreProcessor + params: + split_by: word + split_length: 1000 + - name: FileTypeClassifier + type: FileTypeClassifier + +pipelines: + - name: query # a sample extractive-qa Pipeline + type: Query + nodes: + - name: Retriever + inputs: [Query] + - name: Ranker + inputs: [Retriever] + - name: indexing + type: Indexing + nodes: + - name: FileTypeClassifier + inputs: [File] + - name: TextFileConverter + inputs: [FileTypeClassifier.output_1] + - name: PDFFileConverter + inputs: [FileTypeClassifier.output_2] + - name: DocxFileConverter + inputs: [FileTypeClassifier.output_4] + - name: ImageFileConverter + inputs: [FileTypeClassifier.output_6] + - name: Preprocessor + inputs: [PDFFileConverter, TextFileConverter, DocxFileConverter, ImageFileConverter] + - name: Retriever + inputs: [Preprocessor] + - name: DocumentStore + inputs: [Retriever] diff --git a/pipelines/ui/webapp_question_answering.py b/pipelines/ui/webapp_question_answering.py index 3636a64b82da..3a6d29dd30fe 100644 --- a/pipelines/ui/webapp_question_answering.py +++ b/pipelines/ui/webapp_question_answering.py @@ -85,14 +85,7 @@ def reset_results(*args): on_change=reset_results, ) - top_k_ranker = st.sidebar.slider( - "最大排序数量", - min_value=1, - max_value=50, - value=DEFAULT_DOCS_FROM_RANKER, - step=1, - on_change=reset_results, - ) + top_k_ranker = 1 top_k_reader = st.sidebar.slider( "最大的答案的数量", diff --git a/pipelines/ui/webapp_semantic_search.py b/pipelines/ui/webapp_semantic_search.py index b4dce0b94c8c..ece261698fbb 100644 --- a/pipelines/ui/webapp_semantic_search.py +++ b/pipelines/ui/webapp_semantic_search.py @@ -197,6 +197,9 @@ def reset_results(*args): markdown(context), unsafe_allow_html=True, ) + # Sqlalchemy Support storing list type data by adding value semicolon, so split str data into separate files + if (type(result['images']) == str): + result['images'] = result['images'].split(';') for image_path in result['images']: image_url = pipelines_files(image_path) st.image( diff --git a/pipelines/utils/offline_ann.py b/pipelines/utils/offline_ann.py index 3a2ac9756dcb..8b1c6d0fabe2 100644 --- a/pipelines/utils/offline_ann.py +++ b/pipelines/utils/offline_ann.py @@ -17,7 +17,7 @@ import paddle from pipelines.utils import convert_files_to_dicts, fetch_archive_from_http -from pipelines.document_stores import ElasticsearchDocumentStore +from pipelines.document_stores import ElasticsearchDocumentStore, MilvusDocumentStore from pipelines.nodes import DensePassageRetriever from pipelines.utils import launch_es @@ -33,21 +33,24 @@ parser.add_argument("--index_name", default='baike_cities', type=str, - help="The index name of the elasticsearch engine") + help="The index name of the ANN search engine") parser.add_argument("--doc_dir", default='data/baike/', type=str, help="The doc path of the corpus") - +parser.add_argument("--search_engine", + choices=['elastic', 'milvus'], + default="elastic", + help="The type of ANN search engine.") parser.add_argument('--host', type=str, default="127.0.0.1", - help='host ip of elastic search') + help='host ip of ANN search engine') parser.add_argument('--port', type=str, default="9200", - help='port of elastic search') + help='port of ANN search engine') parser.add_argument("--embedding_dim", default=312, @@ -83,15 +86,25 @@ def offline_ann(index_name, doc_dir): - launch_es() - - document_store = ElasticsearchDocumentStore( - host=args.host, - port=args.port, - username="", - password="", - embedding_dim=args.embedding_dim, - index=index_name) + if (args.search_engine == "milvus"): + document_store = MilvusDocumentStore(embedding_dim=args.embedding_dim, + host=args.host, + index=args.index_name, + port=args.port, + index_param={ + "M": 16, + "efConstruction": 50 + }, + index_type="HNSW") + else: + launch_es() + document_store = ElasticsearchDocumentStore( + host=args.host, + port=args.port, + username="", + password="", + embedding_dim=args.embedding_dim, + index=index_name) # 将每篇文档按照段落进行切分 dicts = convert_files_to_dicts(dir_path=doc_dir, split_paragraphs=True, @@ -104,44 +117,42 @@ def offline_ann(index_name, doc_dir): document_store.write_documents(dicts) ### 语义索引模型 - if (os.path.exists(args.params_path)): - retriever = DensePassageRetriever( - document_store=document_store, - query_embedding_model=args.query_embedding_model, - params_path=args.params_path, - output_emb_size=args.embedding_dim, - max_seq_len_query=64, - max_seq_len_passage=256, - batch_size=16, - use_gpu=True, - embed_title=False, - ) - - else: - retriever = DensePassageRetriever( - document_store=document_store, - query_embedding_model=args.query_embedding_model, - passage_embedding_model=args.passage_embedding_model, - max_seq_len_query=64, - max_seq_len_passage=256, - batch_size=16, - use_gpu=True, - embed_title=False, - ) + retriever = DensePassageRetriever( + document_store=document_store, + query_embedding_model=args.query_embedding_model, + passage_embedding_model=args.passage_embedding_model, + params_path=args.params_path, + output_emb_size=args.embedding_dim, + max_seq_len_query=64, + max_seq_len_passage=256, + batch_size=16, + use_gpu=True, + embed_title=False, + ) # 建立索引库 document_store.update_embeddings(retriever) def delete_data(index_name): - document_store = ElasticsearchDocumentStore( - host=args.host, - port=args.port, - username="", - password="", - embedding_dim=args.embedding_dim, - index=index_name) - + if (args.search_engine == 'milvus'): + document_store = MilvusDocumentStore(embedding_dim=args.embedding_dim, + host=args.host, + index=args.index_name, + port=args.port, + index_param={ + "M": 16, + "efConstruction": 50 + }, + index_type="HNSW") + else: + document_store = ElasticsearchDocumentStore( + host=args.host, + port=args.port, + username="", + password="", + embedding_dim=args.embedding_dim, + index=index_name) document_store.delete_index(index_name) print('Delete an existing elasticsearch index {} Done.'.format(index_name)) From 8950381dd11b706db0b7ec14a3b452cde5562df4 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Fri, 16 Sep 2022 11:14:41 +0000 Subject: [PATCH 2/2] Remove unused comments --- pipelines/pipelines/pipelines/base.py | 1 - pipelines/rest_api/pipeline/semantic_search.yaml | 2 +- pipelines/rest_api/pipeline/semantic_search_milvus.yaml | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pipelines/pipelines/pipelines/base.py b/pipelines/pipelines/pipelines/base.py index 25fdc51950bc..447fa88ac8f1 100644 --- a/pipelines/pipelines/pipelines/base.py +++ b/pipelines/pipelines/pipelines/base.py @@ -832,7 +832,6 @@ def _load_or_get_component(cls, name: str, definitions: dict, component_type=component_type, **component_params) components[name] = instance except Exception as e: - # breakpoint() raise Exception(f"Failed loading pipeline component '{name}': {e}") return instance diff --git a/pipelines/rest_api/pipeline/semantic_search.yaml b/pipelines/rest_api/pipeline/semantic_search.yaml index 855e4811ef3f..faea615f2ced 100644 --- a/pipelines/rest_api/pipeline/semantic_search.yaml +++ b/pipelines/rest_api/pipeline/semantic_search.yaml @@ -2,7 +2,7 @@ version: '1.1.0' components: # define all the building-blocks for Pipeline - name: DocumentStore - type: ElasticsearchDocumentStore # consider using MilvusDocumentStore or WeaviateDocumentStore for scaling to large number of documents + type: ElasticsearchDocumentStore # consider using Milvus2DocumentStore or WeaviateDocumentStore for scaling to large number of documents params: host: localhost port: 9200 diff --git a/pipelines/rest_api/pipeline/semantic_search_milvus.yaml b/pipelines/rest_api/pipeline/semantic_search_milvus.yaml index 0fbbbdd243ed..dbac53876bf9 100644 --- a/pipelines/rest_api/pipeline/semantic_search_milvus.yaml +++ b/pipelines/rest_api/pipeline/semantic_search_milvus.yaml @@ -2,7 +2,7 @@ version: '1.1.0' components: # define all the building-blocks for Pipeline - name: DocumentStore - type: Milvus2DocumentStore # consider using MilvusDocumentStore or WeaviateDocumentStore for scaling to large number of documents + type: Milvus2DocumentStore # consider using Milvus2DocumentStore or WeaviateDocumentStore for scaling to large number of documents params: host: localhost port: 8530