Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 53 additions & 91 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,132 +120,78 @@ class InferenceSession::Impl {
return Status::OK();
}

template <typename T>
common::Status Load(const T& model_uri) {
common::Status Load(std::function<common::Status(std::shared_ptr<Model>&)> loader, const std::string& event_name) {
Status status = Status::OK();
auto tp = session_profiler_.StartTime();
try {
std::lock_guard<std::mutex> l(session_mutex_);
if (is_model_loaded_) { // already loaded
LOGS(*session_logger_, ERROR) << "This session already contains a loaded model.";
return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model.");
return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED,
"This session already contains a loaded model.");
}

std::shared_ptr<onnxruntime::Model> p_tmp_model;
ORT_RETURN_IF_ERROR(onnxruntime::Model::Load(model_uri, p_tmp_model,
HasLocalSchema() ? &custom_schema_registries_ : nullptr));
status = loader(p_tmp_model);
ORT_RETURN_IF_ERROR(status);

model_ = p_tmp_model;

ORT_RETURN_IF_ERROR(DoPostLoadProcessing(*model_.get()));
status = DoPostLoadProcessing(*model_);
ORT_RETURN_IF_ERROR(status);

// all steps complete, mark the model as loaded.
is_model_loaded_ = true;
} catch (const std::exception& ex) {
return Status(common::ONNXRUNTIME, common::FAIL, "Exception during loading: " + std::string(ex.what()));
status = Status(common::ONNXRUNTIME, common::FAIL, "Exception during loading: " + std::string(ex.what()));
} catch (...) {
LOGS(*session_logger_, ERROR) << "Unknown exception in Load()";
return Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Load()");
status = Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Load()");
}
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_loading_uri", tp);
return common::Status::OK();
}

common::Status Load(const ModelProto& model_proto) {
auto tp = session_profiler_.StartTime();
try {
LOGS(*session_logger_, INFO) << "Loading model using model_proto";
std::lock_guard<std::mutex> l(session_mutex_);
if (is_model_loaded_) { // already loaded
LOGS(*session_logger_, ERROR) << "This session already contains a loaded model.";
return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model.");
}
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, event_name, tp);
return status;
}

std::shared_ptr<onnxruntime::Model> p_tmp_model;
ORT_RETURN_IF_ERROR(onnxruntime::Model::Load(model_proto, p_tmp_model,
HasLocalSchema() ? &custom_schema_registries_ : nullptr));
model_ = p_tmp_model;
template <typename T>
common::Status Load(const T& model_uri) {
auto loader = [this, &model_uri](std::shared_ptr<onnxruntime::Model>& model) {
return onnxruntime::Model::Load(model_uri, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr);
};

ORT_RETURN_IF_ERROR(DoPostLoadProcessing(*model_.get()));
return Load(loader, "model_loading_uri");
}

// all steps complete, mark the model as loaded.
is_model_loaded_ = true;
common::Status Load(const ModelProto& model_proto) {
auto loader = [this, &model_proto](std::shared_ptr<onnxruntime::Model>& model) {
return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr);
};

LOGS(*session_logger_, INFO) << "Model successfully loaded.";
} catch (const std::exception& ex) {
return Status(common::ONNXRUNTIME, common::FAIL, "Exception during loading: " + std::string(ex.what()));
} catch (...) {
LOGS(*session_logger_, ERROR) << "Unknown exception in Load()";
return Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Load()");
}
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_loading_proto", tp);
return Status::OK();
return Load(loader, "model_loading_proto");
}

common::Status Load(std::unique_ptr<ModelProto> p_model_proto) {
auto tp = session_profiler_.StartTime();
try {
LOGS(*session_logger_, INFO) << "Loading model using model_proto";
std::lock_guard<std::mutex> l(session_mutex_);
if (is_model_loaded_) { // already loaded
LOGS(*session_logger_, ERROR) << "This session already contains a loaded model.";
return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model.");
}
auto loader = [this, &p_model_proto](std::shared_ptr<onnxruntime::Model>& model) {
return onnxruntime::Model::Load(std::move(p_model_proto), model,
HasLocalSchema() ? &custom_schema_registries_ : nullptr);
};

std::shared_ptr<onnxruntime::Model> p_tmp_model;
ORT_RETURN_IF_ERROR(onnxruntime::Model::Load(std::move(p_model_proto), p_tmp_model,
HasLocalSchema() ? &custom_schema_registries_ : nullptr));
model_ = p_tmp_model;

ORT_RETURN_IF_ERROR(DoPostLoadProcessing(*model_.get()));

// all steps complete, mark the model as loaded.
is_model_loaded_ = true;

LOGS(*session_logger_, INFO) << "Model successfully loaded.";
} catch (const std::exception& ex) {
return Status(common::ONNXRUNTIME, common::FAIL, "Exception during loading: " + std::string(ex.what()));
} catch (...) {
LOGS(*session_logger_, ERROR) << "Unknown exception in Load()";
return Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Load()");
}
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_loading_proto", tp);
return Status::OK();
return Load(loader, "model_loading_proto");
}

common::Status Load(std::istream& model_istream) {
auto tp = session_profiler_.StartTime();
try {
LOGS(*session_logger_, INFO) << "Loading model using istream";
std::lock_guard<std::mutex> l(session_mutex_);
if (is_model_loaded_) { // already loaded
LOGS(*session_logger_, ERROR) << "This session already contains a loaded model.";
return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model.");
}

auto loader = [this, &model_istream](std::shared_ptr<onnxruntime::Model>& model) {
ModelProto model_proto;
const bool result = model_proto.ParseFromIstream(&model_istream);
if (!result) {
return Status(common::ONNXRUNTIME, common::INVALID_PROTOBUF, "Failed to load model because protobuf parsing failed.");
return Status(common::ONNXRUNTIME, common::INVALID_PROTOBUF,
"Failed to load model because protobuf parsing failed.");
}

std::shared_ptr<onnxruntime::Model> p_tmp_model;
ORT_RETURN_IF_ERROR(onnxruntime::Model::Load(model_proto, p_tmp_model,
HasLocalSchema() ? &custom_schema_registries_ : nullptr));
model_ = p_tmp_model;

ORT_RETURN_IF_ERROR(DoPostLoadProcessing(*model_.get()));

// all steps complete, mark the model as loaded.
is_model_loaded_ = true;
return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr);
};

LOGS(*session_logger_, INFO) << "Model successfully loaded.";
} catch (const std::exception& ex) {
return Status(common::ONNXRUNTIME, common::FAIL, "Exception during loading: " + std::string(ex.what()));
} catch (...) {
LOGS(*session_logger_, ERROR) << "Unknown exception in Load()";
return Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Load()");
}
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_loading_istream", tp);
return common::Status::OK();
return Load(loader, "model_loading_istream");
}

static common::Status TransformGraph(onnxruntime::Graph& graph,
Expand Down Expand Up @@ -1157,6 +1103,22 @@ common::Status InferenceSession::Initialize() {
return impl_->Initialize();
}

common::Status InferenceSession::Initialize(std::shared_ptr<Model>& model_in) {
if (!model_in) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "model_in does not contain a Model instance.");
}

// make sure graph is resolved.
Status status = model_in->MainGraph().Resolve();
ORT_RETURN_IF_ERROR(status);

status = impl_->Load([&](std::shared_ptr<Model>& model) { model = model_in; return Status::OK(); },
"initialize_with_model");
ORT_RETURN_IF_ERROR(status);

return impl_->Initialize();
}

common::Status InferenceSession::Run(const NameMLValMap& feeds,
const std::vector<std::string>& output_names,
std::vector<MLValue>* p_fetches) {
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/session/inference_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,14 @@ class InferenceSession {
*/
common::Status Initialize();

/**
* Initializes with a Model instance. Initialization includes but is not
* limited to graph transformations, construction of kernels, etc.
* @param model Model to use. InferenceSession::Load must not have been called previously.
* @return OK if success. Error information if failure.
*/
common::Status Initialize(std::shared_ptr<Model>& model);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we ensure that the model is created with the exact same version of protobuf?

@pranavsharma pranavsharma Jan 2, 2019

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scott, FYI: We intentionally removed this method from the interface after RS5 due to protobuf version mismatch issues. We'd added it initially to facilitate faster dev cycles for the Windows folks.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So what alternate approach is best to avoid current ugly necessity to have to serialize a dynamically created Model and reload it in order to be able to execute it? Internal usage only.

e.g. #168 constant_folding.cc dynamically creates a model but is forced to serialize it to use the current API.

Do we need to refactor to split out internal aspects from InferenceSession so that those can be called directly in this sort of scenario instead of using the public InferenceSession API?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed from this PR. Constant folding is going to be refactored and shouldn't need an API change in InferenceSession.


/**
* Run a pre-loaded and pre-intialized model.
* Multiple threads are allowed to run this function; hence its thread-safe.
Expand Down
Loading