Skip to content

Commit

Permalink
fix ModelForQABasic
Browse files Browse the repository at this point in the history
  • Loading branch information
zheyuye committed Jul 6, 2020
1 parent 0e13a58 commit b8c85bb
Showing 1 changed file with 55 additions and 1 deletion.
56 changes: 55 additions & 1 deletion scripts/question_answering/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,67 @@ def hybrid_forward(self, F, tokens, token_types, valid_length, p_mask):
contextual_embeddings = self.backbone(tokens, token_types, valid_length)
else:
contextual_embeddings = self.backbone(tokens, valid_length)
scores = self.qa_outputs(contextual_embedding)
scores = self.qa_outputs(contextual_embeddings)
start_scores = scores[:, :, 0]
end_scores = scores[:, :, 1]
start_logits = masked_logsoftmax(F, start_scores, mask=p_mask, axis=-1)
end_logits = masked_logsoftmax(F, end_scores, mask=p_mask, axis=-1)
return start_logits, end_logits

def inference(self, tokens, token_types, valid_length, p_mask,
start_top_n: int = 5, end_top_n: int = 5):
"""Get the inference result with beam search
Parameters
----------
tokens
The input tokens. Shape (batch_size, sequence_length)
token_types
The input token types. Shape (batch_size, sequence_length)
valid_length
The valid length of the tokens. Shape (batch_size,)
p_mask
The mask which indicates that some tokens won't be used in the calculation.
Shape (batch_size, sequence_length)
start_top_n
The number of candidates to select for the start position.
end_top_n
The number of candidates to select for the end position.
Returns
-------
start_top_logits
The top start logits
Shape (batch_size, start_top_n)
start_top_index
Index of the top start logits
Shape (batch_size, start_top_n)
end_top_logits
The top end logits.
Shape (batch_size, end_top_n)
end_top_index
Index of the top end logits
Shape (batch_size, end_top_n)
"""
# Shape (batch_size, sequence_length, C)
if self.use_segmentation:
contextual_embeddings = self.backbone(tokens, token_types, valid_length)
else:
contextual_embeddings = self.backbone(tokens, valid_length)
scores = self.qa_outputs(contextual_embeddings)
start_scores = scores[:, :, 0]
end_scores = scores[:, :, 1]
start_logits = masked_logsoftmax(mx.nd, start_scores, mask=p_mask, axis=-1)
end_logits = masked_logsoftmax(mx.nd, end_scores, mask=p_mask, axis=-1)
# The shape of start_top_index will be (..., start_top_n)
start_top_logits, start_top_index = mx.npx.topk(start_logits, k=start_top_n, axis=-1,
ret_typ='both')
# Note that end_top_index and end_top_log_probs have shape (bsz, start_n_top, end_n_top)
# So that for each start position, there are end_n_top end positions on the third dim.
end_top_logits, end_top_index = mx.npx.topk(end_logits, k=end_top_n, axis=-1,
ret_typ='both')
return start_top_logits, start_top_index, end_top_logits, end_top_index


@use_np
class ModelForQAConditionalV1(HybridBlock):
Expand Down

0 comments on commit b8c85bb

Please sign in to comment.