Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement the grad and enhance the cache of norm_convolution fusion ops. #36168

Merged
merged 8 commits into from
Sep 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/fluid/framework/operator_kernel_configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ limitations under the License. */
#pragma once

#include <algorithm>
#include <mutex>
#include <unordered_map>
#include <vector>
#include "glog/logging.h"

namespace paddle {
namespace framework {
Expand Down
65 changes: 34 additions & 31 deletions paddle/fluid/operators/fused/cudnn_fusion_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@ limitations under the License. */

#pragma once

#include <string>
#include <vector>
#include "paddle/fluid/platform/cudnn_desc.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/enforce.h"

Expand All @@ -41,12 +39,9 @@ class CudnnFusionOp {
}

~CudnnFusionOp() {
  // Release the cuDNN fused-op descriptors owned by this object.
  //
  // NOTE: the destroy calls are deliberately NOT wrapped in
  // PADDLE_ENFORCE_CUDA_SUCCESS. A destructor must never throw: if it
  // throws while the stack is already unwinding from another exception,
  // the program terminates. Best-effort cleanup is the correct policy
  // here; a failed descriptor destroy at teardown is not recoverable.
  dynload::cudnnDestroyFusedOpsVariantParamPack(op_variant_params_);
  dynload::cudnnDestroyFusedOpsConstParamPack(op_const_params_);
  dynload::cudnnDestroyFusedOpsPlan(op_);
}

// Execute fused op
Expand Down Expand Up @@ -121,41 +116,49 @@ class CudnnFusionOp {

// Get the workspace size (in bytes) required before Execute().
//
// cudnnMakeFusedOpsPlan() is expensive, so the plan is created at most
// once per CudnnFusionOp: the first call computes and memoizes the size
// in workspace_bytes_ (guarded by plan_created_); subsequent calls
// return the cached value without touching cuDNN again.
size_t GetWorkspaceSizeInBytes(cudnnHandle_t cudnn_handle) {
  if (!plan_created_) {
    workspace_bytes_ = 0U;
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnMakeFusedOpsPlan(
        cudnn_handle, op_, op_const_params_, &workspace_bytes_));
    plan_created_ = true;
  }
  return workspace_bytes_;
}

private:
 // True once cudnnMakeFusedOpsPlan() has been invoked; guards the
 // memoized result of GetWorkspaceSizeInBytes(). Brace-initialized so
 // the flag is well-defined even before the constructor body runs
 // (the visible diff does not show it being set there -- TODO confirm).
 bool plan_created_{false};
 // Cached workspace size; meaningful only when plan_created_ is true.
 size_t workspace_bytes_{0U};

 // cuDNN fused-op descriptors; created in the constructor (outside this
 // view) and destroyed in ~CudnnFusionOp().
 cudnnFusedOpsPlan_t op_;
 cudnnFusedOpsConstParamPack_t op_const_params_;
 cudnnFusedOpsVariantParamPack_t op_variant_params_;
};

static inline std::vector<int> GetStrides(const std::vector<int> &shape) {
if (shape.size() < 1) {
return {};
class CudnnFusionOpCache {
public:
static CudnnFusionOpCache &Instance() {
static CudnnFusionOpCache instance;
return instance;
}

framework::AlgorithmsCache<CudnnFusionOp *> *GetForward() {
return &forward_cache_;
}
int dim = static_cast<int>(shape.size());
std::vector<int> pro_shape(shape);
std::vector<int> strides(dim);
int temp = pro_shape[1];
pro_shape.erase(pro_shape.begin() + 1);
pro_shape.push_back(temp);
strides.back() = 1;
for (int i = dim - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * pro_shape[i + 1];
framework::AlgorithmsCache<CudnnFusionOp *> *GetBackward() {
return &backward_cache_;
}
strides.pop_back();
strides.insert(strides.begin() + 1, 1);
return strides;
}

static inline int64_t AlignUp(int64_t a, int64_t b) { return (a + b - 1) / b; }
private:
CudnnFusionOpCache() {}
~CudnnFusionOpCache() {
// Need to delete the memory of cache.
}
CudnnFusionOpCache(const CudnnFusionOpCache &) {}

private:
framework::AlgorithmsCache<CudnnFusionOp *> forward_cache_;
framework::AlgorithmsCache<CudnnFusionOp *> backward_cache_;
};

#endif // CUDNN_VERSION >= 8000
} // namespace operators
Expand Down
Loading