Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
pmattione-nvidia committed Feb 28, 2025
1 parent 99b501f commit c8a8505
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 23 deletions.
45 changes: 22 additions & 23 deletions include/rmm/mr/device/pool_memory_resource.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,8 @@ class pool_memory_resource final
* Attempts to allocate `try_size` bytes from upstream. If it fails, it iteratively reduces the
* attempted size by half until `min_size`, returning the allocated block once it succeeds.
*
* @throws rmm::bad_alloc if `min_size` bytes cannot be allocated from upstream or maximum pool
* size is exceeded.
* @throws rmm::out_of_memory if `min_size` bytes cannot be allocated from upstream or maximum
* pool size is exceeded.
*
* @param try_size The initial requested size to try allocating.
* @param min_size The minimum requested size to try allocating.
Expand All @@ -252,24 +252,34 @@ class pool_memory_resource final
*/
block_type try_to_expand(std::size_t try_size, std::size_t min_size, cuda_stream_view stream)
{
while (true) {
auto report_error = [&](const char* reason) {
RMM_LOG_ERROR("[A][Stream %s][Upstream %zuB][FAILURE maximum pool size exceeded: %s]",
rmm::detail::format_stream(stream),
min_size,
reason);
auto const msg = std::string("Maximum pool size exceeded (failed to allocate ") +
rmm::detail::format_bytes(min_size) + std::string("): ") + reason;
RMM_FAIL(msg.c_str(), rmm::out_of_memory);
};

while (try_size >= min_size) {
try {
auto block = block_from_upstream(try_size, stream);
current_pool_size_ += block.size();
return block;
} catch (std::exception const& e) {
if (try_size <= min_size) {
RMM_LOG_ERROR("[A][Stream %s][Upstream %zuB][FAILURE maximum pool size exceeded: %s]",
rmm::detail::format_stream(stream),
try_size,
e.what());
auto const msg = std::string("Maximum pool size exceeded (failed to allocate ") +
rmm::detail::format_bytes(try_size) + std::string("): ") + e.what();
RMM_FAIL(msg.c_str(), rmm::out_of_memory);
}
if (try_size == min_size) { report_error(e.what()); }
}
try_size = std::max(min_size, try_size / 2);
}

auto const max_size = maximum_pool_size_.value_or(std::numeric_limits<std::size_t>::max());
auto const msg = std::string("Not enough room to grow, current/max/try size = ") +
rmm::detail::format_bytes(pool_size()) + ", " +
rmm::detail::format_bytes(max_size) + ", " +
rmm::detail::format_bytes(min_size);
report_error(msg.c_str());
return {};
}

/**
Expand Down Expand Up @@ -311,17 +321,6 @@ class pool_memory_resource final
// time. Upon failure, attempt to back off exponentially, e.g. by half the attempted size,
// until either success or the attempt is less than the requested size.

if (maximum_pool_size_.has_value()) {
auto const max_size = maximum_pool_size_.value();
if (size > max_size) {
auto const msg = std::string("Maximum pool size exceeded (failed to allocate ") +
rmm::detail::format_bytes(size) +
std::string("): Request larger than capacity (") +
rmm::detail::format_bytes(max_size) + std::string(")");
RMM_FAIL(msg.c_str(), rmm::out_of_memory);
}
}

return try_to_expand(size_to_grow(size), size, stream);
}

Expand Down
11 changes: 11 additions & 0 deletions tests/mr/host/pinned_pool_mr_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,5 +88,16 @@ TEST(PinnedPoolTest, NonAlignedPoolSize)
rmm::logic_error);
}

TEST(PinnedPoolTest, ThrowOutOfMemory)
{
rmm::mr::pinned_memory_resource pinned_mr{};
const auto initial{0};
const auto maximum{1024};
pool_mr mr{pinned_mr, initial, maximum};
mr.allocate(1024);

EXPECT_THROW(mr.allocate(1024), rmm::out_of_memory);
}

} // namespace
} // namespace rmm::test

0 comments on commit c8a8505

Please sign in to comment.