Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions extension/data_loader/shared_ptr_data_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/platform/log.h>
#include <memory>
#include <cstring>

namespace executorch {
namespace extension {
Expand Down Expand Up @@ -43,11 +44,35 @@ class SharedPtrDataLoader final : public executorch::runtime::DataLoader {
return executorch::runtime::FreeableBuffer(
static_cast<uint8_t*>(data_.get()) + offset, size, /*free_fn=*/nullptr);
}

ET_NODISCARD executorch::runtime::Error load_into(
size_t offset,
size_t size,
const DataLoader::SegmentInfo& segment_info,
void* buffer) const override;

ET_NODISCARD executorch::runtime::Result<size_t> size() const override {
return size_;
}

ET_NODISCARD executorch::runtime::Error SharedPtrDataLoader::load_into(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you merge the declaration and implementation?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have merged the 'load_into' function declaration and implemention as suggested by you, please review it.

size_t offset,
size_t size,
const DataLoader::SegmentInfo& segment_info,
void* buffer) const {
ET_CHECK_OR_RETURN_ERROR(
offset + size <= size_,
executorch::runtime::Error::OutOfBounds,
"offset %zu + size %zu exceeds buffer size %zu",
offset,
size,
size_);

std::memcpy(buffer, static_cast<uint8_t*>(data_.get()) + offset, size);
return executorch::runtime::Error::Ok;
}


private:
const std::shared_ptr<void> data_;
const size_t size_;
Expand Down
33 changes: 33 additions & 0 deletions extension/data_loader/test/shared_ptr_data_loader_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,36 @@ TEST_F(SharedPtrDataLoaderTest, OutOfBoundsLoadFails) {
EXPECT_NE(fb.error(), Error::Ok);
}
}

// Unit test to check that SharedPtrDataLoader::load_into copies the correct data.
TEST(SharedPtrDataLoaderTest, LoadIntoCopiesCorrectData) {
std::vector<uint8_t> source_data = {10, 20, 30, 40, 50};

// Wrap the source data in a shared_ptr without taking ownership.
auto data_ptr = std::shared_ptr<void>(source_data.data(), [](void*) {});
SharedPtrDataLoader loader(data_ptr, source_data.size());

uint8_t buffer[3] = {0};

// Load 3 bytes starting from offset 1 (expecting values 20, 30, 40).
auto err = loader.load_into(1, 3, DataLoader::SegmentInfo{}, buffer);

EXPECT_EQ(err, Error::Ok);
EXPECT_EQ(buffer[0], 20);
EXPECT_EQ(buffer[1], 30);
EXPECT_EQ(buffer[2], 40);
}

// Unit test to verify that SharedPtrDataLoader::load_into handles out-of-bounds requests.
TEST(SharedPtrDataLoaderTest, LoadIntoRejectsOutOfBoundsAccess) {
std::vector<uint8_t> source_data = {10, 20, 30, 40, 50};
auto data_ptr = std::shared_ptr<void>(source_data.data(), [](void*) {});
SharedPtrDataLoader loader(data_ptr, source_data.size());

uint8_t buffer[3] = {0};

// This should fail because offset + size = 4 + 3 = 7 > 5 (size of data).
auto err = loader.load_into(4, 3, DataLoader::SegmentInfo{}, buffer);

EXPECT_EQ(err, Error::OutOfBounds);
}
Loading