Skip to content

Commit 44de935

Browse files
authored
[Auto Parallel] Add spmd rule No.6 for unique ops. (#72824)
* add spmd_rule for unique * update * refine code * apply review * try
1 parent 993cf9b commit 44de935

File tree

11 files changed

+292
-0
lines changed

11 files changed

+292
-0
lines changed

paddle/fluid/pybind/auto_parallel_py.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,11 @@ static void parse_single_pyobject(PyObject *obj,
944944
phi::distributed::InferSpmdContext *ctx,
945945
const size_t arg_pos) {
946946
if (PyList_Check(obj)) { // list inputs, spmd not allow tuple inputs
947+
Py_ssize_t list_size = PyList_Size(obj);
948+
if (list_size == 0) {
949+
ctx->EmplaceBackAttr(std::vector<int64_t>());
950+
return;
951+
}
947952
PyObject *first_item = PyList_GetItem(obj, 0);
948953
if (PyObject_TypeCheck(first_item, g_dist_tensor_spec_pytype)) {
949954
parse_tensors(obj, ctx, arg_pos);

paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ std::vector<int> InferSpmdContext::AttrAt(size_t idx) const {
9292
if (attr.type() == typeid(std::vector<bool>)) {
9393
std::vector<bool> val = PADDLE_GET_CONST(std::vector<bool>, attr);
9494
return std::vector<int>(val.begin(), val.end());
95+
} else if (attr.type() == typeid(std::vector<int64_t>) &&
96+
paddle::get<std::vector<int64_t>>(attr).empty()) {
97+
return std::vector<int>();
9598
} else {
9699
return paddle::get<std::vector<int>>(attr);
97100
}

paddle/phi/infermeta/spmd_rules/rules.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,9 @@ PD_REGISTER_SPMD_RULE(cumsum,
708708
PD_INFER_SPMD(phi::distributed::CumSumInferSpmd),
709709
PD_INFER_SPMD(phi::distributed::CumSumInferSpmdReverse));
710710

711+
// unique
712+
PD_REGISTER_SPMD_RULE(unique, PD_INFER_SPMD(phi::distributed::UniqueInferSpmd));
713+
711714
// argmin
712715
PD_REGISTER_SPMD_RULE(
713716
argmin,

paddle/phi/infermeta/spmd_rules/rules.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,5 +75,6 @@ limitations under the License. */
7575
#include "paddle/phi/infermeta/spmd_rules/transpose.h"
7676
#include "paddle/phi/infermeta/spmd_rules/triu.h"
7777
#include "paddle/phi/infermeta/spmd_rules/unbind.h"
78+
#include "paddle/phi/infermeta/spmd_rules/unique.h"
7879
#include "paddle/phi/infermeta/spmd_rules/unsqueeze.h"
7980
#include "paddle/phi/infermeta/spmd_rules/where.h"
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/phi/infermeta/spmd_rules/unique.h"
16+
#include "glog/logging.h"
17+
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
18+
#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h"
19+
#include "paddle/phi/infermeta/spmd_rules/utils.h"
20+
21+
namespace phi {
22+
namespace distributed {
23+
24+
SpmdInfo UniqueInferSpmd(const DistMetaTensor& x,
                         bool return_index,
                         bool return_inverse,
                         bool return_counts,
                         const std::vector<int>& axis,
                         DataType dtype) {
  // The output shape of unique is data dependent (it depends on how many
  // duplicate values exist), so the only safe SPMD decision is to mark the
  // input and every output as fully replicated.
  // Extracts x_shape, x_ndim, x_dims_mapping_src and x_dist_attr_src.
  EXTRACT_SHAPE_AND_DIST_ATTR(x);

  // Helper: build a dist attr derived from the input's attr, replicated
  // according to the given dims mapping (-1 in every slot).
  auto make_replicated_attr =
      [&x_dist_attr_src](const std::vector<int64_t>& dims_mapping) {
        TensorDistAttr attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
        attr.set_dims_mapping(dims_mapping);
        return attr;
      };

  // Input: replicated on every dimension.
  const std::vector<int64_t> x_dims_mapping_dst(x_ndim, -1);
  TensorDistAttr x_dist_attr_dst = make_replicated_attr(x_dims_mapping_dst);

  // out: 1-D when axis is empty (unique over the flattened tensor),
  // otherwise it keeps the rank of x.
  const std::vector<int64_t> out_dims_mapping_dst =
      axis.empty() ? std::vector<int64_t>{-1} : x_dims_mapping_dst;
  TensorDistAttr out_dist_attr_dst =
      make_replicated_attr(out_dims_mapping_dst);

  // Optional 1-D outputs: only give them a meaningful dist attr when the
  // corresponding flag is set; otherwise leave a default TensorDistAttr,
  // matching the op's optional-output contract.
  TensorDistAttr indices_dist_attr_dst;
  if (return_index) {
    indices_dist_attr_dst = make_replicated_attr({-1});
  }

  TensorDistAttr inverse_dist_attr_dst;
  if (return_inverse) {
    inverse_dist_attr_dst = make_replicated_attr({-1});
    // TODO(dev): https://github.com/PaddlePaddle/Paddle/issues/72822
    // if (axis.empty()) {
    //   inverse_dist_attr_dst.set_dims_mapping(x_dims_mapping_dst);
    // }
  }

  TensorDistAttr counts_dist_attr_dst;
  if (return_counts) {
    counts_dist_attr_dst = make_replicated_attr({-1});
  }

  // NOTE: `dtype` only controls the integer type of the index outputs and
  // has no influence on the sharding decision.
  VLOG(4) << "UniqueInferSpmd: All input and output TensorDistAttr are set to "
             "fully replicated status.";
  return {{x_dist_attr_dst},
          {out_dist_attr_dst,
           indices_dist_attr_dst,
           inverse_dist_attr_dst,
           counts_dist_attr_dst}};
}
74+
75+
SpmdInfo UniqueInferSpmdStatic(const DistMetaTensor& x,
                               bool return_index,
                               bool return_inverse,
                               bool return_counts,
                               const std::vector<int>& axis,
                               DataType dtype,
                               bool is_sorted) {
  // Static-graph entry point. The extra `is_sorted` flag carried by the
  // static unique op has no influence on the sharding decision, so this
  // simply forwards every other argument to the dynamic-graph rule.
  return UniqueInferSpmd(
      x, return_index, return_inverse, return_counts, axis, dtype);
}
85+
} // namespace distributed
86+
} // namespace phi
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once

#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

namespace phi {
namespace distributed {

// SPMD rule for the unique op (dynamic graph). Because the output shape of
// unique is data dependent, the rule marks the input and all four outputs
// (out, indices, inverse, counts) as fully replicated. When `axis` is empty
// the op operates on the flattened tensor, so `out` is 1-D; otherwise `out`
// keeps the rank of `x`. The optional outputs receive a default-constructed
// dist attr when their corresponding `return_*` flag is false.
SpmdInfo UniqueInferSpmd(const DistMetaTensor& x,
                         bool return_index,
                         bool return_inverse,
                         bool return_counts,
                         const std::vector<int>& axis,
                         DataType dtype);

// Static-graph variant registered for the static unique op. `is_sorted` does
// not affect the inferred sharding; the call forwards to UniqueInferSpmd.
SpmdInfo UniqueInferSpmdStatic(const DistMetaTensor& x,
                               bool return_index,
                               bool return_inverse,
                               bool return_counts,
                               const std::vector<int>& axis,
                               DataType dtype,
                               bool is_sorted);
}  // namespace distributed
}  // namespace phi

paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@
381381
output : Tensor(out), Tensor(indices), Tensor(inverse), Tensor(counts)
382382
infer_meta :
383383
func : UniqueInferMeta
384+
spmd_rule : UniqueInferSpmd
384385
kernel :
385386
func : unique
386387
data_type : x

paddle/phi/ops/yaml/inconsistent/static_ops.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,7 @@
850850
optional : indices, counts
851851
infer_meta :
852852
func : UniqueRawInferMeta
853+
spmd_rule : UniqueInferSpmdStatic
853854
kernel :
854855
func : unique
855856
data_type : x

test/auto_parallel/spmd_rules/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ if(WITH_DISTRIBUTE)
4444
py_test_modules(test_logsumexp_rule MODULES test_logsumexp_rule)
4545
py_test_modules(test_nonzero_rule MODULES test_nonzero_rule)
4646
if(NOT WITH_ROCM)
47+
py_test_modules(test_unique_rule MODULES test_unique_rule)
4748
py_test_modules(test_topk_rule MODULES test_topk_rule)
4849
py_test_modules(test_add_n_rule MODULES test_add_n_rule)
4950
py_test_modules(test_mean_all_rule MODULES test_mean_all_rule)
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
from collections import OrderedDict
17+
18+
from paddle.distributed.auto_parallel.static.dist_attribute import (
19+
DistTensorSpec,
20+
TensorDistAttr,
21+
)
22+
from paddle.distributed.fleet import auto
23+
from paddle.framework import convert_np_dtype_to_dtype_, core
24+
25+
26+
class TestUniqueSPMDRule(unittest.TestCase):
    """Tests the forward SPMD rule of the unique op.

    The rule is expected to force full replication: regardless of the
    input's dims_mapping, the input and all four outputs come back with
    every dimension mapped to -1. Only the rank of `out` changes with
    `axis` (1-D when axis is empty, same rank as x otherwise).
    """

    def setUp(self):
        self.rule = core.get_phi_spmd_rule("unique")
        x_shape = [4, 8]
        process_mesh = auto.ProcessMesh(mesh=[[0, 1], [2, 3]])

        x_tensor_dist_attr = TensorDistAttr()
        x_tensor_dist_attr.dims_mapping = [1, 0]
        x_tensor_dist_attr.process_mesh = process_mesh
        self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
        self.attrs = OrderedDict()
        self.attrs["return_index"] = True
        self.attrs["return_inverse"] = True
        self.attrs["return_counts"] = True
        self.attrs["axis"] = []
        self.attrs['dtype'] = convert_np_dtype_to_dtype_("int32")

    def _infer_forward(self):
        # Invoke the rule with the current input spec and attrs.
        return self.rule.infer_forward(
            self.x_dist_tensor_spec,
            self.attrs["return_index"],
            self.attrs["return_inverse"],
            self.attrs["return_counts"],
            self.attrs["axis"],
            self.attrs['dtype'],
        )

    def _check_replicated(self, result_dist_attrs, expected_out_mapping):
        # Common assertions: one input and four outputs, all replicated.
        # Only `out`'s dims_mapping (rank) varies with `axis`.
        self.assertEqual(len(result_dist_attrs), 2)
        inferred_input_dist_attrs = result_dist_attrs[0]
        inferred_output_dist_attrs = result_dist_attrs[1]

        self.assertEqual(len(inferred_input_dist_attrs), 1)
        self.assertEqual(len(inferred_output_dist_attrs), 4)

        self.assertEqual(inferred_input_dist_attrs[0].dims_mapping, [-1, -1])
        self.assertEqual(
            inferred_output_dist_attrs[0].dims_mapping, expected_out_mapping
        )
        self.assertEqual(inferred_output_dist_attrs[1].dims_mapping, [-1])
        self.assertEqual(inferred_output_dist_attrs[2].dims_mapping, [-1])
        self.assertEqual(inferred_output_dist_attrs[3].dims_mapping, [-1])

    def test_infer_forward(self):
        # return_index=True, return_inverse=True, return_counts=True, axis={}
        # [0, -1] --> [-1,-1], [-1], [-1], [-1], [-1]
        self.x_dist_tensor_spec.set_dims_mapping([0, -1])
        self._check_replicated(self._infer_forward(), [-1])

        # return_index=True, return_inverse=True, return_counts=True, axis={0}
        # [0, -1] --> [-1,-1], [-1,-1], [-1], [-1], [-1]
        self.x_dist_tensor_spec.set_dims_mapping([0, -1])
        self.attrs["axis"] = [0]
        self._check_replicated(self._infer_forward(), [-1, -1])
94+
95+
96+
# Allow running this test file directly (outside the CMake-driven runner).
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)