[Prim][PIR] support roll, gather, scatter, scatter_nd_add op backward in pir prim (#60481)

* prim gather op backward

* prim scatter op backward

* prim roll op backward

* prim scatter_nd_add op backward
kevincheng2 authored Jan 2, 2024
1 parent cfad7d2 commit 5e2a3db
Showing 6 changed files with 183 additions and 14 deletions.
4 changes: 4 additions & 0 deletions paddle/fluid/primitive/codegen/gen.py
@@ -80,13 +80,17 @@
'sum_grad',
'cast_grad',
'reshape_grad',
'roll_grad',
'split_grad',
'transpose_grad',
'concat_grad',
'expand_grad',
'gather_grad',
'gather_nd_grad',
'pad_grad',
'max_grad',
'scatter_grad',
'scatter_nd_add_grad',
'slice_grad',
'tile_grad',
'topk_grad',
100 changes: 100 additions & 0 deletions paddle/fluid/primitive/rule/vjp/details.h
@@ -243,6 +243,23 @@ void reshape_grad(const Tensor& xshape,
}
}

template <typename T>
void roll_grad(const Tensor& x,
               const Tensor& out_grad,
               const IntArray& shifts,
               const std::vector<int64_t>& axis,
               Tensor* x_grad) {
  if (x_grad) {
    auto shifts_ = shifts.GetData();
    int64_t nums = shifts_.size();
    for (int64_t i = 0; i < nums; i++) {
      shifts_[i] = 0 - shifts_[i];
    }
    auto x_grad_output = roll<T>(out_grad, shifts_, axis);
    set_output<T>(x_grad_output, x_grad);
  }
}
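
A quick way to sanity-check the rule above: roll is a pure permutation of elements, so its VJP is just a roll of the incoming gradient by the negated shifts. A minimal NumPy sketch, with shapes and values chosen only for illustration:

import numpy as np

# Adjoint identity for a permutation: <roll(x, s), w> == <x, roll(w, -s)>,
# so the gradient of sum(roll(x, s) * w) w.r.t. x is roll(w, -s).
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3))
w = rng.standard_normal((2, 3))
shifts, axis = (1,), (1,)
lhs = np.sum(np.roll(x, shifts, axis) * w)
rhs = np.sum(x * np.roll(w, tuple(-s for s in shifts), axis))
assert np.isclose(lhs, rhs)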

template <typename T>
void transpose_grad(const Tensor& grad_out,
const std::vector<int>& perm,
@@ -262,6 +279,43 @@ void transpose_grad(const Tensor& grad_out,
}
}

template <typename T>
void scatter_grad(const Tensor& index,
                  const Tensor& updates,
                  const Tensor& out_grad,
                  bool overwrite,
                  Tensor* x_grad,
                  Tensor* updates_grad) {
  if (x_grad) {
    auto zero_tensor =
        full<T>(common::vectorize(updates.dims()), 0.0, updates.dtype());
    auto tmp_grad = scatter<T>(out_grad, index, zero_tensor, false);
    set_output<T>(tmp_grad, x_grad);
  }

  if (updates_grad) {
    Scalar tmp_zero = 0;
    auto tmp_updates_grad = gather<T>(out_grad, index, tmp_zero);
    set_output<T>(tmp_updates_grad, updates_grad);
  }
}
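
One way to read the scatter rule: under overwrite semantics, rows of the output that were replaced by updates no longer depend on x, so x_grad is out_grad with those rows zeroed, while updates_grad gathers out_grad at index along axis 0. A rough NumPy sketch assuming a 1-D index with no duplicates (shapes are illustrative only):

import numpy as np

# Forward (overwrite): out = x.copy(); out[index] = updates
index = np.array([1, 3])
out_grad = np.arange(8.0).reshape(4, 2)

x_grad = out_grad.copy()
x_grad[index] = 0.0             # overwritten rows contribute nothing to x
updates_grad = out_grad[index]  # each update row receives the grad of its slot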

template <typename T>
void scatter_nd_add_grad(const Tensor& index,
                         const Tensor& updates,
                         const Tensor& out_grad,
                         Tensor* x_grad,
                         Tensor* updates_grad) {
  if (x_grad) {
    by_pass<T>(out_grad, x_grad);
  }
  if (updates_grad) {
    // Gradient by Gather: dUpdates = dO[Ids]
    auto tmp_updates_grad = gather_nd<T>(out_grad, index);
    set_output<T>(tmp_updates_grad, updates_grad);
  }
}
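
scatter_nd_add computes out as x plus updates routed to index, so the VJP is direct: x_grad passes out_grad through unchanged, and updates_grad is a gather_nd of out_grad at index, i.e. the dUpdates = dO[Ids] noted in the comment above. A small NumPy sketch with a hypothetical 2-D index into a 3x4 tensor:

import numpy as np

out_grad = np.arange(12.0).reshape(3, 4)
index = np.array([[0, 1], [2, 3]])                 # two (row, col) positions

x_grad = out_grad                                  # pass-through, as in by_pass
updates_grad = out_grad[index[:, 0], index[:, 1]]  # gather_nd at the index
assert updates_grad.tolist() == [1.0, 11.0]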

template <typename T>
void sin_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
auto x_grad_tmp = cos<T>(x) * out_grad;
@@ -818,6 +872,52 @@ void relu_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
}
}

template <typename T>
void gather_grad(const Tensor& x,
                 const Tensor& index,
                 const Tensor& out_grad,
                 const Scalar& axis,
                 Tensor* grad_x) {
  auto zero_tensor = full<T>(common::vectorize(x.dims()), 0.0, x.dtype());
  std::vector<int> tmp_perm;

  // change axis to rank 0
  int axis_value = axis.to<int>();
  tmp_perm.push_back(axis_value);
  // make other ranks
  for (int i = 0; i < x.dims().size(); ++i) {
    if (i != axis_value) {
      tmp_perm.push_back(i);
    }
  }
  std::vector<int> reverse_perm(tmp_perm);
  // make origin ranks
  for (int i = 0; i < static_cast<int>(tmp_perm.size()); ++i) {
    if (tmp_perm[i] >= 0) {
      reverse_perm[tmp_perm[i]] = i;
    } else {
      reverse_perm[tmp_perm[i] + tmp_perm.size()] = i;
    }
  }

  // transpose out_grad and zero grad to target rank.
  auto tmp_zero_x_grad = zero_tensor;
  auto tmp_out_grad = out_grad;
  if (zero_tensor.dims().size() > 0) {
    tmp_zero_x_grad = transpose<T>(zero_tensor, tmp_perm);
  }
  if (out_grad.dims().size() > 0) {
    tmp_out_grad = transpose<T>(out_grad, tmp_perm);
  }
  // scatter grad to grad_x
  auto tmp_grad_x = scatter<T>(tmp_zero_x_grad, index, tmp_out_grad, false);
  auto tmp_grad_x_tranposed = tmp_grad_x;
  if (tmp_grad_x.dims().size() > 0) {
    tmp_grad_x_tranposed = transpose<T>(tmp_grad_x, reverse_perm);
  }
  set_output<T>(tmp_grad_x_tranposed, grad_x);
}
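
The gather rule above works by moving the gathered axis to the front, scattering out_grad into a zero tensor at index, then transposing back. The same transpose / scatter / transpose sequence in NumPy, assuming index has no duplicate entries (with duplicates the scatter step would need to accumulate):

import numpy as np

x = np.zeros((2, 5, 3))
axis, index = 1, np.array([4, 0, 2])
out_grad = np.ones((2, 3, 3))                 # grad of gather(x, index, axis=1)

perm = [axis] + [i for i in range(x.ndim) if i != axis]
inv_perm = np.argsort(perm)

gx = np.zeros_like(np.transpose(x, perm))     # zeros with the axis moved to front
gx[index] = np.transpose(out_grad, perm)      # scatter the moved gradient rows
x_grad = np.transpose(gx, inv_perm)           # restore the original layout
assert x_grad.shape == x.shape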

template <typename T>
void gather_nd_grad(const Tensor& x,
const Tensor& index,
11 changes: 9 additions & 2 deletions test/legacy_test/test_gather_op.py
@@ -45,7 +45,9 @@ def test_check_output(self):
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True, check_pir=True)
self.check_grad(
['X'], 'Out', check_prim=True, check_pir=True, check_prim_pir=True
)

def config(self):
"""
@@ -119,7 +121,12 @@ def test_check_output(self):

def test_check_grad(self):
self.check_grad_with_place(
paddle.CUDAPlace(0), ['X'], 'Out', check_prim=True, check_pir=True
paddle.CUDAPlace(0),
['X'],
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
)


18 changes: 15 additions & 3 deletions test/legacy_test/test_roll_op.py
@@ -52,7 +52,9 @@ def test_check_output(self):
self.check_output(check_prim=True, check_pir=True)

def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', check_prim=True, check_pir=True)
self.check_grad(
['X'], 'Out', check_prim=True, check_pir=True, check_prim_pir=True
)


class TestRollOpCase2(TestRollOp):
@@ -139,7 +141,12 @@ def test_check_output(self):

def test_check_grad_normal(self):
self.check_grad_with_place(
self.place, ['X'], 'Out', check_prim=True, check_pir=True
self.place,
['X'],
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
)


@@ -163,7 +170,12 @@ def test_check_output(self):

def test_check_grad_normal(self):
self.check_grad_with_place(
self.place, ['X'], 'Out', check_prim=True, check_pir=True
self.place,
['X'],
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
)


26 changes: 22 additions & 4 deletions test/legacy_test/test_scatter_nd_op.py
@@ -98,7 +98,11 @@ def test_check_output(self):

def test_check_grad(self):
self.check_grad(
['X', 'Updates'], 'Out', check_prim=True, check_pir=True
['X', 'Updates'],
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
)


@@ -133,7 +137,12 @@ def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X', 'Updates'], 'Out', check_prim=True, check_pir=True
place,
['X', 'Updates'],
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
)


@@ -176,7 +185,11 @@ def test_check_output(self):

def test_check_grad(self):
self.check_grad(
['X', 'Updates'], 'Out', check_prim=True, check_pir=True
['X', 'Updates'],
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
)


@@ -211,7 +224,12 @@ def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X', 'Updates'], 'Out', check_prim=True, check_pir=True
place,
['X', 'Updates'],
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
)


38 changes: 33 additions & 5 deletions test/legacy_test/test_scatter_op.py
@@ -57,7 +57,11 @@ def test_check_output(self):

def test_check_grad(self):
self.check_grad(
["X", "Updates"], "Out", check_prim=True, check_pir=True
["X", "Updates"],
"Out",
check_prim=True,
check_pir=True,
check_prim_pir=True,
)


@@ -92,6 +96,7 @@ def test_check_grad(self):
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
)


@@ -128,7 +133,11 @@ def test_check_output(self):

def test_check_grad(self):
self.check_grad(
["X", "Updates"], "Out", check_prim=True, check_pir=True
["X", "Updates"],
"Out",
check_prim=True,
check_pir=True,
check_prim_pir=True,
)


@@ -163,6 +172,7 @@ def test_check_grad(self):
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
)


@@ -202,7 +212,11 @@ def test_check_output(self):

def test_check_grad(self):
self.check_grad(
["X", "Updates"], "Out", check_prim=True, check_pir=True
["X", "Updates"],
"Out",
check_prim=True,
check_pir=True,
check_prim_pir=True,
)


@@ -237,6 +251,7 @@ def test_check_grad(self):
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
)


@@ -284,6 +299,7 @@ def test_check_grad(self):
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
)


@@ -356,6 +372,7 @@ def test_check_grad(self):
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
)


@@ -412,7 +429,11 @@ def test_check_output(self):

def test_check_grad(self):
self.check_grad(
['X', 'Updates'], 'Out', check_prim=True, check_pir=True
['X', 'Updates'],
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
)


@@ -447,6 +468,7 @@ def test_check_grad(self):
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
)


@@ -494,6 +516,7 @@ def test_check_grad(self):
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
)


@@ -550,7 +573,11 @@ def test_check_output(self):

def test_check_grad(self):
self.check_grad(
["X", "Updates"], "Out", check_prim=True, check_pir=True
["X", "Updates"],
"Out",
check_prim=True,
check_pir=True,
check_prim_pir=True,
)


@@ -585,6 +612,7 @@ def test_check_grad(self):
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
)

