diff --git a/pytext/models/doc_model.py b/pytext/models/doc_model.py
index 190f0de8c..55bec6c2e 100644
--- a/pytext/models/doc_model.py
+++ b/pytext/models/doc_model.py
@@ -14,6 +14,7 @@
     TokenTensorizer,
     UidTensorizer,
 )
+from pytext.data.tokenizers import DoNothingTokenizer
 from pytext.data.utils import PAD, UNK
 from pytext.exporters.exporter import ModelExporter
 from pytext.loss import BinaryCrossEntropyLoss, MultiLabelSoftMarginLoss
@@ -119,6 +120,13 @@ def torchscriptify(self, tensorizers, traced_model):
         input_vocab = tensorizers["tokens"].vocab
         max_seq_len = tensorizers["tokens"].max_seq_len or -1
+        scripted_tokenizer: Optional[jit.ScriptModule] = None
+        try:
+            scripted_tokenizer = tensorizers["tokens"].tokenizer.torchscriptify()
+        except NotImplementedError:
+            pass
+        if scripted_tokenizer and isinstance(scripted_tokenizer, DoNothingTokenizer):
+            scripted_tokenizer = None
         """
         The input tensor packing memory is allocated/cached for different shapes,
@@ -139,6 +147,7 @@ def __init__(self):
                 self.output_layer = output_layer
                 self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int)
                 self.max_seq_len = jit.Attribute(max_seq_len, int)
+                self.tokenizer = scripted_tokenizer

             @jit.script_method
             def forward(
@@ -148,6 +157,16 @@ def forward(
                 tokens: Optional[List[List[str]]] = None,
                 languages: Optional[List[str]] = None,
             ):
+                # PyTorch breaks with 2 'not None' checks right now.
+                if texts is not None:
+                    if tokens is not None:
+                        raise RuntimeError("Can't set both tokens and texts")
+                    if self.tokenizer is not None:
+                        tokens = [
+                            [t[0] for t in self.tokenizer.tokenize(text)]
+                            for text in texts
+                        ]
+
                 if tokens is None:
                     raise RuntimeError("tokens is required")
@@ -171,6 +190,7 @@ def __init__(self):
                 self.output_layer = output_layer
                 self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int)
                 self.max_seq_len = jit.Attribute(max_seq_len, int)
+                self.tokenizer = scripted_tokenizer

             @jit.script_method
             def forward(
@@ -181,6 +201,16 @@ def forward(
                 languages: Optional[List[str]] = None,
                 dense_feat: Optional[List[List[float]]] = None,
             ):
+                # PyTorch breaks with 2 'not None' checks right now.
+                if texts is not None:
+                    if tokens is not None:
+                        raise RuntimeError("Can't set both tokens and texts")
+                    if self.tokenizer is not None:
+                        tokens = [
+                            [t[0] for t in self.tokenizer.tokenize(text)]
+                            for text in texts
+                        ]
+
                 if tokens is None:
                     raise RuntimeError("tokens is required")
                 if dense_feat is None:
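
Usage sketch (not part of the diff): with this change, the torchscriptified doc model can accept raw `texts` and tokenize inside `forward()` whenever the tensorizer's tokenizer could itself be scripted; the `[t[0] for t in ...]` in the diff suggests the scripted tokenizer returns per-token tuples whose first element is the token string. The model path and input strings below are hypothetical, and the keyword names simply mirror the `forward()` signature above.

    import torch

    # Hypothetical path to a model produced by torchscriptify().
    model = torch.jit.load("doc_model.torchscript")

    # Previous behavior: callers tokenize up front and pass tokens.
    logits = model(tokens=[["hello", "world"], ["good", "morning"]])

    # With this change: raw strings can be passed; the embedded scripted
    # tokenizer (when present) produces the tokens inside forward().
    # Passing both texts and tokens raises a RuntimeError.
    logits = model(texts=["hello world", "good morning"])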