Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
[PERFORMANCE] Improve vocab lookup performance by working with a dict…
Browse files Browse the repository at this point in the history
…() directly (#1382) (#1385)

Co-authored-by: Sheng Zha <[email protected]>

Co-authored-by: shishirb126 <[email protected]>
Co-authored-by: Sheng Zha <[email protected]>
  • Loading branch information
3 people authored Oct 8, 2020
1 parent 3fbe961 commit e65cd41
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions src/gluonnlp/vocab/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from .. import _constants as C
from .. import embedding as emb
from ..data.utils import Counter, DefaultLookupDict, count_tokens
from ..data.utils import Counter, count_tokens

UNK_IDX = 0
_DEPR_PAD = object()
Expand Down Expand Up @@ -219,10 +219,7 @@ def __init__(self, counter: Optional[Counter] = None, max_size: Optional[int] =
# Set up idx_to_token and token_to_idx based on presence of unknown token
self._unknown_token = unknown_token
self._idx_to_token = [unknown_token] if unknown_token else []
if unknown_token:
self._token_to_idx = DefaultLookupDict(UNK_IDX)
else:
self._token_to_idx = {}
self._token_to_idx = dict()

# Handle special tokens
special_tokens = []
Expand Down Expand Up @@ -267,10 +264,6 @@ def __init__(self, counter: Optional[Counter] = None, max_size: Optional[int] =

if token_to_idx:
self._sort_index_according_to_user_specification(token_to_idx)
if unknown_token:
self._token_to_idx._default = \
self._token_to_idx[unknown_token] # pytype: disable=not-writable


def _index_counter_keys(self, counter, unknown_token, special_tokens, max_size,
min_freq):
Expand Down Expand Up @@ -395,9 +388,17 @@ def __getitem__(self, tokens):
"""

if not isinstance(tokens, (list, tuple)):
return self._token_to_idx[tokens]
if self._unknown_token:
unknown_token_idx = self._token_to_idx[self._unknown_token]
return self._token_to_idx.get(tokens, unknown_token_idx)
else:
return self._token_to_idx[tokens]
else:
return [self._token_to_idx[token] for token in tokens]
if self._unknown_token:
unknown_token_idx = self._token_to_idx[self._unknown_token]
return [self._token_to_idx.get(token, unknown_token_idx) for token in tokens]
else:
return [self._token_to_idx[token] for token in tokens]

def __len__(self):
return len(self._idx_to_token)
Expand Down

0 comments on commit e65cd41

Please sign in to comment.