From 42348f0142ee5a20be557def18a206ec407506c8 Mon Sep 17 00:00:00 2001 From: karthikrangasai <39360170+karthikrangasai@users.noreply.github.com> Date: Tue, 5 Oct 2021 14:42:22 +0530 Subject: [PATCH] Add val_loss and test_loss calculation and logging for QnA task (#832) --- flash/text/question_answering/data.py | 12 ++++++++++-- flash/text/question_answering/model.py | 10 ++++++---- tests/text/question_answering/test_data.py | 4 ++++ 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/flash/text/question_answering/data.py b/flash/text/question_answering/data.py index 35f3af8df7..778eb2d143 100644 --- a/flash/text/question_answering/data.py +++ b/flash/text/question_answering/data.py @@ -87,8 +87,16 @@ def _tokenize_fn(self, samples: Any) -> Callable: if stage == RunningStage.TRAINING: # Preprocess function for training - tokenized_samples = self._prepare_train_features(samples, tokenized_samples) + tokenized_samples, _, _ = self._prepare_train_features(samples, tokenized_samples) elif self._running_stage.evaluating or stage == RunningStage.PREDICTING: + if self._running_stage.evaluating: + tokenized_samples, _sample_mapping, _offset_mapping = self._prepare_train_features( + samples, tokenized_samples + ) + + tokenized_samples["overflow_to_sample_mapping"] = _sample_mapping + tokenized_samples["offset_mapping"] = _offset_mapping + # Preprocess function for eval or predict tokenized_samples = self._prepare_val_features(samples, tokenized_samples) @@ -169,7 +177,7 @@ def _prepare_train_features(self, samples: Any, tokenized_samples: Any): token_end_index -= 1 tokenized_samples["end_positions"].append(token_end_index + 1) - return tokenized_samples + return tokenized_samples, sample_mapping, offset_mapping def _prepare_val_features(self, samples: Any, tokenized_samples: Any): # Since one example might give us several features if it has a long context, we need a map from a feature to diff --git a/flash/text/question_answering/model.py b/flash/text/question_answering/model.py index 664f3d7dfe..4956e951f7 100644 --- a/flash/text/question_answering/model.py +++ b/flash/text/question_answering/model.py @@ -251,12 +251,13 @@ def _generate_answers(self, pred_start_logits, pred_end_logits, examples): def forward(self, batch: Any) -> Any: metadata = batch.pop(DefaultDataKeys.METADATA) outputs = self.model(**batch) + loss = outputs.loss start_logits = outputs.start_logits end_logits = outputs.end_logits generated_answers = self._generate_answers(start_logits, end_logits, metadata) batch[DefaultDataKeys.METADATA] = metadata - return generated_answers + return loss, generated_answers def training_step(self, batch: Any, batch_idx: int) -> Tensor: outputs = self.model(**batch) @@ -265,9 +266,10 @@ def training_step(self, batch: Any, batch_idx: int) -> Tensor: return loss def common_step(self, prefix: str, batch: Any) -> torch.Tensor: - generated_answers = self(batch) + loss, generated_answers = self(batch) result = self.compute_metrics(generated_answers, batch[DefaultDataKeys.METADATA]) - self.log_dict(result, on_step=False, on_epoch=True, prog_bar=True) + self.log(f"{prefix}_loss", loss, on_step=False, on_epoch=True, prog_bar=True) + self.log_dict(result, on_step=False, on_epoch=True, prog_bar=False) def compute_metrics(self, generated_tokens, batch): for example in batch: @@ -285,7 +287,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0): self.common_step("test", batch) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - generated_answers = self(batch) + _, generated_answers = self(batch) return generated_answers @property diff --git a/tests/text/question_answering/test_data.py b/tests/text/question_answering/test_data.py index b2a7ba62dc..6a6893babf 100644 --- a/tests/text/question_answering/test_data.py +++ b/tests/text/question_answering/test_data.py @@ -136,6 +136,8 @@ def test_from_files(tmpdir): batch = next(iter(dm.val_dataloader())) assert "input_ids" in batch assert "attention_mask" in batch + assert "start_positions" in batch + assert "end_positions" in batch assert DefaultDataKeys.METADATA in batch assert "context" in batch[DefaultDataKeys.METADATA][0] assert "answer" in batch[DefaultDataKeys.METADATA][0] @@ -145,6 +147,8 @@ def test_from_files(tmpdir): batch = next(iter(dm.test_dataloader())) assert "input_ids" in batch assert "attention_mask" in batch + assert "start_positions" in batch + assert "end_positions" in batch assert DefaultDataKeys.METADATA in batch assert "context" in batch[DefaultDataKeys.METADATA][0] assert "answer" in batch[DefaultDataKeys.METADATA][0]