Skip to content

Commit b1448a6

Browse files
committed
Address comments
1 parent cf1305b commit b1448a6

File tree

2 files changed

+36
-47
lines changed

2 files changed

+36
-47
lines changed

backends/cuda/cuda_backend.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,9 @@ def preprocess(
183183
elif path.endswith(".wrapper_weights.blob"):
184184
blob_path = path
185185

186-
if so_path is None:
186+
if so_path is None or blob_path is None:
187187
raise RuntimeError(
188-
f"Could not find .wrapper.so file in compiled paths, got {paths}"
188+
f"Could not find required files in compiled paths, got {paths}"
189189
)
190190

191191
# pyre-ignorep[6]: Incompatible parameter type
@@ -198,15 +198,14 @@ def preprocess(
198198
# Keep the so file in the NamedDataStore, so that it can be packaged into the .pte file.
199199
named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None)
200200

201-
# Add weights blob to named data store if it exists
202-
if blob_path is not None:
203-
with open(blob_path, "rb") as f:
204-
blob_data = f.read()
205-
named_data_store.add_named_data(
206-
method_name + "_weights_blob", blob_data, 1, "aoti_cuda_blob"
207-
)
208-
# Clean up the weights blob file
209-
os.remove(blob_path)
201+
# Add weights blob to named data store
202+
with open(blob_path, "rb") as f:
203+
blob_data = f.read()
204+
named_data_store.add_named_data(
205+
method_name + "_weights_blob", blob_data, 1, "aoti_cuda_blob"
206+
)
207+
# Clean up the weights blob file
208+
os.remove(blob_path)
210209

211210
# Clean up the generated so file; it has been packaged into the NamedDataStore
212211
# pyre-ignorep[6]: Incompatible parameter type

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 26 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,6 @@
2828

2929
namespace executorch::backends::cuda {
3030

31-
#define LOAD_SYMBOL(handle, member, name, so_handle) \
32-
do { \
33-
handle->member = reinterpret_cast<name##Func>(dlsym(so_handle, #name)); \
34-
ET_CHECK_OR_RETURN_ERROR( \
35-
handle->member != nullptr, AccessFailed, "Failed to load " #name); \
36-
} while (0)
37-
3831
using namespace std;
3932
using namespace aoti;
4033

@@ -60,35 +53,32 @@ class ET_EXPERIMENTAL CudaBackend final
6053
Error load_function_pointers_into_handle(
6154
void* so_handle,
6255
AOTIDelegateHandle* handle) const {
63-
LOAD_SYMBOL(
64-
handle,
65-
create_with_device,
66-
AOTInductorModelContainerCreateWithDevice,
67-
so_handle);
68-
69-
LOAD_SYMBOL(
70-
handle, delete_container, AOTInductorModelContainerDelete, so_handle);
71-
72-
LOAD_SYMBOL(
73-
handle,
74-
get_num_inputs,
75-
AOTInductorModelContainerGetNumInputs,
76-
so_handle);
77-
78-
LOAD_SYMBOL(
79-
handle,
80-
get_num_outputs,
81-
AOTInductorModelContainerGetNumOutputs,
82-
so_handle);
83-
84-
LOAD_SYMBOL(handle, run, AOTInductorModelContainerRun, so_handle);
85-
86-
LOAD_SYMBOL(
87-
handle,
88-
update_constants_from_blob,
89-
AOTInductorModelUpdateConstantsFromBlob,
90-
so_handle);
56+
#define LOAD_SYMBOL(member, name) \
57+
do { \
58+
handle->member = reinterpret_cast<name##Func>(dlsym(so_handle, #name)); \
59+
ET_CHECK_OR_RETURN_ERROR( \
60+
handle->member != nullptr, AccessFailed, "Failed to load " #name); \
61+
} while (0)
62+
LOAD_SYMBOL(create_with_device, AOTInductorModelContainerCreateWithDevice);
63+
64+
LOAD_SYMBOL(delete_container, AOTInductorModelContainerDelete);
9165

66+
LOAD_SYMBOL(get_num_inputs, AOTInductorModelContainerGetNumInputs);
67+
68+
LOAD_SYMBOL(get_num_outputs, AOTInductorModelContainerGetNumOutputs);
69+
70+
LOAD_SYMBOL(run, AOTInductorModelContainerRun);
71+
#undef LOAD_SYMBOL
72+
73+
handle->update_constants_from_blob =
74+
reinterpret_cast<AOTInductorModelUpdateConstantsFromBlobFunc>(
75+
dlsym(so_handle, "AOTInductorModelUpdateConstantsFromBlob"));
76+
if (handle->update_constants_from_blob == nullptr) {
77+
ET_LOG(
78+
Info,
79+
"Failed to load AOTInductorModelUpdateConstantsFromBlob"
80+
" this .so is probably compiled on an old version of torch (<2.9.0)");
81+
}
9282
return Error::Ok;
9383
}
9484

@@ -184,7 +174,7 @@ class ET_EXPERIMENTAL CudaBackend final
184174
std::string weights_blob_key =
185175
method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
186176
auto buffer_res = named_data_map->get_data(weights_blob_key.c_str());
187-
if (buffer_res.ok()) {
177+
if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) {
188178
ET_LOG(Info, "Found %s in named data map", weights_blob_key.c_str());
189179
const void* weights_blob = buffer_res->data();
190180
// Feed the weights blob into the container. Under the hood it's copying

0 commit comments

Comments
 (0)