diff --git a/examples/models/llama2/eval_llama_lib.py b/examples/models/llama2/eval_llama_lib.py index 8bb653739b3..d8c6c7bf1d4 100644 --- a/examples/models/llama2/eval_llama_lib.py +++ b/examples/models/llama2/eval_llama_lib.py @@ -42,12 +42,11 @@ def __init__( tokenizer: Union[SentencePieceTokenizer, Tiktoken], max_seq_length: Optional[int] = None, ): - super().__init__() + device = "cuda" if torch.cuda.is_available() else "cpu" + super().__init__(device=device) self._model = model self._tokenizer = tokenizer - self._device = ( - torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - ) + self._device = torch.device(device) self._max_seq_length = 2048 if max_seq_length is None else max_seq_length @property