Skip to content

Commit a309b6b

Browse files
authored
[Thrust] Use pointer to tls pool to prevent creating new pool (#16856)
1 parent 0594994 commit a309b6b

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/runtime/contrib/thrust/thrust.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class WorkspaceMemoryResource : public thrust::mr::memory_resource<void*> {
5454
this->workspace_size = workspace->shape[0];
5555
} else {
5656
// Fallback to thrust TLS caching allocator if workspace is not provided.
57-
thrust_pool_ = thrust::mr::tls_disjoint_pool(
57+
thrust_pool_ = &thrust::mr::tls_disjoint_pool(
5858
thrust::mr::get_global_resource<thrust::device_memory_resource>(),
5959
thrust::mr::get_global_resource<thrust::mr::new_delete_resource>());
6060
}
@@ -67,20 +67,20 @@ class WorkspaceMemoryResource : public thrust::mr::memory_resource<void*> {
6767
<< " bytes.";
6868
return result;
6969
}
70-
return thrust_pool_.do_allocate(bytes, alignment).get();
70+
return thrust_pool_->do_allocate(bytes, alignment).get();
7171
}
7272

7373
void do_deallocate(void* p, size_t bytes, size_t alignment) override {
7474
if (workspace != nullptr) {
7575
// No-op
7676
} else {
77-
thrust_pool_.do_deallocate(thrust::device_memory_resource::pointer(p), bytes, alignment);
77+
thrust_pool_->do_deallocate(thrust::device_memory_resource::pointer(p), bytes, alignment);
7878
}
7979
}
8080

8181
thrust::mr::disjoint_unsynchronized_pool_resource<thrust::device_memory_resource,
82-
thrust::mr::new_delete_resource>
83-
thrust_pool_;
82+
thrust::mr::new_delete_resource>* thrust_pool_ =
83+
nullptr;
8484

8585
void* workspace = nullptr;
8686
size_t workspace_size = 0;

0 commit comments

Comments
 (0)