Scripted tokenizer support for DocModel (facebookresearch#1314)
Summary:
Pull Request resolved: facebookresearch#1314

Adds scripted tokenization support to DocModel, the most widely used model.

OSS test failures: waiting for a TorchScript diff to land (https://fb.workplace.com/groups/329222650990087/permalink/632527153992967/).

Differential Revision: D20955370

fbshipit-source-id: 6e002136fc0113dfe87a24ca3f82be9a73d1d0bc
snisarg authored and facebook-github-bot committed Apr 16, 2020
1 parent 522e079 commit c1b4155
Showing 1 changed file with 30 additions and 0 deletions.
pytext/models/doc_model.py (+30, -0)
@@ -14,6 +14,7 @@
     TokenTensorizer,
     UidTensorizer,
 )
+from pytext.data.tokenizers import DoNothingTokenizer
 from pytext.data.utils import PAD, UNK
 from pytext.exporters.exporter import ModelExporter
 from pytext.loss import BinaryCrossEntropyLoss, MultiLabelSoftMarginLoss
@@ -119,6 +120,13 @@ def torchscriptify(self, tensorizers, traced_model):
 
         input_vocab = tensorizers["tokens"].vocab
         max_seq_len = tensorizers["tokens"].max_seq_len or -1
+        scripted_tokenizer = None
+        try:
+            scripted_tokenizer = tensorizers["tokens"].tokenizer.torchscriptify()
+        except NotImplementedError:
+            pass
+        if scripted_tokenizer and isinstance(scripted_tokenizer, DoNothingTokenizer):
+            scripted_tokenizer = None
 
         """
         The input tensor packing memory is allocated/cached for different shapes,
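
Note on the hunk above: torchscriptify() on the tensorizer's tokenizer is expected either to return a scripted tokenizer or to raise NotImplementedError, and a DoNothingTokenizer result is dropped because it would not actually split raw text. Below is a minimal, self-contained sketch of the same fallback pattern; the two tokenizer classes are illustrative stand-ins, not the real PyText implementations.

from typing import List, Tuple


class WhitespaceStandInTokenizer:
    # Stand-in for a tokenizer that does have a TorchScript equivalent.
    def torchscriptify(self):
        return self

    def tokenize(self, text: str) -> List[Tuple[str, int, int]]:
        # Returns (token, start, end) triples; the scripted forward() below
        # keeps only t[0], the token string.
        pieces = []
        pos = 0
        for piece in text.split():
            start = text.index(piece, pos)
            pieces.append((piece, start, start + len(piece)))
            pos = start + len(piece)
        return pieces


class LegacyStandInTokenizer:
    # Stand-in for a tokenizer with no TorchScript equivalent.
    def torchscriptify(self):
        raise NotImplementedError


def maybe_script(tokenizer):
    # Same fallback as the diff: keep None when scripting is unsupported.
    try:
        return tokenizer.torchscriptify()
    except NotImplementedError:
        return None


assert maybe_script(WhitespaceStandInTokenizer()) is not None
assert maybe_script(LegacyStandInTokenizer()) is None
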
@@ -139,6 +147,7 @@ def __init__(self):
                 self.output_layer = output_layer
                 self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int)
                 self.max_seq_len = jit.Attribute(max_seq_len, int)
+                self.tokenizer = scripted_tokenizer
 
             @jit.script_method
             def forward(
@@ -148,6 +157,16 @@ def forward(
                 tokens: Optional[List[List[str]]] = None,
                 languages: Optional[List[str]] = None,
             ):
+                # PyTorch breaks with 2 'not None' checks right now.
+                if texts is not None:
+                    if tokens is not None:
+                        raise RuntimeError("Can't set both tokens and texts")
+                    if self.tokenizer is not None:
+                        tokens = [
+                            [t[0] for t in self.tokenizer.tokenize(text)]
+                            for text in texts
+                        ]
+
                 if tokens is None:
                     raise RuntimeError("tokens is required")
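
The "PyTorch breaks with 2 'not None' checks" comment above refers to a TorchScript limitation at the time: refining two Optional values in one combined condition reportedly failed to compile, so the check is written as nested ifs. Here is a standalone sketch of that nested-check pattern, not PyText code; the function name and behavior are illustrative.

from typing import List, Optional

import torch


@torch.jit.script
def count_examples(
    texts: Optional[List[str]] = None,
    tokens: Optional[List[List[str]]] = None,
) -> int:
    # Each Optional is narrowed by its own "is not None" check, mirroring the
    # nested form used in the diff instead of a single combined condition.
    n = 0
    if texts is not None:
        if tokens is not None:
            raise RuntimeError("Can't set both tokens and texts")
        n = len(texts)
    if tokens is not None:
        n = len(tokens)
    return n


print(count_examples(texts=["order a pizza", "play some jazz"]))  # 2
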

@@ -171,6 +190,7 @@ def __init__(self):
                 self.output_layer = output_layer
                 self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int)
                 self.max_seq_len = jit.Attribute(max_seq_len, int)
+                self.tokenizer = scripted_tokenizer
 
             @jit.script_method
             def forward(
@@ -181,6 +201,16 @@
                 languages: Optional[List[str]] = None,
                 dense_feat: Optional[List[List[float]]] = None,
             ):
+                # PyTorch breaks with 2 'not None' checks right now.
+                if texts is not None:
+                    if tokens is not None:
+                        raise RuntimeError("Can't set both tokens and texts")
+                    if self.tokenizer is not None:
+                        tokens = [
+                            [t[0] for t in self.tokenizer.tokenize(text)]
+                            for text in texts
+                        ]
+
                 if tokens is None:
                     raise RuntimeError("tokens is required")
                 if dense_feat is None:
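
For completeness, a hypothetical usage sketch of the exported model after this change; the file name, the save/load step, and the keyword-style call are assumptions for illustration, and the dense-feature variant additionally expects dense_feat.

import torch

# Assumes the scripted DocModel was saved earlier, e.g. with
# torch.jit.save(script_module, "doc_model.pt"); the path is illustrative.
scripted = torch.jit.load("doc_model.pt")

# Raw-text path: tokenization now happens inside TorchScript when the
# underlying tokenizer could be scripted.
scores_from_text = scripted(texts=["order a pizza", "play some jazz"])

# Pre-tokenized path: unchanged, and still required when no scripted
# tokenizer is available.
scores_from_tokens = scripted(
    tokens=[["order", "a", "pizza"], ["play", "some", "jazz"]]
)
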
