@@ -229,59 +229,6 @@ AOTITorchError aoti_torch_mps_mm_out(
229229 }
230230}
231231
232- AOTITorchError aoti_torch_mps_addmm_out (
233- AOTITensorHandle out,
234- AOTITensorHandle self,
235- AOTITensorHandle mat1,
236- AOTITensorHandle mat2,
237- double beta,
238- double alpha) {
239- ET_LOG (Debug, " aoti_torch_mps_addmm_out: Starting with out=%p, self=%p, mat1=%p, mat2=%p, beta=%f, alpha=%f" ,
240- out, self, mat1, mat2, beta, alpha);
241-
242- if (!out || !self || !mat1 || !mat2) {
243- ET_LOG (Error, " aoti_torch_mps_addmm_out: null tensor handles" );
244- return Error::InvalidArgument;
245- }
246-
247- @autoreleasepool {
248- try {
249- // Convert AOTITensorHandle to ExecutorTorch tensors
250- auto out_tensor = reinterpret_cast <executorch::runtime::etensor::Tensor*>(out);
251- auto self_tensor = reinterpret_cast <executorch::runtime::etensor::Tensor*>(self);
252- auto mat1_tensor = reinterpret_cast <executorch::runtime::etensor::Tensor*>(mat1);
253- auto mat2_tensor = reinterpret_cast <executorch::runtime::etensor::Tensor*>(mat2);
254-
255- ET_LOG (Debug, " aoti_torch_mps_addmm_out: Converted tensor handles to ET tensors" );
256-
257- // For now, just zero out the output tensor to get the right shape
258- // TODO: Implement actual matrix multiplication: out = beta * self + alpha * (mat1 @ mat2)
259-
260- // Get output data pointer and size
261- float * out_data = static_cast <float *>(out_tensor->mutable_data_ptr ());
262- size_t out_numel = out_tensor->numel ();
263-
264- if (!out_data) {
265- ET_LOG (Error, " aoti_torch_mps_addmm_out: null output data pointer" );
266- return Error::InvalidArgument;
267- }
268-
269- // Zero out the output tensor
270- std::memset (out_data, 0 , out_numel * sizeof (float ));
271-
272- ET_LOG (Debug, " aoti_torch_mps_addmm_out: Zeroed output tensor with %zu elements" , out_numel);
273- return Error::Ok;
274-
275- } catch (const std::exception& e) {
276- ET_LOG (Error, " aoti_torch_mps_addmm_out exception: %s" , e.what ());
277- return Error::Internal;
278- } catch (...) {
279- ET_LOG (Error, " aoti_torch_mps_addmm_out: unknown exception" );
280- return Error::Internal;
281- }
282- }
283- }
284-
285232AOTITorchError aoti_torch_mps_convolution (
286233 AOTITensorHandle input,
287234 AOTITensorHandle weight,
@@ -743,7 +690,7 @@ AOTITorchError aoti_torch_mps_convolution(
743690 output_strides.data (),
744691 0 , // storage_offset
745692 dtype, // dtype
746- 2 , // device_type (MPS)
693+ 13 , // device_type (MPS)
747694 0 , // device_index
748695 &output_tensor_handle,
749696 0 , // layout (strided)
@@ -859,6 +806,12 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
859806
860807 ET_LOG (Debug, " aoti_torch_mps__scaled_dot_product_attention_math_for_mps: mps_dtype=%d, element_size=%zu" , mps_dtype, element_size);
861808
809+ // Check that headSize is not zero to avoid division by zero
810+ if (headSize == 0 ) {
811+ ET_LOG (Error, " aoti_torch_mps__scaled_dot_product_attention_math_for_mps: headSize is zero" );
812+ throw std::runtime_error (" headSize must be non-zero for scaled dot product attention" );
813+ }
814+
862815 // Calculate scale factor
863816 double scale_factor = scale ? *scale : (1.0 / sqrt (static_cast <double >(headSize)));
864817 ET_LOG (Debug, " aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scale_factor=%f" , scale_factor);
@@ -1193,7 +1146,7 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
11931146 out_strides.data (),
11941147 0 , // storage_offset
11951148 dtype,
1196- 2 , // device_type (MPS)
1149+ 13 , // device_type (MPS)
11971150 0 , // device_index
11981151 &out_tensor_handle,
11991152 0 , // layout (strided)
@@ -1208,7 +1161,7 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
12081161 attn_strides.data (),
12091162 0 , // storage_offset
12101163 dtype,
1211- 2 , // device_type (MPS)
1164+ 13 , // device_type (MPS)
12121165 0 , // device_index
12131166 &attn_tensor_handle,
12141167 0 , // layout (strided)
0 commit comments