diff --git a/extension/module/module.cpp b/extension/module/module.cpp
index aa750e2691e..26e74e84364 100644
--- a/extension/module/module.cpp
+++ b/extension/module/module.cpp
@@ -178,34 +178,36 @@ runtime::Result<std::unordered_set<std::string>> Module::method_names() {
 
 runtime::Error Module::load_method(
     const std::string& method_name,
+    runtime::HierarchicalAllocator* planned_memory,
     torch::executor::EventTracer* event_tracer) {
   if (!is_method_loaded(method_name)) {
     ET_CHECK_OK_OR_RETURN_ERROR(load());
 
     MethodHolder method_holder;
 
-    const auto method_metadata =
-        ET_UNWRAP(program_->method_meta(method_name.c_str()));
-    const auto planned_buffersCount =
-        method_metadata.num_memory_planned_buffers();
-    method_holder.planned_buffers.reserve(planned_buffersCount);
-    method_holder.planned_spans.reserve(planned_buffersCount);
+    if (!planned_memory) {
+      const auto method_metadata =
+          ET_UNWRAP(program_->method_meta(method_name.c_str()));
+      const auto planned_buffers_count =
+          method_metadata.num_memory_planned_buffers();
+      method_holder.planned_buffers.reserve(planned_buffers_count);
+      method_holder.planned_spans.reserve(planned_buffers_count);
 
-    for (auto index = 0; index < planned_buffersCount; ++index) {
-      const auto buffer_size =
-          method_metadata.memory_planned_buffer_size(index).get();
-      method_holder.planned_buffers.emplace_back(buffer_size);
-      method_holder.planned_spans.emplace_back(
-          method_holder.planned_buffers.back().data(), buffer_size);
+      for (auto index = 0; index < planned_buffers_count; ++index) {
+        const auto buffer_size =
+            method_metadata.memory_planned_buffer_size(index).get();
+        method_holder.planned_buffers.emplace_back(buffer_size);
+        method_holder.planned_spans.emplace_back(
+            method_holder.planned_buffers.back().data(), buffer_size);
+      }
+      method_holder.planned_memory =
+          std::make_unique<runtime::HierarchicalAllocator>(runtime::Span(
+              method_holder.planned_spans.data(),
+              method_holder.planned_spans.size()));
+      planned_memory = method_holder.planned_memory.get();
     }
-    method_holder.planned_memory =
-        std::make_unique<runtime::HierarchicalAllocator>(runtime::Span(
-            method_holder.planned_spans.data(),
-            method_holder.planned_spans.size()));
     method_holder.memory_manager = std::make_unique<runtime::MemoryManager>(
-        memory_allocator_.get(),
-        method_holder.planned_memory.get(),
-        temp_allocator_.get());
+        memory_allocator_.get(), planned_memory, temp_allocator_.get());
     method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
         method_name.c_str(),
         method_holder.memory_manager.get(),
diff --git a/extension/module/module.h b/extension/module/module.h
index dc7c930d7c6..d58a447fdba 100644
--- a/extension/module/module.h
+++ b/extension/module/module.h
@@ -152,6 +152,8 @@ class Module {
    * needed. The loaded method is cached to reuse the next time it's executed.
    *
    * @param[in] method_name The name of the method to load.
+   * @param[in] planned_memory The memory-planned buffers to use for mutable
+   * tensor data when executing a method.
    * @param[in] event_tracer Per-method event tracer to profile/trace methods
    * individually. When not given, the event tracer passed to the Module
    * constructor is used. Otherwise, this per-method event tracer takes
@@ -162,20 +164,35 @@ class Module {
   ET_NODISCARD runtime::Error load_method(
       const std::string& method_name,
+      runtime::HierarchicalAllocator* planned_memory = nullptr,
       torch::executor::EventTracer* event_tracer = nullptr);
 
+  ET_DEPRECATED ET_NODISCARD runtime::Error inline load_method(
+      const std::string& method_name,
+      torch::executor::EventTracer* event_tracer) {
+    return load_method(method_name, nullptr, event_tracer);
+  }
+
   /**
    * Load the 'forward' method from the program and set up memory management if
    * needed. The loaded method is cached to reuse the next time it's executed.
    *
+   * @param[in] planned_memory The memory-planned buffers to use for mutable
+   * tensor data when executing the 'forward' method.
   * @param[in] event_tracer An event tracer used for tracking and logging
    * events.
    *
    * @returns An Error to indicate success or failure.
    */
   ET_NODISCARD inline runtime::Error load_forward(
+      runtime::HierarchicalAllocator* planned_memory = nullptr,
       torch::executor::EventTracer* event_tracer = nullptr) {
-    return load_method("forward", event_tracer);
+    return load_method("forward", planned_memory, event_tracer);
+  }
+
+  ET_DEPRECATED ET_NODISCARD inline runtime::Error load_forward(
+      torch::executor::EventTracer* event_tracer) {
+    return load_forward(nullptr, event_tracer);
   }
 
   /**
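
Usage note (illustration only, not part of the diff): with this change a caller can build its own runtime::HierarchicalAllocator over pre-allocated buffers and pass it to load_method()/load_forward(); when the pointer is null, the Module keeps allocating the planned buffers itself as before. The sketch below shows the intended call pattern. The include paths, the executorch::extension / executorch::runtime namespaces, the model path, and the buffer size are assumptions made for illustration; real per-buffer sizes should come from the method's MethodMeta (num_memory_planned_buffers / memory_planned_buffer_size).

// Minimal usage sketch, assuming the ExecuTorch include paths and namespaces
// below; the model path and buffer size are placeholders.

#include <executorch/extension/module/module.h>
#include <executorch/runtime/core/hierarchical_allocator.h>
#include <executorch/runtime/core/span.h>

#include <array>
#include <cstdint>

using executorch::extension::Module;
using executorch::runtime::Error;
using executorch::runtime::HierarchicalAllocator;
using executorch::runtime::Span;

int main() {
  Module module("/path/to/model.pte");  // placeholder path

  // Caller-owned backing storage for a single memory-planned buffer; a real
  // model may need several buffers with sizes taken from MethodMeta.
  static std::array<uint8_t, 1024 * 1024> buffer;
  std::array<Span<uint8_t>, 1> spans = {
      Span<uint8_t>(buffer.data(), buffer.size())};
  HierarchicalAllocator planned_memory(
      Span<Span<uint8_t>>(spans.data(), spans.size()));

  // New in this change: the caller-provided planned memory is forwarded to
  // the MemoryManager instead of the Module allocating its own buffers.
  // Passing nullptr (the default) preserves the old self-allocating behavior.
  const Error error = module.load_forward(&planned_memory);
  return error == Error::Ok ? 0 : 1;
}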