2828
2929namespace executorch ::backends::cuda {
3030
31- #define LOAD_SYMBOL (handle, member, name, so_handle ) \
32- do { \
33- handle->member = reinterpret_cast <name##Func>(dlsym (so_handle, #name)); \
34- ET_CHECK_OR_RETURN_ERROR ( \
35- handle->member != nullptr , AccessFailed, " Failed to load " #name); \
36- } while (0 )
37-
3831using namespace std ;
3932using namespace aoti ;
4033
@@ -60,35 +53,32 @@ class ET_EXPERIMENTAL CudaBackend final
6053 Error load_function_pointers_into_handle (
6154 void * so_handle,
6255 AOTIDelegateHandle* handle) const {
63- LOAD_SYMBOL (
64- handle,
65- create_with_device,
66- AOTInductorModelContainerCreateWithDevice,
67- so_handle);
68-
69- LOAD_SYMBOL (
70- handle, delete_container, AOTInductorModelContainerDelete, so_handle);
71-
72- LOAD_SYMBOL (
73- handle,
74- get_num_inputs,
75- AOTInductorModelContainerGetNumInputs,
76- so_handle);
77-
78- LOAD_SYMBOL (
79- handle,
80- get_num_outputs,
81- AOTInductorModelContainerGetNumOutputs,
82- so_handle);
83-
84- LOAD_SYMBOL (handle, run, AOTInductorModelContainerRun, so_handle);
85-
86- LOAD_SYMBOL (
87- handle,
88- update_constants_from_blob,
89- AOTInductorModelUpdateConstantsFromBlob,
90- so_handle);
56+ #define LOAD_SYMBOL (member, name ) \
57+ do { \
58+ handle->member = reinterpret_cast <name##Func>(dlsym (so_handle, #name)); \
59+ ET_CHECK_OR_RETURN_ERROR ( \
60+ handle->member != nullptr , AccessFailed, " Failed to load " #name); \
61+ } while (0 )
62+ LOAD_SYMBOL (create_with_device, AOTInductorModelContainerCreateWithDevice);
63+
64+ LOAD_SYMBOL (delete_container, AOTInductorModelContainerDelete);
9165
66+ LOAD_SYMBOL (get_num_inputs, AOTInductorModelContainerGetNumInputs);
67+
68+ LOAD_SYMBOL (get_num_outputs, AOTInductorModelContainerGetNumOutputs);
69+
70+ LOAD_SYMBOL (run, AOTInductorModelContainerRun);
71+ #undef LOAD_SYMBOL
72+
73+ handle->update_constants_from_blob =
74+ reinterpret_cast <AOTInductorModelUpdateConstantsFromBlobFunc>(
75+ dlsym (so_handle, " AOTInductorModelUpdateConstantsFromBlob" ));
76+ if (handle->update_constants_from_blob == nullptr ) {
77+ ET_LOG (
78+ Info,
79+ " Failed to load AOTInductorModelUpdateConstantsFromBlob"
80+ " this .so is probably compiled on an old version of torch (<2.9.0)" );
81+ }
9282 return Error::Ok;
9383 }
9484
@@ -184,7 +174,7 @@ class ET_EXPERIMENTAL CudaBackend final
184174 std::string weights_blob_key =
185175 method_name.empty () ? " weights_blob" : method_name + " _weights_blob" ;
186176 auto buffer_res = named_data_map->get_data (weights_blob_key.c_str ());
187- if (buffer_res.ok ()) {
177+ if (buffer_res.ok () && handle-> update_constants_from_blob != nullptr ) {
188178 ET_LOG (Info, " Found %s in named data map" , weights_blob_key.c_str ());
189179 const void * weights_blob = buffer_res->data ();
190180 // Feed the weights blob into the container. Under the hood it's copying
0 commit comments