Add Ernie Search base model into pipelines #3906

Merged (4 commits) on Nov 28, 2022
Changes from 3 commits
31 changes: 31 additions & 0 deletions paddlenlp/transformers/ernie/modeling.py
@@ -715,6 +715,33 @@ class ErniePretrainedModel(PretrainedModel):
"vocab_size": 30522,
"pad_token_id": 0,
},
"ernie-search-base-dual-encoder-marco-en": {
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"type_vocab_size": 4,
"vocab_size": 30522,
"pad_token_id": 0,
},
"ernie-search-large-cross-encoder-marco-en": {
"attention_probs_dropout_prob": 0.1,
"intermediate_size": 4096,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"max_position_embeddings": 512,
"num_attention_heads": 16,
"num_hidden_layers": 24,
"type_vocab_size": 4,
"vocab_size": 30522,
"pad_token_id": 0,
},
}
resource_files_names = {"model_state": "model_state.pdparams"}
pretrained_resource_files_map = {
@@ -800,6 +827,10 @@ class ErniePretrainedModel(PretrainedModel):
"https://paddlenlp.bj.bcebos.com/models/transformers/rocketqa/rocketqav2_en_marco_query_encoder.pdparams",
"rocketqav2-en-marco-para-encoder":
"https://paddlenlp.bj.bcebos.com/models/transformers/rocketqa/rocketqav2_en_marco_para_encoder.pdparams",
"ernie-search-base-dual-encoder-marco-en":
"https://paddlenlp.bj.bcebos.com/models/transformers/ernie_search/ernie_search_base_dual_encoder_marco_en.pdparams",
"ernie-search-large-cross-encoder-marco-en":
"https://paddlenlp.bj.bcebos.com/models/transformers/ernie_search/ernie_search_large_cross_encoder_marco_en.pdparams",
}
}
base_model_prefix = "ernie"
12 changes: 12 additions & 0 deletions paddlenlp/transformers/ernie/tokenizer.py
@@ -70,6 +70,8 @@
"rocketqav2-en-marco-cross-encoder": 512,
"rocketqav2-en-marco-query-encoder": 512,
"rocketqav2-en-marco-para-encoder": 512,
"ernie-search-base-dual-encoder-marco-en": 512,
"ernie-search-large-cross-encoder-marco-en": 512,
}


@@ -211,6 +213,10 @@ class ErnieTokenizer(PretrainedTokenizer):
"https://bj.bcebos.com/paddlenlp/models/transformers/ernie_v2_base/vocab.txt",
"rocketqav2-en-marco-para-encoder":
"https://bj.bcebos.com/paddlenlp/models/transformers/ernie_v2_base/vocab.txt",
"ernie-search-base-dual-encoder-marco-en":
"https://bj.bcebos.com/paddlenlp/models/transformers/ernie_v2_base/vocab.txt",
"ernie-search-large-cross-encoder-marco-en":
"https://bj.bcebos.com/paddlenlp/models/transformers/ernie_v2_large/vocab.txt",
}
}
pretrained_init_configuration = {
@@ -343,6 +349,12 @@ class ErnieTokenizer(PretrainedTokenizer):
"rocketqav2-en-marco-para-encoder": {
"do_lower_case": True
},
"ernie-search-base-dual-encoder-marco-en": {
"do_lower_case": True
},
"ernie-search-large-cross-encoder-marco-en": {
"do_lower_case": True
},
}

max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
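Taken together, the modeling and tokenizer registrations above should make the two new checkpoints loadable by name. A minimal sketch of what that enables (assuming the standard PaddleNLP `from_pretrained` flow and that the hosted weights download correctly; the query text is illustrative):

```python
import paddle
from paddlenlp.transformers import ErnieModel, ErnieTokenizer

# Resolve the new checkpoint through the entries registered in this PR.
name = "ernie-search-base-dual-encoder-marco-en"
tokenizer = ErnieTokenizer.from_pretrained(name)
model = ErnieModel.from_pretrained(name)
model.eval()

# Encode a toy query and run the backbone.
encoded = tokenizer("what is paddlenlp")
input_ids = paddle.to_tensor([encoded["input_ids"]])
token_type_ids = paddle.to_tensor([encoded["token_type_ids"]])
sequence_output, pooled_output = model(input_ids, token_type_ids)
print(sequence_output.shape)  # [1, seq_len, 768] for the base dual encoder
```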
51 changes: 47 additions & 4 deletions paddlenlp/transformers/semantic_search/modeling.py
@@ -23,12 +23,25 @@

class ErnieEncoder(ErniePretrainedModel):

def __init__(self, ernie, dropout=None, num_classes=2):
def __init__(self,
ernie,
dropout=None,
output_emb_size=None,
num_classes=2):
super(ErnieEncoder, self).__init__()
self.ernie = ernie # allow ernie to be config
self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
self.classifier = nn.Linear(self.ernie.config["hidden_size"],
num_classes)
        # Compatible with ERNIE-Search, which adds an extra linear layer
self.output_emb_size = output_emb_size
if output_emb_size is not None and output_emb_size > 0:
weight_attr = paddle.ParamAttr(
initializer=paddle.nn.initializer.TruncatedNormal(std=0.02))
self.emb_reduce_linear = paddle.nn.Linear(
self.ernie.config["hidden_size"],
output_emb_size,
weight_attr=weight_attr)
self.apply(self.init_weights)

def init_weights(self, layer):
@@ -79,21 +92,23 @@ def __init__(self,
query_model_name_or_path=None,
title_model_name_or_path=None,
share_parameters=False,
output_emb_size=None,
dropout=None,
reinitialize=False,
use_cross_batch=False):

super().__init__()
self.query_ernie, self.title_ernie = None, None
self.use_cross_batch = use_cross_batch
self.output_emb_size = output_emb_size
if query_model_name_or_path is not None:
self.query_ernie = ErnieEncoder.from_pretrained(
query_model_name_or_path)
query_model_name_or_path, output_emb_size=output_emb_size)
if share_parameters:
self.title_ernie = self.query_ernie
elif title_model_name_or_path is not None:
self.title_ernie = ErnieEncoder.from_pretrained(
title_model_name_or_path)
title_model_name_or_path, output_emb_size=output_emb_size)
assert (self.query_ernie is not None) or (self.title_ernie is not None), \
"At least one of query_ernie and title_ernie should not be None"

@@ -125,16 +140,27 @@ def get_pooled_embedding(self,
position_ids=None,
attention_mask=None,
is_query=True):
"""Get the first feature of each sequence for classification"""
assert (is_query and self.query_ernie is not None) or (not is_query and self.title_ernie), \
"Please check whether your parameter for `is_query` are consistent with DualEncoder initialization."
if is_query:
sequence_output, _ = self.query_ernie(input_ids, token_type_ids,
position_ids, attention_mask)
if self.output_emb_size is not None and self.output_emb_size > 0:
cls_embedding = self.query_ernie.emb_reduce_linear(
sequence_output[:, 0])
else:
cls_embedding = sequence_output[:, 0]

else:
sequence_output, _ = self.title_ernie(input_ids, token_type_ids,
position_ids, attention_mask)
return sequence_output[:, 0]
if self.output_emb_size is not None and self.output_emb_size > 0:
cls_embedding = self.title_ernie.emb_reduce_linear(
sequence_output[:, 0])
else:
cls_embedding = sequence_output[:, 0]
return cls_embedding

def cosine_sim(self,
query_input_ids,
@@ -272,6 +298,7 @@ def matching(self,
position_ids=None,
attention_mask=None,
return_prob_distributation=False):
"""Use the pooled_output as the feature for pointwise prediction, eg. RocketQAv1"""
_, pooled_output = self.ernie(input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -288,6 +315,7 @@ def matching_v2(self,
token_type_ids=None,
position_ids=None,
attention_mask=None):
"""Use the cls token embedding as the feature for listwise prediction, eg. RocketQAv2"""
sequence_output, _ = self.ernie(input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
@@ -296,6 +324,21 @@ def matching_v3(self,
probs = self.ernie.classifier(pooled_output)
return probs

def matching_v3(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=None):
"""Use the pooled_output as the feature for listwise prediction, eg. ERNIE-Search"""
sequence_output, pooled_output = self.ernie(
input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
attention_mask=attention_mask)
pooled_output = self.ernie.dropout(pooled_output)
probs = self.ernie.classifier(pooled_output)
return probs

def forward(self,
input_ids,
token_type_ids=None,
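The `output_emb_size` plumbing and the new `matching_v3` head are what the ERNIE-Search checkpoints rely on. A rough usage sketch of the extended dual encoder (the checkpoint name comes from this PR; the toy token ids are illustrative, and the projection layer is only created when `output_emb_size` is a positive integer):

```python
import paddle
from paddlenlp.transformers.semantic_search.modeling import ErnieDualEncoder

# share_parameters=True reuses the query tower for titles, matching the
# single-backbone setup of ernie-search-base-dual-encoder-marco-en.
encoder = ErnieDualEncoder(
    query_model_name_or_path="ernie-search-base-dual-encoder-marco-en",
    share_parameters=True,
    output_emb_size=None)  # None keeps the raw 768-d [CLS] embedding

input_ids = paddle.to_tensor([[1, 5, 6, 2]])  # toy token ids
token_type_ids = paddle.zeros_like(input_ids)
query_emb = encoder.get_pooled_embedding(input_ids, token_type_ids, is_query=True)
print(query_emb.shape)  # [1, 768]; with output_emb_size=N it becomes [1, N]
```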
2 changes: 2 additions & 0 deletions pipelines/API.md
@@ -14,6 +14,7 @@
| rocketqa-zh-micro-query-encoder | Chinese | 4-layer, 384-hidden, 12-heads, 23M parameters. Trained on DuReader retrieval text. |
| rocketqa-zh-nano-query-encoder | Chinese | 4-layer, 312-hidden, 12-heads, 18M parameters. Trained on DuReader retrieval text. |
| rocketqav2-en-marco-query-encoder | English | 12-layer, 768-hidden, 12-heads, 118M parameters. Trained on MSMARCO. |
| ernie-search-base-dual-encoder-marco-en | English | 12-layer, 768-hidden, 12-heads, 118M parameters. Trained on MSMARCO. |

## ErnieRanker

@@ -27,6 +28,7 @@
| rocketqa-micro-cross-encoder | Chinese | 4-layer, 384-hidden, 12-heads, 23M parameters. Trained on DuReader retrieval text. |
| rocketqa-nano-cross-encoder | Chinese | 4-layer, 312-hidden, 12-heads, 18M parameters. Trained on DuReader retrieval text. |
| rocketqav2-en-marco-cross-encoder | English | 12-layer, 768-hidden, 12-heads, 118M parameters. Trained on MSMARCO. |
| ernie-search-large-cross-encoder-marco-en | English | 24-layer, 1024-hidden, 16-heads, 340M parameters. Trained on MSMARCO. |

## ErnieReader

17 changes: 13 additions & 4 deletions pipelines/pipelines/nodes/retriever/dense.py
@@ -49,7 +49,9 @@ def __init__(
Path, str] = "rocketqa-zh-dureader-para-encoder",
params_path: Optional[str] = "",
model_version: Optional[str] = None,
output_emb_size=256,
output_emb_size: Optional[int] = None,
reinitialize: bool = False,
share_parameters: bool = False,
max_seq_len_query: int = 64,
max_seq_len_passage: int = 256,
top_k: int = 10,
Expand Down Expand Up @@ -98,7 +100,7 @@ def __init__(
:param progress_bar: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean.
"""
# save init parameters to enable export of component config as YAML
# Save init parameters to enable export of component config as YAML
self.set_config(
document_store=document_store,
query_embedding_model=query_embedding_model,
Expand All @@ -110,6 +112,9 @@ def __init__(
use_gpu=use_gpu,
batch_size=batch_size,
embed_title=embed_title,
reinitialize=reinitialize,
share_parameters=share_parameters,
output_emb_size=output_emb_size,
similarity_function=similarity_function,
progress_bar=progress_bar,
)
@@ -150,8 +155,12 @@ def __init__(
self.passage_tokenizer = AutoTokenizer.from_pretrained(
query_embedding_model)
else:
self.ernie_dual_encoder = ErnieDualEncoder(query_embedding_model,
passage_embedding_model)
self.ernie_dual_encoder = ErnieDualEncoder(
query_embedding_model,
passage_embedding_model,
output_emb_size=output_emb_size,
reinitialize=reinitialize,
share_parameters=share_parameters)
self.query_tokenizer = AutoTokenizer.from_pretrained(
query_embedding_model)
self.passage_tokenizer = AutoTokenizer.from_pretrained(
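On the pipelines side, the retriever now simply forwards `output_emb_size`, `reinitialize`, and `share_parameters` into `ErnieDualEncoder`. A hedged sketch of wiring it up in Python instead of YAML (the Elasticsearch settings mirror the YAML further below; the import paths follow the usual pipelines layout and are assumed here):

```python
from pipelines.document_stores import ElasticsearchDocumentStore
from pipelines.nodes import DensePassageRetriever

document_store = ElasticsearchDocumentStore(host="localhost",
                                            port=9200,
                                            index="msmarco",
                                            embedding_dim=768)

# ERNIE-Search ships a single shared backbone, so share_parameters=True and
# no extra projection layer (output_emb_size=None keeps the 768-d embeddings).
retriever = DensePassageRetriever(
    document_store=document_store,
    query_embedding_model="ernie-search-base-dual-encoder-marco-en",
    share_parameters=True,
    output_emb_size=None,
    max_seq_len_query=64,
    max_seq_len_passage=256,
    batch_size=16,
    embed_title=False)

docs = retriever.retrieve("what is paddlenlp", top_k=10)
```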
1 change: 1 addition & 0 deletions pipelines/rest_api/config.py
@@ -19,6 +19,7 @@
"PIPELINE_YAML_PATH",
str((Path(__file__).parent / "pipeline" / "pipelines.yaml").absolute()))
QUERY_PIPELINE_NAME = os.getenv("QUERY_PIPELINE_NAME", "query")
QUERY_QA_PAIRS_NAME = os.getenv('QUERY_QA_PAIRS_NAME', 'query_qa_pairs')
INDEXING_PIPELINE_NAME = os.getenv("INDEXING_PIPELINE_NAME", "indexing")
INDEXING_QA_GENERATING_PIPELINE_NAME = os.getenv(
"INDEXING_QA_GENERATING_PIPELINE_NAME", "indexing_qa_generating")
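The new `QUERY_QA_PAIRS_NAME` setting is read from the environment at import time, so the QA-pair pipeline name can be switched without touching code. A small illustrative sketch (run from the `pipelines` directory so `rest_api` is importable; the pipeline name is hypothetical):

```python
import os

# Must be set before rest_api.config is imported, since os.getenv runs at import time.
os.environ["QUERY_QA_PAIRS_NAME"] = "my_query_qa_pairs"

from rest_api import config
print(config.QUERY_QA_PAIRS_NAME)  # -> "my_query_qa_pairs"
```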
9 changes: 6 additions & 3 deletions pipelines/rest_api/controller/search.py
@@ -25,7 +25,7 @@

import pipelines
from pipelines.pipelines.base import Pipeline
from rest_api.config import PIPELINE_YAML_PATH, QUERY_PIPELINE_NAME
from rest_api.config import PIPELINE_YAML_PATH, QUERY_PIPELINE_NAME, QUERY_QA_PAIRS_NAME
from rest_api.config import LOG_LEVEL, CONCURRENT_REQUEST_PER_WORKER
from rest_api.schema import QueryRequest, QueryResponse, DocumentRequest, DocumentResponse, QueryImageResponse, QueryQAPairResponse, QueryQAPairRequest
from rest_api.controller.utils import RequestLimiter
@@ -42,8 +42,11 @@
PIPELINE = Pipeline.load_from_yaml(Path(PIPELINE_YAML_PATH),
pipeline_name=QUERY_PIPELINE_NAME)

QA_PAIR_PIPELINE = Pipeline.load_from_yaml(Path(PIPELINE_YAML_PATH),
pipeline_name="query_qa_pairs")
try:
QA_PAIR_PIPELINE = Pipeline.load_from_yaml(
Path(PIPELINE_YAML_PATH), pipeline_name=QUERY_QA_PAIRS_NAME)
except Exception as e:
logger.warning(f"Request pipeline ('{QUERY_QA_PAIRS_NAME}: is null'). ")
DOCUMENT_STORE = PIPELINE.get_document_store()
logging.info(f"Loaded pipeline nodes: {PIPELINE.graph.nodes.keys()}")

69 changes: 69 additions & 0 deletions pipelines/rest_api/pipeline/semantic_ernie_search_en.yaml
@@ -0,0 +1,69 @@
version: '1.1.0'

components: # define all the building-blocks for Pipeline
- name: DocumentStore
type: ElasticsearchDocumentStore # consider using Milvus2DocumentStore or WeaviateDocumentStore for scaling to large number of documents
params:
host: localhost
port: 9200
index: msmarco
embedding_dim: 768
- name: Retriever
type: DensePassageRetriever
params:
document_store: DocumentStore # params can reference other components defined in the YAML
top_k: 10
query_embedding_model: ernie-search-base-dual-encoder-marco-en # an example of using ernie search models
share_parameters: True
output_emb_size: 768
embed_title: False
- name: Ranker # custom-name for the component; helpful for visualization & debugging
type: ErnieRanker # pipelines Class name for the component
params:
model_name_or_path: rocketqav2-en-marco-cross-encoder
top_k: 3
      use_en: True
reinitialize: True
- name: TextFileConverter
type: TextConverter
- name: ImageFileConverter
type: ImageToTextConverter
- name: PDFFileConverter
type: PDFToTextConverter
- name: DocxFileConverter
type: DocxToTextConverter
- name: Preprocessor
type: PreProcessor
params:
split_by: word
split_length: 1000
- name: FileTypeClassifier
type: FileTypeClassifier

pipelines:
- name: query
type: Query
nodes:
- name: Retriever
inputs: [Query]
- name: Ranker
inputs: [Retriever]
- name: indexing
type: Indexing
nodes:
- name: FileTypeClassifier
inputs: [File]
- name: TextFileConverter
inputs: [FileTypeClassifier.output_1]
- name: PDFFileConverter
inputs: [FileTypeClassifier.output_2]
- name: DocxFileConverter
inputs: [FileTypeClassifier.output_4]
- name: ImageFileConverter
inputs: [FileTypeClassifier.output_6]
- name: Preprocessor
inputs: [PDFFileConverter, TextFileConverter, DocxFileConverter, ImageFileConverter]
- name: Retriever
inputs: [Preprocessor]
- name: DocumentStore
inputs: [Retriever]
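The new YAML registers a `query` pipeline that pairs the ERNIE-Search dual encoder with the RocketQAv2 cross-encoder ranker, plus the usual `indexing` pipeline. Beyond the REST API, the same file can be driven directly through `Pipeline.load_from_yaml`, as `search.py` does. A rough sketch, assuming it is run from the `pipelines` directory (the query string and printed fields are illustrative):

```python
from pathlib import Path
from pipelines.pipelines.base import Pipeline

yaml_path = Path("rest_api/pipeline/semantic_ernie_search_en.yaml")
pipeline = Pipeline.load_from_yaml(yaml_path, pipeline_name="query")

prediction = pipeline.run(query="what is paddlenlp",
                          params={"Retriever": {"top_k": 10},
                                  "Ranker": {"top_k": 3}})
for doc in prediction["documents"]:
    print(doc.score, doc.meta.get("name"))
```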
9 changes: 8 additions & 1 deletion pipelines/utils/offline_ann.py
@@ -82,7 +82,13 @@
parser.add_argument(
'--delete_index',
action='store_true',
help='whether to delete existing index while updating index')
help='Whether to delete existing index while updating index')

parser.add_argument(
'--share_parameters',
action='store_true',
    help='Whether the query and title models share the same parameters'
)

args = parser.parse_args()

@@ -126,6 +132,7 @@ def offline_ann(index_name, doc_dir):
passage_embedding_model=args.passage_embedding_model,
params_path=args.params_path,
output_emb_size=args.embedding_dim,
share_parameters=args.share_parameters,
max_seq_len_query=64,
max_seq_len_passage=256,
batch_size=16,