Skip to content

Commit

Permalink
Fix skip generation (NVIDIA#7270)
Browse files (browse the repository at this point in the history)
* Fix skip generation

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add metric condition

Signed-off-by: Cheng-Ping Hsieh <[email protected]>

---------

Signed-off-by: Cheng-Ping Hsieh <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adi Renduchintala <[email protected]>
Signed-off-by: Siddharth Tyagi <[email protected]>
Signed-off-by: Siddharth Tyagi <[email protected]>
  • Loading branch information
3 people authored and siddhartht130 committed Aug 23, 2023
1 parent dc3d723 commit 21977bd
Showing 1 changed file with 18 additions and 15 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -402,21 +402,24 @@ def inference_step(self, dataloader_iter, batch_idx, mode, dataloader_idx=0):
metadata = batch.get('metadata', [{}] * len(batch['tokens']))
loss = super().validation_step(itertools.chain([batch]), batch_idx)

# We need _inference_config to get generation params
# add_BOS and tokens_to_generate are set in dataset
if self.get_inference_config() is None:
self.set_inference_config(inference_config={})
self._inference_config['add_BOS'] = data_cfg.add_bos
self._inference_config['tokens_to_generate'] = data_cfg.get('tokens_to_generate')

output = self.predict_step(batch, batch_idx, dataloader_idx)

inputs_text = [self.tokenizer.ids_to_text(c.tolist()) for c in batch['contexts']]
labels_text = [self.tokenizer.ids_to_text(a.tolist()) for a in batch['answers']]
preds_text = [
self.tokenizer.ids_to_text(t[l.item() :][: data_cfg.get('tokens_to_generate')])
for t, l in zip(output['token_ids'], batch['context_lengths'])
]
if data_cfg.get("write_predictions_to_file", False) or data_cfg.metric.name != 'loss':
# We need _inference_config to get generation params
# add_BOS and tokens_to_generate are set in dataset
if self.get_inference_config() is None:
self.set_inference_config(inference_config={})
self._inference_config['add_BOS'] = data_cfg.add_bos
self._inference_config['tokens_to_generate'] = data_cfg.get('tokens_to_generate')

output = self.predict_step(batch, batch_idx, dataloader_idx)
inputs_text = [self.tokenizer.ids_to_text(c.tolist()) for c in batch['contexts']]
labels_text = [self.tokenizer.ids_to_text(a.tolist()) for a in batch['answers']]
preds_text = [
self.tokenizer.ids_to_text(t[l.item() :][: data_cfg.get('tokens_to_generate')])
for t, l in zip(output['token_ids'], batch['context_lengths'])
]
else:
inputs_text, labels_text, preds_text = [], [], []

outputs = {
'loss': loss,
'preds': preds_text, # [str]
Expand Down

0 comments on commit 21977bd

Please sign in to comment.