diff --git a/extension/data_loader/shared_ptr_data_loader.h b/extension/data_loader/shared_ptr_data_loader.h index 551ab4d498c..d791c132803 100644 --- a/extension/data_loader/shared_ptr_data_loader.h +++ b/extension/data_loader/shared_ptr_data_loader.h @@ -8,6 +8,7 @@ #pragma once +#include #include #include #include @@ -43,6 +44,24 @@ class SharedPtrDataLoader final : public executorch::runtime::DataLoader { return executorch::runtime::FreeableBuffer( static_cast(data_.get()) + offset, size, /*free_fn=*/nullptr); } + + ET_NODISCARD + executorch::runtime::Error load_into( + size_t offset, + size_t size, + const DataLoader::SegmentInfo& segment_info, + void* buffer) const override { + ET_CHECK_OR_RETURN_ERROR( + offset + size <= size_, + InvalidArgument, + "offset %zu + size %zu exceeds buffer size %zu", + offset, + size, + size_); + + std::memcpy(buffer, static_cast(data_.get()) + offset, size); + return executorch::runtime::Error::Ok; + } ET_NODISCARD executorch::runtime::Result size() const override { return size_; diff --git a/extension/data_loader/test/shared_ptr_data_loader_test.cpp b/extension/data_loader/test/shared_ptr_data_loader_test.cpp index 62d71ae0560..af735fadf04 100644 --- a/extension/data_loader/test/shared_ptr_data_loader_test.cpp +++ b/extension/data_loader/test/shared_ptr_data_loader_test.cpp @@ -140,3 +140,36 @@ TEST_F(SharedPtrDataLoaderTest, OutOfBoundsLoadFails) { EXPECT_NE(fb.error(), Error::Ok); } } + +// Unit test to check that SharedPtrDataLoader::load_into copies the correct data. +TEST(SharedPtrDataLoaderTest, LoadIntoCopiesCorrectData) { + std::vector source_data = {10, 20, 30, 40, 50}; + + // Wrap the source data in a shared_ptr without taking ownership. + auto data_ptr = std::shared_ptr(source_data.data(), [](void*) {}); + SharedPtrDataLoader loader(data_ptr, source_data.size()); + + uint8_t buffer[3] = {0}; + + // Load 3 bytes starting from offset 1 (expecting values 20, 30, 40). + auto err = loader.load_into(1, 3, DataLoader::SegmentInfo{}, buffer); + + EXPECT_EQ(err, Error::Ok); + EXPECT_EQ(buffer[0], 20); + EXPECT_EQ(buffer[1], 30); + EXPECT_EQ(buffer[2], 40); +} + +// Unit test to verify that SharedPtrDataLoader::load_into handles out-of-bounds requests. +TEST(SharedPtrDataLoaderTest, LoadIntoRejectsOutOfBoundsAccess) { + std::vector source_data = {10, 20, 30, 40, 50}; + auto data_ptr = std::shared_ptr(source_data.data(), [](void*) {}); + SharedPtrDataLoader loader(data_ptr, source_data.size()); + + uint8_t buffer[3] = {0}; + + // This should fail because offset + size = 4 + 3 = 7 > 5 (size of data). + auto err = loader.load_into(4, 3, DataLoader::SegmentInfo{}, buffer); + + EXPECT_EQ(err, Error::InvalidArgument); +} \ No newline at end of file