Skip to content

Commit 55b3540

Browse files
authored
[cadence][hifi] Patch operators on build time issues
Differential Revision: D77957939 Pull Request resolved: #12283
1 parent 3c33d29 commit 55b3540

File tree

17 files changed

+234
-497
lines changed

17 files changed

+234
-497
lines changed

backends/cadence/aot/functions_hifi.yaml

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575
- op: clamp.Tensor_out
7676
kernels:
7777
- arg_meta: null
78-
kernel_name: cadence::impl::HiFi::clamp_tensor_out
78+
kernel_name: cadence::impl::HiFi::clamp_Tensor_out
7979

8080
- op: clone.out
8181
kernels:
@@ -100,7 +100,7 @@
100100
- op: eq.Tensor_out
101101
kernels:
102102
- arg_meta: null
103-
kernel_name: cadence::impl::HiFi::eq_tensor_out
103+
kernel_name: cadence::impl::HiFi::eq_Tensor_out
104104

105105
- op: fmod.Tensor_out
106106
kernels:
@@ -120,12 +120,12 @@
120120
- op: ge.Scalar_out
121121
kernels:
122122
- arg_meta: null
123-
kernel_name: cadence::impl::HiFi::ge_scalar_out
123+
kernel_name: cadence::impl::HiFi::ge_Scalar_out
124124

125125
- op: ge.Tensor_out
126126
kernels:
127127
- arg_meta: null
128-
kernel_name: cadence::impl::HiFi::ge_tensor_out
128+
kernel_name: cadence::impl::HiFi::ge_Tensor_out
129129

130130
- op: gelu.out
131131
kernels:
@@ -135,12 +135,12 @@
135135
- op: gt.Scalar_out
136136
kernels:
137137
- arg_meta: null
138-
kernel_name: cadence::impl::HiFi::gt_scalar_out
138+
kernel_name: cadence::impl::HiFi::gt_Scalar_out
139139

140140
- op: gt.Tensor_out
141141
kernels:
142142
- arg_meta: null
143-
kernel_name: cadence::impl::HiFi::gt_tensor_out
143+
kernel_name: cadence::impl::HiFi::gt_Tensor_out
144144

145145
- op: hardtanh.out
146146
kernels:
@@ -150,27 +150,27 @@
150150
- op: le.Scalar_out
151151
kernels:
152152
- arg_meta: null
153-
kernel_name: cadence::impl::HiFi::le_scalar_out
153+
kernel_name: cadence::impl::HiFi::le_Scalar_out
154154

155155
- op: le.Tensor_out
156156
kernels:
157157
- arg_meta: null
158-
kernel_name: cadence::impl::HiFi::le_tensor_out
158+
kernel_name: cadence::impl::HiFi::le_Tensor_out
159159

160160
- op: lt.Scalar_out
161161
kernels:
162162
- arg_meta: null
163-
kernel_name: cadence::impl::HiFi::lt_scalar_out
163+
kernel_name: cadence::impl::HiFi::lt_Scalar_out
164164

165165
- op: lt.Tensor_out
166166
kernels:
167167
- arg_meta: null
168-
kernel_name: cadence::impl::HiFi::lt_tensor_out
168+
kernel_name: cadence::impl::HiFi::lt_Tensor_out
169169

170170
- op: masked_fill.Scalar_out
171171
kernels:
172172
- arg_meta: null
173-
kernel_name: cadence::impl::HiFi::masked_fill_scalar_out
173+
kernel_name: cadence::impl::HiFi::masked_fill_Scalar_out
174174

175175
- op: max_pool2d_with_indices.out
176176
kernels:
@@ -185,7 +185,7 @@
185185
- op: mean.out
186186
kernels:
187187
- arg_meta: null
188-
kernel_name: cadence::impl::HiFi::mean_out
188+
kernel_name: cadence::impl::HiFi::mean_out
189189

190190
- op: minimum.out
191191
kernels:
@@ -205,7 +205,7 @@
205205
- op: ne.Tensor_out
206206
kernels:
207207
- arg_meta: null
208-
kernel_name: cadence::impl::HiFi::ne_tensor_out
208+
kernel_name: cadence::impl::HiFi::ne_Tensor_out
209209

210210
- op: permute_copy.out
211211
kernels:
@@ -289,11 +289,11 @@
289289
kernels:
290290
- arg_meta: null
291291
kernel_name: cadence::impl::HiFi::dequantize_per_tensor_out
292-
292+
293293
- func: cadence::quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
294294
kernels:
295295
- arg_meta: null
296-
kernel_name: cadence::impl::HiFi::quantized_conv_out
296+
kernel_name: cadence::impl::HiFi::quantized_conv_out
297297

298298
- func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)
299299
kernels:

backends/cadence/hifi/operators/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ target_include_directories(
8888

8989
# Custom ops that are needed to run the test model.
9090
add_library(
91-
custom_ops "op_quantized_linear_out.cpp" "op_quantized_layer_norm.cpp" "quantized_matmul_out.cpp"
91+
custom_ops "op_quantized_linear_out.cpp" "op_quantized_layer_norm.cpp" "op_quantized_matmul_out.cpp"
9292
"op_quantize_per_tensor.cpp" "op_quantized_relu_out.cpp" "op_dequantize_per_tensor.cpp"
9393
"op_quantized_conv_out.cpp" "op_quantized_fully_connected_out"
9494
)

backends/cadence/hifi/operators/op_bitwise_and.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ Tensor& bitwise_and_Tensor_out(
169169
return out;
170170
}
171171

172-
return torch::executor::native::internal::bitwise_tensor_out<op_name>(
173-
ctx, a, b, out);
172+
return torch::executor::native::internal::
173+
bitwise_tensor_out<std::bit_and, op_name>(ctx, a, b, out);
174174
}
175175

176176
Tensor& bitwise_and_Scalar_out(
@@ -180,8 +180,8 @@ Tensor& bitwise_and_Scalar_out(
180180
Tensor& out) {
181181
// @lint-ignore CLANGTIDY facebook-hte-CArray
182182
static constexpr const char op_name[] = "bitwise_and.Scalar_out";
183-
return torch::executor::native::internal::bitwise_scalar_out<op_name>(
184-
ctx, a, b, out);
183+
return torch::executor::native::internal::
184+
bitwise_scalar_out<std::bit_and, op_name>(ctx, a, b, out);
185185
}
186186

187187
} // namespace native

backends/cadence/hifi/operators/op_bitwise_or.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ Tensor& bitwise_or_Tensor_out(
169169
return out;
170170
}
171171

172-
return torch::executor::native::internal::bitwise_tensor_out<op_name>(
173-
ctx, a, b, out);
172+
return torch::executor::native::internal::
173+
bitwise_tensor_out<std::bit_or, op_name>(ctx, a, b, out);
174174
}
175175

176176
Tensor& bitwise_or_Scalar_out(
@@ -180,8 +180,8 @@ Tensor& bitwise_or_Scalar_out(
180180
Tensor& out) {
181181
// @lint-ignore CLANGTIDY facebook-hte-CArray
182182
static constexpr const char op_name[] = "bitwise_or.Scalar_out";
183-
return torch::executor::native::internal::bitwise_scalar_out<op_name>(
184-
ctx, a, b, out);
183+
return torch::executor::native::internal::
184+
bitwise_scalar_out<std::bit_or, op_name>(ctx, a, b, out);
185185
}
186186

187187
} // namespace native

backends/cadence/hifi/operators/op_bitwise_xor.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ Tensor& bitwise_xor_Tensor_out(
169169
return out;
170170
}
171171

172-
return torch::executor::native::internal::bitwise_tensor_out<op_name>(
173-
ctx, a, b, out);
172+
return torch::executor::native::internal::
173+
bitwise_tensor_out<std::bit_xor, op_name>(ctx, a, b, out);
174174
}
175175

176176
Tensor& bitwise_xor_Scalar_out(
@@ -180,8 +180,8 @@ Tensor& bitwise_xor_Scalar_out(
180180
Tensor& out) {
181181
// @lint-ignore CLANGTIDY facebook-hte-CArray
182182
static constexpr const char op_name[] = "bitwise_xor.Scalar_out";
183-
return torch::executor::native::internal::bitwise_scalar_out<op_name>(
184-
ctx, a, b, out);
183+
return torch::executor::native::internal::
184+
bitwise_scalar_out<std::bit_xor, op_name>(ctx, a, b, out);
185185
}
186186

187187
} // namespace native

backends/cadence/hifi/operators/op_clamp.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -322,15 +322,6 @@ Tensor& clamp_Tensor_out(
322322
return out;
323323
}
324324

325-
Tensor& clamp_tensor_out(
326-
RuntimeContext& ctx,
327-
const Tensor& in,
328-
const std::optional<Tensor>& min_opt,
329-
const std::optional<Tensor>& max_opt,
330-
Tensor& out) {
331-
return clamp_Tensor_out(ctx, in, min_opt, max_opt, out);
332-
}
333-
334325
} // namespace native
335326
} // namespace HiFi
336327
} // namespace impl

backends/cadence/hifi/operators/op_eq.cpp

Lines changed: 14 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <executorch/kernels/portable/cpu/pattern/comparison_op.h>
910
#include <executorch/kernels/portable/cpu/scalar_utils.h>
1011
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
1112
#include <executorch/kernels/portable/cpu/util/functional_util.h>
@@ -28,7 +29,7 @@ namespace impl {
2829
namespace HiFi {
2930
namespace native {
3031

31-
Tensor& eq_tensor_out(
32+
Tensor& eq_Tensor_out(
3233
RuntimeContext& ctx,
3334
const Tensor& a,
3435
const Tensor& b,
@@ -39,14 +40,14 @@ Tensor& eq_tensor_out(
3940
InvalidArgument,
4041
out);
4142

42-
ScalarType a_type = a.scalar_type();
43-
ScalarType b_type = b.scalar_type();
4443
ScalarType out_type = out.scalar_type();
4544

46-
constexpr auto name = "eq.Tensor_out";
45+
// @lint-ignore CLANGTIDY facebook-hte-CArray
46+
static constexpr const char name[] = "eq.Tensor_out";
4747
constexpr int kNnlibMaxDim = 4; /*fallback if broadcast and dim > 4 */
4848

49-
int a_dim = a.dim(), b_dim = b.dim(), out_dim = out.dim();
49+
int a_dim = a.dim();
50+
int b_dim = b.dim();
5051
bool optimized = true;
5152
/*find broadcast*/
5253
const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
@@ -110,32 +111,11 @@ Tensor& eq_tensor_out(
110111
return out;
111112
}
112113

113-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, name, CTYPE_A, [&]() {
114-
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, name, CTYPE_B, [&]() {
115-
using CTYPE_IN =
116-
typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
117-
ET_DCHECK(
118-
CppTypeToScalarType<CTYPE_IN>::value == promoteTypes(a_type, b_type));
119-
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, name, CTYPE_OUT, [&]() {
120-
torch::executor::
121-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
122-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
123-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
124-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
125-
bool value = a_casted == b_casted;
126-
return static_cast<CTYPE_OUT>(value);
127-
},
128-
a,
129-
b,
130-
out);
131-
});
132-
});
133-
});
134-
135-
return out;
114+
return torch::executor::native::internal::
115+
comparison_tensor_out<std::equal_to, name>(ctx, a, b, out);
136116
}
137117

138-
Tensor& eq_scalar_out(
118+
Tensor& eq_Scalar_out(
139119
RuntimeContext& ctx,
140120
const Tensor& a,
141121
const Scalar& b,
@@ -149,40 +129,14 @@ Tensor& eq_scalar_out(
149129
InvalidArgument,
150130
out,
151131
"Failed to resize output tensor.");
132+
// @lint-ignore CLANGTIDY facebook-hte-CArray
133+
static constexpr const char name[] = "eq.Scalar_out";
152134

153-
constexpr auto name = "eq.Scalar_out";
154-
155-
ScalarType a_type = a.scalar_type();
156-
ScalarType b_type = torch::executor::native::utils::get_scalar_dtype(b);
157-
ScalarType out_type = out.scalar_type();
158-
159-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, name, CTYPE_A, [&]() {
160-
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
161-
using CTYPE_IN =
162-
typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
163-
ET_DCHECK(
164-
CppTypeToScalarType<CTYPE_IN>::value == promoteTypes(a_type, b_type));
165-
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, name, CTYPE_OUT, [&]() {
166-
CTYPE_B val_b = 0;
167-
torch::executor::native::utils::extract_scalar(b, &val_b);
168-
torch::executor::apply_unary_map_fn(
169-
[val_b](const CTYPE_A val_a) {
170-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
171-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
172-
bool value = a_casted == b_casted;
173-
return static_cast<CTYPE_OUT>(value);
174-
},
175-
a.const_data_ptr<CTYPE_A>(),
176-
out.mutable_data_ptr<CTYPE_OUT>(),
177-
out.numel());
178-
});
179-
});
180-
});
181-
182-
return out;
135+
return torch::executor::native::internal::
136+
comparison_scalar_out<std::equal_to, name>(ctx, a, b, out);
183137
}
184138

185139
} // namespace native
186140
} // namespace HiFi
187141
} // namespace impl
188-
} // namespace cadence
142+
} // namespace cadence

0 commit comments

Comments
 (0)