Tts fixed vocab #6172

Merged (20 commits) on Mar 20, 2023

Changes from 12 commits
@@ -505,6 +505,7 @@ def __init__(
locale="en-US",
punct=True,
non_default_punct_list=None,
fixed_vocab=None,
*,
space=' ',
silence=None,
@@ -522,6 +523,12 @@ def __init__(
Specify None if implementing custom logic for a new locale.
punct: Whether to reserve grapheme for basic punctuation or not.
non_default_punct_list: List of punctuation marks which will be used instead default, if any.
fixed_vocab: List of valid grapheme/phoneme tokens for the model.
Set only if overriding the default vocab generation process (reading from G2P dict).
If set, any dataset entries containing graphemes not in the vocab will be filtered out, and any words whose
pronunciations contain phonemes not in the vocab will be treated as OOV.
Please make sure that the grapheme prefixes and cases are consistent with the G2P module's settings.
Defaults to None, which means default vocab generation is used.
space: Space token as string.
silence: Silence token as string (will be disabled if it is None).
apostrophe: Whether to use apostrophe or not.
@@ -546,8 +553,14 @@ def __init__(
if hasattr(g2p, "phoneme_probability"):
self.phoneme_probability = g2p.phoneme_probability

# Build tokens list
tokens = set(g2p.symbols)
# Build tokens list if fixed_vocab isn't set
if fixed_vocab:
tokens = set(fixed_vocab)
XuesongYang (Collaborator) commented on Mar 15, 2023:
I wonder if this function takes good care of the case where, e.g., 'ö' can be encoded either as a single precomposed character (UTF-8 b'\xc3\xb6') or as 'o' plus a combining diaeresis (UTF-8 b'o\xcc\x88'). We discussed something similar a long time ago: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/tokenizers/text_to_speech/tokenizer_utils.py#L96-L101

redoctopus (Collaborator, Author) commented on Mar 15, 2023:
I'm thinking it's reasonable to assume the user has passed in a "correct"/"canonical" version of the symbols they want (mostly I'm assuming they're copy/pasting from a previous config or model).

Are you suggesting we run normalization over the user-provided fixed vocab?

XuesongYang (Collaborator):
Hmm... I was thinking of applying the same process (calling normalize_unicode_text) that we use now. But in the current implementation that process lives in g2p/modules.py rather than in tts_tokenizers.py. I guess the replace_symbols func would be a better place?

XuesongYang (Collaborator):
Tagging @rlangman for further comments.

Collaborator:
Right now it is done both in tts_tokenizers.py (as part of text_preprocessing_func) and in g2p/modules.py. I would favor putting any text normalization in tts_tokenizers.py where possible.

redoctopus (Collaborator, Author):
Alright, added a bit of code to preprocess the fixed vocab symbols in the tokenizer init.
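
For context, a minimal sketch of the precomposed vs. decomposed issue raised above, using only the Python standard library; whether normalize_unicode_text uses NFC or another normalization form is not confirmed here, and normalize_fixed_vocab below is a hypothetical helper, not part of this PR.

import unicodedata

# 'ö' has two valid Unicode spellings: a single precomposed code point,
# or 'o' followed by a combining diaeresis.
precomposed = "\u00f6"   # encodes to b'\xc3\xb6' in UTF-8 (one character)
decomposed = "o\u0308"   # encodes to b'o\xcc\x88' in UTF-8 (two characters)
assert precomposed != decomposed
assert unicodedata.normalize("NFC", decomposed) == precomposed

def normalize_fixed_vocab(symbols, form="NFC"):
    # Hypothetical helper: canonicalize a user-provided vocab so membership
    # checks against G2P output don't fail on encoding differences.
    return {unicodedata.normalize(form, s) for s in symbols}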

self.set_fixed_vocab = True # Used to check whether dataset entries need filtering
g2p.replace_symbols(tokens)
else:
tokens = set(g2p.symbols)
self.set_fixed_vocab = False

if apostrophe:
tokens.add("'")
Expand All @@ -573,6 +586,8 @@ def __init__(

super().__init__(tokens, oov=oov, sep=sep, add_blank_at=add_blank_at)

self.tokens_set = set(self.tokens) # To save some repeated work when filtering entries

self.punct = punct
self.pad_with_space = pad_with_space

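A rough usage sketch of the fixed_vocab argument documented above; the import paths are assumed from this PR's file layout, and the toy phoneme dict is illustrative rather than taken from the PR.

# Assumed import paths; adjust to the NeMo version in use.
from nemo.collections.tts.g2p.modules import IPAG2P
from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import IPATokenizer

phoneme_dict = {"HELLO": ["həˈɫoʊ"], "WORLD": ["ˈwɝɫd"]}  # toy dictionary
g2p = IPAG2P(phoneme_dict=phoneme_dict)

# Restrict the vocab to an explicit symbol set instead of deriving it from
# the G2P dict; 'ɝ' is deliberately excluded here.
fixed_vocab = (set("HELOWRD") | set("həˈɫoʊwɝd")) - {"ɝ"}
tokenizer = IPATokenizer(g2p=g2p, locale="en-US", fixed_vocab=fixed_vocab)

# "WORLD" loses its only pronunciation and is treated as OOV, while "HELLO"
# keeps its pronunciation; grapheme prefixes/cases must match the G2P settings.
assert "WORLD" not in tokenizer.g2p.phoneme_dict
assert "HELLO" in tokenizer.g2p.phoneme_dict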
2 changes: 0 additions & 2 deletions nemo/collections/tts/data/tts_dataset.py
@@ -185,7 +185,6 @@ def __init__(
self.phoneme_probability = None
if isinstance(self.text_tokenizer, BaseTokenizer):
self.text_tokenizer_pad_id = text_tokenizer.pad
self.tokens = text_tokenizer.tokens
self.phoneme_probability = getattr(self.text_tokenizer, "phoneme_probability", None)
else:
if text_tokenizer_pad_id is None:
@@ -195,7 +194,6 @@ def __init__(
raise ValueError(f"tokens must be specified if text_tokenizer is not BaseTokenizer")

self.text_tokenizer_pad_id = text_tokenizer_pad_id
self.tokens = tokens
self.cache_text = True if self.phoneme_probability is None else False

# Initialize text normalizer if specified
54 changes: 54 additions & 0 deletions nemo/collections/tts/g2p/modules.py
@@ -542,6 +542,60 @@ def _normalize_dict(self, phoneme_dict_obj: Dict[str, List[List[str]]]) -> Tuple

return g2p_dict, symbols

def replace_symbols(self, symbols, keep_alternate=True):
"""Replaces the vocabulary of symbols with the one given.
Also filters out any entries with illegal graphemes or phonemes according to the new vocab.

Args:
symbols (List, Set): User-provided set of valid symbols, both graphemes and phonemes
keep_alternate (bool): Whether to keep the other pronunciation(s) of a word if not all of them contain
illegal phonemes (and the word doesn't contain illegal graphemes).
Warning: this may change a word from being ambiguous to having only one valid pronunciation.
Defaults to True.
"""
new_symbols = set(symbols)
if self.symbols == new_symbols:
logging.info("Did not replace G2P valid symbol set since the given set is equivalent to the existing one.")
return

# Keep track of what will need to be deleted or (if keep_alternate=True) replaced
deletion_words = []
replacement_dict = {}

for word, prons in self.phoneme_dict.items():
# Check for illegal grapheme in the word itself
word_graphemes = set(self._prepend_prefix_for_one_word(set_grapheme_case(word, self.grapheme_case)))
word_diff = word_graphemes - new_symbols
if word_diff:
deletion_words.append(word)
continue

# Check for illegal phonemes in the pronunciation(s)
legal_prons = []
for pron in prons:
pron_diff = set(pron) - new_symbols
if not pron_diff:
legal_prons.append(pron)

# Check if at least one pronunciation was illegal
if len(legal_prons) != len(prons):
if not keep_alternate: # Remove the word and entry fully
deletion_words.append(word)
else: # Need to check if all prons were illegal
if not legal_prons:
deletion_words.append(word)
else:
replacement_dict[word] = legal_prons

# Update pronunciation dictionary as needed
for del_word in deletion_words:
del self.phoneme_dict[del_word]

if keep_alternate:
self.phoneme_dict.update(replacement_dict)

self.symbols = new_symbols

def is_unique_in_phoneme_dict(self, word: str) -> bool:
return len(self.phoneme_dict[word]) == 1

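A short sketch of calling the new replace_symbols method directly, in the spirit of the tests below; the import path and toy dictionary are assumptions for illustration.

# Assumed import path; adjust to the NeMo version in use.
from nemo.collections.tts.g2p.modules import IPAG2P

# "READ" has two pronunciations; excluding 'ɛ' from the symbol set removes
# only the pronunciation that uses it when keep_alternate=True.
g2p = IPAG2P(phoneme_dict={"READ": ["ˈɹid", "ˈɹɛd"], "WORLD": ["ˈwɝɫd"]})
symbols = set("READWORLD") | set("ˈɹidwɝɫ")  # no 'ɛ'

g2p.replace_symbols(symbols, keep_alternate=True)
assert g2p.phoneme_dict["READ"] == [list("ˈɹid")]  # alternate pron dropped
assert "WORLD" in g2p.phoneme_dict

# With keep_alternate=False, a word with any illegal pronunciation is removed entirely.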
@@ -157,3 +157,39 @@ def test_ipa_tokenizer_es_es(self):
chars, tokens = self._parse_text(tokenizer, input_text)

assert chars == expected_output

# @pytest.mark.run_only_on('CPU')
@pytest.mark.unit
def test_ipa_tokenizer_fixed_vocab(self):
phoneme_dict = self.PHONEME_DICT_EN
phoneme_dict["WOUND"] = ["ˈwaʊnd", "ˈwund"]
g2p = IPAG2P(phoneme_dict=phoneme_dict)

assert "WOUND" in g2p.phoneme_dict

# fmt: off
symbol_vocab = {
'H', 'E', 'L', 'L', 'O',
'W', 'O', 'R', 'L', 'D',
'C', 'A', 'F', 'E',
'W', 'O', 'U', 'N', 'D',
'h', 'ə', 'ˈ', 'ɫ', 'o', 'ʊ',
'ˈ', 'w', 'ɝ', 'ɫ', 'd',
'k', 'ə', 'ˈ', 'f', 'e', 'ɪ',
'ˈ', 'w', 'a', 'ʊ', 'n', 'd',
'ˈ', 'w', 'u', 'n', 'd',
}
# fmt: on
fixed_vocab = symbol_vocab - {'ʊ', 'F'}
tokenizer = IPATokenizer(g2p=g2p, locale="en-US", fixed_vocab=fixed_vocab)

# Make sure phoneme_dict has been updated properly
assert "HELLO" not in tokenizer.g2p.phoneme_dict
assert "WORLD" in tokenizer.g2p.phoneme_dict
assert "CAFE" not in tokenizer.g2p.phoneme_dict
assert len(tokenizer.g2p.phoneme_dict["WOUND"]) == 1
assert tokenizer.g2p.phoneme_dict["WOUND"][0] == list("ˈwund")

chars, tokens = self._parse_text(tokenizer, "Hello, wound")
expected_output = "HELLO, ˈwund"
assert chars == expected_output
61 changes: 61 additions & 0 deletions tests/collections/tts/g2p/test_modules.py
@@ -122,6 +122,67 @@ def test_normalize_dict_with_graphemes_and_phonemes(self):
assert g2p.phoneme_dict["JONES"][0] == list("ˈdʒoʊnz")
assert g2p.phoneme_dict["AIRPORT"][0] == list("ˈɛɹˌpɔɹt")

@pytest.mark.run_only_on('CPU')
@pytest.mark.unit
def test_replace_symbols(self):
g2p = self._create_g2p(use_chars=True, grapheme_prefix=self.GRAPHEME_PREFIX)

# fmt: off
# Get full vocab without 'i' (phoneme) and 'J' (grapheme)
fixed_symbols = {
f"{self.GRAPHEME_PREFIX}{char}"
for char in {
'H', 'E', 'L', 'L', 'O',
'W', 'O', 'R', 'L', 'D',
'L', 'E', 'A', 'D',
'N', 'V', 'I', 'D', 'I', 'A',
'O', 'N', 'E', 'S',
'A', 'I', 'R', 'P', 'O', 'R', 'T',
}
}.union(
{
'h', 'ə', 'ˈ', 'ɫ', 'o', 'ʊ',
'ˈ', 'w', 'ɝ', 'ɫ', 'd',
'ˈ', 'l', 'ɛ', 'd',
'ˈ', 'l', 'd',
'ɛ', 'n', 'ˈ', 'v', 'ɪ', 'd', 'ə',
'ˈ', 'd', 'ʒ', 'o', 'ʊ', 'n', 'z',
'ˈ', 'ɛ', 'ɹ', 'ˌ', 'p', 'ɔ', 'ɹ', 't',
}
)
# fmt: on

assert len(g2p.phoneme_dict["LEAD"]) == 2
assert len(g2p.phoneme_dict["JONES"]) == 1
assert len(g2p.phoneme_dict["NVIDIA"]) == 1

# Test with keep_alternate set to True (default)
g2p.replace_symbols(symbols=fixed_symbols, keep_alternate=True)

# Check that the alternate pron of "LEAD" was kept
assert len(g2p.phoneme_dict["LEAD"]) == 1
assert g2p.phoneme_dict["LEAD"][0] == list("ˈlɛd")
# Check that filtering was done for unique entries, both grapheme and phoneme
assert "JONES" not in g2p.phoneme_dict
assert "NVIDIA" not in g2p.phoneme_dict
# Check that other words weren't affected
assert g2p.phoneme_dict["HELLO"][0] == list("həˈɫoʊ")
assert g2p.phoneme_dict["WORLD"][0] == list("ˈwɝɫd")
assert g2p.phoneme_dict["AIRPORT"][0] == list("ˈɛɹˌpɔɹt")

# Test with keep_alternate set to False
g2p = self._create_g2p(use_chars=True, grapheme_prefix=self.GRAPHEME_PREFIX)
g2p.replace_symbols(symbols=fixed_symbols, keep_alternate=False)

# Check that both "LEAD" entries were removed
assert "LEAD" not in g2p.phoneme_dict
# Other checks remain the same
assert "JONES" not in g2p.phoneme_dict
assert "NVIDIA" not in g2p.phoneme_dict
assert g2p.phoneme_dict["HELLO"][0] == list("həˈɫoʊ")
assert g2p.phoneme_dict["WORLD"][0] == list("ˈwɝɫd")
assert g2p.phoneme_dict["AIRPORT"][0] == list("ˈɛɹˌpɔɹt")

@pytest.mark.run_only_on('CPU')
@pytest.mark.unit
def test_forward_call(self):