Skip to content

Commit 6954fdf

Browse files
[ET][Metal] Update aoti_common with additional AOTI functions needed by Metal backend
ghstack-source-id: f11bbd0 ghstack-comment-id: 3391401077 Pull-Request: #15003
1 parent 896178e commit 6954fdf

File tree

4 files changed

+30
-0
lines changed

4 files changed

+30
-0
lines changed

backends/aoti/aoti_model_container.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ AOTInductorModelContainerGetNumOutputsFunc
2525
AOTInductorModelContainerGetNumOutputs = nullptr;
2626
AOTInductorModelContainerRunFunc AOTInductorModelContainerRun = nullptr;
2727

28+
// Global function pointers needed by Metal backend
29+
AOTInductorModelContainerGetInputNameFunc
30+
AOTInductorModelContainerGetInputName = nullptr;
31+
AOTInductorModelContainerGetNumConstantsFunc
32+
AOTInductorModelContainerGetNumConstants = nullptr;
33+
2834
} // extern "C"
2935

3036
} // namespace aoti

backends/aoti/aoti_model_container.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,22 @@ extern AOTInductorModelContainerGetNumOutputsFunc
7070
AOTInductorModelContainerGetNumOutputs;
7171
extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun;
7272

73+
// Function pointer types needed by Metal backend
74+
using AOTInductorModelContainerGetInputNameFunc = AOTIRuntimeError (*)(
75+
AOTInductorModelContainerHandle container_handle,
76+
size_t input_idx,
77+
const char** input_name);
78+
79+
using AOTInductorModelContainerGetNumConstantsFunc = AOTIRuntimeError (*)(
80+
AOTInductorModelContainerHandle container_handle,
81+
size_t* num_constants);
82+
83+
// Global function pointers needed by Metal backend
84+
extern AOTInductorModelContainerGetInputNameFunc
85+
AOTInductorModelContainerGetInputName;
86+
extern AOTInductorModelContainerGetNumConstantsFunc
87+
AOTInductorModelContainerGetNumConstants;
88+
7389
} // extern "C"
7490

7591
// AOTI Delegate Handle structure

backends/aoti/common_shims.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,11 @@ void cleanup_tensor_metadata() {
145145
internal::tensor_to_strides.clear();
146146
}
147147

148+
// Needed by Metal backend
149+
size_t aoti_torch_dtype_element_size(int32_t dtype) {
150+
return dtype_to_element_size(dtype);
151+
}
152+
148153
} // extern "C"
149154

150155
} // namespace aoti

backends/aoti/common_shims.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ void aoti_torch_grad_mode_set_enabled(bool enabled);
6868
// Cleanup functions for clearing global state
6969
void cleanup_tensor_metadata();
7070

71+
// Needed by Metal backend
72+
size_t aoti_torch_dtype_element_size(int32_t dtype);
73+
7174
} // extern "C"
7275

7376
} // namespace aoti

0 commit comments

Comments
 (0)