Skip to content

Commit

Permalink
feat: remove bfloat16 for cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
gwkrsrch committed Nov 14, 2022
1 parent 6f8a40d commit 1f4b27c
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 6 deletions.
2 changes: 0 additions & 2 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ def demo_process(input_img):
pretrained_model.half()
device = torch.device("cuda")
pretrained_model.to(device)
else:
pretrained_model.encoder.to(torch.bfloat16)

pretrained_model.eval()

Expand Down
2 changes: 0 additions & 2 deletions donut/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,6 @@ def inference(
if self.device.type == "cuda": # half is not compatible in cpu implementation.
image_tensors = image_tensors.half()
image_tensors = image_tensors.to(self.device)
else:
image_tensors = image_tensors.to(torch.bfloat16)

if prompt_tensors is None:
prompt_tensors = self.decoder.tokenizer(prompt, add_special_tokens=False, return_tensors="pt")["input_ids"]
Expand Down
2 changes: 0 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ def test(args):
if torch.cuda.is_available():
pretrained_model.half()
pretrained_model.to("cuda")
else:
pretrained_model.encoder.to(torch.bfloat16)

pretrained_model.eval()

Expand Down

0 comments on commit 1f4b27c

Please sign in to comment.