Skip to content

Commit

Permalink
Fix for T5 finetuning when starting with pad instead of bos (#6278)
Browse files (browse the repository at this point in the history)
* Fix for T5 finetuning when starting with pad instead of bos

Signed-off-by: MaximumEntropy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: MaximumEntropy <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
MaximumEntropy and pre-commit-ci[bot] authored Mar 24, 2023
1 parent dfc0e69 commit 797318c
Showing 1 changed file with 8 additions and 10 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -286,17 +286,19 @@ def _reconfigure_and_process_inference_batch(self, batch, ds_config):
def inference_step(self, batch, batch_idx, mode, dataloader_idx=0):
# Regular finetuning datasets will return a list of dicts for each microbatch. But T0 datasets will return a single dict for the global batch.
batch_has_lang_information = isinstance(batch, list) and len(batch[0]) == 7
data_cfg = self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds

processed_batch = self._reconfigure_and_process_inference_batch(
batch, self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds
)
processed_batch = self._reconfigure_and_process_inference_batch(batch, data_cfg)

# Call parent validation step to get the loss.
# NOTE: There could be extra keys in the processed_batch dictionary such as "langs" for XNLI, this will be ignored in the parent class.
loss = super().validation_step(processed_batch, batch_idx)

predicted_token_ids, _ = self.decode(
tokens_enc=processed_batch['text_enc'], enc_mask=processed_batch['enc_mask'], num_tokens_to_generate=30
tokens_enc=processed_batch['text_enc'],
enc_mask=processed_batch['enc_mask'],
num_tokens_to_generate=30,
bos_id=self.tokenizer.pad_id if data_cfg.replace_bos_with_pad else self.tokenizer.bos_id,
)

# Special ids to text function to handle stripping <eos> and special tokens with sentencepiece tokenizers.
Expand All @@ -317,12 +319,8 @@ def inference_step(self, batch, batch_idx, mode, dataloader_idx=0):
pred=pred,
label=label,
metric_name=self.val_metric_name if mode == 'validation' else self.test_metric_name,
class_labels=self.cfg.data.validation_ds.metric.get('class_labels', None)
if mode == 'validation'
else self.cfg.data.test_ds.metric.get('class_labels', None),
labels_are_strings=self.cfg.data.validation_ds.metric.get('labels_are_strings', False)
if mode == 'validation'
else self.cfg.data.test_ds.metric.get('labels_are_strings', False),
class_labels=data_cfg.metric.get('class_labels', None),
labels_are_strings=data_cfg.metric.get('labels_are_strings', False),
)
if batch_has_lang_information:
_ = metric(pred, label, category)
Expand Down

0 comments on commit 797318c

Please sign in to comment.