Merge pull request #48 from dongwenxin2046/paddlebox
model server 3thrd
qingshui authored Oct 8, 2022
2 parents 2dc67a7 + 3641fc1 commit 6466203
Showing 4 changed files with 197 additions and 1 deletion.
34 changes: 34 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -1060,6 +1060,40 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<AnalysisConfig>(
config);
}

namespace experimental {

void InternalUtils::SyncStream(paddle::PaddlePredictor *p) {
#ifdef PADDLE_WITH_CUDA
auto *pred = dynamic_cast<paddle::AnalysisPredictor *>(p);
paddle::platform::DeviceContextPool &pool =
paddle::platform::DeviceContextPool::Instance();
auto *dev_ctx = reinterpret_cast<paddle::platform::CUDADeviceContext *>(
pool.Get(pred->place_));
cudaStreamSynchronize(dev_ctx->stream());
#endif
}
bool InternalUtils::QueryStream(paddle::PaddlePredictor *p) {
#ifdef PADDLE_WITH_CUDA
auto *pred = dynamic_cast<paddle::AnalysisPredictor *>(p);
paddle::platform::DeviceContextPool &pool =
paddle::platform::DeviceContextPool::Instance();
auto *dev_ctx = reinterpret_cast<paddle::platform::CUDADeviceContext *>(
pool.Get(pred->place_));
return (cudaSuccess == cudaStreamQuery(dev_ctx->stream()));
#endif
}
void InternalUtils::SyncStream(cudaStream_t stream) {
#ifdef PADDLE_WITH_CUDA
cudaStreamSynchronize(stream);
#endif
}
bool InternalUtils::QueryStream(cudaStream_t stream) {
#ifdef PADDLE_WITH_CUDA
return (cudaSuccess == cudaStreamQuery(stream));
#endif
}

} // namespace experimental
} // namespace paddle

#if PADDLE_WITH_TENSORRT
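A minimal usage sketch of the two predictor-level helpers added above, assuming a GPU-enabled predictor created elsewhere; the predictor variable and the use of ZeroCopyRun are illustrative, not part of this commit:

// Sketch only: `predictor` is a std::unique_ptr<paddle::PaddlePredictor>
// obtained from CreatePaddlePredictor with GPU enabled.
predictor->ZeroCopyRun();  // enqueue inference work on the predictor's stream

// Non-blocking check: true once all work queued on the stream has finished.
bool done = paddle::experimental::InternalUtils::QueryStream(predictor.get());

// Blocking wait on the same stream.
paddle::experimental::InternalUtils::SyncStream(predictor.get());

Per the note added to paddle_api.h below, these helpers are only meant to be used under thread_local semantics, i.e. on the thread that owns the predictor.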
7 changes: 7 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.h
@@ -49,6 +49,12 @@ using inference::analysis::Analyzer;
using framework::proto::ProgramDesc;
using framework::NaiveExecutor;

namespace experimental {
using float16 = paddle::platform::float16;
using PlaceType = paddle::PaddlePlace;
class InternalUtils;
};

///
/// \class AnalysisPredictor
///
@@ -416,6 +422,7 @@ class AnalysisPredictor : public PaddlePredictor {
// Some status here that help to determine the status inside the predictor.
bool status_is_cloned_{false};
bool status_use_gpu_{false};
friend class paddle::experimental::InternalUtils;
};

} // namespace paddle
127 changes: 127 additions & 0 deletions paddle/fluid/inference/api/details/zero_copy_tensor.cc
@@ -17,6 +17,7 @@
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {

@@ -230,4 +231,130 @@ std::vector<std::vector<size_t>> ZeroCopyTensor::lod() const {
return res;
}

namespace experimental {

using float16 = paddle::platform::float16;

template <typename T>
void InternalUtils::CopyFromCpuWithIoStream(paddle::ZeroCopyTensor *t,
const T *data,
cudaStream_t stream) {
if (t->tensor_ == nullptr) {
PADDLE_ENFORCE_EQ(
t->name_.empty(), false,
paddle::platform::errors::PreconditionNotMet(
"Need to SetName first, so that the corresponding tensor can "
"be retrieved."));
auto *scope = static_cast<paddle::framework::Scope *>(t->scope_);
auto *var = scope->FindVar(t->name_);
PADDLE_ENFORCE_NOT_NULL(
var, paddle::platform::errors::PreconditionNotMet(
"No tensor called [%s] in the runtime scope", t->name_));
auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();
t->tensor_ = tensor;
}

auto *tensor = static_cast<paddle::framework::LoDTensor *>(t->tensor_);
PADDLE_ENFORCE_GE(tensor->numel(), 0,
paddle::platform::errors::PreconditionNotMet(
"You should call Tensor::Reshape(const "
"std::vector<int> &shape)"
"function before copying data from cpu."));
size_t ele_size = tensor->numel() * sizeof(T);
if (t->place_ == PlaceType::kCPU) {
auto *t_data = tensor->mutable_data<T>(paddle::platform::CPUPlace());
std::memcpy(static_cast<void *>(t_data), data, ele_size);
} else if (t->place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
paddle::platform::CUDAPlace gpu_place(t->device_);
auto *t_data = tensor->mutable_data<T>(gpu_place);
paddle::memory::Copy(gpu_place, static_cast<void *>(t_data),
paddle::platform::CPUPlace(), data, ele_size, stream);
#else
PADDLE_THROW(paddle::platform::errors::Unavailable(
"Can not create tensor with CUDA place because paddle is not compiled "
"with CUDA."));
#endif
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"CopyFromCpuWithIoStream only supports CPU and GPU now."));
}
}

template <typename T>
void InternalUtils::CopyToCpuWithIoStream(paddle::ZeroCopyTensor *t, T *data,
cudaStream_t stream) {
if (t->tensor_ == nullptr) {
PADDLE_ENFORCE_EQ(
t->name_.empty(), false,
paddle::platform::errors::PreconditionNotMet(
"Need to SetName first, so that the corresponding tensor can "
"be retrieved."));
auto *scope = static_cast<paddle::framework::Scope *>(t->scope_);
auto *var = scope->FindVar(t->name_);
PADDLE_ENFORCE_NOT_NULL(
var, paddle::platform::errors::PreconditionNotMet(
"No tensor called [%s] in the runtime scope", t->name_));
auto *tensor = var->GetMutable<paddle::framework::LoDTensor>();
t->tensor_ = tensor;
}

auto *tensor = static_cast<paddle::framework::LoDTensor *>(t->tensor_);
auto ele_num = tensor->numel();
auto *t_data = tensor->data<T>();
auto t_place = tensor->place();

paddle::framework::Tensor out;
auto mem_allocation =
std::make_shared<paddle::memory::allocation::Allocation>(
static_cast<void *>(data), ele_num * sizeof(T),
paddle::platform::CPUPlace());
out.ResetHolder(mem_allocation);

if (paddle::platform::is_cpu_place(t_place)) {
std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
} else if (t->place_ == PlaceType::kGPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, t_place);
paddle::memory::Copy(paddle::platform::CPUPlace(),
static_cast<void *>(data), gpu_place, t_data,
ele_num * sizeof(T), stream);
#else
PADDLE_THROW(paddle::platform::errors::Unavailable(
"Can not create tensor with CUDA place because paddle is not compiled "
"with CUDA."));
#endif
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"CopyToCpuWithIoStream only supports CPU and GPU now."));
}
}

template void InternalUtils::CopyFromCpuWithIoStream<float>(
paddle::ZeroCopyTensor *t, const float *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<int64_t>(
paddle::ZeroCopyTensor *t, const int64_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<int32_t>(
paddle::ZeroCopyTensor *t, const int32_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<uint8_t>(
paddle::ZeroCopyTensor *t, const uint8_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<int8_t>(
paddle::ZeroCopyTensor *t, const int8_t *data, cudaStream_t stream);
template void InternalUtils::CopyFromCpuWithIoStream<float16>(
paddle::ZeroCopyTensor *t, const float16 *data, cudaStream_t stream);

template void InternalUtils::CopyToCpuWithIoStream<float>(
paddle::ZeroCopyTensor *t, float *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int64_t>(
paddle::ZeroCopyTensor *t, int64_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int32_t>(
paddle::ZeroCopyTensor *t, int32_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<uint8_t>(
paddle::ZeroCopyTensor *t, uint8_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<int8_t>(
paddle::ZeroCopyTensor *t, int8_t *data, cudaStream_t stream);
template void InternalUtils::CopyToCpuWithIoStream<float16>(
paddle::ZeroCopyTensor *t, float16 *data, cudaStream_t stream);

} // namespace experimental
} // namespace paddle
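A short sketch of CopyFromCpuWithIoStream as defined above, assuming a ZeroCopyTensor obtained from a predictor running in zero-copy mode and a caller-owned CUDA stream; the input shape and buffer are assumptions for illustration:

// Sketch only: predictor, shape, and buffer are illustrative.
cudaStream_t io_stream;
cudaStreamCreate(&io_stream);

auto input = predictor->GetInputTensor(predictor->GetInputNames()[0]);
input->Reshape({1, 3, 224, 224});
std::vector<float> host_in(1 * 3 * 224 * 224, 0.f);

// The host-to-device copy is enqueued on io_stream rather than the default
// stream, so it can overlap with other device work; synchronize before the
// data is relied upon.
paddle::experimental::InternalUtils::CopyFromCpuWithIoStream(
    input.get(), host_in.data(), io_stream);
paddle::experimental::InternalUtils::SyncStream(io_stream);

Note that only the element types explicitly instantiated at the end of this file (float, int64_t, int32_t, uint8_t, int8_t, and float16) are usable through these templates.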
30 changes: 29 additions & 1 deletion paddle/fluid/inference/api/paddle_api.h
@@ -27,12 +27,17 @@
#include <memory>
#include <string>
#include <vector>
#include <cuda_runtime.h>
#include "crypto/cipher.h"
#include "paddle_infer_declare.h" // NOLINT
/*! \namespace paddle
*/
namespace paddle {

namespace experimental {
class InternalUtils;
};

/// \brief Paddle data type.
enum PaddleDType {
FLOAT32,
@@ -233,7 +238,7 @@ class PD_INFER_DECL ZeroCopyTensor {
void SetName(const std::string& name) { name_ = name; }
void* FindTensor() const;

private:
protected:
std::string name_;
bool input_or_output_;
friend class AnalysisPredictor;
@@ -244,6 +249,8 @@
PaddlePlace place_;
PaddleDType dtype_;
int device_;

friend class paddle::experimental::InternalUtils;
};

/// \brief A Predictor for executing inference on a model.
@@ -450,4 +457,25 @@ PD_INFER_DECL std::string get_version();

PD_INFER_DECL std::string UpdateDllFlag(const char* name, const char* value);

namespace experimental {

using PlaceType = paddle::PaddlePlace;

// Unstable interface, may be modified or deleted in the future.
class PD_INFER_DECL InternalUtils {
public:
// Note: Can only be used under thread_local semantics.
static void SyncStream(paddle::PaddlePredictor* pred);
static bool QueryStream(paddle::PaddlePredictor* pred);
static void SyncStream(cudaStream_t stream);
static bool QueryStream(cudaStream_t stream);
template <typename T>
static void CopyFromCpuWithIoStream(paddle::ZeroCopyTensor* t, const T* data,
cudaStream_t stream);
template <typename T>
static void CopyToCpuWithIoStream(paddle::ZeroCopyTensor* t, T* data,
cudaStream_t stream);
};
} // namespace experimental

} // namespace paddle
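Putting the pieces together, a hedged end-to-end sketch of the experimental interface declared above. The model path, input shape, and include path are assumptions, and the zero-copy setup (SwitchUseFeedFetchOps(false), GetInputTensor/GetOutputTensor, ZeroCopyRun) is the standard Paddle inference workflow rather than something introduced by this commit:

#include <cuda_runtime.h>
#include <functional>
#include <numeric>
#include <vector>
#include "paddle_inference_api.h"  // assumed include path of the inference library

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("./model_dir");          // hypothetical model directory
  config.EnableUseGpu(100 /*MB*/, 0 /*gpu id*/);
  config.SwitchUseFeedFetchOps(false);     // required for ZeroCopyTensor I/O
  auto predictor = paddle::CreatePaddlePredictor(config);

  cudaStream_t io_stream;
  cudaStreamCreate(&io_stream);

  // Stage the input on the caller-owned stream.
  auto input = predictor->GetInputTensor(predictor->GetInputNames()[0]);
  input->Reshape({1, 3, 224, 224});
  std::vector<float> in(1 * 3 * 224 * 224, 0.f);
  paddle::experimental::InternalUtils::CopyFromCpuWithIoStream(
      input.get(), in.data(), io_stream);
  // The predictor runs on its own stream, so finish the input copy first.
  paddle::experimental::InternalUtils::SyncStream(io_stream);

  predictor->ZeroCopyRun();

  // Fetch the output on the same stream and wait before reading it.
  auto output = predictor->GetOutputTensor(predictor->GetOutputNames()[0]);
  auto shape = output->shape();
  int numel = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
  std::vector<float> out(numel);
  paddle::experimental::InternalUtils::CopyToCpuWithIoStream(
      output.get(), out.data(), io_stream);
  paddle::experimental::InternalUtils::SyncStream(io_stream);

  cudaStreamDestroy(io_stream);
  return 0;
}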
