Skip to content

Commit b382d32

Browse files
committed
Make thrust_allocator deallocate safe in multi-device setting
Previously, the user had to arrange that the device active when a thrust_allocator object was created was also active when allocate and deallocate were called. This is hard to manage if exceptions are thrown. Instead, save the active device on construction and ensure that it is active when calling allocate and deallocate. This means that device_vector is safe to destruct with RAII semantics in a multi-device setting. Add tests of this facility, and correct the parameterization usage in the other thrust allocator tests such that we actually check the MRs we're parameterizing over. - Closes rapidsai#1527
1 parent 7d7d65a commit b382d32

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

include/rmm/mr/device/thrust_allocator_adaptor.hpp

+8-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#pragma once
1818

19+
#include <rmm/cuda_device.hpp>
1920
#include <rmm/detail/thrust_namespace.h>
2021
#include <rmm/mr/device/per_device_resource.hpp>
2122
#include <rmm/resource_ref.hpp>
@@ -39,6 +40,9 @@ namespace rmm::mr {
3940
* allocate objects of a specific type `T`, but can be freely rebound to other
4041
* types.
4142
*
43+
* The allocator records the current cuda device and may only be used with a backing
44+
* `device_async_resource_ref` valid for the same device.
45+
*
4246
* @tparam T The type of the objects that will be allocated by this allocator
4347
*/
4448
template <typename T>
@@ -92,7 +96,7 @@ class thrust_allocator : public thrust::device_malloc_allocator<T> {
9296
*/
9397
template <typename U>
9498
thrust_allocator(thrust_allocator<U> const& other)
95-
: _mr(other.resource()), _stream{other.stream()}
99+
: _mr(other.resource()), _stream{other.stream()}, _device{other._device}
96100
{
97101
}
98102

@@ -104,6 +108,7 @@ class thrust_allocator : public thrust::device_malloc_allocator<T> {
104108
*/
105109
pointer allocate(size_type num)
106110
{
111+
cuda_set_device_raii dev{_device};
107112
return thrust::device_pointer_cast(
108113
static_cast<T*>(_mr.allocate_async(num * sizeof(T), _stream)));
109114
}
@@ -117,6 +122,7 @@ class thrust_allocator : public thrust::device_malloc_allocator<T> {
117122
*/
118123
void deallocate(pointer ptr, size_type num)
119124
{
125+
cuda_set_device_raii dev{_device};
120126
return _mr.deallocate_async(thrust::raw_pointer_cast(ptr), num * sizeof(T), _stream);
121127
}
122128

@@ -143,6 +149,7 @@ class thrust_allocator : public thrust::device_malloc_allocator<T> {
143149
private:
144150
cuda_stream_view _stream{};
145151
rmm::device_async_resource_ref _mr{rmm::mr::get_current_device_resource()};
152+
cuda_device_id _device{get_current_cuda_device()};
146153
};
147154
/** @} */ // end of group
148155
} // namespace rmm::mr

tests/mr/device/thrust_allocator_tests.cu

+19
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717
#include "mr_ref_test.hpp"
1818

19+
#include <rmm/cuda_device.hpp>
1920
#include <rmm/cuda_stream_view.hpp>
21+
#include <rmm/detail/error.hpp>
2022
#include <rmm/device_vector.hpp>
2123
#include <rmm/mr/device/per_device_resource.hpp>
2224
#include <rmm/mr/device/thrust_allocator_adaptor.hpp>
@@ -36,19 +38,36 @@ struct allocator_test : public mr_ref_test {};
3638

3739
TEST_P(allocator_test, first)
3840
{
41+
rmm::mr::set_current_device_resource(this->mr.get());
3942
auto const num_ints{100};
4043
rmm::device_vector<int> ints(num_ints, 1);
4144
EXPECT_EQ(num_ints, thrust::reduce(ints.begin(), ints.end()));
4245
}
4346

4447
TEST_P(allocator_test, defaults)
4548
{
49+
rmm::mr::set_current_device_resource(this->mr.get());
4650
rmm::mr::thrust_allocator<int> allocator(rmm::cuda_stream_default);
4751
EXPECT_EQ(allocator.stream(), rmm::cuda_stream_default);
4852
EXPECT_EQ(allocator.get_upstream_resource(),
4953
rmm::device_async_resource_ref{rmm::mr::get_current_device_resource()});
5054
}
5155

56+
TEST_P(allocator_test, multi_device)
57+
{
58+
if (rmm::get_num_cuda_devices() < 2) { GTEST_SKIP() << "Needs at least two devices"; }
59+
cuda_set_device_raii with_device{rmm::get_current_cuda_device()};
60+
rmm::cuda_stream stream{};
61+
// make allocator on device-0
62+
rmm::mr::thrust_allocator<int> allocator(stream.view(), this->ref);
63+
auto const size{100};
64+
EXPECT_NO_THROW([&]() {
65+
auto vec = rmm::device_vector<int>(size, allocator);
66+
// Destruct with device-1 active
67+
RMM_CUDA_TRY(cudaSetDevice(1));
68+
}());
69+
}
70+
5271
INSTANTIATE_TEST_CASE_P(ThrustAllocatorTests,
5372
allocator_test,
5473
::testing::Values(mr_factory{"CUDA", &make_cuda},

0 commit comments

Comments
 (0)