Add Milvus 2.1 Support and Update pipelines qa ui #3283

Merged · 3 commits · Sep 18, 2022
13 changes: 8 additions & 5 deletions pipelines/examples/semantic-search/README.md
@@ -73,10 +73,12 @@ python setup.py install
# We recommend running this example in a GPU environment; it runs faster
# Use one idle GPU card; here we assume card 0 is the idle GPU
export CUDA_VISIBLE_DEVICES=0
python examples/semantic-search/semantic_search_example.py --device gpu
python examples/semantic-search/semantic_search_example.py --device gpu \
--search_engine faiss
# On a CPU-only machine, pass cpu via the --device flag; the run takes considerably longer
unset CUDA_VISIBLE_DEVICES
python examples/semantic-search/semantic_search_example.py --device cpu
python examples/semantic-search/semantic_search_example.py --device cpu \
--search_engine faiss
@tianxin1860 commented on Sep 17, 2022:

Isn't the default ANN engine ES? Why does this need a new --search_engine parameter? As I recall, ES supports only one ANN algorithm.

The contributor (PR author) replied:

Milvus support has been added, so Milvus can be used as well; the new parameter selects between faiss and milvus (see the dispatch in semantic_search_example.py below).

```
For an introduction to the `DensePassageRetriever` and `ErnieRanker` models used in `semantic_search_example.py`, see the [API reference](../../API.md).

@@ -107,6 +109,7 @@ curl http://localhost:9200/_aliases?pretty=true
# Build the ANN index, using the DuReader-Robust dataset as an example
python utils/offline_ann.py --index_name dureader_robust_query_encoder \
--doc_dir data/dureader_dev \
--search_engine elastic \
--delete_index
```
You can inspect the indexed data with the following command:
@@ -119,8 +122,9 @@ curl http://localhost:9200/dureader_robust_query_encoder/_search
Parameter descriptions:
* `index_name`: name of the index
* `doc_dir`: path to the txt text data
* `host`: IP address of Elasticsearch
* `port`: port of Elasticsearch
* `host`: IP address of the ANN search engine
* `port`: port of the ANN search engine
* `search_engine`: which ANN engine to use, elastic or milvus; defaults to elastic
* `delete_index`: whether to delete the existing index and data (used to clear the ES data); defaults to false

#### 3.4.3 Start the RestAPI model service
@@ -139,7 +143,6 @@ sh examples/semantic-search/run_search_server.sh

```
curl -X POST -k http://localhost:8891/query -H 'Content-Type: application/json' -d '{"query": "衡量酒水的价格的因素有哪些?","params": {"Retriever": {"top_k": 5}, "Ranker":{"top_k": 5}}}'

```
#### 3.4.4 Start the WebUI
```bash
164 changes: 113 additions & 51 deletions pipelines/examples/semantic-search/semantic_search_example.py
@@ -17,13 +17,15 @@

import paddle
from pipelines.document_stores import FAISSDocumentStore
from pipelines.document_stores import MilvusDocumentStore
from pipelines.nodes import DensePassageRetriever, ErnieRanker
from pipelines.utils import convert_files_to_dicts, fetch_archive_from_http, print_documents

# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to run dense_qa system, defaults to gpu.")
parser.add_argument("--index_name", default='faiss_index', type=str, help="The ann index name of FAISS.")
parser.add_argument("--index_name", default='dureader_index', type=str, help="The ann index name of ANN.")
parser.add_argument("--search_engine", choices=['faiss', 'milvus'], default="faiss", help="The type of ANN search engine.")
parser.add_argument("--max_seq_len_query", default=64, type=int, help="The maximum total length of query after tokenization.")
parser.add_argument("--max_seq_len_passage", default=256, type=int, help="The maximum total length of passage after tokenization.")
parser.add_argument("--retriever_batch_size", default=16, type=int, help="The batch size of retriever to extract passage embedding for building ANN index.")
@@ -44,41 +46,38 @@
default=312,
type=int,
help="The embedding_dim of index")
args = parser.parse_args()
# yapf: enable

parser.add_argument('--host',
type=str,
default="localhost",
help='host ip of ANN search engine')

def semantic_search_tutorial():
parser.add_argument('--port',
type=str,
default="8530",
help='port of ANN search engine')

use_gpu = True if args.device == 'gpu' else False
args = parser.parse_args()
# yapf: enable


def get_faiss_retriever(use_gpu):
faiss_document_store = "faiss_document_store.db"
if os.path.exists(args.index_name) and os.path.exists(faiss_document_store):
# connect to the existing FAISS index
document_store = FAISSDocumentStore.load(args.index_name)
if (os.path.exists(args.params_path)):
retriever = DensePassageRetriever(
document_store=document_store,
query_embedding_model=args.query_embedding_model,
params_path=args.params_path,
output_emb_size=args.embedding_dim,
max_seq_len_query=args.max_seq_len_query,
max_seq_len_passage=args.max_seq_len_passage,
batch_size=args.retriever_batch_size,
use_gpu=use_gpu,
embed_title=False,
)
else:
retriever = DensePassageRetriever(
document_store=document_store,
query_embedding_model=args.query_embedding_model,
passage_embedding_model=args.passage_embedding_model,
max_seq_len_query=args.max_seq_len_query,
max_seq_len_passage=args.max_seq_len_passage,
batch_size=args.retriever_batch_size,
use_gpu=use_gpu,
embed_title=False,
)
retriever = DensePassageRetriever(
document_store=document_store,
query_embedding_model=args.query_embedding_model,
passage_embedding_model=args.passage_embedding_model,
params_path=args.params_path,
output_emb_size=args.embedding_dim,
max_seq_len_query=args.max_seq_len_query,
max_seq_len_passage=args.max_seq_len_passage,
batch_size=args.retriever_batch_size,
use_gpu=use_gpu,
embed_title=False,
)
else:
doc_dir = "data/dureader_dev"
dureader_data = "https://paddlenlp.bj.bcebos.com/applications/dureader_dev.zip"
@@ -97,35 +96,98 @@ def semantic_search_tutorial():
faiss_index_factory_str="Flat")
document_store.write_documents(dicts)

if (os.path.exists(args.params_path)):
retriever = DensePassageRetriever(
document_store=document_store,
query_embedding_model=args.query_embedding_model,
params_path=args.params_path,
output_emb_size=args.embedding_dim,
max_seq_len_query=args.max_seq_len_query,
max_seq_len_passage=args.max_seq_len_passage,
batch_size=args.retriever_batch_size,
use_gpu=use_gpu,
embed_title=False,
)
else:
retriever = DensePassageRetriever(
document_store=document_store,
query_embedding_model=args.query_embedding_model,
passage_embedding_model=args.passage_embedding_model,
max_seq_len_query=args.max_seq_len_query,
max_seq_len_passage=args.max_seq_len_passage,
batch_size=args.retriever_batch_size,
use_gpu=use_gpu,
embed_title=False,
)
retriever = DensePassageRetriever(
document_store=document_store,
query_embedding_model=args.query_embedding_model,
passage_embedding_model=args.passage_embedding_model,
params_path=args.params_path,
output_emb_size=args.embedding_dim,
max_seq_len_query=args.max_seq_len_query,
max_seq_len_passage=args.max_seq_len_passage,
batch_size=args.retriever_batch_size,
use_gpu=use_gpu,
embed_title=False,
)

# update Embedding
document_store.update_embeddings(retriever)

# save index
document_store.save(args.index_name)
    return retriever


def get_milvus_retriever(use_gpu):

milvus_document_store = "milvus_document_store.db"
if os.path.exists(milvus_document_store):
document_store = MilvusDocumentStore(embedding_dim=args.embedding_dim,
host=args.host,
index=args.index_name,
port=args.port,
index_param={
"M": 16,
"efConstruction": 50
},
index_type="HNSW")
        # connect to the existing Milvus index
retriever = DensePassageRetriever(
document_store=document_store,
query_embedding_model=args.query_embedding_model,
passage_embedding_model=args.passage_embedding_model,
params_path=args.params_path,
output_emb_size=args.embedding_dim,
max_seq_len_query=args.max_seq_len_query,
max_seq_len_passage=args.max_seq_len_passage,
batch_size=args.retriever_batch_size,
use_gpu=use_gpu,
embed_title=False,
)
else:
doc_dir = "data/dureader_dev"
dureader_data = "https://paddlenlp.bj.bcebos.com/applications/dureader_dev.zip"

fetch_archive_from_http(url=dureader_data, output_dir=doc_dir)
dicts = convert_files_to_dicts(dir_path=doc_dir,
split_paragraphs=True,
encoding='utf-8')
document_store = MilvusDocumentStore(embedding_dim=args.embedding_dim,
host=args.host,
index=args.index_name,
port=args.port,
index_param={
"M": 16,
"efConstruction": 50
},
index_type="HNSW")
retriever = DensePassageRetriever(
document_store=document_store,
query_embedding_model=args.query_embedding_model,
passage_embedding_model=args.passage_embedding_model,
params_path=args.params_path,
output_emb_size=args.embedding_dim,
max_seq_len_query=args.max_seq_len_query,
max_seq_len_passage=args.max_seq_len_passage,
batch_size=args.retriever_batch_size,
use_gpu=use_gpu,
embed_title=False,
)

document_store.write_documents(dicts)
# update Embedding
document_store.update_embeddings(retriever)

return retriever


def semantic_search_tutorial():

use_gpu = True if args.device == 'gpu' else False

if (args.search_engine == 'milvus'):
retriever = get_milvus_retriever(use_gpu)
else:
retriever = get_faiss_retriever(use_gpu)

### Ranker
ranker = ErnieRanker(
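The rest of the tutorial body is truncated in this diff. For orientation, here is a hypothetical sketch of how the retriever and ranker are commonly wired into a pipeline; the Pipeline API and node names are assumptions (they mirror the "Retriever"/"Ranker" params in the curl example above), not part of this PR:

```python
# Hypothetical wiring sketch -- the Pipeline API shown here is assumed,
# not taken from this diff. `retriever` and `ranker` come from the
# tutorial code above.
from pipelines import Pipeline
from pipelines.utils import print_documents

pipe = Pipeline()
pipe.add_node(component=retriever, name="Retriever", inputs=["Query"])
pipe.add_node(component=ranker, name="Ranker", inputs=["Retriever"])

prediction = pipe.run(query="衡量酒水的价格的因素有哪些?",
                      params={"Retriever": {"top_k": 5},
                              "Ranker": {"top_k": 5}})
print_documents(prediction)
```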
3 changes: 3 additions & 0 deletions pipelines/pipelines/document_stores/__init__.py
@@ -31,6 +31,9 @@
FAISSDocumentStore = safe_import("pipelines.document_stores.faiss",
"FAISSDocumentStore", "faiss")

MilvusDocumentStore = safe_import("pipelines.document_stores.milvus2",
"Milvus2DocumentStore", "milvus")

from pipelines.document_stores.utils import (
eval_data_from_json,
eval_data_from_jsonl,
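For context, `safe_import` defers the optional dependency: importing the module succeeds even when milvus is not installed, and the error surfaces only when the store is actually used. A minimal sketch of that pattern, assuming a simplified signature (the real helper in pipelines may differ):

```python
# A minimal sketch of the optional-import pattern (an assumed
# simplification; the real safe_import helper may differ).
import importlib


def safe_import(module_path: str, class_name: str, dep_group: str):
    """Return the named class, or a placeholder that fails only on use."""
    try:
        return getattr(importlib.import_module(module_path), class_name)
    except ImportError as exc:
        error = exc  # keep a reference; `exc` is cleared after the block

        class _MissingDependency:

            def __init__(self, *args, **kwargs):
                raise ImportError(
                    f"Could not import '{class_name}'. Install the "
                    f"'{dep_group}' extra to use it.") from error

        return _MissingDependency
```

With this shape, `MilvusDocumentStore` resolves to `Milvus2DocumentStore` when the milvus extra is installed, and to a loud failure otherwise.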
7 changes: 7 additions & 0 deletions pipelines/pipelines/document_stores/base.py
@@ -228,6 +228,13 @@ def __next__(self):
self.ids_iterator = self.ids_iterator[1:]
return ret

def scale_to_unit_interval(self, score: float,
similarity: Optional[str]) -> float:
if similarity == "cosine":
return (score + 1) / 2
else:
return float(expit(score / 100))

@abstractmethod
def get_all_labels(
self,
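To make the normalization above concrete: a cosine score of 0.8 maps to (0.8 + 1) / 2 = 0.9, while a dot-product score of 100 maps to expit(1) ≈ 0.73. A standalone check, assuming scipy supplies the `expit` used in base.py:

```python
# Standalone check of the score normalization added in base.py;
# assumes scipy provides the logistic function `expit`.
from typing import Optional

from scipy.special import expit


def scale_to_unit_interval(score: float, similarity: Optional[str]) -> float:
    if similarity == "cosine":
        # Cosine similarity lies in [-1, 1]; shift and scale into [0, 1].
        return (score + 1) / 2
    # Other metrics (e.g. dot product) are squashed with a logistic curve;
    # dividing by 100 keeps typical score magnitudes out of the flat tails.
    return float(expit(score / 100))


print(scale_to_unit_interval(0.8, "cosine"))         # 0.9
print(scale_to_unit_interval(100.0, "dot_product"))  # ~0.731
```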