Add gpu kernel for new api : linalg.lstsq #38621

Merged: 11 commits, Jan 10, 2022
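This PR adds a CUDA (cuSOLVER-based) kernel for the operator behind paddle.linalg.lstsq, which previously ran only on CPU through LAPACK. As a minimal usage sketch, assuming the public paddle.linalg.lstsq Python API rather than anything shown in this diff (on GPU, the QR-based "gels" driver is the path the new kernel serves):

    import paddle

    paddle.set_device("gpu")  # assumes a CUDA build of Paddle
    x = paddle.randn([4, 3])  # m >= n: an overdetermined system
    y = paddle.randn([4, 2])
    # lstsq returns a tuple; element 0 is the [3, 2] solution minimizing
    # the Frobenius norm of (x @ solution - y)
    results = paddle.linalg.lstsq(x, y, driver="gels")
    print(results[0])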
1 change: 1 addition & 0 deletions cmake/operators.cmake
@@ -202,6 +202,7 @@ function(op_library TARGET)
list(REMOVE_ITEM hip_srcs "eigvalsh_op.cu")
list(REMOVE_ITEM hip_srcs "qr_op.cu")
list(REMOVE_ITEM hip_srcs "eigh_op.cu")
list(REMOVE_ITEM hip_srcs "lstsq_op.cu")
list(REMOVE_ITEM hip_srcs "multinomial_op.cu")
list(REMOVE_ITEM hip_srcs "decode_jpeg_op.cu")
hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS ${op_library_DEPS}
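Note: lstsq_op.cu is added to the HIP exclusion list above because the new kernel depends on cuSOLVER, which Paddle's ROCm (HIP) build does not provide; the source file below accordingly guards itself with #ifndef PADDLE_WITH_HIP.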
211 changes: 211 additions & 0 deletions paddle/fluid/operators/lstsq_op.cu
@@ -0,0 +1,211 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef PADDLE_WITH_HIP
// HIP does not support cuSOLVER

#include <string>
#include <vector>
#include "paddle/fluid/operators/lstsq_op.h"
#include "paddle/fluid/operators/qr_op.h"
#include "paddle/fluid/platform/dynload/cusolver.h"

namespace paddle {
namespace operators {

using paddle::framework::Tensor;

template <typename DeviceContext, typename T>
class LstsqCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor& x = *context.Input<Tensor>("X");
const Tensor& y = *context.Input<Tensor>("Y");
auto* solution = context.Output<Tensor>("Solution");

auto dito =
math::DeviceIndependenceTensorOperations<platform::CUDADeviceContext,
T>(context);
auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();

auto x_dims = x.dims();
auto y_dims = y.dims();
int dim_size = x_dims.size();
int m = x_dims[dim_size - 2];
int n = x_dims[dim_size - 1];
int nrhs = y_dims[dim_size - 1];
int min_mn = std::min(m, n);
int max_mn = std::max(m, n);
int k = min_mn;

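// The strides below count elements per matrix, so &data[i * stride]
// addresses the i-th matrix of the batch.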
int x_stride = MatrixStride(x);
int y_stride = MatrixStride(y);
int tau_stride = min_mn;
int batch_count = BatchCount(x);

Tensor new_x, new_y;
new_x.mutable_data<T>(context.GetPlace(),
size_t(batch_count * m * n * sizeof(T)));
new_y.mutable_data<T>(context.GetPlace(),
size_t(batch_count * m * nrhs * sizeof(T)));
framework::TensorCopy(x, context.GetPlace(), &new_x);
framework::TensorCopy(y, context.GetPlace(), &new_y);

// Prepare tau: geqrf writes min(m, n) Householder scalars per matrix
auto tau_dims_vec = framework::vectorize<int>(x_dims);
tau_dims_vec.pop_back();
tau_dims_vec[tau_dims_vec.size() - 1] = min_mn;
Tensor tau = dito.Fill(tau_dims_vec, 0);
auto tau_data = tau.mutable_data<T>(context.GetPlace());

if (m >= n) {
Tensor tmp_x = dito.Transpose(new_x);
Tensor tmp_y = dito.Transpose(new_y);
auto x_data = tmp_x.mutable_data<T>(context.GetPlace());
auto y_data = tmp_y.mutable_data<T>(context.GetPlace());

// Step 1, compute QR factorization using geqrf
BatchedGeqrf<DeviceContext, T>(dev_ctx, batch_count, m, n, x_data, m,
tau_data, x_stride, tau_stride);

// Step 2, Y <- Q^H Y
BatchedOrmqr<DeviceContext, T>(dev_ctx, true, true, batch_count, m, n, k,
x_data, x_stride, tau_data, tau_stride,
y_data, y_stride);

Tensor trans_r = dito.Transpose(tmp_x);
Tensor slice_r = dito.Slice(trans_r, {-2}, {0}, {min_mn});
Tensor res_r = dito.TrilTriu(slice_r, 0, false);

Tensor trans_y = dito.Transpose(tmp_y);
Tensor slice_y = dito.Slice(trans_y, {-2}, {0}, {min_mn});

// Step 3, solve R X = Y
triangular_solve<DeviceContext, T>(dev_ctx, res_r, slice_y, solution,
true, false, false);
} else {
auto x_data = new_x.mutable_data<T>(context.GetPlace());
auto y_data = new_y.mutable_data<T>(context.GetPlace());

// Step 1, compute QR factorization using geqrf
BatchedGeqrf<DeviceContext, T>(dev_ctx, batch_count, n, m, x_data, n,
tau_data, x_stride, tau_stride);

// Step 2, solve R^H Z = Y
Tensor trans_r = dito.Transpose(new_x);
triangular_solve<DeviceContext, T>(dev_ctx, trans_r, new_y, solution,
true, true, false);

// Step 3, X <- Q Z
BatchedOrgqr<DeviceContext, T>(dev_ctx, batch_count, n, n, min_mn, x_data,
n, tau_data, x_stride, tau_stride);
Tensor trans_q = dito.Transpose(new_x);
Tensor slice_q = dito.Slice(trans_q, {-1}, {0}, {m});
Tensor solu_tensor = dito.Matmul(slice_q, *solution, false, false);
framework::TensorCopy(solu_tensor, solution->place(), solution);
}
}
};
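For reference, the linear algebra implemented by the two branches above can be sketched in NumPy (illustrative only, not Paddle code): for m >= n, factor X = QR and solve the triangular system R S = Q^T Y; for m < n, factor X^T = QR, solve R^T Z = Y, and form the minimum-norm solution S = Q Z.

    import numpy as np

    def lstsq_via_qr(a, b):
        m, n = a.shape
        if m >= n:
            # Overdetermined: A = QR, then solve R x = Q^T b (Steps 1-3 above).
            q, r = np.linalg.qr(a)        # q: (m, n), r: (n, n)
            return np.linalg.solve(r, q.T @ b)
        # Underdetermined: A^T = QR, solve R^T z = b, minimum-norm x = Q z.
        q, r = np.linalg.qr(a.T)          # q: (n, m), r: (m, m)
        z = np.linalg.solve(r.T, b)
        return q @ z

    a = np.random.randn(3, 5)
    b = np.random.randn(3, 2)
    x = lstsq_via_qr(a, b)
    np.testing.assert_allclose(a @ x, b, atol=1e-8)  # consistent when m < n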

template <>
void BatchedOrmqr<platform::CUDADeviceContext, float>(
const platform::CUDADeviceContext& dev_ctx, bool left, bool transpose,
int batch_size, int m, int n, int k, float* a, int a_stride, float* tau,
int tau_stride, float* other, int other_stride) {
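// Workspace pattern used below: query lwork once with ormqr_bufferSize,
// allocate a single device workspace plus one device int for the status,
// then loop over the batch reusing both and checking info after each call.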
int lwork = 0;
auto side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
auto trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
int lda = std::max<int>(1, left ? m : n);
int ldc = std::max<int>(1, m);

auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSormqr_bufferSize(
handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork));
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float));
float* workspace_ptr = reinterpret_cast<float*>(workspace->ptr());
auto info = memory::Alloc(dev_ctx, sizeof(int));
int* info_d = reinterpret_cast<int*>(info->ptr());

for (int i = 0; i < batch_size; ++i) {
float* a_working_ptr = &a[i * a_stride];
float* tau_working_ptr = &tau[i * tau_stride];
float* other_working_ptr = &other[i * other_stride];
// compute ormqr
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSormqr(
handle, side, trans, m, n, k, a_working_ptr, lda, tau_working_ptr,
other_working_ptr, ldc, workspace_ptr, lwork, info_d));

// check the error info
int info_h;
memory::Copy(platform::CPUPlace(), &info_h,
BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
info_d, sizeof(int), dev_ctx.stream());
PADDLE_ENFORCE_EQ(
info_h, 0,
platform::errors::PreconditionNotMet(
"For batch [%d]: CUSolver info is not zero but [%d]", i, info_h));
}
}

template <>
void BatchedOrmqr<platform::CUDADeviceContext, double>(
const platform::CUDADeviceContext& dev_ctx, bool left, bool transpose,
int batch_size, int m, int n, int k, double* a, int a_stride, double* tau,
int tau_stride, double* other, int other_stride) {
int lwork = 0;
auto side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
auto trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
int lda = std::max<int>(1, left ? m : n);
int ldc = std::max<int>(1, m);

auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDormqr_bufferSize(
handle, side, trans, m, n, k, a, lda, tau, other, ldc, &lwork));
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double));
double* workspace_ptr = reinterpret_cast<double*>(workspace->ptr());
auto info = memory::Alloc(dev_ctx, sizeof(int));
int* info_d = reinterpret_cast<int*>(info->ptr());

for (int i = 0; i < batch_size; ++i) {
double* a_working_ptr = &a[i * a_stride];
double* tau_working_ptr = &tau[i * tau_stride];
double* other_working_ptr = &other[i * other_stride];
// compute ormqr
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDormqr(
handle, side, trans, m, n, k, a_working_ptr, lda, tau_working_ptr,
other_working_ptr, ldc, workspace_ptr, lwork, info_d));

// check the error info
int info_h;
memory::Copy(platform::CPUPlace(), &info_h,
BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
info_d, sizeof(int), dev_ctx.stream());
PADDLE_ENFORCE_EQ(
info_h, 0,
platform::errors::PreconditionNotMet(
"For batch [%d]: CUSolver info is not zero but [%d]", i, info_h));
}
}

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(
lstsq, ops::LstsqCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::LstsqCUDAKernel<paddle::platform::CUDADeviceContext, double>);

#endif // not PADDLE_WITH_HIP
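Note: only float and double kernels are registered above, matching the real-valued cusolverDnSormqr/cusolverDnDormqr wrappers; complex dtypes are not covered by this GPU kernel.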
47 changes: 37 additions & 10 deletions paddle/fluid/operators/lstsq_op.h
@@ -49,7 +49,7 @@ class LstsqCPUKernel : public framework::OpKernel<T> {
using ValueType = math::Real<T>;

const Tensor& x = *context.Input<Tensor>("X");
const Tensor& y = *context.Input<Tensor>("Y");
auto y = context.Input<Tensor>("Y");
auto rcond = context.Attr<float>("rcond");
auto driver_string = context.Attr<std::string>("driver");

@@ -68,13 +68,15 @@
math::DeviceIndependenceTensorOperations<DeviceContext, T>(context);

auto x_dims = x.dims();
- auto y_dims = y.dims();
+ auto y_dims = y->dims();
int dim_size = x_dims.size();
int x_stride = MatrixStride(x);
- int y_stride = MatrixStride(y);
+ int y_stride = MatrixStride(*y);
int batch_count = BatchCount(x);
- auto ori_solution_dim = solution->dims();
+ auto solution_dim = solution->dims();
int ori_solu_stride = MatrixStride(*solution);
+ int max_solu_stride = std::max(y_stride, ori_solu_stride);
+ int min_solu_stride = std::min(y_stride, ori_solu_stride);

// LAPACK uses column-major storage; transposing gives the input a
// contiguous memory layout
@@ -88,13 +90,24 @@
Tensor new_x;
new_x.mutable_data<T>(context.GetPlace(),
size_t(batch_count * m * n * sizeof(T)));
+ framework::TensorCopy(x, context.GetPlace(), &new_x);

solution->mutable_data<T>(
context.GetPlace(),
size_t(batch_count * std::max(m, n) * nrhs * sizeof(T)));
- framework::TensorCopy(x, context.GetPlace(), &new_x);
- framework::TensorCopy(y, context.GetPlace(), solution);

- if (m < n) solution->Resize(UDDim(ori_solution_dim));
+ if (m >= n) {
+ const Tensor& new_y = *context.Input<Tensor>("Y");
+ framework::TensorCopy(new_y, context.GetPlace(), solution);
+ } else {
+ auto* solu_data = solution->data<T>();
+ auto* y_data = y->data<T>();
+ for (auto i = 0; i < batch_count; i++) {
+ for (auto j = 0; j < min_solu_stride; j++) {
+ solu_data[i * max_solu_stride + j] = y_data[i * y_stride + j];
+ }
+ }
+ }

Tensor input_x_trans = dito.Transpose(new_x);
Tensor input_y_trans = dito.Transpose(*solution);
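The buffer juggling above follows the LAPACK gels-family convention (gelsd/gelsy/gelss, selected by the driver attribute): the right-hand-side array is overwritten in place with the solution, so it must hold max(m, n) rows per batch. When m < n, the m-row inputs are scattered into that larger stride before the call; when m > n, the n-row solutions are compacted back afterwards (see the hunk below). A NumPy sketch of the batched repacking, with np.linalg.lstsq standing in for the per-batch in-place LAPACK call (illustrative only, not Paddle code):

    import numpy as np

    batch, m, n, nrhs = 4, 2, 5, 3             # m < n: underdetermined
    a = np.random.randn(batch, m, n)
    y = np.random.randn(batch, m, nrhs)

    work = np.zeros((batch, max(m, n), nrhs))  # gels-style RHS/solution buffer
    work[:, :m, :] = y                         # the m < n scatter-copy above

    for i in range(batch):                     # stand-in for the LAPACK call
        work[i, :n, :] = np.linalg.lstsq(a[i], work[i, :m, :], rcond=None)[0]

    solution = work[:, :n, :]                  # leading n rows hold the result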
@@ -186,10 +199,9 @@ class LstsqCPUKernel : public framework::OpKernel<T> {
iwork_data = iwork.mutable_data<int>(context.GetPlace());
}

- int solu_stride = std::max(y_stride, ori_solu_stride);
for (auto i = 0; i < batch_count; ++i) {
auto* x_input = &x_vector[i * x_stride];
- auto* y_input = &y_vector[i * solu_stride];
+ auto* y_input = &y_vector[i * max_solu_stride];
rank_working_ptr = rank_working_ptr ? &rank_data[i] : nullptr;
s_working_ptr = s_working_ptr ? &s_data[i * s_stride] : nullptr;

@@ -221,9 +233,24 @@
Tensor tmp_s = dito.Transpose(*solution);
framework::TensorCopy(tmp_s, solution->place(), solution);

- if (m >= n) solution->Resize(UDDim(ori_solution_dim));
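// The solver wrote each per-batch solution at stride max(m, n) * nrhs; when
// m > n, compact the results back to the contiguous n * nrhs output stride.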
+ if (m > n) {
+ auto* solu_data = solution->data<T>();
+ for (auto i = 1; i < batch_count; i++) {
+ for (auto j = 0; j < min_solu_stride; j++) {
+ solu_data[i * min_solu_stride + j] =
+ solu_data[i * max_solu_stride + j];
+ }
+ }
+ }

+ solution->Resize(UDDim(solution_dim));
}
};

+ template <typename DeviceContext, typename T>
+ void BatchedOrmqr(const DeviceContext& dev_ctx, bool left, bool transpose,
+ int batch_size, int m, int n, int k, T* a, int a_stride,
+ T* tau, int tau_stride, T* other, int other_stride);

} // namespace operators
} // namespace paddle