From bbdce60df2415f6880aa383c321b9c8397efe88e Mon Sep 17 00:00:00 2001 From: w5688414 Date: Thu, 10 Nov 2022 07:46:16 +0000 Subject: [PATCH 1/5] Add English semantic search of pipelines --- paddlenlp/transformers/ernie/modeling.py | 45 ++++++++++++ paddlenlp/transformers/ernie/tokenizer.py | 20 +++++- .../transformers/semantic_search/modeling.py | 55 ++++++++++++--- pipelines/API.md | 4 +- .../semantic-search/run_search_server.sh | 16 +++++ .../semantic-search/run_search_web.sh | 21 +++++- .../pipelines/nodes/ranker/ernie_ranker.py | 17 +++-- pipelines/rest_api/pipeline/dense_faq.yaml | 2 +- .../rest_api/pipeline/semantic_search.yaml | 2 +- .../rest_api/pipeline/semantic_search_en.yaml | 68 +++++++++++++++++++ pipelines/ui/country_search.csv | 3 + pipelines/ui/webapp_semantic_search.py | 4 +- 12 files changed, 237 insertions(+), 20 deletions(-) create mode 100644 pipelines/rest_api/pipeline/semantic_search_en.yaml create mode 100644 pipelines/ui/country_search.csv diff --git a/paddlenlp/transformers/ernie/modeling.py b/paddlenlp/transformers/ernie/modeling.py index 744b7096000d..08ed2061d05d 100644 --- a/paddlenlp/transformers/ernie/modeling.py +++ b/paddlenlp/transformers/ernie/modeling.py @@ -673,6 +673,45 @@ class ErniePretrainedModel(PretrainedModel): "use_task_id": True, "vocab_size": 40000 }, + "rocketqav2-en-marco-cross-encoder": { + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "max_position_embeddings": 512, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "type_vocab_size": 4, + "vocab_size": 30522, + "pad_token_id": 0, + }, + "rocketqav2-en-marco-query-encoder": { + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "max_position_embeddings": 512, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "type_vocab_size": 4, + "vocab_size": 30522, + "pad_token_id": 0, + }, + "rocketqav2-en-marco-para-encoder": { + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "max_position_embeddings": 512, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "type_vocab_size": 4, + "vocab_size": 30522, + "pad_token_id": 0, + }, } resource_files_names = {"model_state": "model_state.pdparams"} pretrained_resource_files_map = { @@ -752,6 +791,12 @@ class ErniePretrainedModel(PretrainedModel): "https://paddlenlp.bj.bcebos.com/models/transformers/rocketqa/rocketqa-micro-cross-encoder.pdparams", "rocketqa-nano-cross-encoder": "https://paddlenlp.bj.bcebos.com/models/transformers/rocketqa/rocketqa-nano-cross-encoder.pdparams", + "rocketqav2-en-marco-cross-encoder": + "https://paddlenlp.bj.bcebos.com/models/transformers/rocketqa/rocketqav2_en_marco_cross_encoder.pdparams", + "rocketqav2-en-marco-query-encoder": + "https://paddlenlp.bj.bcebos.com/models/transformers/rocketqa/rocketqav2_en_marco_query_encoder.pdparams", + "rocketqav2-en-marco-para-encoder": + "https://paddlenlp.bj.bcebos.com/models/transformers/rocketqa/rocketqav2_en_marco_para_encoder.pdparams", } } base_model_prefix = "ernie" diff --git a/paddlenlp/transformers/ernie/tokenizer.py b/paddlenlp/transformers/ernie/tokenizer.py index eedeb3d9e225..a2cb4b114ca2 100644 --- a/paddlenlp/transformers/ernie/tokenizer.py +++ b/paddlenlp/transformers/ernie/tokenizer.py @@ -66,7 +66,10 @@ "rocketqa-medium-cross-encoder": 2048, "rocketqa-mini-cross-encoder": 2048, "rocketqa-micro-cross-encoder": 2048, - "rocketqa-nano-cross-encoder": 2048 + "rocketqa-nano-cross-encoder": 2048, + "rocketqav2-en-marco-cross-encoder": 512, + "rocketqav2-en-marco-query-encoder": 512, + "rocketqav2-en-marco-para-encoder": 512, } @@ -202,6 +205,12 @@ class ErnieTokenizer(PretrainedTokenizer): "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_3.0/ernie_3.0_micro_zh_vocab.txt", "rocketqa-nano-cross-encoder": "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_3.0/ernie_3.0_nano_zh_vocab.txt", + "rocketqav2-en-marco-cross-encoder": + "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_v2_base/vocab.txt", + "rocketqav2-en-marco-query-encoder": + "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_v2_base/vocab.txt", + "rocketqav2-en-marco-para-encoder": + "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_v2_base/vocab.txt", } } pretrained_init_configuration = { @@ -325,6 +334,15 @@ class ErnieTokenizer(PretrainedTokenizer): "rocketqa-nano-cross-encoder": { "do_lower_case": True }, + "rocketqav2-en-marco-cross-encoder": { + "do_lower_case": True + }, + "rocketqav2-en-marco-query-encoder": { + "do_lower_case": True + }, + "rocketqav2-en-marco-para-encoder": { + "do_lower_case": True + }, } max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES diff --git a/paddlenlp/transformers/semantic_search/modeling.py b/paddlenlp/transformers/semantic_search/modeling.py index 82b1f04a2ba6..97e88b390bfa 100644 --- a/paddlenlp/transformers/semantic_search/modeling.py +++ b/paddlenlp/transformers/semantic_search/modeling.py @@ -80,6 +80,7 @@ def __init__(self, title_model_name_or_path=None, share_parameters=False, dropout=None, + reinitialize=False, use_cross_batch=False): super().__init__() @@ -96,6 +97,15 @@ def __init__(self, assert (self.query_ernie is not None) or (self.title_ernie is not None), \ "At least one of query_ernie and title_ernie should not be None" + # Compatible to rocketv2 initialization + if (reinitialize): + self.apply(self.init_weights) + + def init_weights(self, layer): + """ Initialization hook """ + if isinstance(layer, nn.LayerNorm): + layer._epsilon = 1e-5 + def get_semantic_embedding(self, data_loader): self.eval() with paddle.no_grad(): @@ -118,13 +128,14 @@ def get_pooled_embedding(self, assert (is_query and self.query_ernie is not None) or (not is_query and self.title_ernie), \ "Please check whether your parameter for `is_query` are consistent with DualEncoder initialization." if is_query: - sequence_output, _ = self.query_ernie(input_ids, token_type_ids, - position_ids, attention_mask) + sequence_output, pool_output = self.query_ernie( + input_ids, token_type_ids, position_ids, attention_mask) else: - sequence_output, _ = self.title_ernie(input_ids, token_type_ids, - position_ids, attention_mask) + sequence_output, pool_output = self.title_ernie( + input_ids, token_type_ids, position_ids, attention_mask) return sequence_output[:, 0] + # return pool_output def cosine_sim(self, query_input_ids, @@ -242,9 +253,20 @@ class ErnieCrossEncoder(nn.Layer): def __init__(self, pretrain_model_name_or_path, num_classes=2, + reinitialize=False, dropout=None): super().__init__() - self.ernie = ErnieEncoder.from_pretrained(pretrain_model_name_or_path) + + self.ernie = ErnieEncoder.from_pretrained(pretrain_model_name_or_path, + num_classes=num_classes) + # Compatible to rocketv2 initialization + if (reinitialize): + self.apply(self.init_weights) + + def init_weights(self, layer): + """ Initialization hook """ + if isinstance(layer, nn.LayerNorm): + layer._epsilon = 1e-5 def matching(self, input_ids, @@ -252,10 +274,11 @@ def matching(self, position_ids=None, attention_mask=None, return_prob_distributation=False): - _, pooled_output = self.ernie(input_ids, - token_type_ids=token_type_ids, - position_ids=position_ids, - attention_mask=attention_mask) + sequence_output, pooled_output = self.ernie( + input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + attention_mask=attention_mask) pooled_output = self.ernie.dropout(pooled_output) cls_embedding = self.ernie.classifier(pooled_output) probs = F.softmax(cls_embedding, axis=1) @@ -263,6 +286,20 @@ def matching(self, return probs return probs[:, 1] + def matching_v2(self, + input_ids, + token_type_ids=None, + position_ids=None, + attention_mask=None): + sequence_output, pooled_output = self.ernie( + input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + attention_mask=attention_mask) + pooled_output = self.ernie.dropout(sequence_output[:, 0]) + cls_embedding = self.ernie.classifier(pooled_output) + return cls_embedding + def forward(self, input_ids, token_type_ids=None, diff --git a/pipelines/API.md b/pipelines/API.md index 1630c7397f16..719d038617b7 100644 --- a/pipelines/API.md +++ b/pipelines/API.md @@ -13,7 +13,7 @@ | rocketqa-zh-mini-query-encoder | Chinese | 6-layer, 384-hidden, 12-heads, 27M parameters. Trained on DuReader retrieval text. | | rocketqa-zh-micro-query-encoder | Chinese | 4-layer, 384-hidden, 12-heads, 23M parameters. Trained on DuReader retrieval text. | | rocketqa-zh-nano-query-encoder | Chinese | 4-layer, 312-hidden, 12-heads, 18M parameters. Trained on DuReader retrieval text. | - +| rocketqav2-en-marco-query-encoder | English | 12-layer, 768-hidden, 12-heads, 118M parameters. Trained on MSMARCO. | ## ErnieRanker @@ -26,7 +26,7 @@ | rocketqa-mini-cross-encoder | Chinese | 6-layer, 384-hidden, 12-heads, 27M parameters. Trained on DuReader retrieval text. | | rocketqa-micro-cross-encoder | Chinese | 4-layer, 384-hidden, 12-heads, 23M parameters. Trained on DuReader retrieval text. | | rocketqa-nano-cross-encoder | Chinese | 4-layer, 312-hidden, 12-heads, 18M parameters. Trained on DuReader retrieval text. | - +| rocketqav2-en-marco-cross-encoder | English | 12-layer, 768-hidden, 12-heads, 118M parameters. Trained on Trained on MSMARCO. | ## ErnieReader diff --git a/pipelines/examples/semantic-search/run_search_server.sh b/pipelines/examples/semantic-search/run_search_server.sh index b940b6fd6e4e..bd3531d40d63 100644 --- a/pipelines/examples/semantic-search/run_search_server.sh +++ b/pipelines/examples/semantic-search/run_search_server.sh @@ -1,5 +1,21 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # 指定语义检索系统的Yaml配置文件 export CUDA_VISIBLE_DEVICES=0 export PIPELINE_YAML_PATH=rest_api/pipeline/semantic_search.yaml +# English Version +# export PIPELINE_YAML_PATH=rest_api/pipeline/semantic_search_en.yaml # 使用端口号 8891 启动模型服务 python rest_api/application.py 8891 \ No newline at end of file diff --git a/pipelines/examples/semantic-search/run_search_web.sh b/pipelines/examples/semantic-search/run_search_web.sh index 05530d8779eb..1a9476f7f53f 100644 --- a/pipelines/examples/semantic-search/run_search_web.sh +++ b/pipelines/examples/semantic-search/run_search_web.sh @@ -1,5 +1,24 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + unset http_proxy && unset https_proxy # 配置模型服务地址 export API_ENDPOINT=http://127.0.0.1:8891 # 在指定端口 8502 启动 WebUI -python -m streamlit run ui/webapp_semantic_search.py --server.port 8502 \ No newline at end of file +python -m streamlit run ui/webapp_semantic_search.py --server.port 8502 + +# English Version +# export EVAL_FILE=ui/country_search.csv +# export DEFAULT_QUESTION_AT_STARTUP="The introduction of United States of America?" +# python -m streamlit run ui/webapp_semantic_search.py --server.port 8502 \ No newline at end of file diff --git a/pipelines/pipelines/nodes/ranker/ernie_ranker.py b/pipelines/pipelines/nodes/ranker/ernie_ranker.py index 8146e246bf06..a5da599a9413 100644 --- a/pipelines/pipelines/nodes/ranker/ernie_ranker.py +++ b/pipelines/pipelines/nodes/ranker/ernie_ranker.py @@ -48,6 +48,8 @@ def __init__( max_seq_len: int = 256, progress_bar: bool = True, batch_size: int = 1000, + reinitialize: bool = False, + user_en: bool = False, ): """ :param model_name_or_path: Directory of a saved model or the name of a public model e.g. @@ -60,14 +62,17 @@ def __init__( self.set_config( model_name_or_path=model_name_or_path, top_k=top_k, + user_en=user_en, ) self.top_k = top_k + self.user_en = user_en self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=True) print("Loading Parameters from:{}".format(model_name_or_path)) - self.transformer_model = ErnieCrossEncoder(model_name_or_path) + self.transformer_model = ErnieCrossEncoder(model_name_or_path, + reinitialize=reinitialize) self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) self.transformer_model.eval() self.progress_bar = progress_bar @@ -156,15 +161,19 @@ def predict_batch( for cur_queries, cur_docs in batches: features = self.tokenizer(cur_queries, [doc.content for doc in cur_docs], - max_seq_len=256, + max_seq_len=self.max_seq_len, pad_to_max_seq_len=True, truncation_strategy="longest_first") tensors = {k: paddle.to_tensor(v) for (k, v) in features.items()} with paddle.no_grad(): - similarity_scores = self.transformer_model.matching( - **tensors).numpy() + if (self.user_en): + similarity_scores = self.transformer_model.matching_v2( + **tensors).numpy() + else: + similarity_scores = self.transformer_model.matching( + **tensors).numpy() preds.extend(similarity_scores) for doc, rank_score in zip(cur_docs, similarity_scores): diff --git a/pipelines/rest_api/pipeline/dense_faq.yaml b/pipelines/rest_api/pipeline/dense_faq.yaml index 8174ae43fe7c..4bae813cdf3e 100644 --- a/pipelines/rest_api/pipeline/dense_faq.yaml +++ b/pipelines/rest_api/pipeline/dense_faq.yaml @@ -39,7 +39,7 @@ components: # define all the building-blocks for Pipeline type: FileTypeClassifier pipelines: - - name: query # a sample extractive-qa Pipeline + - name: query type: Query nodes: - name: Retriever diff --git a/pipelines/rest_api/pipeline/semantic_search.yaml b/pipelines/rest_api/pipeline/semantic_search.yaml index faea615f2ced..b948b5931fcf 100644 --- a/pipelines/rest_api/pipeline/semantic_search.yaml +++ b/pipelines/rest_api/pipeline/semantic_search.yaml @@ -38,7 +38,7 @@ components: # define all the building-blocks for Pipeline type: FileTypeClassifier pipelines: - - name: query # a sample extractive-qa Pipeline + - name: query type: Query nodes: - name: Retriever diff --git a/pipelines/rest_api/pipeline/semantic_search_en.yaml b/pipelines/rest_api/pipeline/semantic_search_en.yaml new file mode 100644 index 000000000000..e251c6b275ab --- /dev/null +++ b/pipelines/rest_api/pipeline/semantic_search_en.yaml @@ -0,0 +1,68 @@ +version: '1.1.0' + +components: # define all the building-blocks for Pipeline + - name: DocumentStore + type: ElasticsearchDocumentStore # consider using Milvus2DocumentStore or WeaviateDocumentStore for scaling to large number of documents + params: + host: localhost + port: 9200 + index: msmarco_query_encoder + embedding_dim: 768 + - name: Retriever + type: DensePassageRetriever + params: + document_store: DocumentStore # params can reference other components defined in the YAML + top_k: 10 + query_embedding_model: rocketqav2-en-marco-query-encoder + passage_embedding_model: rocketqav2-en-marco-query-encoder + embed_title: False + - name: Ranker # custom-name for the component; helpful for visualization & debugging + type: ErnieRanker # pipelines Class name for the component + params: + model_name_or_path: rocketqav2-en-marco-cross-encoder + top_k: 3 + user_en: True, + reinitialize: True + - name: TextFileConverter + type: TextConverter + - name: ImageFileConverter + type: ImageToTextConverter + - name: PDFFileConverter + type: PDFToTextConverter + - name: DocxFileConverter + type: DocxToTextConverter + - name: Preprocessor + type: PreProcessor + params: + split_by: word + split_length: 1000 + - name: FileTypeClassifier + type: FileTypeClassifier + +pipelines: + - name: query + type: Query + nodes: + - name: Retriever + inputs: [Query] + - name: Ranker + inputs: [Retriever] + - name: indexing + type: Indexing + nodes: + - name: FileTypeClassifier + inputs: [File] + - name: TextFileConverter + inputs: [FileTypeClassifier.output_1] + - name: PDFFileConverter + inputs: [FileTypeClassifier.output_2] + - name: DocxFileConverter + inputs: [FileTypeClassifier.output_4] + - name: ImageFileConverter + inputs: [FileTypeClassifier.output_6] + - name: Preprocessor + inputs: [PDFFileConverter, TextFileConverter, DocxFileConverter, ImageFileConverter] + - name: Retriever + inputs: [Preprocessor] + - name: DocumentStore + inputs: [Retriever] diff --git a/pipelines/ui/country_search.csv b/pipelines/ui/country_search.csv new file mode 100644 index 000000000000..7a6ff67ea5a3 --- /dev/null +++ b/pipelines/ui/country_search.csv @@ -0,0 +1,3 @@ +"Question Text";"Answer" +"What is the capital of America?";"Washington" +"How many people live in the capital of the US?";"689,545" \ No newline at end of file diff --git a/pipelines/ui/webapp_semantic_search.py b/pipelines/ui/webapp_semantic_search.py index 59cf9342df53..f18dc88af79c 100644 --- a/pipelines/ui/webapp_semantic_search.py +++ b/pipelines/ui/webapp_semantic_search.py @@ -27,6 +27,7 @@ sys.path.append('ui') from utils import pipelines_is_ready, semantic_search, send_feedback, upload_doc, pipelines_version, get_backlink from utils import pipelines_files + # Adjust to a question that you would like users to see in the search bar when they load the UI: DEFAULT_QUESTION_AT_STARTUP = os.getenv("DEFAULT_QUESTION_AT_STARTUP", "衡量酒水的价格的因素有哪些?") @@ -227,4 +228,5 @@ def reset_results(*args): st.write("___") -main() \ No newline at end of file +if __name__ == "__main__": + main() From c852d61ddb6aa559594a0f706ca647230f135fde Mon Sep 17 00:00:00 2001 From: w5688414 Date: Thu, 10 Nov 2022 09:18:54 +0000 Subject: [PATCH 2/5] Optimize semantic_search modeling.py --- .../transformers/semantic_search/modeling.py | 31 +++++++++---------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/paddlenlp/transformers/semantic_search/modeling.py b/paddlenlp/transformers/semantic_search/modeling.py index 97e88b390bfa..f0c61cdfbe72 100644 --- a/paddlenlp/transformers/semantic_search/modeling.py +++ b/paddlenlp/transformers/semantic_search/modeling.py @@ -128,14 +128,13 @@ def get_pooled_embedding(self, assert (is_query and self.query_ernie is not None) or (not is_query and self.title_ernie), \ "Please check whether your parameter for `is_query` are consistent with DualEncoder initialization." if is_query: - sequence_output, pool_output = self.query_ernie( - input_ids, token_type_ids, position_ids, attention_mask) + sequence_output, _ = self.query_ernie(input_ids, token_type_ids, + position_ids, attention_mask) else: - sequence_output, pool_output = self.title_ernie( - input_ids, token_type_ids, position_ids, attention_mask) + sequence_output, _ = self.title_ernie(input_ids, token_type_ids, + position_ids, attention_mask) return sequence_output[:, 0] - # return pool_output def cosine_sim(self, query_input_ids, @@ -274,11 +273,10 @@ def matching(self, position_ids=None, attention_mask=None, return_prob_distributation=False): - sequence_output, pooled_output = self.ernie( - input_ids, - token_type_ids=token_type_ids, - position_ids=position_ids, - attention_mask=attention_mask) + _, pooled_output = self.ernie(input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + attention_mask=attention_mask) pooled_output = self.ernie.dropout(pooled_output) cls_embedding = self.ernie.classifier(pooled_output) probs = F.softmax(cls_embedding, axis=1) @@ -291,14 +289,13 @@ def matching_v2(self, token_type_ids=None, position_ids=None, attention_mask=None): - sequence_output, pooled_output = self.ernie( - input_ids, - token_type_ids=token_type_ids, - position_ids=position_ids, - attention_mask=attention_mask) + sequence_output, _ = self.ernie(input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + attention_mask=attention_mask) pooled_output = self.ernie.dropout(sequence_output[:, 0]) - cls_embedding = self.ernie.classifier(pooled_output) - return cls_embedding + probs = self.ernie.classifier(pooled_output) + return probs def forward(self, input_ids, From 0447b0b3972dbbb626224eea58948d6ddc30e7fe Mon Sep 17 00:00:00 2001 From: w5688414 Date: Wed, 16 Nov 2022 08:01:58 +0000 Subject: [PATCH 3/5] Update syntax --- paddlenlp/transformers/semantic_search/modeling.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/paddlenlp/transformers/semantic_search/modeling.py b/paddlenlp/transformers/semantic_search/modeling.py index f0c61cdfbe72..db09d4e2d7bd 100644 --- a/paddlenlp/transformers/semantic_search/modeling.py +++ b/paddlenlp/transformers/semantic_search/modeling.py @@ -97,8 +97,8 @@ def __init__(self, assert (self.query_ernie is not None) or (self.title_ernie is not None), \ "At least one of query_ernie and title_ernie should not be None" - # Compatible to rocketv2 initialization - if (reinitialize): + # Compatible to rocketv2 initialization for setting layer._epsilon to 1e-5 + if reinitialize: self.apply(self.init_weights) def init_weights(self, layer): @@ -210,7 +210,6 @@ def forward(self, paddle.distributed.all_gather(tensor_list, all_title_cls_embedding) all_title_cls_embedding = paddle.concat(x=tensor_list, axis=0) - # multiply logits = paddle.matmul(query_cls_embedding, all_title_cls_embedding, transpose_y=True) @@ -258,8 +257,8 @@ def __init__(self, self.ernie = ErnieEncoder.from_pretrained(pretrain_model_name_or_path, num_classes=num_classes) - # Compatible to rocketv2 initialization - if (reinitialize): + # Compatible to rocketv2 initialization for setting layer._epsilon to 1e-5 + if reinitialize: self.apply(self.init_weights) def init_weights(self, layer): From 0cdce544cb8eb453b9370c26b7ebee67278e58b1 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Wed, 16 Nov 2022 08:04:26 +0000 Subject: [PATCH 4/5] Update semantic search english examples --- pipelines/ui/country_search.csv | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipelines/ui/country_search.csv b/pipelines/ui/country_search.csv index 7a6ff67ea5a3..d1b6e5157d7c 100644 --- a/pipelines/ui/country_search.csv +++ b/pipelines/ui/country_search.csv @@ -1,3 +1,3 @@ "Question Text";"Answer" "What is the capital of America?";"Washington" -"How many people live in the capital of the US?";"689,545" \ No newline at end of file +"How many states of the United States ?";"50" \ No newline at end of file From 1cd6407b55511799e9c7c585157ed7be3c31bb9b Mon Sep 17 00:00:00 2001 From: w5688414 Date: Wed, 16 Nov 2022 09:29:21 +0000 Subject: [PATCH 5/5] Change parameter user_en to use_en --- pipelines/pipelines/nodes/ranker/ernie_ranker.py | 9 +++++---- pipelines/rest_api/pipeline/semantic_search_en.yaml | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pipelines/pipelines/nodes/ranker/ernie_ranker.py b/pipelines/pipelines/nodes/ranker/ernie_ranker.py index a5da599a9413..15651d30414d 100644 --- a/pipelines/pipelines/nodes/ranker/ernie_ranker.py +++ b/pipelines/pipelines/nodes/ranker/ernie_ranker.py @@ -49,7 +49,7 @@ def __init__( progress_bar: bool = True, batch_size: int = 1000, reinitialize: bool = False, - user_en: bool = False, + use_en: bool = False, ): """ :param model_name_or_path: Directory of a saved model or the name of a public model e.g. @@ -62,11 +62,12 @@ def __init__( self.set_config( model_name_or_path=model_name_or_path, top_k=top_k, - user_en=user_en, + use_en=use_en, ) self.top_k = top_k - self.user_en = user_en + # Parameter to control the use of English Cross Encoder Model + self.use_en = use_en self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=True) @@ -168,7 +169,7 @@ def predict_batch( tensors = {k: paddle.to_tensor(v) for (k, v) in features.items()} with paddle.no_grad(): - if (self.user_en): + if (self.use_en): similarity_scores = self.transformer_model.matching_v2( **tensors).numpy() else: diff --git a/pipelines/rest_api/pipeline/semantic_search_en.yaml b/pipelines/rest_api/pipeline/semantic_search_en.yaml index e251c6b275ab..f98feac8369c 100644 --- a/pipelines/rest_api/pipeline/semantic_search_en.yaml +++ b/pipelines/rest_api/pipeline/semantic_search_en.yaml @@ -21,7 +21,7 @@ components: # define all the building-blocks for Pipeline params: model_name_or_path: rocketqav2-en-marco-cross-encoder top_k: 3 - user_en: True, + use_en: True, reinitialize: True - name: TextFileConverter type: TextConverter