Merge pull request #83 from xingyaoww/no_new_tokens
Use --no_new_tokens to stop adding built-in special tokens
AleHD authored Nov 29, 2023
2 parents 01fa877 + 24a174d commit dead8d2
Showing 2 changed files with 10 additions and 9 deletions.
12 changes: 8 additions & 4 deletions megatron/tokenizer/tokenizer.py
@@ -349,8 +349,8 @@ def _initalize(self, vocab_extra_ids, vocab_extra_ids_list, new_tokens):
            self._inv_vocab[i] = t
            self._vocab[t] = i

-        def _add_special_token(t):
-            if t not in self.vocab and not new_tokens:
+        def _add_special_token(t, force=False):
+            if t not in self.vocab and not new_tokens and not force:
                return
            if t not in self._vocab:
                next_id = len(self._vocab)
@@ -392,13 +392,17 @@ def _add_special_token(t):
        _add_special_token(eos_token)
        self._eos_id = self._vocab.get(eos_token)

+        if not new_tokens:
+            # default to eos
+            self._pad_id = self._eos_id
+
        for i in range(vocab_extra_ids):
            t = "<extra_id_{}>".format(i)
-            _add_special_token(t)
+            _add_special_token(t, force=True)
            self._t5_tokens += [t]
        if vocab_extra_ids_list:
            for t in vocab_extra_ids_list.split(","):
-                _add_special_token(t)
+                _add_special_token(t, force=True)
        print("Special tokens: {}".format(self._special_tokens))

    @property
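
The tokenizer change above adds a force flag to _add_special_token: when --no_new_tokens is passed (new_tokens=False), built-in special tokens missing from the SentencePiece vocabulary are still skipped, but <extra_id_*> tokens and entries from --vocab_extra_ids_list are now registered unconditionally, and the pad id falls back to the eos id. A minimal standalone sketch of that guard logic, using a toy dict vocab and hypothetical helper names rather than the real tokenizer class:

# Sketch only: the toy vocab and helper names are illustrative, not from the repo.
def build_vocab(base_vocab, builtin_special_tokens, new_tokens, vocab_extra_ids):
    vocab = dict(base_vocab)

    def add_special_token(t, force=False):
        # Same guard as the patched _add_special_token: skip tokens that are
        # not already in the vocab when new_tokens is disabled, unless forced.
        if t not in vocab and not new_tokens and not force:
            return
        if t not in vocab:
            vocab[t] = len(vocab)

    for t in builtin_special_tokens:
        add_special_token(t)  # may be skipped under --no_new_tokens
    for i in range(vocab_extra_ids):
        add_special_token("<extra_id_{}>".format(i), force=True)  # always added
    return vocab

vocab = build_vocab({"<s>": 0, "</s>": 1}, ["<CLS>", "<SEP>"],
                    new_tokens=False, vocab_extra_ids=2)
print(vocab)  # {'<s>': 0, '</s>': 1, '<extra_id_0>': 2, '<extra_id_1>': 3}
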
7 changes: 2 additions & 5 deletions weights_conversion/megatron_to_hf.py
@@ -393,14 +393,11 @@ def write_tokenizer(args: Namespace):
    hf_tokenizer.pad_token_id = mt_tokenizer.pad

    additional_special_tokens = hf_tokenizer.additional_special_tokens
-    special_tokens = {"additional_special_tokens": additional_special_tokens}
    if args.vocab_extra_ids_list:
        additional_special_tokens.extend(args.vocab_extra_ids_list.split(","))

-    hf_tokenizer.add_special_tokens(special_tokens_dict=special_tokens, replace_additional_special_tokens=True)
-
-    additional_special_tokens_ids = [mt_tokenizer.vocab.get(t) for t in additional_special_tokens]
-    hf_tokenizer.additional_special_tokens_ids = additional_special_tokens_ids
+    for special_token in additional_special_tokens:
+        hf_tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})

    hf_vocab = hf_tokenizer.get_vocab()
    tokens_to_check = [
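
In the conversion script, the single add_special_tokens call (plus the direct assignment to additional_special_tokens_ids) is replaced by registering each token with the HuggingFace tokenizer one at a time. A rough sketch of the same pattern, assuming a transformers tokenizer loaded from a placeholder checkpoint path and an example --vocab_extra_ids_list value:

# Sketch only: the checkpoint path and extra-token list are placeholders.
from transformers import AutoTokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("path/to/converted_hf_checkpoint")
vocab_extra_ids_list = "<extra_id_0>,<extra_id_1>"

additional_special_tokens = hf_tokenizer.additional_special_tokens
additional_special_tokens.extend(vocab_extra_ids_list.split(","))

# Register each token individually, mirroring the loop added in megatron_to_hf.py.
for special_token in additional_special_tokens:
    hf_tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})

print(hf_tokenizer.convert_tokens_to_ids("<extra_id_0>"))
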
