inference time cpu vs gpu #42

Open
ganga7445 opened this issue Nov 6, 2023 · 3 comments

@ganga7445

I have used gte-tiny embeddings for my custom NER model and need to speed up the inference time.
Below are the stats for different batch sizes.

Batch Size    Average Inference Time per Batch (s), GPU    Average Inference Time per Batch (s), CPU
16            0.14945                                      1.23388
32            0.28                                         3.24456
64            0.51582                                      6.57234
128           1.10669                                      13.73319
256           2.24729                                      28.236

Is there any specific method to speed this up further? @tomaarsen
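
For context, here is a minimal sketch of how per-batch timings like the ones above can be measured on CPU and GPU. This is a reconstruction, not the exact script used; the model path is a placeholder for the custom gte-tiny NER checkpoint.

import time

import torch
from span_marker import SpanMarkerModel

sentences = ["Leonardo da Vinci painted the Mona Lisa."] * 256  # dummy inputs

for device in ("cpu", "cuda"):
    if device == "cuda" and not torch.cuda.is_available():
        continue
    # "path/to/custom-gte-tiny-ner" is a placeholder, not a real checkpoint name.
    model = SpanMarkerModel.from_pretrained("path/to/custom-gte-tiny-ner").to(device)
    for batch_size in (16, 32, 64, 128, 256):
        model.predict(sentences[:batch_size], batch_size=batch_size)  # warm-up, untimed
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        model.predict(sentences[:batch_size], batch_size=batch_size)
        if device == "cuda":
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
        print(f"{device}: batch_size={batch_size} took {elapsed:.5f}s per batch")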

@tomaarsen
Owner

You may experience improved speed if you use SpanMarkerModel.from_pretrained(..., torch_dtype=torch.float16) or torch.bfloat16. See e.g.:

import time

import torch
from span_marker import SpanMarkerModel

# Load the model in bfloat16 (or swap in torch.float16); the commented-out line below is the float32 baseline.
model = SpanMarkerModel.from_pretrained(
    "tomaarsen/span-marker-roberta-large-fewnerd-fine-super",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
# model = SpanMarkerModel.from_pretrained("tomaarsen/span-marker-roberta-large-fewnerd-fine-super", device_map="cuda")

text = [
    "Leonardo da Vinci recently published a scientific paper on combatting Mitocromulent disease. Leonardo da Vinci painted the most famous painting in existence: the Mona Lisa.",
    "Leonardo da Vinci scored a critical goal towards the end of the second half. Leonardo da Vinci controversially veto'd a bill regarding public health care last friday. Leonardo da Vinci was promoted to Sergeant after his outstanding work in the war.",
]
BS = 64   # batch size
N = 500   # number of times the two sample sentences are repeated
model.predict(text * 50, batch_size=BS)  # warm-up pass, excluded from the timing below
start_t = time.time()
model.predict(text * N, batch_size=BS)  # timed pass over N * 2 sentences
print(f"{time.time() - start_t:8f}s for {N * 2} samples with batch_size={BS} and torch_dtype={model.dtype}.")

This gave me:

20.745640s for 1000 samples with batch_size=64 and torch_dtype=torch.float16.
16.534876s for 1000 samples with batch_size=64 and torch_dtype=torch.bfloat16.

and

39.655506s for 1000 samples with batch_size=64 and torch_dtype=torch.float32.

Note that float16 is not available on CPU, though! I'm not sure about bfloat16.
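
A quick way to check whether bfloat16 kernels at least run on your CPU is a small standalone check like the following (my own sketch, independent of SpanMarker):

import torch

# If this prints torch.bfloat16 without raising an error, basic bfloat16
# matmuls are supported on CPU by your PyTorch build.
x = torch.randn(8, 8, dtype=torch.bfloat16)
print((x @ x).dtype)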

If you have a Linux (or possibly Mac) device, then you can also use load_in_8bit=True or load_in_4bit=True by installing bitsandbytes, but I don't know whether that improves inference speed, and it is also CUDA-only.
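
For reference, a minimal sketch of that 8-bit loading path, assuming the flag is simply forwarded to the underlying transformers loader; it requires `pip install bitsandbytes accelerate` and a CUDA GPU:

from span_marker import SpanMarkerModel

# Sketch only: 8-bit weight loading via bitsandbytes; CUDA-only, and it is
# untested here whether this actually improves inference speed.
model = SpanMarkerModel.from_pretrained(
    "tomaarsen/span-marker-roberta-large-fewnerd-fine-super",
    load_in_8bit=True,  # or load_in_4bit=True
    device_map="cuda",
)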

Beyond that, the steps to increase inference speed become pretty challenging. Hope this helps a bit.

Also, you can already process about 8 sentences per second on CPU and about 110 sentences per second on GPU; is that not sufficiently fast?

  • Tom Aarsen

@ganga7445
Author

Thank you @tomaarsen. Using torch.float16 worked for me. It would be excellent if the operation could be completed in less than one second with a batch size of 256.

Batch Size    Average Inference Time per Batch (s), float32 GPU    New Inference Time per Batch (s), float16 GPU
16            0.14945                                               0.09211015701
32            0.28                                                  0.1645913124
64            0.51582                                               0.2973537445
128           1.10669                                               0.6381671429
256           2.24729                                               1.238643169
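
For reference, converting the batch-size-256 rows above into throughput (a quick check using the numbers exactly as posted):

# Batch-size-256 times from the tables above, interpreted as seconds per batch.
float32_gpu = 2.24729
float16_gpu = 1.238643169
for label, seconds in (("float32 GPU", float32_gpu), ("float16 GPU", float16_gpu)):
    print(f"{label}: {256 / seconds:.0f} sentences per second")
# float32 GPU: 114 sentences per second
# float16 GPU: 207 sentences per second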

@tomaarsen
Owner

@polodealvarado started working on ONNX support here: #26 (comment)
If we can make it work, perhaps we can then improve the speed even further. Until then, it will be hard to get even faster results. A batch of 256 in under a second amounts to more than 256 sentences per second, which is already quite efficient.
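
Until that lands, here is a rough, generic illustration of what exporting a plain transformers encoder to ONNX looks like. This is not the SpanMarker ONNX API being developed in #26, and "roberta-base" is just a stand-in model:

import torch
from transformers import AutoModel, AutoTokenizer

name = "roberta-base"  # stand-in encoder, not a SpanMarker checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name)
model.config.return_dict = False  # export tuple outputs instead of a ModelOutput
model.eval()

inputs = tokenizer("Leonardo da Vinci painted the Mona Lisa.", return_tensors="pt")
torch.onnx.export(
    model,
    (inputs["input_ids"], inputs["attention_mask"]),
    "encoder.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["last_hidden_state", "pooler_output"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
        "last_hidden_state": {0: "batch", 1: "sequence"},
    },
    opset_version=14,
)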

  • Tom Aarsen
