59 changes: 57 additions & 2 deletions scripts/inference/README.md
@@ -1,5 +1,60 @@
# Inference scripts for BLOOM

To run a server using HuggingFace Transformers (requires [accelerate](https://github.com/huggingface/accelerate) to be installed):
```
python scripts/inference/bloom-accelerate-server.py --model_name bigscience/bloom --dtype bf16 --log_file data.log --host $ADDRESS --port $PORT
```

To run a server using DeepSpeed (requires [DeepSpeed MII](https://github.com/microsoft/DeepSpeed-mii) to be installed):
```
export DS_CACHE=<path where the pre-sharded TP=8 checkpoints will be dumped>

deepspeed --num_gpus 8 scripts/inference/cache-ds-model.py --model_name bigscience/bloom --dtype fp16

python scripts/inference/bloom-ds-server.py --model_name bigscience/bloom --dtype fp16 --log_file data.log --host $ADDRESS --port $PORT
```

Usage:
Currently, the server supports 3 methods (a Python client sketch for `/generate/` follows this list):
1. The main generate method
```
curl -H "Content-Type: application/json" -X POST -d '{ "input_text": "India is a country of", "top_k": "5", "top_p": "0.9", "temperature": "0.7", "min_length": "1", "max_new_tokens": "40" }' http://$ADDRESS:$PORT/generate/
```
returns
```
{"output_text":" many languages and cultures. The country is a melting pot of different cultures and languages. The country is a home to more than 1.2 billion people. The country is a home to more than 22","query_id":8,"total_time_taken":"19.358 s"}
```
2. Method that returns a description of the endpoint (usage and default generation parameters)
```
curl -H "Content-Type: application/json" -X GET http://$ADDRESS:$PORT/about/
```
returns
```
Please don't send any personal information to this endpoint. We are logging your data.

Usage:
A request object should look like:
{
input_text: "Hello, I'm a model",
"top_k": 5,
"top_p": 0.9,
"temperature": 0.7,
"min_length": 1,
"max_new_tokens": 40,
}

Default values (use if not provided in request object):
top_k = 50
top_p = 1
temperature = 1
min_length = 1
max_new_tokens = 40
```
3. Method to check GPU usage
```
curl -H "Content-Type: application/json" -X GET http://$ADDRESS:$PORT/gpu/
```
returns the `nvidia-smi` output
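
The same `/generate/` request can also be issued from Python. Below is a minimal client sketch using the `requests` library, mirroring the curl payload from method 1; `ADDRESS` and `PORT` are placeholders for wherever the server is running:
```
import requests

ADDRESS, PORT = "localhost", 8000  # placeholders: point these at the running server

payload = {
    # as in the curl example above, all values are passed as strings
    "input_text": "India is a country of",
    "top_k": "5",
    "top_p": "0.9",
    "temperature": "0.7",
    "min_length": "1",
    "max_new_tokens": "40",
}

response = requests.post(f"http://{ADDRESS}:{PORT}/generate/", json=payload)
response.raise_for_status()

result = response.json()
print(result["output_text"], result["total_time_taken"])
```
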
## BLOOM Inference solutions

Here are some stats on JeanZay's 8x80GB A100 node w/ 512GB of CPU memory:
@@ -14,7 +69,7 @@ Throughput in msecs:

| project \ bs | 1 | 8 | 16 | 32 | 64 | 128 |
| :----------- | :---- | :---- | :---- | :---- | :---- | :--- |
| accelerate | 230.38 | 31.78 | 17.84 | 10.89 | oom | omm |
| accelerate | 230.38 | 31.78 | 17.84 | 10.89 | oom | oom |
| ds-inference | 40.57 | 5.23 | | | 2.77 | 0.66 |
| ds-zero | 283 | 34.88 | oom | oom | oom | oom |
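
Assuming these figures are milliseconds per generated token (the benchmark script accumulates `total_new_tokens_generated` over the timed cycles), converting to aggregate tokens per second is a one-liner; a quick sketch:
```
def tokens_per_second(msec_per_token):
    # assumes the table reports milliseconds per generated token
    return 1000.0 / msec_per_token

print(tokens_per_second(0.66))    # ds-inference at bs=128 -> ~1515 tokens/s
print(tokens_per_second(230.38))  # accelerate at bs=1     -> ~4.3 tokens/s
```
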

@@ -192,4 +247,4 @@ $ python scripts/inference/bloom-accelerate-inference.py --name bigscience/bloom
[...]


```
```
93 changes: 31 additions & 62 deletions scripts/inference/bloom-accelerate-inference.py
@@ -4,7 +9,9 @@
import gc
import torch
import math
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils import generate_, get_max_memory_per_gpu_dict


def get_args():
parser = argparse.ArgumentParser()
@@ -18,46 +20,6 @@ def get_args():

return parser.parse_args()

def get_max_memory_per_gpu_dict(dtype, model_name):
""" try to generate the memory map based on what we know about the model and the available hardware """

# figure out the memory map - the minimum per gpu required to load the model
n_gpus = torch.cuda.device_count()

if model_name == "bigscience/bloom" and n_gpus == 8 and torch.cuda.get_device_properties(0).total_memory > 79*2**30:
# hand crafted optimized memory map for 8x80 setup over BLOOM
# this works with bs=40
return {0: '0GIB', 1: '51GIB', 2: '51GIB', 3: '51GIB', 4: '51GIB', 5: '51GIB', 6: '51GIB', 7: '51GIB'}

try:
# model_params calculation, as we don't have a model yet to do:
#model_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())

config = AutoConfig.from_pretrained(model_name)
h = config.n_embed
l = config.n_layer
v = config.vocab_size
# from https://github.com/bigscience-workshop/bigscience/tree/6917a3b5fefcf439d3485ca184b4d9f6ab605150/math#model-sizing
model_params = l*(12*h**2 + 13*h) + v*h + 4*h
except:
print(f"The model {model_name} has a broken config file. Please notify the owner")
raise

bytes = torch.finfo(dtype).bits / 8
param_memory_total_in_bytes = model_params * bytes
# add 5% since weight sizes aren't the same and some GPU may need more memory
param_memory_per_gpu_in_bytes = int(param_memory_total_in_bytes / n_gpus * 1.05)
print(f"Estimating {param_memory_per_gpu_in_bytes/2**30:0.2f}GB per gpu for weights")

# check the real available memory
# load cuda kernels first and only measure the real free memory after loading (shorter by ~2GB)
torch.ones(1).cuda()
max_memory_per_gpu_in_bytes = torch.cuda.mem_get_info(0)[0]
if max_memory_per_gpu_in_bytes < param_memory_per_gpu_in_bytes:
raise ValueError(f"Unable to generate the memory map automatically as the needed estimated memory per gpu ({param_memory_per_gpu_in_bytes/2**30:0.2f}GB) is bigger than the available per gpu memory ({max_memory_per_gpu_in_bytes/2**30:0.2f}GB)")

return {i: param_memory_per_gpu_in_bytes for i in range(torch.cuda.device_count())}
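
As a sanity check of the sizing formula above (the helper now lives in `utils.py`), here is a back-of-the-envelope calculation for `bigscience/bloom` using the published config values (h=14336, l=70, v=250880), assuming fp16 weights on an 8-GPU node; a rough sketch, not part of the PR:
```
# rough sizing check, assuming BLOOM config: h=14336, l=70, v=250880
h, l, v = 14336, 70, 250880
model_params = l * (12 * h**2 + 13 * h) + v * h + 4 * h
print(f"{model_params / 1e9:.1f}B parameters")        # ~176.2B

bytes_per_param = 2                                    # fp16
param_memory_total = model_params * bytes_per_param    # ~328 GiB of weights
per_gpu = int(param_memory_total / 8 * 1.05)           # 5% headroom, as in the function above
print(f"~{per_gpu / 2**30:.0f} GiB per GPU")           # ~43 GiB
```
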

t_start = time.time()

num_tokens = 100
@@ -122,30 +84,25 @@ def get_max_memory_per_gpu_dict(dtype, model_name):
if rank == 0:
print(f"Generate args {generate_kwargs}")
inputs = input_sentences[:args.batch_size]
def generate():
""" returns a list of zipped inputs, outputs and number of new tokens """

input_tokens = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
for t in input_tokens:
if torch.is_tensor(input_tokens[t]):
input_tokens[t] = input_tokens[t].to("cuda:0")

outputs = model.generate(**input_tokens, **generate_kwargs)

input_tokens_lengths = [x.shape[0] for x in input_tokens.input_ids]
output_tokens_lengths = [x.shape[0] for x in outputs]

total_new_tokens = [o-i for i,o in zip(input_tokens_lengths, output_tokens_lengths)]
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

return zip(inputs, outputs, total_new_tokens)

# warmup is a must if measuring speed as it's when all the optimizations are performed
# e.g. on 8x80 a100 the first pass of 100 tokens takes 23sec, and the next one is 4secs
_ = generate()
_ = generate_(
inputs,
model,
tokenizer,
generate_kwargs,
"cuda:0"
)
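
This PR replaces the inline `generate()` closure above with a shared `generate_` helper imported from `utils`. The `utils.py` file itself is not shown in this diff; a minimal sketch of what the helper presumably looks like, assuming it simply takes the model, tokenizer and device as explicit arguments instead of closing over them:
```
import torch


def generate_(inputs, model, tokenizer, generate_kwargs, device):
    """ returns a list of zipped inputs, outputs and number of new tokens """
    input_tokens = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
    for t in input_tokens:
        if torch.is_tensor(input_tokens[t]):
            input_tokens[t] = input_tokens[t].to(device)

    outputs = model.generate(**input_tokens, **generate_kwargs)

    input_tokens_lengths = [x.shape[0] for x in input_tokens.input_ids]
    output_tokens_lengths = [x.shape[0] for x in outputs]

    total_new_tokens = [o - i for i, o in zip(input_tokens_lengths, output_tokens_lengths)]
    outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    return zip(inputs, outputs, total_new_tokens)
```
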

t_generate_start = time.time()
generated = generate()
generated = generate_(
inputs,
model,
tokenizer,
generate_kwargs,
"cuda:0"
)
t_generate_span = time.time() - t_generate_start
if rank == 0:
for i,o,_ in generated:
@@ -164,15 +121,27 @@ def generate():

# warm up
for i in range(1):
_ = generate()
_ = generate_(
inputs,
model,
tokenizer,
generate_kwargs,
"cuda:0"
)
torch.cuda.synchronize()

# benchmark
t0 = time.time()
cycles = 5
total_new_tokens_generated = 0
for i in range(cycles):
generated = generate()
generated = generate_(
inputs,
model,
tokenizer,
generate_kwargs,
"cuda:0"
)
total_new_tokens_generated += sum(new_tokens for _,_,new_tokens in generated)
torch.cuda.synchronize()
if rank == 0: