Skip to content

Commit

Permalink
refine close channel
Browse files Browse the repository at this point in the history
  • Loading branch information
chengduoZH committed Jan 30, 2018
1 parent be0525f commit 3e7ca5c
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 6 deletions.
6 changes: 6 additions & 0 deletions paddle/framework/channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class Channel {
virtual void Send(T*) = 0;
virtual void Receive(T*) = 0;
virtual size_t Cap() = 0;
virtual void Close() = 0;

// Don't delete channels; instead, call Channel::Close.
protected:
Expand All @@ -50,6 +51,11 @@ Channel<T>* MakeChannel(size_t buffer_size) {

template <typename T>
void CloseChannel(Channel<T>* ch) {
ch->Close();
}

template <typename T>
void DeleteChannel(Channel<T>* ch) {
if (ch->Cap() > 0) {
delete dynamic_cast<details::Buffered<T>*>(ch);
} else {
Expand Down
1 change: 1 addition & 0 deletions paddle/framework/channel_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,5 @@ TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) {
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.5 sec
EXPECT_EQ(sum, 45U);
CloseChannel(ch);
t.join();
}
25 changes: 19 additions & 6 deletions paddle/framework/details/buffered_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,17 @@ class Buffered : public paddle::framework::Channel<T> {
virtual void Send(T*);
virtual void Receive(T*);
virtual size_t Cap() { return cap_; }
virtual void Close();

private:
size_t cap_;
std::mutex mu_;
std::condition_variable empty_cond_var_;
std::condition_variable full_cond_var_;
std::deque<T> channel_;
bool close;

Buffered(size_t cap) : cap_(cap) { PADDLE_ENFORCE_GT(cap, 0); }
Buffered(size_t cap) : cap_(cap), close(false) { PADDLE_ENFORCE_GT(cap, 0); }
virtual ~Buffered();

void NotifyAllSenders(std::unique_lock<std::mutex>*);
Expand All @@ -50,10 +52,13 @@ class Buffered : public paddle::framework::Channel<T> {
template <typename T>
void Buffered<T>::Send(T* item) {
std::unique_lock<std::mutex> lock(mu_);
full_cond_var_.wait(lock, [this]() { return channel_.size() < cap_; });
channel_.push_back(std::move(*item));
lock.unlock();
empty_cond_var_.notify_one();
full_cond_var_.wait(lock,
[this]() { return channel_.size() < cap_ || close; });
if (!close) {
channel_.push_back(std::move(*item));
lock.unlock();
empty_cond_var_.notify_one();
}
}

template <typename T>
Expand All @@ -65,17 +70,25 @@ void Buffered<T>::Receive(T* item) {
NotifyAllSenders(&lock);
}

template <typename T>
void Buffered<T>::Close() {
std::unique_lock<std::mutex> lock(mu_);
close = true;
NotifyAllSenders(&lock);
}

template <typename T>
Buffered<T>::~Buffered() {
std::unique_lock<std::mutex> lock(mu_);
close = true;
channel_.clear();
NotifyAllSenders(&lock);
}

template <typename T>
void Buffered<T>::NotifyAllSenders(std::unique_lock<std::mutex>* lock) {
lock->unlock();
full_cond_var_.notify_one();
full_cond_var_.notify_all();
}

} // namespace details
Expand Down
4 changes: 4 additions & 0 deletions paddle/framework/details/unbuffered_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class UnBuffered : public paddle::framework::Channel<T> {
virtual void Send(T*);
virtual void Receive(T*);
virtual size_t Cap() { return 0; }
virtual void Close();

private:
UnBuffered() {}
Expand All @@ -41,6 +42,9 @@ class UnBuffered : public paddle::framework::Channel<T> {
template <typename T>
void UnBuffered<T>::Send(T* channel_element) {}

template <typename T>
void UnBuffered<T>::Close() {}

template <typename T>
void UnBuffered<T>::Receive(T*) {}

Expand Down

0 comments on commit 3e7ca5c

Please sign in to comment.