Make dict_embedding Torchscript friendly (#1240)
Summary:
Pull Request resolved: #1240

The existing implementation cannot be scripted with TorchScript; see error P125669619.

Reviewed By: twwhatever

Differential Revision: D19646957

fbshipit-source-id: c536b60467ef2c820bf2b8f5b949835f2c3f39c6
arbabu123 authored and facebook-github-bot committed Feb 4, 2020
1 parent d000e31 commit e5a5559
Showing 2 changed files with 17 additions and 14 deletions.
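
For context before the diff: the old DictEmbedding inherited from both EmbeddingBase and nn.Embedding, which torch.jit.script could not handle, so the commit switches to composition and stores the nn.Embedding as a submodule. Below is a minimal sketch of that pattern under stated assumptions; the class name and sizes are illustrative stand-ins, not the PyText module itself.

import torch
from torch import nn


class ScriptableDictEmbedding(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, num_embeddings: int, embed_dim: int, pad_index: int = 1) -> None:
        super().__init__()
        # Composition instead of multiple inheritance keeps the module scriptable.
        self.embedding = nn.Embedding(num_embeddings, embed_dim, padding_idx=pad_index)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        # Call the wrapped embedding explicitly instead of relying on super().forward(feats).
        return self.embedding(feats)


scripted = torch.jit.script(ScriptableDictEmbedding(num_embeddings=100, embed_dim=8))
print(scripted(torch.tensor([[2, 3, 1]])).shape)  # torch.Size([1, 3, 8])
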
27 changes: 15 additions & 12 deletions pytext/models/embeddings/dict_embedding.py
@@ -10,12 +10,11 @@
 from pytext.data.tensorizers import Tensorizer
 from pytext.data.utils import PAD_INDEX, UNK_INDEX, Vocabulary
 from pytext.fields import FieldMeta
-from pytext.utils import cuda

 from .embedding_base import EmbeddingBase


-class DictEmbedding(EmbeddingBase, nn.Embedding):
+class DictEmbedding(EmbeddingBase):
     """
     Module for dictionary feature embeddings for tokens. Dictionary features are
     also known as gazetteer features. These are per token discrete features that
@@ -102,13 +101,15 @@ def __init__(
         unk_index: int = UNK_INDEX,
         mobile: bool = False,
     ) -> None:
-        self.pad_index = pad_index
+        super().__init__(embed_dim)
         self.unk_index = unk_index
-        EmbeddingBase.__init__(self, embed_dim)
-        nn.Embedding.__init__(
-            self, num_embeddings, embed_dim, padding_idx=self.pad_index
+        self.pad_index = pad_index
+        self.embedding = nn.Embedding(
+            num_embeddings, embed_dim, padding_idx=self.pad_index
         )
-        self.pooling_type = pooling_type
+        # Temporary workaround till https://github.com/pytorch/pytorch/issues/32840
+        # is resolved
+        self.pooling_type = str(pooling_type)
         self.mobile = mobile

     def find_and_replace(
@@ -124,7 +125,7 @@ def find_and_replace(
         else:
             return torch.where(
                 tensor == find_val,
-                cuda.GetTensor(torch.full_like(tensor, replace_val)),
+                torch.full_like(tensor, replace_val, device=tensor.device),
                 tensor,
             )

@@ -157,23 +158,25 @@ def forward(
         # convert all unk indices to pad indices
         feats = self.find_and_replace(feats, self.unk_index, self.pad_index)

-        dict_emb = super().forward(feats)
+        dict_emb = self.embedding(feats)

         # Calculate weighted average of the embeddings
         weighted_embds = dict_emb * weights.unsqueeze(2)
         new_emb_shape = torch.cat(
             (
                 batch_size.view(1),
                 max_toks.view(1),
-                torch.LongTensor([-1]),
-                torch.LongTensor([weighted_embds.size()[-1]]),
+                torch.tensor([-1]).long(),
+                torch.tensor([weighted_embds.size()[-1]]).long(),
             )
         )
         weighted_embds = torch.onnx.operators.reshape_from_tensor_shape(
             weighted_embds, new_emb_shape
         )

-        if self.pooling_type == PoolingType.MEAN:
+        # Temporary workaround till https://github.com/pytorch/pytorch/issues/32840
+        # is resolved
+        if self.pooling_type == "mean":
             reduced_embeds = (
                 torch.sum(weighted_embds, dim=2) / lengths.unsqueeze(2).float()
             )
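The smaller edits in this file follow the same theme, swapping out constructs TorchScript cannot compile: cuda.GetTensor gives way to torch.full_like(..., device=tensor.device), torch.LongTensor([...]) becomes torch.tensor([...]).long(), and the PoolingType enum is stored and compared as a plain string until pytorch/pytorch#32840 is resolved. A standalone sketch of the device-agnostic replacement pattern (my own minimal example, not PyText code):

import torch


@torch.jit.script
def find_and_replace(tensor: torch.Tensor, find_val: int, replace_val: int) -> torch.Tensor:
    # Build the replacement tensor on the same device as the input, with no
    # CUDA-specific helper, so the function compiles under TorchScript.
    return torch.where(
        tensor == find_val,
        torch.full_like(tensor, replace_val, device=tensor.device),
        tensor,
    )


feats = torch.tensor([[0, 5, 0, 2]])
print(find_and_replace(feats, 0, 1))  # tensor([[1, 5, 1, 2]])
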
4 changes: 2 additions & 2 deletions pytext/models/test/dict_embedding_test.py
@@ -17,8 +17,8 @@ def test_basic(self):
             embed_dim=output_dim,
             pooling_type=PoolingType.MEAN,
         )
-        self.assertEqual(embedding_module.weight.size(0), num_embeddings)
-        self.assertEqual(embedding_module.weight.size(1), output_dim)
+        self.assertEqual(embedding_module.embedding.weight.size(0), num_embeddings)
+        self.assertEqual(embedding_module.embedding.weight.size(1), output_dim)

         # The first and last tokens should be mapped to the zero vector.
         # This is due to the invariant that both unk and pad are considered
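Finally, a hedged sketch of how the updated module might be exercised end to end. The constructor arguments mirror the test above, the PoolingType import path is an assumption, and scripting is the goal stated in the summary rather than something verified here.

import torch
from pytext.config.module_config import PoolingType  # assumed import path
from pytext.models.embeddings.dict_embedding import DictEmbedding

embedding_module = DictEmbedding(
    num_embeddings=5, embed_dim=4, pooling_type=PoolingType.MEAN
)
# The weight matrix now lives on the wrapped nn.Embedding submodule.
assert embedding_module.embedding.weight.size() == (5, 4)
# With this commit the module is expected to compile under TorchScript.
scripted = torch.jit.script(embedding_module)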
