Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 80 additions & 4 deletions extension/data_loader/mmap_data_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,10 @@ void MunmapSegment(void* context, void* data, size_t size) {
}
} // namespace

Result<FreeableBuffer> MmapDataLoader::load(
size_t offset,
size_t size,
ET_UNUSED const DataLoader::SegmentInfo& segment_info) const {
/**
* Helper for input validation.
*/
Error MmapDataLoader::validate_input(size_t offset, size_t size) const {
ET_CHECK_OR_RETURN_ERROR(
// Probably had its value moved to another instance.
fd_ >= 0,
Expand All @@ -173,6 +173,18 @@ Result<FreeableBuffer> MmapDataLoader::load(
InvalidArgument,
"Offset %zu too large for off_t",
offset);
return Error::Ok;
}

Result<FreeableBuffer> MmapDataLoader::load(
size_t offset,
size_t size,
ET_UNUSED const DataLoader::SegmentInfo& segment_info) const {
// Validate input.
auto err = validate_input(offset, size);
if (err != Error::Ok) {
return err;
}

// mmap() will fail if the size is zero.
if (size == 0) {
Expand Down Expand Up @@ -268,5 +280,69 @@ Result<size_t> MmapDataLoader::size() const {
return file_size_;
}

Error MmapDataLoader::load_into(
size_t offset,
size_t size,
ET_UNUSED const SegmentInfo& segment_info,
void* buffer) const {
ET_CHECK_OR_RETURN_ERROR(
buffer != nullptr, InvalidArgument, "Buffer is null");

// Validate input.
auto err = validate_input(offset, size);
if (err != Error::Ok) {
return err;
}

// Nothing to copy
if (size == 0) {
return Error::Ok;
}

// Find the range of pages that covers the requested region.
Range range =
get_overlapping_pages(static_cast<uintptr_t>(offset), size, page_size_);

size_t map_size = range.size;
if (range.start + map_size > file_size_) {
// Clamp to the end of the file.
//
// The Windows implementation of mmap uses CreateFileMapping which returns
// error STATUS_SECTION_TOO_BIG (0xc0000040) if we try to map past the end
// of the last page of a file mapped in as read-only.
map_size = file_size_ - range.start;
}

// Map the pages read-only. MAP_PRIVATE vs. MAP_SHARED doesn't matter since
// the data is read-only, but use PRIVATE just to further avoid accidentally
// modifying the file.
void* pages = ::mmap(
nullptr,
map_size,
PROT_READ,
MAP_PRIVATE,
fd_,
static_cast<off_t>(range.start));
ET_CHECK_OR_RETURN_ERROR(
pages != MAP_FAILED,
AccessFailed,
"Failed to map %s: mmap(..., size=%zd, ..., fd=%d, offset=0x%zx)",
file_name_,
range.size,
fd_,
range.start);

// Offset into mapped region.
const size_t map_delta = offset - range.start;

// Copy data into caller's buffer.
std::memcpy(buffer, static_cast<uint8_t*>(pages) + map_delta, size);

// Unmap mapped region.
::munmap(pages, map_size);

return Error::Ok;
}

} // namespace extension
} // namespace executorch
11 changes: 11 additions & 0 deletions extension/data_loader/mmap_data_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,17 @@ class MmapDataLoader final : public executorch::runtime::DataLoader {

ET_NODISCARD executorch::runtime::Result<size_t> size() const override;

ET_NODISCARD executorch::runtime::Error validate_input(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Can you make this private

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, moved to private section

size_t offset,
size_t size) const;

ET_NODISCARD
executorch::runtime::Error load_into(
size_t offset,
size_t size,
ET_UNUSED const SegmentInfo& segment_info,
void* buffer) const override;

private:
MmapDataLoader(
int fd,
Expand Down
53 changes: 53 additions & 0 deletions extension/data_loader/test/mmap_data_loader_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,3 +376,56 @@ TEST_F(MmapDataLoaderTest, DEPRECATEDFrom) {
ASSERT_EQ(total_size.error(), Error::Ok);
EXPECT_EQ(*total_size, contents_size);
}

// Tests that load_into copies bytes correctly.
TEST_F(MmapDataLoaderTest, LoadIntoCopiesCorrectly) {
// Create a test string.
const char* test_text = "FILE_CONTENTS";
const size_t text_size = std::strlen(test_text);
TempFile tf(test_text);

// Wrap it in a loader.
Result<MmapDataLoader> mdl = MmapDataLoader::from(tf.path().c_str());
ASSERT_EQ(mdl.error(), Error::Ok);

// Destination buffer.
std::vector<uint8_t> dst(text_size);

// Call load_into()
Error err = mdl->load_into(
/*offset=*/0,
/*size=*/text_size,
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program),
dst.data());
ASSERT_EQ(err, Error::Ok);

// Verify memory copied correctly.
EXPECT_EQ(0, std::memcmp(dst.data(), test_text, text_size));
}

// Tests that load_into copies offset slice correctly.
TEST_F(MmapDataLoaderTest, LoadIntoCopiesOffsetCorrectly) {
// Create a test string.
const char* contents = "ABCDEFGH";
TempFile tf(contents);

// Wrap it in a loader.
Result<MmapDataLoader> mdl = MmapDataLoader::from(tf.path().c_str());
ASSERT_EQ(mdl.error(), Error::Ok);

// Copying 3 bytes starting at offset 2 = "CDE"
const size_t offset = 2;
const size_t size = 3;
uint8_t dst[size];

// Call load_into()
Error err = mdl->load_into(
offset,
size,
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program),
dst);
ASSERT_EQ(err, Error::Ok);

// Verify memory copied correctly.
EXPECT_EQ(0, std::memcmp(dst, contents + offset, size));
}
Loading