forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[BYOC][TensorRT] Add TensorRT own int8 calibration support to TensorR…
…T BYOC integration (apache#8808) * update trt * clean codes * tetsing running trt * clean data * clean codes? * remove env func * fix num_bings * add buildfromjson func * change condition * reset input and output func * re-config func * re-added trt version check * checking sanity * try to fix sanity issue * checking sainity * fixing sanity issue * fixing sainity issue * fixing sanity * clang format fixed * clang format fixing * clean trt cali * try to fix clang format * fixed some comments * remove double destroy engine codes * modify comments * add checking function * add trt int8 test * update trt int8 test file * Update test_tensorrt_int8_exp.py * update trt int8 fikle * change a little * upate trt int8 file * upate trt int8 file * fixing ci * fixing ci * fixing ci * fixing ci * fixing ci * fixing ci issue * fixing ci issue * fixing ci * fixing ci issue * fixing ci * fixing ci problem * fixing ci * upate trt python int8 test file * fixed ci * fixed ci * fix gpu build * fixed ci * update trt int8 test file * fix bug * fix bug * update trtint8 file * reformat * update trt int8 file * update * modify
- Loading branch information
1 parent
128d3dd
commit 95860bb
Showing
5 changed files
with
399 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
/* * Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
* file runtime/contrib/tensorrt/tensorrt_builder.h | ||
* brief Contains TensorRTBuilder class which can be used to convert a relay | ||
* program into a TRT engine which can be used for inference. | ||
*/ | ||
|
||
#ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_CALIBRATOR_H_ | ||
#define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_CALIBRATOR_H_ | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
#include "../../cuda/cuda_common.h" | ||
#include "NvInfer.h" | ||
|
||
namespace tvm { | ||
namespace runtime { | ||
|
||
class TensorRTCalibrator : public nvinfer1::IInt8EntropyCalibrator2 { | ||
public: | ||
TensorRTCalibrator(int batch_size, const std::vector<std::string>& input_names) | ||
: batch_size_(batch_size), num_batches_calibrated_(0), input_names_(input_names) {} | ||
|
||
~TensorRTCalibrator() { | ||
// Free calibration data | ||
for (auto& inputs : data_) { | ||
for (size_t i = 0; i < inputs.size(); ++i) { | ||
delete[] inputs[i]; | ||
} | ||
} | ||
// Free buffers | ||
for (size_t i = 0; i < buffers_.size(); ++i) { | ||
CUDA_CALL(cudaFree(buffers_[i])); | ||
} | ||
} | ||
|
||
void AddBatchData(const std::vector<void*>& bindings, const std::vector<size_t>& binding_sizes) { | ||
// Copy data from GPU | ||
std::vector<float*> data_host(bindings.size(), nullptr); | ||
for (size_t i = 0; i < bindings.size(); ++i) { | ||
data_host[i] = new float[batch_size_ * binding_sizes[i]]; | ||
CUDA_CALL(cudaMemcpy(static_cast<void*>(data_host[i]), bindings[i], | ||
batch_size_ * binding_sizes[i] * sizeof(float), cudaMemcpyDeviceToHost)); | ||
} | ||
data_.push_back(data_host); | ||
data_sizes_.push_back(binding_sizes); | ||
} | ||
|
||
int getBatchSize() const override { return batch_size_; } | ||
|
||
/*! | ||
* \brief TensorRT will call this method to get next batch of data to | ||
* calibrate with. | ||
*/ | ||
bool getBatch(void* bindings[], const char* names[], int nbBindings) override { | ||
AllocateBuffersIfNotAllocated(); | ||
CHECK_EQ(input_names_.size(), nbBindings); | ||
for (size_t i = 0; i < input_names_.size(); ++i) { | ||
CHECK_EQ(input_names_[i], names[i]); | ||
CUDA_CALL(cudaMemcpy(buffers_[i], data_[num_batches_calibrated_][i], | ||
batch_size_ * data_sizes_[num_batches_calibrated_][i] * sizeof(float), | ||
cudaMemcpyHostToDevice)); | ||
bindings[i] = buffers_[i]; | ||
} | ||
num_batches_calibrated_++; | ||
// TODO(trevmorr): Free data from previous batch? | ||
return (num_batches_calibrated_ < data_.size()); | ||
} | ||
|
||
const void* readCalibrationCache(size_t& length) override { | ||
if (calibration_cache_.empty()) return nullptr; | ||
length = calibration_cache_.size(); | ||
return calibration_cache_.data(); | ||
} | ||
|
||
void writeCalibrationCache(const void* cache, size_t length) override { | ||
calibration_cache_.assign(static_cast<const char*>(cache), length); | ||
} | ||
|
||
private: | ||
/*! \brief Batch size. */ | ||
int batch_size_; | ||
/*! \brief Number of batches already fed to calibrator. */ | ||
int num_batches_calibrated_; | ||
/*! \brief Storage for calibration cache. */ | ||
std::string calibration_cache_; | ||
|
||
/*! \brief Data to be used for calibration. */ | ||
std::vector<std::vector<float*>> data_; | ||
/*! \brief Number of elements for data to be used for calibration. */ | ||
std::vector<std::vector<size_t>> data_sizes_; | ||
|
||
/*! \brief Device buffers to be used for calibration. */ | ||
std::vector<void*> buffers_; | ||
|
||
/*! \brief Names of inputs */ | ||
const std::vector<std::string> input_names_; | ||
|
||
/*! \brief Allocate device memory buffers. data_sizes_ must already have one | ||
* entry. */ | ||
void AllocateBuffersIfNotAllocated() { | ||
if (!buffers_.empty()) return; | ||
CHECK_GE(data_sizes_.size(), 1); | ||
const int num_inputs = data_sizes_[0].size(); | ||
buffers_.assign(num_inputs, nullptr); | ||
for (int i = 0; i < num_inputs; ++i) { | ||
CUDA_CALL(cudaMalloc(&buffers_[i], data_sizes_[0][i] * sizeof(float))); | ||
} | ||
} | ||
}; | ||
|
||
} // namespace runtime | ||
} // namespace tvm | ||
#endif // TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_CALIBRATOR_H_ |
Oops, something went wrong.