Skip to content

Commit

Permalink
Merge release/5.1 changes into master
Browse files Browse the repository at this point in the history
Signed-off-by: Rajeev Rao <[email protected]>
  • Loading branch information
rajeevsrao committed Aug 27, 2019
1 parent 296ee92 commit 443e495
Show file tree
Hide file tree
Showing 71 changed files with 7,299 additions and 2,877 deletions.
33 changes: 19 additions & 14 deletions demo/BERT/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,12 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} \
-O3")

set(BERT_LIBS
common
bert_plugins
cudart
cublas
nvinfer
pthread
z
cublas
)
)

include_directories(
../../include
Expand All @@ -45,25 +43,30 @@ include_directories(
/workspace/cutlass/
)

link_directories(
/usr/local/cuda-10.1/targets/x86_64-linux/lib
/tensorrt/lib
)

add_library(common SHARED
../../samples/common/logger.cpp
util/data_utils.cpp
util/dataUtils.cpp
)

add_library(bert_plugins SHARED
plugins/gelu_plugin.cu
plugins/skip_layer_norm_plugin.cu
plugins/qkv2context_plugin.cu
plugins/emb_layer_norm_plugin.cu
plugins/geluPlugin.cu
plugins/skipLayerNormPlugin.cu
plugins/qkvToContextPlugin.cu
plugins/embLayerNormPlugin.cu
)


link_directories(
/usr/local/cuda-10.1/targets/x86_64-linux/lib
/tensorrt/lib
target_link_libraries(bert_plugins
${BERT_LIBS}
)

target_link_libraries(common
${BERT_LIBS}
)

add_executable(sample_bert
sampleBERT.cpp
Expand All @@ -72,5 +75,7 @@ add_executable(sample_bert
target_compile_features(sample_bert PUBLIC cxx_std_11)

target_link_libraries(sample_bert
${BERT_LIBS}
common
bert_plugins
)

4 changes: 2 additions & 2 deletions demo/BERT/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ RUN echo $mygid
RUN groupadd -r -g ${mygid} nb && useradd -r -u ${myuid} -g ${mygid} -ms /bin/bash nb

RUN apt-get update && apt-get install -y software-properties-common && add-apt-repository ppa:ubuntu-toolchain-r/test
RUN apt-get update && apt-get install -y pbzip2 pv bzip2 sudo gcc-7 g++-7 zlib1g-dev
RUN apt-get update && apt-get install -y pbzip2 pv bzip2 sudo gcc-7 g++-7 zlib1g-dev g++-4.9
RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-7 60 \
--slave /usr/bin/g++ g++ /usr/bin/g++-7 && \
update-alternatives --config gcc

RUN wget https://cmake.org/files/v3.14/cmake-3.14.0-Linux-x86_64.sh && \
sh cmake-3.14.0-Linux-x86_64.sh --prefix=/usr/local --exclude-subdir
RUN pip install tensorflow==1.13.1 horovod
RUN pip install tensorflow==1.13.1 && pip install horovod

RUN echo 'nb:abc123' | chpasswd

Expand Down
76 changes: 41 additions & 35 deletions demo/BERT/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ To build the TensorRT OSS components, ensure you meet the following package requ
* [CUDA](https://developer.nvidia.com/cuda-toolkit)
* Recommended versions:
* [cuda-10.1](https://developer.nvidia.com/cuda-10.1-download-archive-base) + cuDNN-7.5

* [GNU Make](https://ftp.gnu.org/gnu/make/) >= v4.1

* [CMake](https://github.com/Kitware/CMake/releases) >= v3.8
Expand Down Expand Up @@ -46,78 +46,84 @@ To build the TensorRT OSS components, ensure you meet the following package requ
* [TensorRT](https://developer.nvidia.com/nvidia-tensorrt-5x-download) v5.1.5


## Building the example

The example was tested in a docker container based on NVIDIA NGC images, that provides the required dependencies, such as CUDA, CuDNN and TensorRT.

This example uses `cmake` and can be built with the following steps:
```
mkdir build
cd build
cmake ..
make -j
```

This will produce an executable `sample_bert` in the `build` folder.


## Example Workflow

The example provides scripts to convert fine-tuned Tensorflow model checkpoints into a simple binary format that can be read by sample binary.

The high-level workflow consists of the following steps:
1. Download the BERT reference code and pre-trained language model
2. Run the fine-tuning script for squad to obtain task specific network weights
3. Convert the fine-tuned checkpoint into our simple format, described in the appendix (the original weights are assumed to be float32 values)
4. Generate a test input/output pair (input sequences are assumed to be int32 values)
5. run the sample

1. Download a pre-trained BERT SQuAD checkpoint from NGC model registry (See optional section if you would like to train your own model)
2. Convert the fine-tuned checkpoint into our simple format, described in the appendix (the original weights are assumed to be float32 values)
3. Generate a test input/output pair (input sequences are assumed to be int32 values)
4. Build and run the sample

### Downloading the BERT reference code and pre-trained language model, and running SQuAD Fine-tuning
### 1. Download a pre-trained BERT SQuAD checkpoint from NGC model registry
```
wget -O bert-base-squad1.1.zip https://api.ngc.nvidia.com/v2/models/nvidia/bert_tf_v1_1_base_fp32_128/versions/2/zip
unzip bert-base-squad1.1.zip -d squad_output_path
```

Below, we will refer to the location `<squad output path>/model.ckpt-<number>` as shell variable `CHECKPOINT` and the path to the folder that contains the `bert_config.json` as `BERT_PATH`.


#### (Optional) Downloading the BERT reference code and pre-trained language model, and running SQuAD Fine-tuning

Please follow the instructions in the [DeepLearningExamples repository](https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/LanguageModeling/BERT) for fine-tuning SQuAD, which involves downloading the pre-trained language model as well as the SQuAD training data.

Then, in the scripts folder, there is `run_squad.sh` script, that adds a SQuAD-specific task layer to BERT and performs the fine-tuning.

This will create three files prefixed with `model.ckpt-<number>` that contain the fine-tuned model parameters, in the specified output directory.

Below, we will refer to the location `<squad output path>/model.ckpt-<number>` as shell variable `CHECKPOINT` and the path to the folder that contains the `bert_config.json` as `BERT_PATH`.

### Convert the fine-tuned checkpoint into a simple format
### 2. Convert the fine-tuned checkpoint into a simple format
Python scripts in step 2 and 3 require Tensorflow on the system. We tested using tensorflow:19.06-py3 NGC container image.

After the fine-tuning is complete, the generated Tensorflow checkpoint can be converted using the following command:
The SQuAD fine-tuned Tensorflow checkpoint can be converted using the following command:

```
python python/convert_weights.py -m $CHECKPOINT -o <weight path>/filename
python helpers/convert_weights.py -m $CHECKPOINT -o <weight path>/filename
```

This will generate a file `<weight path>/<filename>.weights`. The path that contains the weights file, will be referred to as `WEIGHT_PATH`.


### Generate an input/output pair
### 3. Generate an input/output pair

To run the sample on random inputs and compare the output to the reference Tensorflow implementation, the following command produces test inputs and outputs:

```python python/generate_dbg.py -f $CHECKPOINT -p $BERT_PATH -o $OUTPUT_PATH -s <seq.len.> -b <batch size>```
```python helpers/generate_dbg.py -f $CHECKPOINT -p $BERT_PATH -o $OUTPUT_PATH -s <seq.len.> -b <batch size>```

Please refer to the help of `generate_dbg.py` for more options.

### Running the example

### 4. Build and run the example

The C++ example was tested using TensorRT OSS docker container image created by following the instruction [in this link](https://github.com/NVIDIA/TensorRT#setting-up-the-build-environment)

This example uses `cmake` and can be built with the following steps:
```
mkdir build
cd build
cmake ..
make -j
```

This will produce an executable `sample_bert` in the `build` folder.

The binary `sample_bert` requires as arguments the paths that contain `bert_config.json` (from the pre-trained BERT checkpoint), `bert.weights` and `test_inputs.weights_int32` and `test_outputs.weights` as generated by the steps above.

```build/sample_bert -d $WEIGHT_PATH -d $OUTPUT_PATH --fp16 --nheads <num heads>```
```build/sample_bert -d $WEIGHT_PATH -d $OUTPUT_PATH --fp16 --nheads <num_attention_heads>```

`<num heads>` refers to the number of attention heads and can be found in the `bert_config.json`.
`<num_attention_heads>` refers to the number of attention heads and can be found in the `bert_config.json`.

# Appendix

## Description of the binary format
## Description of the binary format

The example expects weights and inputs in a simple tensor dictionary format.
The example expects weights and inputs in a simple tensor dictionary format.
It consists of an integer in the first line `N` denoting the number of entries in the dictionary.
Then there are `N` lines, each line following the format
`[tensor name: String] [element type: DataType] [number of dimensions D: int] [dim1, dim2, ..., dimD] [binary tensor data]\n`
DataType is the `nvinfer1` enumeration, that encodes types as numbers. E.g. `DataType::kFLOAT = 0` (float32) and `DataType::kINT32 = 3`.
DataType is the `nvinfer1` enumeration, that encodes types as numbers. E.g. `DataType::kFLOAT = 0` (float32) and `DataType::kINT32 = 3`.
The binary tensor data is dim1 * dim2 * ... * dimD * sizeof(type) bytes followed by a line break.
Methods to read this format can be found in `data_utils.hpp`
Methods to read this format can be found in `dataUtils.hpp`
1 change: 1 addition & 0 deletions demo/BERT/docker/launch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ docker run -it --rm \
-u $(id -u):$(id -g) \
-v ${HOME}:/host/ \
-v $1:/data/ \
-v $(pwd)/../../:/workspace/TensorRT \
sample-bert bash


Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -12,7 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.


# The reference bert implementation preprocesses the input e.g. a squad dataset
# generating a tf record dataset of token ids, input masks and segment ids.
# Given such a dataset (e.g. eval.tf_record), this script converts the tf records
# into a simpler binary format readable by the bert sample.

import sys
import struct
Expand Down Expand Up @@ -64,29 +69,26 @@ def _decode_record(record, name_to_features):

return example

try:
raw_dataset = tf.data.TFRecordDataset([inputbase])
outputFileName = outputbase + ".weights_int32"
with open(outputFileName, 'wb') as outputFile:
raw_dataset = tf.data.TFRecordDataset([inputbase])
out_fn = outputbase + ".weights_int32"
with open(out_fn, 'wb') as output_file:

count = raw_dataset.reduce(0, lambda x,y: x+ 1).numpy()
print(count)

output_file.write("{}\n".format(count).encode('ASCII'))

for idx, record in enumerate(raw_dataset):
dec = _decode_record(record, name_to_features)

count = raw_dataset.reduce(0, lambda x,y: x+ 1).numpy()
print(count)
for k,v in dec.items():
a = v.numpy()
outname = '{}_{}'.format(k, idx)

shape = a.shape
shape_str = '{} '.format(len(shape)) + ' '.join([str(d) for d in shape])

outputFile.write("{}\n".format(count).encode('ASCII'))
output_file.write("{} 3 {} ".format(outname, shape_str).encode('ASCII'))
output_file.write(a.tobytes())
output_file.write("\n".encode('ASCII'));

for idx, record in enumerate(raw_dataset):
dec = _decode_record(record, name_to_features)

for k,v in dec.items():
a = v.numpy()
outname = '{}_{}'.format(k, idx)

shape = a.shape
shape_str = '{} '.format(len(shape)) + ' '.join([str(d) for d in shape])

outputFile.write("{} 3 {} ".format(outname, shape_str).encode('ASCII'))
outputFile.write(a.tobytes())
outputFile.write("\n".encode('ASCII'));

except Exception as error:
print(str(error))
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#!/usr/bin/env python3

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -38,37 +40,34 @@
inputbase = opt.model
outputbase = opt.output

try:
reader = pyTF.NewCheckpointReader(inputbase)
tensorDict = reader.get_variable_to_shape_map()
outputFileName = outputbase + ".weights"
with open(outputFileName, 'wb') as outputFile:
reader = pyTF.NewCheckpointReader(inputbase)
tensor_dict = reader.get_variable_to_shape_map()
out_fn = outputbase + ".weights"
with open(out_fn, 'wb') as output_file:

# there might be training-related variables in the checkpoint that can be discarded
paramNames = [key for key in sorted(tensorDict) if 'adam' not in key and 'global_step' not in key and 'pooler' not in key]
# there might be training-related variables in the checkpoint that can be discarded
param_names = [key for key in sorted(tensor_dict) if 'adam' not in key and 'global_step' not in key and 'pooler' not in key]

count = len(paramNames)
print(count)
count = len(param_names)
print(count)

outputFile.write('{}\n'.format(count).encode('ASCII'))
for pn in paramNames:
toks = pn.lower().split('/')
if 'encoder' in pn:
assert('layer' in pn)
l = (re.findall('\d+',pn))[0]
outname = 'l{}_'.format(l) + '_'.join(toks[3:])
else:
outname = '_'.join(toks)
output_file.write('{}\n'.format(count).encode('ASCII'))
for pn in param_names:
toks = pn.lower().split('/')
if 'encoder' in pn:
assert('layer' in pn)
l = (re.findall('\d+',pn))[0]
outname = 'l{}_'.format(l) + '_'.join(toks[3:])
else:
outname = '_'.join(toks)

tensor = reader.get_tensor(pn)
shape = tensor.shape
flat_tensor = tensor.flatten()
shape_str = '{} '.format(len(shape)) + ' '.join([str(d) for d in shape])
tensor = reader.get_tensor(pn)
shape = tensor.shape
flat_tensor = tensor.flatten()
shape_str = '{} '.format(len(shape)) + ' '.join([str(d) for d in shape])

outputFile.write('{} 0 {} '.format(outname, shape_str).encode('ASCII'))
outputFile.write(flat_tensor.tobytes())
outputFile.write('\n'.encode('ASCII'));
print('Orig.name:', pn,'TRT name:', outname, 'shape:' , shape_str)
output_file.write('{} 0 {} '.format(outname, shape_str).encode('ASCII'))
output_file.write(flat_tensor.tobytes())
output_file.write('\n'.encode('ASCII'));
print('Orig.name:', pn,'TRT name:', outname, 'shape:' , shape_str)

except Exception as error:
print(str(error))
Loading

0 comments on commit 443e495

Please sign in to comment.