Commit 90d11c5

fix 'BlasAXPBY unimplemented' error with custom device (PaddlePaddle#48762)

* fix 'BlasAXPBY unimplemented' error with custom device

* fix utils CmakeLists bug
USTCKAY authored Dec 8, 2022
1 parent 816065f commit 90d11c5
Showing 7 changed files with 117 additions and 36 deletions.
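
Background: on a plugin ("custom") device, accumulating gradients into a leaf tensor went through paddle::imperative::TensorAdd, whose dense-tensor path calls a BlasAXPBY kernel that custom-device plugins do not provide, so any workload that accumulates gradients (for example, two backward passes over the same graph) failed with 'BlasAXPBY unimplemented'. A minimal sketch of the failure mode, modeled on the new test below (it assumes the custom_cpu PluggableDevice plugin used by this test suite is installed; the snippet itself is illustrative and not part of the commit):

import paddle

# Assumption: the custom_cpu plugin device from this test suite is registered.
paddle.set_device('custom_cpu')

w = paddle.to_tensor([1.0, 2.0], stop_gradient=False)
loss = (w * w).sum()
loss.backward(retain_graph=True)  # first backward populates w.grad
loss.backward()  # second backward adds into w.grad; this accumulation
                 # used to fail with 'BlasAXPBY unimplemented'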
7 changes: 5 additions & 2 deletions paddle/fluid/eager/CMakeLists.txt
100644 → 100755
@@ -13,9 +13,12 @@ set(eager_deps
   eager_nan_inf_utils
   grad_node_info
   grad_tensor_holder
-  accumulation_node
   custom_operator_node)
 
+if(NOT (NOT WITH_PYTHON AND ON_INFER))
+  set(eager_deps ${eager_deps} accumulation_node)
+endif()
+
 set(fluid_deps
   tracer
   layer
@@ -33,9 +36,9 @@ if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
 endif()
 
 add_subdirectory(api)
-add_subdirectory(accumulation)
 add_subdirectory(custom_operator)
 if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
+  add_subdirectory(accumulation)
   add_subdirectory(tests)
   add_subdirectory(pylayer)
   cc_library(
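
Why the new guard: accumulation_node.cc now includes the generated dygraph_functions.h header (see the accumulation_node.cc diff below), and that header only exists in builds where the eager API is generated, i.e. not in Python-free inference builds. Every target and test that links accumulation_node therefore moves under the same if(NOT (NOT WITH_PYTHON AND ON_INFER)) condition, and the guarded tests pick up ${generated_deps}.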
10 changes: 6 additions & 4 deletions paddle/fluid/eager/accumulation/CMakeLists.txt
100644 → 100755
@@ -1,4 +1,6 @@
-cc_library(
-  accumulation_node
-  SRCS accumulation_node.cc
-  DEPS gradient_accumulator phi_api grad_node_info)
+if(NOT (NOT WITH_PYTHON AND ON_INFER))
+  cc_library(
+    accumulation_node
+    SRCS accumulation_node.cc
+    DEPS gradient_accumulator phi_api grad_node_info)
+endif()
17 changes: 13 additions & 4 deletions paddle/fluid/eager/accumulation/accumulation_node.cc
100644 → 100755
@@ -15,6 +15,7 @@
 #include "paddle/fluid/eager/accumulation/accumulation_node.h"
 
 #include "glog/logging.h"
+#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
 #include "paddle/fluid/eager/eager_tensor.h"
 #include "paddle/fluid/eager/utils.h"
 #include "paddle/fluid/imperative/gradient_accumulator.h"
@@ -44,8 +45,12 @@ static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
   // Accumulation
   if (LIKELY(t.is_dense_tensor())) {
     if (LIKELY(tensor->is_dense_tensor())) {
-      paddle::imperative::TensorAdd<paddle::experimental::Tensor>(t,
-                                                                  tensor);
+      if (t.is_custom_device()) {
+        *tensor = add_ad_func(t, *tensor);
+      } else {
+        paddle::imperative::TensorAdd<paddle::experimental::Tensor>(t,
+                                                                    tensor);
+      }
     } else {
       // TODO(jiabin): Support Other TensorBase later
       // TODO(zhanlve): Replace SelectedRowsAddTensor with
@@ -68,8 +73,12 @@ static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
       paddle::experimental::Tensor tensor_values(
           std::make_shared<phi::DenseTensor>(
               tensor_sparse->non_zero_elements()));
-      paddle::imperative::TensorAdd<paddle::experimental::Tensor>(
-          t_values, &tensor_values);
+      if (t.is_custom_device()) {
+        tensor_values = add_ad_func(t_values, tensor_values);
+      } else {
+        paddle::imperative::TensorAdd<paddle::experimental::Tensor>(
+            t_values, &tensor_values);
+      }
     }
   } else {
     // TODO(jiabin): Support Other TensorBase later
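
Design note: add_ad_func is the generated eager add, which dispatches through the phi kernel registry where device plugins register their own kernels; TensorAdd instead calls BlasAXPBY directly, which custom devices do not implement. A rough Python analogue of the patched dense-tensor branch (illustrative only; copy_or_add and the place-string check are stand-ins for the C++ above, not Paddle APIs):

import paddle

def copy_or_add(dst, src):
    # Stand-in for t.is_custom_device() in the C++ code.
    if 'custom' in str(src.place):
        # Out-of-place add, like add_ad_func: goes through the phi dispatcher,
        # so a plugin-registered add kernel handles it.
        return paddle.add(src, dst)
    # Legacy path: in-place accumulation (TensorAdd -> BlasAXPBY in C++).
    dst.add_(src)
    return dst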
28 changes: 20 additions & 8 deletions paddle/fluid/eager/api/utils/CMakeLists.txt
100644 → 100755
@@ -1,12 +1,24 @@
-cc_library(
-  tensor_utils
-  SRCS tensor_utils.cc
-  DEPS phi_api autograd_meta grad_node_info accumulation_node)
-cc_library(
-  hook_utils
-  SRCS hook_utils.cc
-  DEPS phi tensor_utils autograd_meta grad_node_info utils accumulation_node)
 cc_library(
   global_utils
   SRCS global_utils.cc
   DEPS place tracer)
+
+if(NOT (NOT WITH_PYTHON AND ON_INFER))
+  cc_library(
+    tensor_utils
+    SRCS tensor_utils.cc
+    DEPS phi_api autograd_meta grad_node_info accumulation_node)
+  cc_library(
+    hook_utils
+    SRCS hook_utils.cc
+    DEPS phi tensor_utils autograd_meta grad_node_info utils accumulation_node)
+else()
+  cc_library(
+    tensor_utils
+    SRCS tensor_utils.cc
+    DEPS phi_api autograd_meta grad_node_info)
+  cc_library(
+    hook_utils
+    SRCS hook_utils.cc
+    DEPS phi tensor_utils autograd_meta grad_node_info utils)
+endif()
12 changes: 6 additions & 6 deletions paddle/fluid/eager/tests/data_structure_tests/CMakeLists.txt
100644 → 100755
@@ -2,14 +2,14 @@ cc_test_old(test_egr_ds_eager_tensor SRCS eager_tensor_test.cc DEPS
   ${eager_deps})
 cc_test_old(test_egr_ds_auotgrad_meta SRCS autograd_meta_test.cc DEPS
   ${eager_deps})
-cc_test_old(test_egr_ds_grad_node_info SRCS grad_node_info_test.cc DEPS
-  ${eager_deps})
-cc_test_old(test_egr_ds_accumulation_node SRCS accumulation_node_test.cc DEPS
-  ${eager_deps})
-cc_test_old(test_egr_ds_tensor_wrapper SRCS tensor_wrapper_test.cc DEPS
-  ${eager_deps})
 
 if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
   cc_test_old(test_egr_ds_grad_tensor_holder SRCS grad_tensor_holder_test.cc
     DEPS ${eager_deps} ${generated_deps})
+  cc_test_old(test_egr_ds_grad_node_info SRCS grad_node_info_test.cc DEPS
+    ${eager_deps} ${generated_deps})
+  cc_test_old(test_egr_ds_accumulation_node SRCS accumulation_node_test.cc DEPS
+    ${eager_deps} ${generated_deps})
+  cc_test_old(test_egr_ds_tensor_wrapper SRCS tensor_wrapper_test.cc DEPS
+    ${eager_deps} ${generated_deps})
 endif()
24 changes: 12 additions & 12 deletions paddle/fluid/eager/tests/task_tests/CMakeLists.txt
100644 → 100755
@@ -1,15 +1,3 @@
-cc_test(
-  test_egr_task_tensor_utils
-  SRCS tensor_utils_test.cc
-  DEPS ${eager_deps})
-cc_test(
-  test_egr_task_eager_utils
-  SRCS eager_utils_test.cc
-  DEPS ${eager_deps})
-cc_test(
-  test_egr_task_forward_autograd
-  SRCS forward_autograd_test.cc
-  DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node)
 cc_test(
   test_egr_task_nan_inf_utils
   SRCS nan_inf_utils_test.cc
@@ -44,4 +32,16 @@ if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
     test_egr_task_autocodegen
     SRCS generated_test.cc
     DEPS ${eager_deps} ${fluid_deps} ${generated_deps})
+  cc_test(
+    test_egr_task_tensor_utils
+    SRCS tensor_utils_test.cc
+    DEPS ${eager_deps} ${generated_deps})
+  cc_test(
+    test_egr_task_eager_utils
+    SRCS eager_utils_test.cc
+    DEPS ${eager_deps} ${generated_deps})
+  cc_test(
+    test_egr_task_forward_autograd
+    SRCS forward_autograd_test.cc
+    DEPS ${eager_deps} ${fluid_deps} ${generated_deps} eager_scale scale_node)
 endif()
55 changes: 55 additions & 0 deletions python/paddle/fluid/tests/custom_runtime/test_custom_cpu_plugin.py
100644 → 100755
@@ -59,6 +59,7 @@ def test_custom_device(self):
         self._test_eager_copy_to()
         self._test_fallback_kernel()
         self._test_scalar()
+        self._test_custom_device_gradient_accumulation()
         self._test_custom_device_dataloader()
         self._test_custom_device_mnist()

@@ -208,6 +209,60 @@ def _test_scalar(self):
         k_t = paddle.to_tensor([3], dtype="int32")
         value_1, indices_1 = paddle.topk(data_1, k=k_t)
 
+    def _test_custom_device_gradient_accumulation(self):
+        import paddle
+
+        class MNIST(paddle.nn.Layer):
+            def __init__(self):
+                super().__init__()
+                self.shape = 1 * 28 * 28
+                self.size = 10
+                self.output_weight = self.create_parameter(
+                    [self.shape, self.size]
+                )
+                self.accuracy = paddle.metric.Accuracy()
+
+            def forward(self, inputs, label=None):
+                x = paddle.reshape(inputs, shape=[-1, self.shape])
+                x = paddle.matmul(x, self.output_weight)
+                x = paddle.nn.functional.softmax(x)
+                if label is not None:
+                    self.accuracy.reset()
+                    correct = self.accuracy.compute(x, label)
+                    self.accuracy.update(correct)
+                    acc = self.accuracy.accumulate()
+                    return x, acc
+                else:
+                    return x
+
+        paddle.set_device('custom_cpu')
+        dataset = paddle.vision.datasets.MNIST(
+            mode='train',
+            transform=paddle.vision.transforms.Compose(
+                [paddle.vision.transforms.ToTensor()]
+            ),
+        )
+        loader = paddle.io.DataLoader(
+            dataset, batch_size=64, num_workers=1, shuffle=True
+        )
+
+        mnist = MNIST()
+        sgd = paddle.optimizer.SGD(
+            learning_rate=0.01, parameters=mnist.parameters()
+        )
+
+        data = next(loader())
+        img = data[0]
+        label = data[1]
+        label_int32 = paddle.cast(label, 'int32')
+
+        pred, acc = mnist(img, label_int32)
+        avg_loss = paddle.nn.functional.cross_entropy(pred, label_int32)
+        avg_loss.backward(retain_graph=True)
+        avg_loss = paddle.nn.functional.cross_entropy(pred, label_int32)
+        avg_loss.backward()
+        sgd.step()
+
 
 if __name__ == '__main__':
     if os.name == 'nt' or sys.platform.startswith('darwin'):
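
Note on the test: it runs backward twice over the same graph (the first call with retain_graph=True), so the second pass must accumulate into the already-populated parameter gradients, which is exactly the CopyOrAddTensor branch patched above. The final sgd.step() then checks that the accumulated gradient is usable by the optimizer.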
