Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Update device type issues present in _generate_answers method.
Browse files Browse the repository at this point in the history
  • Loading branch information
karthikrangasai committed Aug 27, 2021
1 parent adbb4aa commit c149c37
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions flash/text/question_answering/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,10 @@ def _generate_answers(self, pred_start_logits, pred_end_logits, examples):
}

# Go through all possibilities for the `n_best_size` greater start and end logits.
start_indexes: List[int] = np.argsort(start_logits.clone().detach().numpy())[
start_indexes: List[int] = np.argsort(start_logits.clone().detach().cpu().numpy())[
-1 : -self.n_best_size - 1 : -1
].tolist()
end_indexes: List[int] = np.argsort(end_logits.clone().detach().numpy())[
end_indexes: List[int] = np.argsort(end_logits.clone().detach().cpu().numpy())[
-1 : -self.n_best_size - 1 : -1
].tolist()

Expand Down

0 comments on commit c149c37

Please sign in to comment.