Skip to content

Commit

Permalink
[xla:gpu] Amortize the cost of acquiring mutex for stream borrowing
Browse files Browse the repository at this point in the history
Make the StreamBorrowerWithPriority interface accept a num_stream argument, so that it allows borrowing multiple streams with a single mutex lock.

PiperOrigin-RevId: 542364157
  • Loading branch information
anlunx authored and pull[bot] committed Jun 27, 2023
1 parent 68f3f27 commit 3350500
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 19 deletions.
18 changes: 18 additions & 0 deletions tensorflow/compiler/xla/service/backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/service/compiler.h"
Expand Down Expand Up @@ -118,6 +119,23 @@ StatusOr<StreamPool::Ptr> Backend::BorrowStream(se::StreamExecutor* executor,
return stream_pools_.at(executor)->BorrowStream(executor, priority);
}

StatusOr<std::vector<StreamPool::Ptr>> Backend::BorrowStreams(
int device_ordinal, int num_streams, se::StreamPriority priority) {
absl::MutexLock l(&mu_);
TF_ASSIGN_OR_RETURN(auto executor, stream_executor(device_ordinal));
if (!stream_pools_.contains(executor)) {
stream_pools_.emplace(executor, std::make_unique<StreamPool>());
}

std::vector<StreamPool::Ptr> ptrs;
for (int i = 0; i < num_streams; i++) {
StreamPool::Ptr ptr =
stream_pools_.at(executor)->BorrowStream(executor, priority);
ptrs.push_back(std::move(ptr));
}
return ptrs;
}

Backend::Backend(se::Platform* platform, Compiler* compiler,
absl::Span<se::StreamExecutor* const> stream_executors,
TransferManager* transfer_manager,
Expand Down
15 changes: 10 additions & 5 deletions tensorflow/compiler/xla/service/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,20 @@ class Backend {
StatusOr<StreamPool::Ptr> BorrowStream(
se::StreamExecutor* executor,
se::StreamPriority priority = se::StreamPriority::Default);
StatusOr<std::vector<StreamPool::Ptr>> BorrowStreams(
int device_ordinal, int num_streams,
se::StreamPriority priority = se::StreamPriority::Default);

// Returns a function to borrow a stream with a given priority,
// as `BorrowStream` above does.
// Returns a function to borrow streams with a given priority,
// as `BorrowStreams` above does.
// Purely for convenience, the caller could rather make this anonymous
// function itself.
std::function<StatusOr<StreamPool::Ptr>(int, se::StreamPriority)>
std::function<StatusOr<std::vector<StreamPool::Ptr>>(int, int,
se::StreamPriority)>
StreamBorrowerWithPriority() {
return [this](int device_ordinal, se::StreamPriority priority) {
return BorrowStream(device_ordinal, priority);
return [this](int device_ordinal, int num_streams,
se::StreamPriority priority) {
return BorrowStreams(device_ordinal, num_streams, priority);
};
}

Expand Down
11 changes: 7 additions & 4 deletions tensorflow/compiler/xla/service/gpu/runtime/concurrent_region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,13 @@ absl::Status ConcurrentRegionStatus::StartConcurrentRegion(
se::StreamExecutor* executor = run_options_->stream()->parent();

// Stream borrowing should only happen in the first call to this function.
for (int i = borrowed_streams_.size(); i < num_borrowed_streams_; i++) {
TF_ASSIGN_OR_RETURN(StreamPool::Ptr ptr,
run_options_->BorrowStream(executor->device_ordinal()));
borrowed_streams_.push_back(std::move(ptr));
if (borrowed_streams_.empty()) {
TF_ASSIGN_OR_RETURN(std::vector<StreamPool::Ptr> borrowed_streams,
run_options_->BorrowStreams(executor->device_ordinal(),
num_borrowed_streams_));
for (StreamPool::Ptr& stream : borrowed_streams) {
borrowed_streams_.push_back(std::move(stream));
}
}

// Switch borrowed streams into capture mode
Expand Down
36 changes: 26 additions & 10 deletions tensorflow/compiler/xla/service/service_executable_run_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.

#include <functional>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
Expand All @@ -31,18 +32,19 @@ namespace xla {
class ServiceExecutableRunOptions {
public:
// Defines the interface of the stream borrower function pointer
// with the first argument being the device ordinal and second
// argument being the priority of the stream.
using StreamBorrower =
std::function<StatusOr<StreamPool::Ptr>(int, se::StreamPriority)>;
// with the first argument being the device ordinal, the second
// argument being the number of streams to borrow, and the third
// argument being the priority of the streams.
using StreamBorrower = std::function<StatusOr<std::vector<StreamPool::Ptr>>(
int, int, se::StreamPriority)>;

ServiceExecutableRunOptions()
: ServiceExecutableRunOptions(ExecutableRunOptions()) {}

explicit ServiceExecutableRunOptions(ExecutableRunOptions run_options,
StreamBorrower borrow_stream = nullptr)
StreamBorrower stream_borrower = nullptr)
: run_options_(std::move(run_options)),
borrow_stream_(std::move(borrow_stream)) {}
stream_borrower_(std::move(stream_borrower)) {}

// Returns reference or pointer to `ExecutableRunOptions` member.
const ExecutableRunOptions& run_options() const { return run_options_; }
Expand All @@ -60,14 +62,28 @@ class ServiceExecutableRunOptions {
StatusOr<StreamPool::Ptr> BorrowStream(
int device_ordinal,
se::StreamPriority priority = se::StreamPriority::Default) const {
return borrow_stream_
? borrow_stream_(device_ordinal, priority)
: Status(absl::StatusCode::kUnimplemented, "No stream cache");
if (!stream_borrower_) {
return Status(absl::StatusCode::kUnimplemented, "No stream borrower");
}

TF_ASSIGN_OR_RETURN(
std::vector<StreamPool::Ptr> streams,
stream_borrower_(device_ordinal, /*num_streams=*/1, priority));
StreamPool::Ptr stream = std::move(streams.back());
return stream;
}

StatusOr<std::vector<StreamPool::Ptr>> BorrowStreams(
int device_ordinal, int num_streams,
se::StreamPriority priority = se::StreamPriority::Default) const {
return stream_borrower_
? stream_borrower_(device_ordinal, num_streams, priority)
: Status(absl::StatusCode::kUnimplemented, "No stream borrower");
}

private:
ExecutableRunOptions run_options_;
StreamBorrower borrow_stream_;
StreamBorrower stream_borrower_;
};

} // namespace xla
Expand Down

0 comments on commit 3350500

Please sign in to comment.