diff --git a/runtime/core/memory_allocator.h b/runtime/core/memory_allocator.h index ec9315a5c2b..8613b9ac647 100644 --- a/runtime/core/memory_allocator.h +++ b/runtime/core/memory_allocator.h @@ -63,7 +63,7 @@ class MemoryAllocator { /** * Allocates `size` bytes of memory. * - * @param[in] size Number of memory chunks to allocate. + * @param[in] size Number of bytes to allocate. * @param[in] alignment Minimum alignment for the returned pointer. Must be a * power of 2. * diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index 0a3a64a13c4..59a00ebd3af 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -1013,11 +1013,14 @@ Error Method::execute_instruction() { EXECUTORCH_SCOPE_PROF("OPERATOR_CALL"); internal::EventTracerProfileScope event_tracer_scope = internal::EventTracerProfileScope(event_tracer_, "OPERATOR_CALL"); - // TODO(T147221312): Also expose the temp allocator and tensor resizer - // via the context. - KernelRuntimeContext context(event_tracer_); + // TODO(T147221312): Also expose the tensor resizer via the context. 
+ // The temp_allocator passed can be null, but calling allocate_temp will + // fail + KernelRuntimeContext context( + event_tracer_, memory_manager_->temp_allocator()); auto args = chain.argument_lists_[step_state_.instr_idx]; chain.kernels_[step_state_.instr_idx](context, args.data()); + // We reset the temp_allocator after the switch statement err = context.failure_state(); if (err != Error::Ok) { // We know that instr_args_as_KernelCall is non-null because it was diff --git a/runtime/kernel/kernel_runtime_context.h b/runtime/kernel/kernel_runtime_context.h index 8b51ca5ed08..7317315b529 100644 --- a/runtime/kernel/kernel_runtime_context.h +++ b/runtime/kernel/kernel_runtime_context.h @@ -10,6 +10,8 @@ #include #include +#include +#include #include namespace torch { @@ -24,10 +26,21 @@ namespace executor { class KernelRuntimeContext { public: /** - * Construct a new kernel runtime context along with an optional event tracer. + * Construct a new kernel runtime context. + * + * KernelRuntimeContext does not take ownership + * of these pointers, so they must outlive the context instance. + * + * @param[in] event_tracer The optional EventTracer to use for + * profiling/debugging + * @param[in] temp_allocator The optional MemoryAllocator used to allocate + * temporary memory for the kernel. If not provided, an error will be + * returned when calling allocate_temp. */ - KernelRuntimeContext(EventTracer* event_tracer = nullptr) - : event_tracer_(event_tracer) {} + KernelRuntimeContext( + EventTracer* event_tracer = nullptr, + MemoryAllocator* temp_allocator = nullptr) + : event_tracer_(event_tracer), temp_allocator_(temp_allocator) {} /** * Tells the runtime that the kernel call has failed. Prefer this over * ET_CHECK_*(), which fatally panics the process/system. @@ -60,12 +73,37 @@ class KernelRuntimeContext { return event_tracer_; } - // TODO(T147221312): Add a way to allocate temporary memory. 
+ /** + * Allocates temporary memory that will be freed when the kernel returns. This + * returns a pointer to the allocated memory or an error if the allocation + * fails. + * + * @param[in] size Number of bytes to allocate. + * @param[in] alignment Minimum alignment for the returned pointer. Must be a + * power of 2. + * + * @returns A result object containing either a pointer to the allocated + * memory or an error to indicate failure. + */ + Result<void*> allocate_temp( + size_t size, + size_t alignment = MemoryAllocator::kDefaultAlignment) { + ET_CHECK_OR_RETURN_ERROR( + temp_allocator_ != nullptr, NotFound, "No temp allocator provided"); + void* temp_memory = temp_allocator_->allocate(size, alignment); + ET_CHECK_OR_RETURN_ERROR( + temp_memory != nullptr, + MemoryAllocationFailed, + "Failed to allocate temp memory. Bytes requested: %zu", + size); + return temp_memory; + } // TODO(T147221312): Add a way to resize a tensor. private: EventTracer* event_tracer_ = nullptr; + MemoryAllocator* temp_allocator_ = nullptr; Error failure_state_ = Error::Ok; }; diff --git a/runtime/kernel/targets.bzl b/runtime/kernel/targets.bzl index a8f9eb50525..0bf45321dc9 100644 --- a/runtime/kernel/targets.bzl +++ b/runtime/kernel/targets.bzl @@ -55,6 +55,7 @@ def define_common_targets(): exported_deps = [ "//executorch/runtime/core:core", "//executorch/runtime/platform:platform", + "//executorch/runtime/core:memory_allocator", "//executorch/runtime/core:event_tracer" + aten_suffix, # TODO(T147221312): This will eventually depend on exec_aten # once KernelRuntimeContext support tensor resizing, which is diff --git a/runtime/kernel/test/kernel_runtime_context_test.cpp b/runtime/kernel/test/kernel_runtime_context_test.cpp index 7147dc2a169..15709d52bff 100644 --- a/runtime/kernel/test/kernel_runtime_context_test.cpp +++ b/runtime/kernel/test/kernel_runtime_context_test.cpp @@ -15,6 +15,8 @@ using namespace ::testing; using torch::executor::Error; using torch::executor::KernelRuntimeContext; 
+using torch::executor::MemoryAllocator; +using torch::executor::Result; class KernelRuntimeContextTest : public ::testing::Test { public: @@ -23,6 +25,17 @@ class KernelRuntimeContextTest : public ::testing::Test { } }; +class TestMemoryAllocator : public MemoryAllocator { + public: + TestMemoryAllocator(uint32_t size, uint8_t* base_address) + : MemoryAllocator(size, base_address), last_seen_alignment(0) {} + void* allocate(size_t size, size_t alignment) override { + last_seen_alignment = alignment; + return MemoryAllocator::allocate(size, alignment); + } + size_t last_seen_alignment; +}; + TEST_F(KernelRuntimeContextTest, FailureStateDefaultsToOk) { KernelRuntimeContext context; @@ -47,3 +60,43 @@ TEST_F(KernelRuntimeContextTest, FailureStateReflectsFailure) { context.fail(Error::Ok); EXPECT_EQ(context.failure_state(), Error::Ok); } + +TEST_F(KernelRuntimeContextTest, FailureNoMemoryAllocatorProvided) { + KernelRuntimeContext context; + Result<void*> allocated_memory = context.allocate_temp(4); + EXPECT_EQ(allocated_memory.error(), Error::NotFound); +} + +TEST_F(KernelRuntimeContextTest, SuccessfulMemoryAllocation) { + constexpr size_t temp_memory_allocator_pool_size = 4; + auto temp_memory_allocator_pool = + std::make_unique<uint8_t[]>(temp_memory_allocator_pool_size); + MemoryAllocator temp_allocator( + temp_memory_allocator_pool_size, temp_memory_allocator_pool.get()); + KernelRuntimeContext context(nullptr, &temp_allocator); + Result<void*> allocated_memory = context.allocate_temp(4); + EXPECT_EQ(allocated_memory.ok(), true); +} + +TEST_F(KernelRuntimeContextTest, FailureMemoryAllocationInsufficientSpace) { + constexpr size_t temp_memory_allocator_pool_size = 4; + auto temp_memory_allocator_pool = + std::make_unique<uint8_t[]>(temp_memory_allocator_pool_size); + MemoryAllocator temp_allocator( + temp_memory_allocator_pool_size, temp_memory_allocator_pool.get()); + KernelRuntimeContext context(nullptr, &temp_allocator); + Result<void*> allocated_memory = context.allocate_temp(8); + 
 EXPECT_EQ(allocated_memory.error(), Error::MemoryAllocationFailed); +} + +TEST_F(KernelRuntimeContextTest, MemoryAllocatorAlignmentPassed) { + constexpr size_t temp_memory_allocator_pool_size = 4; + auto temp_memory_allocator_pool = + std::make_unique<uint8_t[]>(temp_memory_allocator_pool_size); + TestMemoryAllocator temp_allocator( + temp_memory_allocator_pool_size, temp_memory_allocator_pool.get()); + KernelRuntimeContext context(nullptr, &temp_allocator); + Result<void*> allocated_memory = context.allocate_temp(4, 2); + EXPECT_EQ(allocated_memory.ok(), true); + EXPECT_EQ(temp_allocator.last_seen_alignment, 2); +}