Skip to content

Commit 876366b

Browse files
authored
为paddle.less_than、paddle.less_equal、paddle.greater_than、paddle.greater_equal添加复数类型支持 (#72619)
* update code * update api * fix less_than
1 parent 5c82d4e commit 876366b

File tree

7 files changed

+658
-96
lines changed

7 files changed

+658
-96
lines changed

paddle/phi/kernels/cpu/compare_kernel.cc

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -124,28 +124,10 @@ PD_REGISTER_KERNEL(equal_all,
124124
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
125125
}
126126

127-
#define PD_REGISTER_COMPARE_KERNEL(name, func) \
128-
PD_REGISTER_KERNEL(name, \
129-
CPU, \
130-
ALL_LAYOUT, \
131-
phi::func##Kernel, \
132-
bool, \
133-
int, \
134-
uint8_t, \
135-
int8_t, \
136-
int16_t, \
137-
int64_t, \
138-
float, \
139-
double, \
140-
phi::dtype::float16, \
141-
phi::dtype::bfloat16) { \
142-
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
143-
}
144-
145-
PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
146-
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
147-
PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan)
148-
PD_REGISTER_COMPARE_KERNEL(greater_equal, GreaterEqual)
127+
PD_REGISTER_COMPLEX_COMPARE_KERNEL(less_than, LessThan)
128+
PD_REGISTER_COMPLEX_COMPARE_KERNEL(less_equal, LessEqual)
129+
PD_REGISTER_COMPLEX_COMPARE_KERNEL(greater_than, GreaterThan)
130+
PD_REGISTER_COMPLEX_COMPARE_KERNEL(greater_equal, GreaterEqual)
149131

150132
PD_REGISTER_COMPLEX_COMPARE_KERNEL(equal, Equal)
151133
PD_REGISTER_COMPLEX_COMPARE_KERNEL(not_equal, NotEqual)

paddle/phi/kernels/funcs/compare_functors.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,49 @@ namespace funcs {
2626
} \
2727
};
2828

29+
#define COMPARE_COMPLEX_FUNCTOR(func_name, op) \
30+
template <typename T> \
31+
struct func_name<phi::dtype::complex<T>> { \
32+
HOSTDEVICE bool operator()(const phi::dtype::complex<T> a, \
33+
const phi::dtype::complex<T> b) const { \
34+
if (isnan(a.real) || isnan(a.imag) || isnan(b.real) || isnan(b.imag)) \
35+
return false; \
36+
T ar = a.real; \
37+
T br = b.real; \
38+
T ai = a.imag; \
39+
T bi = b.imag; \
40+
return (ar op br) || (ar == br && ai op bi); \
41+
} \
42+
};
43+
44+
#define COMPARE_COMPLEX_EQUAL_FUNCTOR(func_name, op_equal, op) \
45+
template <typename T> \
46+
struct func_name<phi::dtype::complex<T>> { \
47+
HOSTDEVICE bool operator()(const phi::dtype::complex<T> a, \
48+
const phi::dtype::complex<T> b) const { \
49+
if (isnan(a.real) || isnan(a.imag) || isnan(b.real) || isnan(b.imag)) \
50+
return false; \
51+
T ar = a.real; \
52+
T br = b.real; \
53+
T ai = a.imag; \
54+
T bi = b.imag; \
55+
return (ar op br) || (ar == br && ai op_equal bi); \
56+
} \
57+
};
58+
2959
COMPARE_FUNCTOR(LessThanFunctor, <)
3060
COMPARE_FUNCTOR(LessEqualFunctor, <=)
3161
COMPARE_FUNCTOR(GreaterThanFunctor, >)
3262
COMPARE_FUNCTOR(GreaterEqualFunctor, >=)
63+
64+
COMPARE_COMPLEX_FUNCTOR(LessThanFunctor, <)
65+
COMPARE_COMPLEX_FUNCTOR(GreaterThanFunctor, >)
66+
COMPARE_COMPLEX_EQUAL_FUNCTOR(LessEqualFunctor, <=, <)
67+
COMPARE_COMPLEX_EQUAL_FUNCTOR(GreaterEqualFunctor, >=, >)
68+
3369
#undef COMPARE_FUNCTOR
70+
#undef COMPARE_COMPLEX_FUNCTOR
71+
#undef COMPARE_COMPLEX_EQUAL_FUNCTOR
3472

3573
template <typename InT, typename OutT = bool>
3674
struct EqualFunctor {

paddle/phi/kernels/kps/compare_kernel.cu

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -145,24 +145,6 @@ PD_REGISTER_KERNEL(equal_all,
145145
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
146146
}
147147

148-
#define PD_REGISTER_COMPARE_KERNEL(name, func) \
149-
PD_REGISTER_KERNEL(name, \
150-
KPS, \
151-
ALL_LAYOUT, \
152-
phi::func##Kernel, \
153-
bool, \
154-
int, \
155-
uint8_t, \
156-
int8_t, \
157-
int16_t, \
158-
int64_t, \
159-
float, \
160-
double, \
161-
phi::dtype::float16, \
162-
phi::dtype::bfloat16) { \
163-
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
164-
}
165-
166148
#define PD_REGISTER_COMPLEX_COMPARE_KERNEL(name, func) \
167149
PD_REGISTER_KERNEL(name, \
168150
KPS, \
@@ -183,11 +165,10 @@ PD_REGISTER_KERNEL(equal_all,
183165
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
184166
}
185167

186-
PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
187-
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
188-
PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan)
189-
PD_REGISTER_COMPARE_KERNEL(greater_equal, GreaterEqual)
190-
168+
PD_REGISTER_COMPLEX_COMPARE_KERNEL(less_than, LessThan)
169+
PD_REGISTER_COMPLEX_COMPARE_KERNEL(less_equal, LessEqual)
170+
PD_REGISTER_COMPLEX_COMPARE_KERNEL(greater_than, GreaterThan)
171+
PD_REGISTER_COMPLEX_COMPARE_KERNEL(greater_equal, GreaterEqual)
191172
PD_REGISTER_COMPLEX_COMPARE_KERNEL(equal, Equal)
192173
PD_REGISTER_COMPLEX_COMPARE_KERNEL(not_equal, NotEqual)
193174

paddle/phi/kernels/legacy/cpu/compare_kernel.cc

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -120,31 +120,15 @@ PD_REGISTER_KERNEL(less_than_raw,
120120
int16_t,
121121
int,
122122
int64_t,
123+
phi::dtype::complex<float>,
124+
phi::dtype::complex<double>,
123125
float,
124126
double,
125127
phi::dtype::float16,
126128
phi::dtype::bfloat16) {
127129
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
128130
}
129131

130-
#define PD_REGISTER_COMPARE_RAW_KERNEL(name, func) \
131-
PD_REGISTER_KERNEL(name##_raw, \
132-
CPU, \
133-
ALL_LAYOUT, \
134-
phi::func##RawKernel, \
135-
bool, \
136-
uint8_t, \
137-
int8_t, \
138-
int16_t, \
139-
int, \
140-
int64_t, \
141-
float, \
142-
double, \
143-
phi::dtype::float16, \
144-
phi::dtype::bfloat16) { \
145-
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
146-
}
147-
148132
#define PD_REGISTER_COMPLEX_COMPARE_RAW_KERNEL(name, func) \
149133
PD_REGISTER_KERNEL(name##_raw, \
150134
CPU, \
@@ -165,9 +149,9 @@ PD_REGISTER_KERNEL(less_than_raw,
165149
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
166150
}
167151

168-
PD_REGISTER_COMPARE_RAW_KERNEL(less_equal, LessEqual)
169-
PD_REGISTER_COMPARE_RAW_KERNEL(greater_than, GreaterThan)
170-
PD_REGISTER_COMPARE_RAW_KERNEL(greater_equal, GreaterEqual)
152+
PD_REGISTER_COMPLEX_COMPARE_RAW_KERNEL(less_equal, LessEqual)
153+
PD_REGISTER_COMPLEX_COMPARE_RAW_KERNEL(greater_than, GreaterThan)
154+
PD_REGISTER_COMPLEX_COMPARE_RAW_KERNEL(greater_equal, GreaterEqual)
171155

172156
PD_REGISTER_COMPLEX_COMPARE_RAW_KERNEL(equal, Equal)
173157
PD_REGISTER_COMPLEX_COMPARE_RAW_KERNEL(not_equal, NotEqual)

paddle/phi/kernels/legacy/kps/compare_kernel.cu

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -146,31 +146,15 @@ PD_REGISTER_KERNEL(less_than_raw,
146146
int16_t,
147147
int,
148148
int64_t,
149+
phi::dtype::complex<float>,
150+
phi::dtype::complex<double>,
149151
float,
150152
double,
151153
phi::dtype::float16,
152154
phi::dtype::bfloat16) {
153155
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
154156
}
155157

156-
#define PD_REGISTER_COMPARE_RAW_KERNEL(name, func) \
157-
PD_REGISTER_KERNEL(name##_raw, \
158-
KPS, \
159-
ALL_LAYOUT, \
160-
phi::func##RawKernel, \
161-
bool, \
162-
uint8_t, \
163-
int16_t, \
164-
int, \
165-
int8_t, \
166-
int64_t, \
167-
float, \
168-
double, \
169-
phi::dtype::float16, \
170-
phi::dtype::bfloat16) { \
171-
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
172-
}
173-
174158
#define PD_REGISTER_COMPLEX_COMPARE_RAW_KERNEL(name, func) \
175159
PD_REGISTER_KERNEL(name##_raw, \
176160
KPS, \
@@ -191,10 +175,9 @@ PD_REGISTER_KERNEL(less_than_raw,
191175
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
192176
}
193177

194-
PD_REGISTER_COMPARE_RAW_KERNEL(less_equal, LessEqual)
195-
PD_REGISTER_COMPARE_RAW_KERNEL(greater_than, GreaterThan)
196-
PD_REGISTER_COMPARE_RAW_KERNEL(greater_equal, GreaterEqual)
197-
178+
PD_REGISTER_COMPLEX_COMPARE_RAW_KERNEL(less_equal, LessEqual)
179+
PD_REGISTER_COMPLEX_COMPARE_RAW_KERNEL(greater_than, GreaterThan)
180+
PD_REGISTER_COMPLEX_COMPARE_RAW_KERNEL(greater_equal, GreaterEqual)
198181
PD_REGISTER_COMPLEX_COMPARE_RAW_KERNEL(equal, Equal)
199182
PD_REGISTER_COMPLEX_COMPARE_RAW_KERNEL(not_equal, NotEqual)
200183

python/paddle/tensor/logic.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -668,8 +668,8 @@ def greater_equal(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
668668
The output has no gradient.
669669
670670
Args:
671-
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64.
672-
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64.
671+
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64, complex64, complex128.
672+
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64, complex64, complex128.
673673
name (str|None, optional): The default value is None. Normally there is no need for
674674
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
675675
Returns:
@@ -704,6 +704,8 @@ def greater_equal(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
704704
"int32",
705705
"int64",
706706
"uint16",
707+
"complex64",
708+
"complex128",
707709
],
708710
"greater_equal",
709711
)
@@ -721,6 +723,8 @@ def greater_equal(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
721723
"int32",
722724
"int64",
723725
"uint16",
726+
"complex64",
727+
"complex128",
724728
],
725729
"greater_equal",
726730
)
@@ -758,8 +762,8 @@ def greater_than(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
758762
The output has no gradient.
759763
760764
Args:
761-
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64.
762-
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64.
765+
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64, complex64, complex128.
766+
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64, complex64, complex128.
763767
name (str|None, optional): The default value is None. Normally there is no need for
764768
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
765769
Returns:
@@ -794,6 +798,8 @@ def greater_than(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
794798
"int32",
795799
"int64",
796800
"uint16",
801+
"complex64",
802+
"complex128",
797803
],
798804
"greater_than",
799805
)
@@ -811,6 +817,8 @@ def greater_than(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
811817
"int32",
812818
"int64",
813819
"uint16",
820+
"complex64",
821+
"complex128",
814822
],
815823
"greater_than",
816824
)
@@ -848,8 +856,8 @@ def less_equal(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
848856
The output has no gradient.
849857
850858
Args:
851-
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64.
852-
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64.
859+
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64, complex64, complex128.
860+
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64, complex64, complex128.
853861
name (str|None, optional): The default value is None. Normally there is no need for
854862
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
855863
@@ -885,6 +893,8 @@ def less_equal(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
885893
"int32",
886894
"int64",
887895
"uint16",
896+
"complex64",
897+
"complex128",
888898
],
889899
"less_equal",
890900
)
@@ -902,6 +912,8 @@ def less_equal(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
902912
"int32",
903913
"int64",
904914
"uint16",
915+
"complex64",
916+
"complex128",
905917
],
906918
"less_equal",
907919
)
@@ -939,8 +951,8 @@ def less_than(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
939951
The output has no gradient.
940952
941953
Args:
942-
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64.
943-
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64.
954+
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64, complex64, complex128.
955+
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, bfloat16, float16, float32, float64, uint8, int8, int16, int32, int64, complex64, complex128.
944956
name (str|None, optional): The default value is None. Normally there is no need for
945957
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
946958
@@ -976,6 +988,8 @@ def less_than(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
976988
"int32",
977989
"int64",
978990
"uint16",
991+
"complex64",
992+
"complex128",
979993
],
980994
"less_than",
981995
)
@@ -993,6 +1007,8 @@ def less_than(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
9931007
"int32",
9941008
"int64",
9951009
"uint16",
1010+
"complex64",
1011+
"complex128",
9961012
],
9971013
"less_than",
9981014
)

0 commit comments

Comments
 (0)