diff --git a/faiss/gpu/StandardGpuResources.cpp b/faiss/gpu/StandardGpuResources.cpp index 004f80a27e..78336b4994 100644 --- a/faiss/gpu/StandardGpuResources.cpp +++ b/faiss/gpu/StandardGpuResources.cpp @@ -257,6 +257,14 @@ void StandardGpuResourcesImpl::setDefaultStream( if (prevStream != stream) { streamWait({stream}, {prevStream}); } +#if defined USE_NVIDIA_RAFT + // delete the raft handle for this device, which will be initialized + // with the updated stream during any subsequent calls to getRaftHandle + auto it2 = raftHandles_.find(device); + if (it2 != raftHandles_.end()) { + raftHandles_.erase(it2); + } +#endif } userDefaultStreams_[device] = stream; @@ -275,6 +283,14 @@ void StandardGpuResourcesImpl::revertDefaultStream(int device) { streamWait({newStream}, {prevStream}); } +#if defined USE_NVIDIA_RAFT + // delete the raft handle for this device, which will be initialized + // with the updated stream during any subsequent calls to getRaftHandle + auto it2 = raftHandles_.find(device); + if (it2 != raftHandles_.end()) { + raftHandles_.erase(it2); + } +#endif } userDefaultStreams_.erase(device);