This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Scripted tokenizer support for DocModel #1314

Closed
wants to merge 1 commit
30 changes: 30 additions & 0 deletions pytext/models/doc_model.py
@@ -14,6 +14,7 @@
     TokenTensorizer,
     UidTensorizer,
 )
+from pytext.data.tokenizers import DoNothingTokenizer
 from pytext.data.utils import PAD, UNK
 from pytext.exporters.exporter import ModelExporter
 from pytext.loss import BinaryCrossEntropyLoss, MultiLabelSoftMarginLoss
@@ -119,6 +120,13 @@ def torchscriptify(self, tensorizers, traced_model):

         input_vocab = tensorizers["tokens"].vocab
         max_seq_len = tensorizers["tokens"].max_seq_len or -1
+        scripted_tokenizer: Optional[jit.ScriptModule] = None
+        try:
+            scripted_tokenizer = tensorizers["tokens"].tokenizer.torchscriptify()
+        except NotImplementedError:
+            pass
+        if scripted_tokenizer and isinstance(scripted_tokenizer, DoNothingTokenizer):
+            scripted_tokenizer = None

         """
         The input tensor packing memory is allocated/cached for different shapes,
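For reference, the `try`/`except` in the hunk above only assumes that `tokenizer.torchscriptify()` either raises `NotImplementedError` or returns a TorchScript module whose `tokenize(text)` yields entries with the token string in position 0 (that is all the `forward()` changes below read). A minimal sketch of a tokenizer that satisfies this contract, with illustrative class names rather than pytext's shipped ones:

```python
# Sketch only -- illustrative names, not pytext's actual classes.
from typing import List, Tuple

import torch


class ScriptWhitespaceTokenizer(torch.jit.ScriptModule):
    """Scripted tokenizer whose tokenize() returns (token, start, end) entries."""

    @torch.jit.script_method
    def tokenize(self, text: str) -> List[Tuple[str, int, int]]:
        tokens: List[Tuple[str, int, int]] = []
        start = 0
        for piece in text.split(" "):
            if len(piece) > 0:
                # keep character offsets alongside the token string
                tokens.append((piece, start, start + len(piece)))
            start += len(piece) + 1
        return tokens


class WhitespaceTokenizer:
    """Stand-in for an eager tokenizer that can hand back its scripted form."""

    def torchscriptify(self) -> torch.jit.ScriptModule:
        return ScriptWhitespaceTokenizer()
```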
@@ -139,6 +147,7 @@ def __init__(self):
                 self.output_layer = output_layer
                 self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int)
                 self.max_seq_len = jit.Attribute(max_seq_len, int)
+                self.tokenizer = scripted_tokenizer

             @jit.script_method
             def forward(
@@ -148,6 +157,16 @@
                 tokens: Optional[List[List[str]]] = None,
                 languages: Optional[List[str]] = None,
             ):
+                # PyTorch breaks with 2 'not None' checks right now.
+                if texts is not None:
+                    if tokens is not None:
+                        raise RuntimeError("Can't set both tokens and texts")
+                    if self.tokenizer is not None:
+                        tokens = [
+                            [t[0] for t in self.tokenizer.tokenize(text)]
+                            for text in texts
+                        ]
+
                 if tokens is None:
                     raise RuntimeError("tokens is required")
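With a tokenizer embedded, the scripted model can accept raw strings as well as pre-tokenized input. A hedged usage sketch (variable names are illustrative; the exact return value depends on the configured output layer):

```python
# Illustrative only: `model`, `tensorizers` and `traced_model` are assumed to
# come from the usual PyText export path for a trained DocModel.
script_model = model.torchscriptify(tensorizers, traced_model)

# Raw text -- the embedded scripted tokenizer produces the tokens internally.
out = script_model(texts=["what is the weather", "play some jazz"])

# Pre-tokenized input keeps working exactly as before.
out = script_model(tokens=[["what", "is", "the", "weather"]])

# Passing both raises RuntimeError("Can't set both tokens and texts").
```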

@@ -171,6 +190,7 @@ def __init__(self):
                 self.output_layer = output_layer
                 self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int)
                 self.max_seq_len = jit.Attribute(max_seq_len, int)
+                self.tokenizer = scripted_tokenizer

             @jit.script_method
             def forward(
@@ -181,6 +201,16 @@
                 languages: Optional[List[str]] = None,
                 dense_feat: Optional[List[List[float]]] = None,
             ):
+                # PyTorch breaks with 2 'not None' checks right now.
+                if texts is not None:
+                    if tokens is not None:
+                        raise RuntimeError("Can't set both tokens and texts")
+                    if self.tokenizer is not None:
+                        tokens = [
+                            [t[0] for t in self.tokenizer.tokenize(text)]
+                            for text in texts
+                        ]
+
                 if tokens is None:
                     raise RuntimeError("tokens is required")
                 if dense_feat is None:
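The dense-feature variant of the scripted `Model` gets the same texts handling; a sketch of a call, again with illustrative values and a hypothetical `script_model_with_dense` handle:

```python
# Illustrative only: dense_feat supplies one float vector per example, matching
# the Optional[List[List[float]]] parameter of forward() above.
out = script_model_with_dense(
    texts=["what is the weather"],
    dense_feat=[[0.3, 1.0, 0.0]],
)
```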