@@ -51,12 +51,8 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
5151}
5252
5353namespace internal {
54- template <
55- typename CTYPE_COMPUTE,
56- const char * op_name,
57- typename Op,
58- typename ... Args>
59- inline void apply_elementwise_fn (
54+ template <typename CTYPE_COMPUTE, typename Op, typename ... Args>
55+ inline bool validate_elementwise_fn_inputs (
6056 const Op& compute_fun,
6157 KernelRuntimeContext& ctx,
6258 const Tensor& out,
@@ -65,7 +61,6 @@ inline void apply_elementwise_fn(
6561 static_assert (
6662 (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
6763 ...));
68- constexpr auto kNumInputs = sizeof ...(inputs);
6964 constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
7065 const auto check_input_dtype = [](auto input, auto compute_type) {
7166 return internal::check_tensor_dtype (
@@ -75,7 +70,30 @@ inline void apply_elementwise_fn(
7570 ctx,
7671 (check_input_dtype (inputs, compute_type) && ...) &&
7772 internal::check_tensor_dtype (out, out_dtypes, compute_type),
78- InvalidArgument, );
73+ InvalidArgument,
74+ false );
75+
76+ return true ;
77+ }
78+
79+ template <
80+ typename CTYPE_COMPUTE,
81+ const char * op_name,
82+ typename Op,
83+ typename ... Args>
84+ inline void apply_elementwise_fn (
85+ const Op& compute_fun,
86+ KernelRuntimeContext& ctx,
87+ const Tensor& out,
88+ SupportedTensorDtypes out_dtypes,
89+ Args... inputs) {
90+ const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
91+ compute_fun, ctx, out, out_dtypes, inputs...);
92+ if (!inputs_valid) {
93+ return ;
94+ }
95+
96+ constexpr auto kNumInputs = sizeof ...(inputs);
7997
8098 struct InputInfo {
8199 load_to_compute_fn<CTYPE_COMPUTE> load_to_compute;
@@ -120,6 +138,7 @@ inline void apply_elementwise_fn(
120138 });
121139}
122140
141+ /// DEPRECATED: prefer the variant with out_dtypes in the template argument list.
123142template <typename CTYPE_COMPUTE, const char * op_name, typename Op>
124143inline void apply_unitensor_elementwise_fn (
125144 const Op& compute_fun,
@@ -132,19 +151,83 @@ inline void apply_unitensor_elementwise_fn(
132151 compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
133152}
134153
/**
 * Single-input elementwise helper that takes out_dtypes as a non-type
 * template argument instead of a trailing function argument.
 *
 * Forwards compute_fun, ctx, out and out_dtypes to
 * internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>, passing the input
 * tensor `a` together with its supported dtypes as a
 * (const Tensor*, SupportedTensorDtypes) pair. Returns nothing.
 *
 * NOTE(review): how op_name is consumed (presumably for diagnostics) is
 * decided by internal::apply_elementwise_fn and is not visible in this
 * wrapper; confirm there.
 */
154+ template <
155+ typename CTYPE_COMPUTE,
156+ const char * op_name,
157+ SupportedTensorDtypes out_dtypes,
158+ typename Op>
159+ inline void apply_unitensor_elementwise_fn (
160+ const Op& compute_fun,
161+ KernelRuntimeContext& ctx,
162+ const Tensor& a,
163+ SupportedTensorDtypes a_dtypes,
164+ const Tensor& out) {
// Argument order differs from this wrapper's own signature: the callee takes
// the output tensor and its dtypes first, then one pair per input tensor.
165+ internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
166+ compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
167+ }
168+
169+ /**
170+ * DEPRECATED: prefer the variant with out_dtypes in the template argument list.
 *
 * Two-input elementwise helper that receives out_dtypes as a runtime function
 * argument. It forwards compute_fun, ctx, out and out_dtypes to
 * internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>, passing `a` and `b`
 * each as a (const Tensor*, SupportedTensorDtypes) pair. Returns nothing.
171+ */
172+ template <typename CTYPE_COMPUTE, const char * op_name, typename Op>
173+ inline void apply_bitensor_elementwise_fn (
174+ const Op& compute_fun,
175+ KernelRuntimeContext& ctx,
176+ const Tensor& a,
177+ SupportedTensorDtypes a_dtypes,
178+ const Tensor& b,
179+ SupportedTensorDtypes b_dtypes,
180+ const Tensor& out,
181+ SupportedTensorDtypes out_dtypes) {
// The callee takes the output tensor and its dtypes before the input pairs,
// so the arguments are reordered relative to this wrapper's signature.
182+ internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
183+ compute_fun,
184+ ctx,
185+ out,
186+ out_dtypes,
187+ std::make_pair (&a, a_dtypes),
188+ std::make_pair (&b, b_dtypes));
189+ }
190+
135191/**
136192 * Useful for bi-tensor elementwise operators. For each element of the inputs,
137193 * perform a computation and write to the corresponding element of the output.
138194 * Tensor broadcasting is applied wherever it is required.
 *
 * In this variant out_dtypes is a non-type template argument rather than a
 * function argument. The wrapper forwards compute_fun, ctx, out and out_dtypes
 * to internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>, passing `a` and
 * `b` each as a (const Tensor*, SupportedTensorDtypes) pair.
139195 */
140- template <typename CTYPE_COMPUTE, const char * op_name, typename Op>
196+ template <
197+ typename CTYPE_COMPUTE,
198+ const char * op_name,
199+ SupportedTensorDtypes out_dtypes,
200+ typename Op>
141201inline void apply_bitensor_elementwise_fn (
142202 const Op& compute_fun,
143203 KernelRuntimeContext& ctx,
144204 const Tensor& a,
145205 SupportedTensorDtypes a_dtypes,
146206 const Tensor& b,
147207 SupportedTensorDtypes b_dtypes,
208+ const Tensor& out) {
// The callee takes the output tensor and its dtypes before the input pairs,
// so the arguments are reordered relative to this wrapper's signature.
209+ internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
210+ compute_fun,
211+ ctx,
212+ out,
213+ out_dtypes,
214+ std::make_pair (&a, a_dtypes),
215+ std::make_pair (&b, b_dtypes));
216+ }
217+
218+ /**
219+ * DEPRECATED: prefer the variant with out_dtypes in the template argument list.
220+ */
221+ template <typename CTYPE_COMPUTE, const char * op_name, typename Op>
222+ inline void apply_tritensor_elementwise_fn (
223+ const Op& compute_fun,
224+ KernelRuntimeContext& ctx,
225+ const Tensor& a,
226+ SupportedTensorDtypes a_dtypes,
227+ const Tensor& b,
228+ SupportedTensorDtypes b_dtypes,
229+ const Tensor& c,
230+ SupportedTensorDtypes c_dtypes,
148231 const Tensor& out,
149232 SupportedTensorDtypes out_dtypes) {
150233 internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
@@ -153,7 +236,8 @@ inline void apply_bitensor_elementwise_fn(
153236 out,
154237 out_dtypes,
155238 std::make_pair (&a, a_dtypes),
156- std::make_pair (&b, b_dtypes));
239+ std::make_pair (&b, b_dtypes),
240+ std::make_pair (&c, c_dtypes));
157241}
158242
159243/* *
@@ -176,7 +260,11 @@ inline void apply_bitensor_elementwise_fn(
176260 * static constexpr const char op_name[] = "my_op";
177261 * apply_tritensor_elementwise_fn<CTYPE_COMPUTE, op_name>.
178262 */
179- template <typename CTYPE_COMPUTE, const char * op_name, typename Op>
263+ template <
264+ typename CTYPE_COMPUTE,
265+ const char * op_name,
266+ SupportedTensorDtypes out_dtypes,
267+ typename Op>
180268inline void apply_tritensor_elementwise_fn (
181269 const Op& compute_fun,
182270 KernelRuntimeContext& ctx,
@@ -186,8 +274,7 @@ inline void apply_tritensor_elementwise_fn(
186274 SupportedTensorDtypes b_dtypes,
187275 const Tensor& c,
188276 SupportedTensorDtypes c_dtypes,
189- const Tensor& out,
190- SupportedTensorDtypes out_dtypes) {
277+ const Tensor& out) {
191278 internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>(
192279 compute_fun,
193280 ctx,
0 commit comments