Skip to content

Commit

Permalink
fix dcnv2 trt8 compile error (#36850)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxinxin08 authored Oct 29, 2021
1 parent f3ee5c9 commit 82fb63e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ void gemm_impl<half>(cublasHandle_t handle, cublasOperation_t transa,
template <typename T>
int DeformableConvPlugin::enqueue_impl(int batch_size,
const void* const* inputs,
void** outputs, void* workspace,
void* const* outputs, void* workspace,
cudaStream_t stream) {
const T* input = reinterpret_cast<const T*>(inputs[0]);
const T* offset = reinterpret_cast<const T*>(inputs[1]);
Expand Down Expand Up @@ -527,8 +527,6 @@ nvinfer1::IPluginV2Ext* DeformableConvPlugin::clone() const TRT_NOEXCEPT {
offset_dim_, mask_dim_, output_dim_);
}

DeformableConvPluginCreator::DeformableConvPluginCreator() TRT_NOEXCEPT {}

void DeformableConvPluginCreator::setPluginNamespace(const char* lib_namespace)
TRT_NOEXCEPT {
namespace_ = std::string(lib_namespace);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ class DeformableConvPlugin : public nvinfer1::IPluginV2Ext {

private:
template <typename T>
int enqueue_impl(int batch_size, const void* const* inputs, void** outputs,
void* workspace, cudaStream_t stream);
int enqueue_impl(int batch_size, const void* const* inputs,
void* const* outputs, void* workspace, cudaStream_t stream);
nvinfer1::Weights copyToDevice(const void* hostData, size_t count);
void serializeFromDevice(void** hostBuffer,
const nvinfer1::Weights& deviceWeights) const;
Expand All @@ -119,7 +119,7 @@ class DeformableConvPlugin : public nvinfer1::IPluginV2Ext {

class DeformableConvPluginCreator : public nvinfer1::IPluginCreator {
public:
DeformableConvPluginCreator();
DeformableConvPluginCreator() = default;
~DeformableConvPluginCreator() override = default;

void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override;
Expand Down

0 comments on commit 82fb63e

Please sign in to comment.