Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine channel test #7946

Merged
merged 11 commits into from
Jan 31, 2018
Merged
2 changes: 1 addition & 1 deletion paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)

cc_test(variable_test SRCS variable_test.cc)

cc_library(threadpool SRCS threadpool.cc)
cc_library(threadpool SRCS threadpool.cc DEPS enforce)
cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)

cc_library(scope SRCS scope.cc DEPS glog threadpool)
Expand Down
59 changes: 55 additions & 4 deletions paddle/framework/channel_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,64 @@ limitations under the License. */

#include "paddle/framework/channel.h"

#include <chrono>
#include <thread>

#include "gtest/gtest.h"


using paddle::framework::Channel;
using paddle::framework::MakeChannel;
using paddle::framework::CloseChannel;

TEST(Channel, MakeAndClose) {
using paddle::framework::Channel;
using paddle::framework::MakeChannel;
using paddle::framework::CloseChannel;
using paddle::framework::details::Buffered;
using paddle::framework::details::UnBuffered;
{
// MakeChannel should return a buffered channel is buffer_size > 0.
auto ch = MakeChannel<int>(10);
EXPECT_NE(dynamic_cast<Buffered<int>*>(ch), nullptr);
EXPECT_EQ(dynamic_cast<UnBuffered<int>*>(ch), nullptr);
CloseChannel(ch);
}
{
// MakeChannel should return an un-buffered channel is buffer_size = 0.
auto ch = MakeChannel<int>(0);
EXPECT_EQ(dynamic_cast<Buffered<int>*>(ch), nullptr);
EXPECT_NE(dynamic_cast<UnBuffered<int>*>(ch), nullptr);
CloseChannel(ch);
}
}

TEST(Channel, SufficientBufferSizeDoesntBlock) {
const size_t buffer_size = 10;
auto ch = MakeChannel<size_t>(buffer_size);
for (size_t i = 0; i < buffer_size; ++i) {
ch->Send(&i); // should not block
}

size_t out;
for (size_t i = 0; i < buffer_size; ++i) {
ch->Receive(&out); // should not block
EXPECT_EQ(out, i);
}
CloseChannel(ch);
}

TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) {
const size_t buffer_size = 10;
auto ch = MakeChannel<size_t>(buffer_size);

Channel<int>* ch = MakeChannel<int>(10);
size_t sum = 0;
std::thread t([&](){
// Try to write more than buffer size.
for (size_t i = 0; i < 2*buffer_size; ++i) {
ch->Send(&i); // should not block
sum += i;
}
});

std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.5 sec
EXPECT_EQ(sum, 45U);
CloseChannel(ch);
}
3 changes: 2 additions & 1 deletion paddle/framework/details/buffered_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License. */
#include <mutex>

#include "paddle/framework/channel.h"
#include "paddle/platform/enforce.h"

namespace paddle {
namespace framework {
Expand All @@ -40,7 +41,7 @@ class Buffered : public paddle::framework::Channel<T> {
std::condition_variable full_cond_var_;
std::deque<T> channel_;

Buffered(size_t cap) : cap_(cap) {}
Buffered(size_t cap) : cap_(cap) { PADDLE_ENFORCE_GT(cap, 0); }
virtual ~Buffered();

void NotifyAllSenders(std::unique_lock<std::mutex>*);
Expand Down
2 changes: 2 additions & 0 deletions paddle/framework/threadpool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#include "paddle/framework/threadpool.h"

#include "paddle/platform/enforce.h"

namespace paddle {
namespace framework {

Expand Down
2 changes: 1 addition & 1 deletion paddle/framework/threadpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ limitations under the License. */
#include <thread>
#include <vector>

#include "paddle/platform/enforce.h"
#include "paddle/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN

namespace paddle {
namespace framework {
Expand Down