66 * LICENSE file in the root directory of this source tree.
77 */
88
9+ #include < executorch/kernels/portable/cpu/pattern/comparison_op.h>
910#include < executorch/kernels/portable/cpu/scalar_utils.h>
1011#include < executorch/kernels/portable/cpu/util/broadcast_util.h>
1112#include < executorch/kernels/portable/cpu/util/functional_util.h>
@@ -28,7 +29,7 @@ namespace impl {
2829namespace HiFi {
2930namespace native {
3031
31- Tensor& eq_tensor_out (
32+ Tensor& eq_Tensor_out (
3233 RuntimeContext& ctx,
3334 const Tensor& a,
3435 const Tensor& b,
@@ -39,14 +40,14 @@ Tensor& eq_tensor_out(
3940 InvalidArgument,
4041 out);
4142
42- ScalarType a_type = a.scalar_type ();
43- ScalarType b_type = b.scalar_type ();
4443 ScalarType out_type = out.scalar_type ();
4544
46- constexpr auto name = " eq.Tensor_out" ;
45+ // @lint-ignore CLANGTIDY facebook-hte-CArray
46+ static constexpr const char name[] = " eq.Tensor_out" ;
4747 constexpr int kNnlibMaxDim = 4 ; /* fallback if broadcast and dim > 4 */
4848
49- int a_dim = a.dim (), b_dim = b.dim (), out_dim = out.dim ();
49+ int a_dim = a.dim ();
50+ int b_dim = b.dim ();
5051 bool optimized = true ;
5152 /* find broadcast*/
5253 const bool a_is_broadcasted = !out.sizes ().equals (a.sizes ());
@@ -110,32 +111,11 @@ Tensor& eq_tensor_out(
110111 return out;
111112 }
112113
113- ET_SWITCH_REAL_TYPES_AND (Bool, a_type, ctx, name, CTYPE_A, [&]() {
114- ET_SWITCH_REAL_TYPES_AND (Bool, b_type, ctx, name, CTYPE_B, [&]() {
115- using CTYPE_IN =
116- typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
117- ET_DCHECK (
118- CppTypeToScalarType<CTYPE_IN>::value == promoteTypes (a_type, b_type));
119- ET_SWITCH_REAL_TYPES_AND (Bool, out_type, ctx, name, CTYPE_OUT, [&]() {
120- torch::executor::
121- apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
122- [](const CTYPE_A val_a, const CTYPE_B val_b) {
123- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
124- CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
125- bool value = a_casted == b_casted;
126- return static_cast <CTYPE_OUT>(value);
127- },
128- a,
129- b,
130- out);
131- });
132- });
133- });
134-
135- return out;
114+ return torch::executor::native::internal::
115+ comparison_tensor_out<std::equal_to, name>(ctx, a, b, out);
136116}
137117
138- Tensor& eq_scalar_out (
118+ Tensor& eq_Scalar_out (
139119 RuntimeContext& ctx,
140120 const Tensor& a,
141121 const Scalar& b,
@@ -149,40 +129,14 @@ Tensor& eq_scalar_out(
149129 InvalidArgument,
150130 out,
151131 " Failed to resize output tensor." );
132+ // @lint-ignore CLANGTIDY facebook-hte-CArray
133+ static constexpr const char name[] = " eq.Scalar_out" ;
152134
153- constexpr auto name = " eq.Scalar_out" ;
154-
155- ScalarType a_type = a.scalar_type ();
156- ScalarType b_type = torch::executor::native::utils::get_scalar_dtype (b);
157- ScalarType out_type = out.scalar_type ();
158-
159- ET_SWITCH_REAL_TYPES_AND (Bool, a_type, ctx, name, CTYPE_A, [&]() {
160- ET_SWITCH_SCALAR_OBJ_TYPES (b_type, ctx, name, CTYPE_B, [&]() {
161- using CTYPE_IN =
162- typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
163- ET_DCHECK (
164- CppTypeToScalarType<CTYPE_IN>::value == promoteTypes (a_type, b_type));
165- ET_SWITCH_REAL_TYPES_AND (Bool, out_type, ctx, name, CTYPE_OUT, [&]() {
166- CTYPE_B val_b = 0 ;
167- torch::executor::native::utils::extract_scalar (b, &val_b);
168- torch::executor::apply_unary_map_fn (
169- [val_b](const CTYPE_A val_a) {
170- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
171- CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
172- bool value = a_casted == b_casted;
173- return static_cast <CTYPE_OUT>(value);
174- },
175- a.const_data_ptr <CTYPE_A>(),
176- out.mutable_data_ptr <CTYPE_OUT>(),
177- out.numel ());
178- });
179- });
180- });
181-
182- return out;
135+ return torch::executor::native::internal::
136+ comparison_scalar_out<std::equal_to, name>(ctx, a, b, out);
183137}
184138
185139} // namespace native
186140} // namespace HiFi
187141} // namespace impl
188- } // namespace cadence
142+ } // namespace cadence
0 commit comments