Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add val_loss and test_loss calculation and logging for QnA task (#832)
Browse files Browse the repository at this point in the history
  • Loading branch information
karthikrangasai authored Oct 5, 2021
1 parent 8477dc2 commit 42348f0
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 6 deletions.
12 changes: 10 additions & 2 deletions flash/text/question_answering/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,16 @@ def _tokenize_fn(self, samples: Any) -> Callable:

if stage == RunningStage.TRAINING:
# Preprocess function for training
tokenized_samples = self._prepare_train_features(samples, tokenized_samples)
tokenized_samples, _, _ = self._prepare_train_features(samples, tokenized_samples)
elif self._running_stage.evaluating or stage == RunningStage.PREDICTING:
if self._running_stage.evaluating:
tokenized_samples, _sample_mapping, _offset_mapping = self._prepare_train_features(
samples, tokenized_samples
)

tokenized_samples["overflow_to_sample_mapping"] = _sample_mapping
tokenized_samples["offset_mapping"] = _offset_mapping

# Preprocess function for eval or predict
tokenized_samples = self._prepare_val_features(samples, tokenized_samples)

Expand Down Expand Up @@ -169,7 +177,7 @@ def _prepare_train_features(self, samples: Any, tokenized_samples: Any):
token_end_index -= 1
tokenized_samples["end_positions"].append(token_end_index + 1)

return tokenized_samples
return tokenized_samples, sample_mapping, offset_mapping

def _prepare_val_features(self, samples: Any, tokenized_samples: Any):
# Since one example might give us several features if it has a long context, we need a map from a feature to
Expand Down
10 changes: 6 additions & 4 deletions flash/text/question_answering/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,13 @@ def _generate_answers(self, pred_start_logits, pred_end_logits, examples):
def forward(self, batch: Any) -> Any:
    """Run the question-answering model on a batch and decode answers.

    METADATA is popped from the batch before the model call (the underlying
    HF model's ``**batch`` expansion would reject the extra key) and restored
    afterwards so downstream steps can still read it.

    Args:
        batch: Tokenized batch dict; must contain ``DefaultDataKeys.METADATA``.

    Returns:
        Tuple of ``(loss, generated_answers)`` where ``loss`` comes from the
        model output and ``generated_answers`` are the decoded answer strings.
    """
    metadata = batch.pop(DefaultDataKeys.METADATA)
    outputs = self.model(**batch)
    loss = outputs.loss
    start_logits = outputs.start_logits
    end_logits = outputs.end_logits

    generated_answers = self._generate_answers(start_logits, end_logits, metadata)
    # Restore metadata so callers (e.g. common_step) can access it after forward.
    batch[DefaultDataKeys.METADATA] = metadata
    # Single return: the diff residue had an earlier unreachable
    # `return generated_answers`; the loss must be returned for val/test logging.
    return loss, generated_answers

def training_step(self, batch: Any, batch_idx: int) -> Tensor:
outputs = self.model(**batch)
Expand All @@ -265,9 +266,10 @@ def training_step(self, batch: Any, batch_idx: int) -> Tensor:
return loss

def common_step(self, prefix: str, batch: Any) -> torch.Tensor:
    """Shared validation/test step: one forward pass, then log loss and metrics.

    Args:
        prefix: Logging prefix, e.g. ``"val"`` or ``"test"`` — produces
            ``{prefix}_loss``.
        batch: Tokenized batch including ``DefaultDataKeys.METADATA``.
    """
    # One forward pass only — the merged diff showed both the old single-value
    # call and the new tuple-unpacking call, which would run the model twice.
    loss, generated_answers = self(batch)
    result = self.compute_metrics(generated_answers, batch[DefaultDataKeys.METADATA])
    # Epoch-level logging: loss on the progress bar, metrics off it.
    self.log(f"{prefix}_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
    self.log_dict(result, on_step=False, on_epoch=True, prog_bar=False)

def compute_metrics(self, generated_tokens, batch):
for example in batch:
Expand All @@ -285,7 +287,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
self.common_step("test", batch)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
    """Prediction step: run forward and return decoded answers.

    The loss returned by ``forward`` is discarded — predictions have no
    labels to compute a meaningful loss against.

    Args:
        batch: Tokenized batch including ``DefaultDataKeys.METADATA``.
        batch_idx: Index of the batch (unused, required by Lightning's API).
        dataloader_idx: Dataloader index (unused, required by Lightning's API).

    Returns:
        The generated answer strings for the batch.
    """
    # Single forward call — the diff residue also showed the pre-commit
    # `generated_answers = self(batch)`, which would double the work.
    _, generated_answers = self(batch)
    return generated_answers

@property
Expand Down
4 changes: 4 additions & 0 deletions tests/text/question_answering/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ def test_from_files(tmpdir):
batch = next(iter(dm.val_dataloader()))
assert "input_ids" in batch
assert "attention_mask" in batch
assert "start_positions" in batch
assert "end_positions" in batch
assert DefaultDataKeys.METADATA in batch
assert "context" in batch[DefaultDataKeys.METADATA][0]
assert "answer" in batch[DefaultDataKeys.METADATA][0]
Expand All @@ -145,6 +147,8 @@ def test_from_files(tmpdir):
batch = next(iter(dm.test_dataloader()))
assert "input_ids" in batch
assert "attention_mask" in batch
assert "start_positions" in batch
assert "end_positions" in batch
assert DefaultDataKeys.METADATA in batch
assert "context" in batch[DefaultDataKeys.METADATA][0]
assert "answer" in batch[DefaultDataKeys.METADATA][0]
Expand Down

0 comments on commit 42348f0

Please sign in to comment.