Scripted tokenizer support for DocModel (facebookresearch#1314)
Summary:
Pull Request resolved: facebookresearch#1314

Adds scripted tokenization support to DocModel, the most widely used model.
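
A minimal usage sketch of what this enables (assumed, not part of the diff: the export path, file name, and inputs are illustrative). With a scriptable tokenizer attached, the exported model can accept raw texts and tokenize internally, instead of requiring callers to pass pre-tokenized tokens.

    import torch

    # Load a previously exported DocModel (path is hypothetical).
    scripted = torch.jit.load("doc_model.pt")

    # Before this change: callers had to tokenize themselves.
    scores_from_tokens = scripted(tokens=[["hello", "world"]])

    # With a scripted tokenizer attached: raw strings are accepted directly,
    # because forward() now tokenizes texts when a scripted tokenizer exists.
    scores_from_texts = scripted(texts=["hello world"])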

Differential Revision: D20955370

fbshipit-source-id: 91c60db42c8eb8deac8afb573942053ac9555e99
snisarg authored and facebook-github-bot committed Apr 10, 2020
1 parent f2b55ff commit e17514d
Showing 1 changed file with 24 additions and 0 deletions.
pytext/models/doc_model.py: 24 additions, 0 deletions
@@ -14,6 +14,7 @@
TokenTensorizer,
UidTensorizer,
)
from pytext.data.tokenizers import DoNothingTokenizer
from pytext.data.utils import PAD, UNK
from pytext.exporters.exporter import ModelExporter
from pytext.loss import BinaryCrossEntropyLoss, MultiLabelSoftMarginLoss
@@ -115,6 +116,13 @@ def torchscriptify(self, tensorizers, traced_model):

input_vocab = tensorizers["tokens"].vocab
max_seq_len = tensorizers["tokens"].max_seq_len or -1
scripted_tokenizer = None
try:
scripted_tokenizer = tensorizers["tokens"].tokenizer.torchscriptify()
except NotImplementedError:
pass
if scripted_tokenizer and isinstance(scripted_tokenizer, DoNothingTokenizer):
scripted_tokenizer = None

"""
The input tensor packing memory is allocated/cached for different shapes,
@@ -135,6 +143,7 @@ def __init__(self):
self.output_layer = output_layer
self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int)
self.max_seq_len = jit.Attribute(max_seq_len, int)
self.tokenizer = scripted_tokenizer

@jit.script_method
def forward(
@@ -144,6 +153,13 @@ def forward(
tokens: Optional[List[List[str]]] = None,
languages: Optional[List[str]] = None,
):
if texts is not None and tokens is not None:
raise RuntimeError("Can't set both tokens and texts")
if self.tokenizer is not None and texts is not None:
tokens = [
[t[0] for t in self.tokenizer.tokenize(text)] for text in texts
]

if tokens is None:
raise RuntimeError("tokens is required")

@@ -167,6 +183,7 @@ def __init__(self):
self.output_layer = output_layer
self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int)
self.max_seq_len = jit.Attribute(max_seq_len, int)
self.tokenizer = scripted_tokenizer

@jit.script_method
def forward(
@@ -177,6 +194,13 @@ def forward(
languages: Optional[List[str]] = None,
dense_feat: Optional[List[List[float]]] = None,
):
if texts is not None and tokens is not None:
raise RuntimeError("Can't set both tokens and texts")
if self.tokenizer is not None and texts is not None:
tokens = [
[t[0] for t in self.tokenizer.tokenize(text)] for text in texts
]

if tokens is None:
raise RuntimeError("tokens is required")
if dense_feat is None:
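
The list comprehension [t[0] for t in self.tokenizer.tokenize(text)] implies that the scripted tokenizer returns one tuple per token, with the token string at index 0 (the remaining fields are assumed to be character offsets). A toy sketch of that assumed contract, not pytext's actual tokenizer:

    from typing import List, Tuple

    class ToyTokenizer:
        # Assumed interface: tokenize() yields (token, start, end) tuples,
        # so keeping t[0] recovers just the token strings.
        def tokenize(self, text: str) -> List[Tuple[str, int, int]]:
            out: List[Tuple[str, int, int]] = []
            pos = 0
            for piece in text.split():
                begin = text.index(piece, pos)
                out.append((piece, begin, begin + len(piece)))
                pos = begin + len(piece)
            return out

    # [t[0] for t in ToyTokenizer().tokenize("hello world")] == ["hello", "world"]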
