Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions extension/data_loader/shared_ptr_data_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#pragma once

#include <cstring>
#include <executorch/runtime/core/data_loader.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/result.h>
Expand Down Expand Up @@ -43,6 +44,24 @@ class SharedPtrDataLoader final : public executorch::runtime::DataLoader {
return executorch::runtime::FreeableBuffer(
static_cast<uint8_t*>(data_.get()) + offset, size, /*free_fn=*/nullptr);
}

ET_NODISCARD
executorch::runtime::Error load_into(
size_t offset,
size_t size,
const DataLoader::SegmentInfo& segment_info,
void* buffer) const override {
ET_CHECK_OR_RETURN_ERROR(
offset + size <= size_,
InvalidArgument,
"offset %zu + size %zu exceeds buffer size %zu",
offset,
size,
size_);

std::memcpy(buffer, static_cast<uint8_t*>(data_.get()) + offset, size);
return executorch::runtime::Error::Ok;
}

ET_NODISCARD executorch::runtime::Result<size_t> size() const override {
return size_;
Expand Down
33 changes: 33 additions & 0 deletions extension/data_loader/test/shared_ptr_data_loader_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,36 @@ TEST_F(SharedPtrDataLoaderTest, OutOfBoundsLoadFails) {
EXPECT_NE(fb.error(), Error::Ok);
}
}

// Unit test to check that SharedPtrDataLoader::load_into copies the correct data.
TEST(SharedPtrDataLoaderTest, LoadIntoCopiesCorrectData) {
std::vector<uint8_t> source_data = {10, 20, 30, 40, 50};

// Wrap the source data in a shared_ptr without taking ownership.
auto data_ptr = std::shared_ptr<void>(source_data.data(), [](void*) {});
SharedPtrDataLoader loader(data_ptr, source_data.size());

uint8_t buffer[3] = {0};

// Load 3 bytes starting from offset 1 (expecting values 20, 30, 40).
auto err = loader.load_into(1, 3, DataLoader::SegmentInfo{}, buffer);

EXPECT_EQ(err, Error::Ok);
EXPECT_EQ(buffer[0], 20);
EXPECT_EQ(buffer[1], 30);
EXPECT_EQ(buffer[2], 40);
}

// Unit test to verify that SharedPtrDataLoader::load_into handles out-of-bounds requests.
TEST(SharedPtrDataLoaderTest, LoadIntoRejectsOutOfBoundsAccess) {
std::vector<uint8_t> source_data = {10, 20, 30, 40, 50};
auto data_ptr = std::shared_ptr<void>(source_data.data(), [](void*) {});
SharedPtrDataLoader loader(data_ptr, source_data.size());

uint8_t buffer[3] = {0};

// This should fail because offset + size = 4 + 3 = 7 > 5 (size of data).
auto err = loader.load_into(4, 3, DataLoader::SegmentInfo{}, buffer);

EXPECT_EQ(err, Error::InvalidArgument);
}