diff --git a/examples/nlp/spellchecking_asr_customization/create_custom_vocab_index.py b/examples/nlp/spellchecking_asr_customization/create_custom_vocab_index.py index 07d64ec5b723..68c55ff51a4f 100644 --- a/examples/nlp/spellchecking_asr_customization/create_custom_vocab_index.py +++ b/examples/nlp/spellchecking_asr_customization/create_custom_vocab_index.py @@ -53,7 +53,7 @@ print("Size of customization vocabulary:", len(custom_phrases)) # Load n-gram mappings vocabulary -ngram_mapping_vocab, ban_ngram = load_ngram_mappings(args.ngram_mappings, max_misspelled_freq=125000) +ngram_mapping_vocab, ban_ngram = load_ngram_mappings(args.ngram_mappings, max_misspelled_freq=args.max_misspelled_freq) # Generate index of custom phrases phrases, ngram2phrases = get_index( diff --git a/examples/nlp/spellchecking_asr_customization/run_infer.sh b/examples/nlp/spellchecking_asr_customization/run_infer.sh index 09da98171c16..b4bbdc4da375 100644 --- a/examples/nlp/spellchecking_asr_customization/run_infer.sh +++ b/examples/nlp/spellchecking_asr_customization/run_infer.sh @@ -31,7 +31,7 @@ BIG_SAMPLE=spellmapper_asr_customization_en/big_sample.txt ## File with input nemo ASR manifest INPUT_MANIFEST=spellmapper_en_evaluation/medical_manifest_ctc.json ## File containing custom words and phrases (plain text) -CUSTOM_VOCAB=spellmapper_en_evaluation/medical_custom_vocab.json +CUSTOM_VOCAB=spellmapper_en_evaluation/medical_custom_vocab.txt ## Other files will be created ## File with index of custom vocabulary diff --git a/nemo/collections/nlp/data/spellchecking_asr_customization/utils.py b/nemo/collections/nlp/data/spellchecking_asr_customization/utils.py index cda551189d78..7385f19b414a 100644 --- a/nemo/collections/nlp/data/spellchecking_asr_customization/utils.py +++ b/nemo/collections/nlp/data/spellchecking_asr_customization/utils.py @@ -764,12 +764,30 @@ def check_banned_replacements(src: str, dst: str) -> bool: # anticipated => anticipate if src.endswith("ed") and dst.endswith("e") and src[0:-2] == dst[0:-1]: return True + # blocks => blocked + if src.endswith("s") and dst.endswith("ed") and src[0:-1] == dst[0:-2]: + return True + # blocked => blocks + if src.endswith("ed") and dst.endswith("s") and src[0:-2] == dst[0:-1]: + return True + # lives => lived + if src.endswith("es") and dst.endswith("ed") and src[0:-2] == dst[0:-2]: + return True + # lived => lives + if src.endswith("ed") and dst.endswith("es") and src[0:-2] == dst[0:-2]: + return True # regarded => regard if src.endswith("ed") and src[0:-2] == dst: return True # regard => regarded if dst.endswith("ed") and dst[0:-2] == src: return True + # regardeding => regard + if src.endswith("ing") and src[0:-3] == dst: + return True + # regard => regarding + if dst.endswith("ing") and dst[0:-3] == src: + return True # longer => long if src.endswith("er") and src[0:-2] == dst: return True @@ -782,48 +800,102 @@ def check_banned_replacements(src: str, dst: str) -> bool: # discussing => discussed if src.endswith("ing") and dst.endswith("ed") and src[0:-3] == dst[0:-2]: return True + # live => living + if src.endswith("e") and dst.endswith("ing") and src[0:-1] == dst[0:-3]: + return True + # living => live + if src.endswith("ing") and dst.endswith("e") and src[0:-3] == dst[0:-1]: + return True # discussion => discussing if src.endswith("ion") and dst.endswith("ing") and src[0:-3] == dst[0:-3]: return True # discussing => discussion if src.endswith("ing") and dst.endswith("ion") and src[0:-3] == dst[0:-3]: return True + # alignment => aligning + if src.endswith("ment") and dst.endswith("ing") and src[0:-4] == dst[0:-3]: + return True + # aligning => alignment + if src.endswith("ing") and dst.endswith("ment") and src[0:-3] == dst[0:-4]: + return True # dispensers => dispensing if src.endswith("ers") and dst.endswith("ing") and src[0:-3] == dst[0:-3]: return True # dispensing => dispensers if src.endswith("ing") and dst.endswith("ers") and src[0:-3] == dst[0:-3]: return True + # integrate => integrity + if src.endswith("ate") and dst.endswith("ity") and src[0:-3] == dst[0:-3]: + return True + # integrity => integrate + if src.endswith("ity") and dst.endswith("ate") and src[0:-3] == dst[0:-3]: + return True # discussion => discussed if src.endswith("ion") and dst.endswith("ed") and src[0:-3] == dst[0:-2]: return True # discussed => discussion if src.endswith("ed") and dst.endswith("ion") and src[0:-2] == dst[0:-3]: return True + # anticipation => anticipate + if src.endswith("ion") and dst.endswith("e") and src[0:-3] == dst[0:-1]: + return True + # anticipate => anticipation + if src.endswith("e") and dst.endswith("ion") and src[0:-1] == dst[0:-3]: + return True # incremental => increment if src.endswith("ntal") and dst.endswith("nt") and src[0:-4] == dst[0:-2]: return True # increment => incremental if src.endswith("nt") and dst.endswith("ntal") and src[0:-2] == dst[0:-4]: return True + # national => nation + if src.endswith("nal") and dst.endswith("n") and src[0:-3] == dst[0:-1]: + return True + # nation => national + if src.endswith("n") and dst.endswith("nal") and src[0:-1] == dst[0:-3]: + return True + # significantly => significant + if src.endswith("ntly") and dst.endswith("nt") and src[0:-4] == dst[0:-2]: + return True + # significant => significantly + if src.endswith("nt") and dst.endswith("ntly") and src[0:-2] == dst[0:-4]: + return True # delivery => deliverer if src.endswith("ery") and dst.endswith("erer") and src[0:-3] == dst[0:-4]: return True # deliverer => delivery if src.endswith("erer") and dst.endswith("ery") and src[0:-4] == dst[0:-3]: return True + # deliver => deliverer + if src.endswith("er") and dst.endswith("erer") and src[0:-2] == dst[0:-4]: + return True + # deliverer => deliver + if src.endswith("erer") and dst.endswith("er") and src[0:-4] == dst[0:-2]: + return True # comparably => comparable if src.endswith("bly") and dst.endswith("ble") and src[0:-3] == dst[0:-3]: return True # comparable => comparably if src.endswith("ble") and dst.endswith("bly") and src[0:-3] == dst[0:-3]: return True + # comparably => comparability + if src.endswith("bly") and dst.endswith("bility") and src[0:-3] == dst[0:-6]: + return True + # comparability => comparably + if src.endswith("bility") and dst.endswith("bly") and src[0:-6] == dst[0:-3]: + return True # beautiful => beautifully if src.endswith("l") and dst.endswith("lly") and src[0:-1] == dst[0:-3]: return True # beautifully => beautiful if src.endswith("lly") and dst.endswith("l") and src[0:-3] == dst[0:-1]: return True + # active => actively + if src.endswith("e") and dst.endswith("ely") and src[0:-1] == dst[0:-3]: + return True + # actively => active + if src.endswith("ely") and dst.endswith("e") and src[0:-3] == dst[0:-1]: + return True # america => american if src.endswith("a") and dst.endswith("an") and src[0:-1] == dst[0:-2]: return True @@ -836,6 +908,18 @@ def check_banned_replacements(src: str, dst: str) -> bool: # investing => reinvesting if dst.startswith("re") and dst[2:] == src: return True + # unchanged => changed + if src.startswith("un") and src[2:] == dst: + return True + # changed => unchanged + if dst.startswith("un") and dst[2:] == src: + return True + # disrespected => respected + if src.startswith("dis") and src[3:] == dst: + return True + # respected => disrespected + if dst.startswith("dis") and dst[3:] == src: + return True # outperformance => performance if src.startswith("out") and src[3:] == dst: return True