
Commit 61ead64

Commit message: Update [ghstack-poisoned]

2 parents: 780d883 + 4eaa345 · commit 61ead64

File tree

4 files changed: +35 −92 lines

  backends/apple/metal/runtime/shims/et_metal_ops.h
  backends/apple/metal/runtime/shims/et_metal_ops.mm
  backends/cadence/aot/compiler.py
  examples/models/llama/export_llama_lib.py

backends/apple/metal/runtime/shims/et_metal_ops.h

Lines changed: 0 additions & 13 deletions

@@ -18,19 +18,6 @@ namespace metal {
 extern "C" {
 #endif
 
-/**
- * ExecutorTorch implementation of aoti_torch_mps_addmm_out.
- * Performs matrix multiplication with bias: out = beta * self + alpha * (mat1 @
- * mat2)
- */
-AOTITorchError aoti_torch_mps_addmm_out(
-    AOTITensorHandle out,
-    AOTITensorHandle self,
-    AOTITensorHandle mat1,
-    AOTITensorHandle mat2,
-    double beta,
-    double alpha);
-
 /**
  * ExecutorTorch implementation of aoti_torch_mps_mm_out.
  * Performs simple matrix multiplication: out = self @ mat2
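
For reference, the removed declaration documented torch.addmm semantics: out = beta * self + alpha * (mat1 @ mat2). A minimal Python sketch of that math, illustrative only (shapes and values are made up and are not part of this commit):

    import torch

    # out = beta * self + alpha * (mat1 @ mat2), per the removed doc comment.
    self_t = torch.randn(3, 5)  # the bias term ("self" in the shim signature)
    mat1 = torch.randn(3, 4)
    mat2 = torch.randn(4, 5)
    beta, alpha = 1.0, 1.0

    out = beta * self_t + alpha * (mat1 @ mat2)
    ref = torch.addmm(self_t, mat1, mat2, beta=beta, alpha=alpha)
    assert torch.allclose(out, ref)  # both compute the same result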

backends/apple/metal/runtime/shims/et_metal_ops.mm

Lines changed: 9 additions & 56 deletions
@@ -229,59 +229,6 @@ AOTITorchError aoti_torch_mps_mm_out(
   }
 }
 
-AOTITorchError aoti_torch_mps_addmm_out(
-    AOTITensorHandle out,
-    AOTITensorHandle self,
-    AOTITensorHandle mat1,
-    AOTITensorHandle mat2,
-    double beta,
-    double alpha) {
-  ET_LOG(Debug, "aoti_torch_mps_addmm_out: Starting with out=%p, self=%p, mat1=%p, mat2=%p, beta=%f, alpha=%f",
-      out, self, mat1, mat2, beta, alpha);
-
-  if (!out || !self || !mat1 || !mat2) {
-    ET_LOG(Error, "aoti_torch_mps_addmm_out: null tensor handles");
-    return Error::InvalidArgument;
-  }
-
-  @autoreleasepool {
-    try {
-      // Convert AOTITensorHandle to ExecutorTorch tensors
-      auto out_tensor = reinterpret_cast<executorch::runtime::etensor::Tensor*>(out);
-      auto self_tensor = reinterpret_cast<executorch::runtime::etensor::Tensor*>(self);
-      auto mat1_tensor = reinterpret_cast<executorch::runtime::etensor::Tensor*>(mat1);
-      auto mat2_tensor = reinterpret_cast<executorch::runtime::etensor::Tensor*>(mat2);
-
-      ET_LOG(Debug, "aoti_torch_mps_addmm_out: Converted tensor handles to ET tensors");
-
-      // For now, just zero out the output tensor to get the right shape
-      // TODO: Implement actual matrix multiplication: out = beta * self + alpha * (mat1 @ mat2)
-
-      // Get output data pointer and size
-      float* out_data = static_cast<float*>(out_tensor->mutable_data_ptr());
-      size_t out_numel = out_tensor->numel();
-
-      if (!out_data) {
-        ET_LOG(Error, "aoti_torch_mps_addmm_out: null output data pointer");
-        return Error::InvalidArgument;
-      }
-
-      // Zero out the output tensor
-      std::memset(out_data, 0, out_numel * sizeof(float));
-
-      ET_LOG(Debug, "aoti_torch_mps_addmm_out: Zeroed output tensor with %zu elements", out_numel);
-      return Error::Ok;
-
-    } catch (const std::exception& e) {
-      ET_LOG(Error, "aoti_torch_mps_addmm_out exception: %s", e.what());
-      return Error::Internal;
-    } catch (...) {
-      ET_LOG(Error, "aoti_torch_mps_addmm_out: unknown exception");
-      return Error::Internal;
-    }
-  }
-}
-
 AOTITorchError aoti_torch_mps_convolution(
     AOTITensorHandle input,
     AOTITensorHandle weight,
@@ -743,7 +690,7 @@ AOTITorchError aoti_torch_mps_convolution(
         output_strides.data(),
         0, // storage_offset
         dtype, // dtype
-        2, // device_type (MPS)
+        13, // device_type (MPS)
         0, // device_index
         &output_tensor_handle,
         0, // layout (strided)
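
All three device_type hunks in this file make the same correction: assuming these integer codes mirror PyTorch's c10::DeviceType enum (an assumption based on upstream c10/core/DeviceType.h; the diff itself only shows the constants), 13 is MPS while the old literal 2 is MKLDNN, so the previous value did not match its "(MPS)" comment. A quick illustrative check via a private torch binding:

    import torch

    # Assumption: the shim's device codes mirror c10::DeviceType.
    # torch._C._autograd.DeviceType is a private binding of that enum,
    # so treat this as a sanity sketch, not a stable API.
    assert int(torch._C._autograd.DeviceType.MPS) == 13
    assert int(torch._C._autograd.DeviceType.MKLDNN) == 2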
@@ -859,6 +806,12 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
 
   ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: mps_dtype=%d, element_size=%zu", mps_dtype, element_size);
 
+  // Check that headSize is not zero to avoid division by zero
+  if (headSize == 0) {
+    ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: headSize is zero");
+    throw std::runtime_error("headSize must be non-zero for scaled dot product attention");
+  }
+
   // Calculate scale factor
   double scale_factor = scale ? *scale : (1.0 / sqrt(static_cast<double>(headSize)));
   ET_LOG(Debug, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: scale_factor=%f", scale_factor);
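
The new guard protects the default-scale branch directly below it, where headSize appears in a denominator. A small Python sketch of the same rule (illustrative; names simplified from the C++):

    import math

    def sdpa_scale(head_size: int, scale: float | None = None) -> float:
        # Mirrors the guarded C++ logic: an explicit scale wins,
        # otherwise fall back to 1/sqrt(head_size).
        if head_size == 0:
            raise ValueError("headSize must be non-zero for scaled dot product attention")
        return scale if scale is not None else 1.0 / math.sqrt(head_size)

    print(sdpa_scale(64))  # 0.125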
@@ -1193,7 +1146,7 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
         out_strides.data(),
         0, // storage_offset
         dtype,
-        2, // device_type (MPS)
+        13, // device_type (MPS)
         0, // device_index
         &out_tensor_handle,
         0, // layout (strided)
@@ -1208,7 +1161,7 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
         attn_strides.data(),
         0, // storage_offset
         dtype,
-        2, // device_type (MPS)
+        13, // device_type (MPS)
         0, // device_index
         &attn_tensor_handle,
         0, // layout (strided)

backends/cadence/aot/compiler.py

Lines changed: 2 additions & 2 deletions
@@ -38,7 +38,7 @@
     ExecutorchProgramManager,
 )
 from executorch.exir.passes import ToOutVarPass
-from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
+from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 from executorch.exir.program._program import to_edge
 
 from torch.export.exported_program import ExportedProgram
@@ -460,7 +460,7 @@ def _lower_ep_to_cadence_gen_etrecord(
             emit_stacktrace=False,
             to_out_var_pass=ToOutVarPass(),
             extract_delegate_segments=False,
-            sym_shape_eval_pass=HintBasedSymShapeEvalPass(),
+            sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
         ),
     )
 
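
As the pass names suggest, the swap changes how upper bounds for symbolic shapes are resolved at emit time: ConstraintBasedSymShapeEvalPass works from the exported program's range constraints rather than from example-input hints. A minimal sketch of the touched configuration (keyword name taken from the diff; all other fields left at their defaults):

    from executorch.exir import ExecutorchBackendConfig
    from executorch.exir.passes.sym_shape_eval_pass import (
        ConstraintBasedSymShapeEvalPass,
    )

    # Only the field this commit touches is shown.
    config = ExecutorchBackendConfig(
        sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
    )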

examples/models/llama/export_llama_lib.py

Lines changed: 24 additions & 21 deletions
@@ -874,6 +874,7 @@ def _to_edge_and_lower_llama_xnnpack(
     xnnpack_extended_ops: bool = False,
     generate_etrecord: bool = False,
     verbose: bool = False,
+    gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None,
 ) -> LLMEdgeManager:  # noqa: C901
     partitioners = []
 
@@ -896,9 +897,27 @@ def _to_edge_and_lower_llama_xnnpack(
     if generate_etrecord:
         builder_exported.generate_etrecord = True
 
-    builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
-        partitioners
-    )
+    builder = builder_exported.pt2e_quantize(quantizers)
+    if gen_tag_fn is not None:
+        from executorch.exir.passes.external_constants_pass import (
+            delegate_external_constants_pass_unlifted,
+            external_constants_pass,
+        )
+
+        assert (
+            builder_exported.pre_autograd_graph_module is not None
+        ), "pre_autograd_graph_module shouldn't be None here"
+        delegate_external_constants_pass_unlifted(
+            module=builder_exported.pre_autograd_graph_module,
+            gen_tag_fn=gen_tag_fn,
+        )
+
+        # Also add a pass for 'to_executorch' to tag weights that aren't delegated.
+        additional_passes.append(
+            partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
+        )
+
+    builder = builder.to_edge_transform_and_lower(partitioners)
     if verbose:
         print_delegation_info(builder.edge_manager.exported_program().graph_module)
 
@@ -1136,6 +1155,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         llm_config.backend.xnnpack.enabled = True
 
     if llm_config.backend.xnnpack.enabled:
+        gen_tag_fn = None
         if (
             llm_config.export.foundation_weights_file is not None
             or llm_config.export.lora_weights_file is not None
@@ -1145,24 +1165,6 @@
                 if "lora" not in x.name
                 else llm_config.export.lora_weights_file
             )
-            from executorch.exir.passes.external_constants_pass import (
-                delegate_external_constants_pass_unlifted,
-                external_constants_pass,
-            )
-
-            assert (
-                builder_exported.pre_autograd_graph_module is not None
-            ), "pre_autograd_graph_module shouldn't be None here"
-            delegate_external_constants_pass_unlifted(
-                module=builder_exported.pre_autograd_graph_module,
-                gen_tag_fn=gen_tag_fn,
-            )
-
-            # Also add a pass for 'to_executorch' to tag weights that aren't delegated.
-            additional_passes.append(
-                partial(external_constants_pass, gen_tag_fn=gen_tag_fn)
-            )
-
         builder = _to_edge_and_lower_llama_xnnpack(
             builder_exported,
             modelname,
@@ -1173,6 +1175,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
             xnnpack_extended_ops=llm_config.backend.xnnpack.extended_ops,
             generate_etrecord=llm_config.debug.generate_etrecord,
             verbose=llm_config.debug.verbose,
+            gen_tag_fn=gen_tag_fn,
         )
     elif llm_config.backend.openvino.enabled:
         builder = _to_edge_and_lower_llama_openvino(
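
Net effect: the external-constants tagging moves out of _export_llama into _to_edge_and_lower_llama_xnnpack, threaded through the new gen_tag_fn parameter, so tagging now runs between quantization and lowering. The callback itself is unchanged; a sketch of its shape, reconstructed from the retained context above (the .ptd file names are placeholders, not values from this commit):

    from typing import Optional

    import torch

    # Matches the Optional[Callable[[torch.fx.Node], Optional[str]]] signature
    # added to _to_edge_and_lower_llama_xnnpack: nodes whose names contain
    # "lora" are tagged into the LoRA weights file, everything else into the
    # foundation weights file.
    def gen_tag_fn(x: torch.fx.Node) -> Optional[str]:
        return "foundation.ptd" if "lora" not in x.name else "lora.ptd"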
