Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine channel test #7946

Merged
merged 11 commits into from
Jan 31, 2018
Merged
2 changes: 1 addition & 1 deletion paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)

cc_test(variable_test SRCS variable_test.cc)

cc_library(threadpool SRCS threadpool.cc)
cc_library(threadpool SRCS threadpool.cc DEPS enforce)
cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)

cc_library(scope SRCS scope.cc DEPS glog threadpool)
Expand Down
10 changes: 2 additions & 8 deletions paddle/framework/channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ class Channel {
virtual void Send(T*) = 0;
virtual void Receive(T*) = 0;
virtual size_t Cap() = 0;

// Don't delete channels; instead, call Channel::Close.
protected:
virtual void Close() = 0;
virtual ~Channel() {}
};

Expand All @@ -50,11 +48,7 @@ Channel<T>* MakeChannel(size_t buffer_size) {

template <typename T>
void CloseChannel(Channel<T>* ch) {
if (ch->Cap() > 0) {
delete dynamic_cast<details::Buffered<T>*>(ch);
} else {
delete dynamic_cast<details::UnBuffered<T>*>(ch);
}
ch->Close();
}

} // namespace framework
Expand Down
62 changes: 58 additions & 4 deletions paddle/framework/channel_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,67 @@ limitations under the License. */

#include "paddle/framework/channel.h"

#include <chrono>
#include <thread>

#include "gtest/gtest.h"

using paddle::framework::Channel;
using paddle::framework::MakeChannel;
using paddle::framework::CloseChannel;

TEST(Channel, MakeAndClose) {
using paddle::framework::Channel;
using paddle::framework::MakeChannel;
using paddle::framework::CloseChannel;
using paddle::framework::details::Buffered;
using paddle::framework::details::UnBuffered;
{
// MakeChannel should return a buffered channel is buffer_size > 0.
auto ch = MakeChannel<int>(10);
EXPECT_NE(dynamic_cast<Buffered<int>*>(ch), nullptr);
EXPECT_EQ(dynamic_cast<UnBuffered<int>*>(ch), nullptr);
CloseChannel(ch);
delete ch;
}
{
// MakeChannel should return an un-buffered channel is buffer_size = 0.
auto ch = MakeChannel<int>(0);
EXPECT_EQ(dynamic_cast<Buffered<int>*>(ch), nullptr);
EXPECT_NE(dynamic_cast<UnBuffered<int>*>(ch), nullptr);
CloseChannel(ch);
delete ch;
}
}

TEST(Channel, SufficientBufferSizeDoesntBlock) {
const size_t buffer_size = 10;
auto ch = MakeChannel<size_t>(buffer_size);
for (size_t i = 0; i < buffer_size; ++i) {
ch->Send(&i); // should not block
}

size_t out;
for (size_t i = 0; i < buffer_size; ++i) {
ch->Receive(&out); // should not block
EXPECT_EQ(out, i);
}
CloseChannel(ch);
delete ch;
}

TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) {
const size_t buffer_size = 10;
auto ch = MakeChannel<size_t>(buffer_size);
size_t sum = 0;
std::thread t([&]() {
// Try to write more than buffer size.
for (size_t i = 0; i < 2 * buffer_size; ++i) {
ch->Send(&i); // should not block
sum += i;
}
});
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.5 sec
EXPECT_EQ(sum, 45U);

Channel<int>* ch = MakeChannel<int>(10);
CloseChannel(ch);
t.join();
delete ch;
}
40 changes: 30 additions & 10 deletions paddle/framework/details/buffered_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License. */
#include <mutex>

#include "paddle/framework/channel.h"
#include "paddle/platform/enforce.h"

namespace paddle {
namespace framework {
Expand All @@ -32,49 +33,68 @@ class Buffered : public paddle::framework::Channel<T> {
virtual void Send(T*);
virtual void Receive(T*);
virtual size_t Cap() { return cap_; }
virtual void Close();
virtual ~Buffered();

private:
size_t cap_;
std::mutex mu_;
std::condition_variable empty_cond_var_;
std::condition_variable full_cond_var_;
std::deque<T> channel_;
bool closed_;

Buffered(size_t cap) : cap_(cap) {}
virtual ~Buffered();
Buffered(size_t cap) : cap_(cap), closed_(false) {
PADDLE_ENFORCE_GT(cap, 0);
}

void NotifyAllSenders(std::unique_lock<std::mutex>*);
};

template <typename T>
void Buffered<T>::Send(T* item) {
std::unique_lock<std::mutex> lock(mu_);
full_cond_var_.wait(lock, [this]() { return channel_.size() < cap_; });
channel_.push_back(std::move(*item));
lock.unlock();
empty_cond_var_.notify_one();
full_cond_var_.wait(lock,
[this]() { return channel_.size() < cap_ || closed_; });
if (!closed_) {
channel_.push_back(std::move(*item));
lock.unlock();
empty_cond_var_.notify_one();
}
}

template <typename T>
void Buffered<T>::Receive(T* item) {
std::unique_lock<std::mutex> lock(mu_);
empty_cond_var_.wait(lock, [this]() { return !channel_.empty(); });
*item = std::move(channel_.front());
channel_.pop_front();
empty_cond_var_.wait(lock, [this]() { return !channel_.empty() || closed_; });
if (!closed_) {
*item = std::move(channel_.front());
channel_.pop_front();
NotifyAllSenders(&lock);
} else {
item = nullptr;
}
}

template <typename T>
void Buffered<T>::Close() {
std::unique_lock<std::mutex> lock(mu_);
closed_ = true;
NotifyAllSenders(&lock);
}

template <typename T>
Buffered<T>::~Buffered() {
std::unique_lock<std::mutex> lock(mu_);
closed_ = true;
channel_.clear();
NotifyAllSenders(&lock);
}

template <typename T>
void Buffered<T>::NotifyAllSenders(std::unique_lock<std::mutex>* lock) {
lock->unlock();
full_cond_var_.notify_one();
full_cond_var_.notify_all();
}

} // namespace details
Expand Down
6 changes: 5 additions & 1 deletion paddle/framework/details/unbuffered_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ class UnBuffered : public paddle::framework::Channel<T> {
virtual void Send(T*);
virtual void Receive(T*);
virtual size_t Cap() { return 0; }
virtual void Close();
virtual ~UnBuffered();

private:
UnBuffered() {}
virtual ~UnBuffered();
};

template <typename T>
Expand All @@ -44,6 +45,9 @@ void UnBuffered<T>::Send(T* channel_element) {}
template <typename T>
void UnBuffered<T>::Receive(T*) {}

template <typename T>
void UnBuffered<T>::Close() {}

template <typename T>
UnBuffered<T>::~UnBuffered() {}

Expand Down
2 changes: 2 additions & 0 deletions paddle/framework/threadpool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#include "paddle/framework/threadpool.h"

#include "paddle/platform/enforce.h"

namespace paddle {
namespace framework {

Expand Down
2 changes: 1 addition & 1 deletion paddle/framework/threadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ limitations under the License. */
#include <thread>
#include <vector>

#include "paddle/platform/enforce.h"
#include "paddle/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN

namespace paddle {
namespace framework {
Expand Down