Commit
fix save_best missing chpt bug, update for setup_tokenizer() changes (#3932)

* fix save_best missing chpt bug, update for setup_tokenizer() changes

Signed-off-by: ekmb <[email protected]>

* style fix

Signed-off-by: ekmb <[email protected]>
ekmb authored and ericharper committed Apr 8, 2022
1 parent 29cce8e commit 3e7e042
Showing 4 changed files with 9 additions and 22 deletions.

@@ -113,6 +113,7 @@ decoder_exp_manager:
     save_top_k: 3
     monitor: "val_loss"
     mode: "min"
+    save_best_model: True

 # Data
 data:
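
For reference, a minimal sketch of how an exp_manager config carrying the new save_best_model flag might be built and inspected in Python. Only save_top_k, monitor, mode, and save_best_model come from the hunk above; the create_checkpoint_callback key and the checkpoint_callback_params grouping are assumptions about the surrounding config.

# Sketch only: illustrative exp_manager config with the new flag.
from omegaconf import OmegaConf

decoder_exp_manager = OmegaConf.create(
    {
        "create_checkpoint_callback": True,  # assumed; enables the checkpoint callback
        "checkpoint_callback_params": {
            "save_top_k": 3,
            "monitor": "val_loss",
            "mode": "min",
            # Added by this commit: re-save the best checkpoint as a .nemo file at the end of training.
            "save_best_model": True,
        },
    }
)
print(OmegaConf.to_yaml(decoder_exp_manager))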

@@ -106,13 +106,6 @@ def main(cfg: DictConfig) -> None:
 tagger_exp_manager = cfg.get('tagger_exp_manager', None)
 exp_manager(tagger_trainer, tagger_exp_manager)
 tagger_trainer.fit(tagger_model)
-if (
-    tagger_exp_manager
-    and tagger_exp_manager.get('create_checkpoint_callback', False)
-    and cfg.tagger_model.nemo_path
-):
-    tagger_model.to(tagger_trainer.accelerator.root_device)
-    tagger_model.save_to(cfg.tagger_model.nemo_path)
 logging.info('Training finished!')

 # Train the decoder

@@ -125,13 +118,6 @@ def main(cfg: DictConfig) -> None:
 decoder_exp_manager = cfg.get('decoder_exp_manager', None)
 exp_manager(decoder_trainer, decoder_exp_manager)
 decoder_trainer.fit(decoder_model)
-if (
-    decoder_exp_manager
-    and decoder_exp_manager.get('create_checkpoint_callback', False)
-    and cfg.decoder_model.nemo_path
-):
-    decoder_model.to(decoder_trainer.accelerator.root_device)
-    decoder_model.save_to(cfg.decoder_model.nemo_path)
 logging.info('Training finished!')

 # Evaluation after training
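
The two deleted blocks above used to re-save the .nemo file by hand after trainer.fit(). With save_best_model enabled in the exp_manager config, that work moves into the checkpoint callback, so the training script only needs to set up the experiment manager and fit. A rough sketch of the simplified tagger flow, assuming a hypothetical helper function and trainer config key; only the tagger_* names and the exp_manager/fit/logging calls appear in the diff.

import pytorch_lightning as pl
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


def train_tagger(cfg, tagger_model):
    # Hypothetical helper; cfg.tagger_trainer as a trainer-kwargs block is an assumption.
    tagger_trainer = pl.Trainer(**cfg.tagger_trainer)
    tagger_exp_manager = cfg.get('tagger_exp_manager', None)
    exp_manager(tagger_trainer, tagger_exp_manager)
    tagger_trainer.fit(tagger_model)
    # No manual tagger_model.save_to(cfg.tagger_model.nemo_path) here any more:
    # with save_best_model=True the checkpoint callback re-saves the best
    # checkpoint as a .nemo file in on_train_end (see the exp_manager.py hunk below).
    logging.info('Training finished!')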

@@ -58,14 +58,14 @@ def output_module(self):
 return self

 def __init__(self, cfg: DictConfig, trainer: Trainer = None):
-    self._tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer, add_prefix_space=True)
+    self.tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer, add_prefix_space=True)
     super().__init__(cfg=cfg, trainer=trainer)
     self.num_labels = len(constants.ALL_TAG_LABELS)
     self.mode = cfg.get('mode', 'joint')

     self.model = AutoModelForTokenClassification.from_pretrained(cfg.transformer, num_labels=self.num_labels)
     self.transformer_name = cfg.transformer
-    self.max_sequence_len = cfg.get('max_sequence_len', self._tokenizer.model_max_length)
+    self.max_sequence_len = cfg.get('max_sequence_len', self.tokenizer.model_max_length)

     # Loss Functions
     self.loss_fct = nn.CrossEntropyLoss(ignore_index=constants.LABEL_PAD_TOKEN_ID)

@@ -175,9 +175,7 @@ def _infer(self, sents: List[List[str]], inst_directions: List[str]):
 texts.append([prefix] + sent)

 # Apply the model
-encodings = self._tokenizer(
-    texts, is_split_into_words=True, padding=True, truncation=True, return_tensors='pt'
-)
+encodings = self.tokenizer(texts, is_split_into_words=True, padding=True, truncation=True, return_tensors='pt')

 inputs = encodings
 encodings_reduced = None

@@ -186,7 +184,7 @@ def _infer(self, sents: List[List[str]], inst_directions: List[str]):
 # if an input symbol is missing in the tokenizer's vocabulary (such as emoji or a Chinese character), it could be skipped
 len_texts = [len(x) for x in texts]
 len_ids = [
-    len(self._tokenizer.convert_ids_to_tokens(x, skip_special_tokens=True)) for x in encodings['input_ids']
+    len(self.tokenizer.convert_ids_to_tokens(x, skip_special_tokens=True)) for x in encodings['input_ids']
 ]
 idx_valid = [i for i, (t, enc) in enumerate(zip(len_texts, len_ids)) if enc >= t]

@@ -346,7 +344,7 @@ def _setup_dataloader_from_config(self, cfg: DictConfig, data_split: str):
 tagger_data_augmentation = cfg.get('tagger_data_augmentation', False)
 dataset = TextNormalizationTaggerDataset(
     input_file=input_file,
-    tokenizer=self._tokenizer,
+    tokenizer=self.tokenizer,
     tokenizer_name=self.transformer_name,
     mode=self.mode,
     tagger_data_augmentation=tagger_data_augmentation,

@@ -355,7 +353,7 @@ def _setup_dataloader_from_config(self, cfg: DictConfig, data_split: str):
     use_cache=cfg.get('use_cache', False),
     max_insts=cfg.get('max_insts', -1),
 )
-data_collator = DataCollatorForTokenClassification(self._tokenizer)
+data_collator = DataCollatorForTokenClassification(self.tokenizer)
 dl = torch.utils.data.DataLoader(
     dataset=dataset, batch_size=cfg.batch_size, shuffle=cfg.shuffle, collate_fn=data_collator
 )
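
The remaining hunks in this file rename the tokenizer attribute from self._tokenizer to self.tokenizer, presumably so the model matches the attribute name expected after the setup_tokenizer() changes mentioned in the commit title. A minimal sketch of that convention, with an illustrative class; only the self.tokenizer name and the tokenizer keyword arguments come from the diff.

from transformers import AutoTokenizer


class TaggerLikeModel:
    # Illustrative stand-in for a model that exposes its tokenizer publicly.
    def __init__(self, tokenizer_name: str):
        # A public self.tokenizer lets shared utilities (e.g. the updated
        # setup_tokenizer() path) locate the tokenizer under one conventional name.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, add_prefix_space=True)

    def encode(self, word_lists):
        # word_lists: a list of pre-split token lists, matching is_split_into_words=True.
        return self.tokenizer(
            word_lists, is_split_into_words=True, padding=True, truncation=True, return_tensors='pt'
        )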

nemo/utils/exp_manager.py: 2 changes (2 additions, 0 deletions)

@@ -797,6 +797,8 @@ def on_train_end(self, trainer, pl_module):

 # Load the best model and then re-save it
 if self.save_best_model:
+    # wait for all processes
+    trainer.training_type_plugin.barrier("SaveBestCheckpointConnector.resume_end")
     if self.best_model_path == "":
         logging.warning(
             f"{self} was told to save the best checkpoint at the end of training, but no saved checkpoints "
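
The barrier added here is presumably the fix for the "missing chpt" part of the commit title: in multi-process training, a rank could otherwise reach the best-checkpoint reload path before every process had finished writing its checkpoint files. A condensed sketch of the surrounding on_train_end logic; the two added lines and the save_best_model / best_model_path checks come from the diff, while the class wrapper and branch bodies are illustrative assumptions.

from nemo.utils import logging


class BestModelSavingCallbackSketch:
    # Illustrative stand-in for the NeMo checkpoint callback touched by this hunk.
    def __init__(self):
        self.save_best_model = True
        self.best_model_path = ""

    def on_train_end(self, trainer, pl_module):
        if self.save_best_model:
            # Synchronize all ranks so the best checkpoint is fully written to
            # disk before any process tries to load and re-save it.
            trainer.training_type_plugin.barrier("SaveBestCheckpointConnector.resume_end")
            if self.best_model_path == "":
                # Paraphrase of the original warning: nothing was checkpointed, so
                # there is no best checkpoint to re-save.
                logging.warning(f"{self} was told to save the best checkpoint, but none was found.")
            else:
                # Reload the best checkpoint and re-save the model as a .nemo file
                # (that logic sits outside this hunk).
                pass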
