Skip to content

Commit

Permalink
Support replacing the internal stream pool.
Browse files Browse the repository at this point in the history
  • Loading branch information
ndryden committed Mar 4, 2021
1 parent c267d7c commit d2ebe52
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
11 changes: 11 additions & 0 deletions include/aluminum/cuda.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

#include <utility>
#include <sstream>
#include <functional>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
Expand Down Expand Up @@ -149,6 +150,16 @@ void release_cuda_event(cudaEvent_t event);
cudaStream_t get_internal_stream();
/** Get a specific internal stream. */
cudaStream_t get_internal_stream(size_t id);
/**
* Replace the internal stream pool with user-provided streams.
*
* stream_getter may be called an arbitrary number of times and should
* return the streams to use in the pool.
*
* This is meant to help interface with external applications that
* need Aluminum to use their streams for everything.
*/
void replace_internal_streams(std::function<cudaStream_t()> stream_getter);

/** Return whether stream memory operations are supported. */
bool stream_memory_operations_supported();
Expand Down
21 changes: 19 additions & 2 deletions src/cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ constexpr int num_internal_streams = 5;
cudaStream_t internal_streams[num_internal_streams];
// Whether stream memory operations are supported.
bool stream_mem_ops_supported = false;
// Whether we're using external streams (these are not freed).
bool using_external_streams = false;
}

void init(int&, char**&) {
Expand Down Expand Up @@ -89,8 +91,10 @@ void finalize() {
for (auto&& event : cuda_events) {
AL_CHECK_CUDA(cudaEventDestroy(event));
}
for (int i = 0; i < num_internal_streams; ++i) {
AL_CHECK_CUDA(cudaStreamDestroy(internal_streams[i]));
if (!using_external_streams) {
for (int i = 0; i < num_internal_streams; ++i) {
AL_CHECK_CUDA(cudaStreamDestroy(internal_streams[i]));
}
}
}

Expand Down Expand Up @@ -120,6 +124,19 @@ cudaStream_t get_internal_stream(size_t id) {
return internal_streams[id];
}

void replace_internal_streams(std::function<cudaStream_t()> stream_getter) {
// Clean up our streams if needed.
if (!using_external_streams) {
for (int i = 0; i < num_internal_streams; ++i) {
AL_CHECK_CUDA(cudaStreamDestroy(internal_streams[i]));
}
}
for (int i = 0; i < num_internal_streams; ++i) {
internal_streams[i] = stream_getter();
}
using_external_streams = true;
}

bool stream_memory_operations_supported() {
return stream_mem_ops_supported;
}
Expand Down

0 comments on commit d2ebe52

Please sign in to comment.