@@ -173,27 +173,13 @@ get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
173173template <typename CTYPE_COMMON, const char * op_name>
174174store_common_to_tensor_fn<CTYPE_COMMON>
175175get_store_common_to_tensor_fn_same_as_compute (const Tensor& t) {
176- return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
176+ // We already validate tensor types earlier in the process, so at
177+ // this phase, treat same_as_compute the same as our widest
178+ // SupportedTensorDtypes set.
179+ return get_store_common_to_tensor_fn_realhbf16<CTYPE_COMMON, op_name>(t);
177180}
178181
179- template <
180- typename CTYPE_COMMON,
181- const char * op_name,
182- std::enable_if_t <std::is_same_v<CTYPE_COMMON, float >, bool > = true >
183- store_common_to_tensor_fn<CTYPE_COMMON>
184- get_store_common_to_tensor_fn_same_as_common (const Tensor& t) {
185- void (*result)(CTYPE_COMMON, void *) = nullptr ;
186- ET_SWITCH_THREE_TYPES (
187- Float, Half, BFloat16, t.scalar_type (), unused, op_name, CTYPE, [&]() {
188- result = internal::convert_and_store<CTYPE, CTYPE_COMMON>;
189- });
190- return result;
191- }
192-
193- template <
194- typename CTYPE_COMMON,
195- const char * op_name,
196- std::enable_if_t <!std::is_same_v<CTYPE_COMMON, float >, bool > = true >
182+ template <typename CTYPE_COMMON, const char * op_name>
197183store_common_to_tensor_fn<CTYPE_COMMON>
198184get_store_common_to_tensor_fn_same_as_common (const Tensor& t) {
199185 return get_store_common_to_tensor_fn_same_as_compute<CTYPE_COMMON, op_name>(
0 commit comments