diff --git a/src/io/indexed_recordio_split.h b/src/io/indexed_recordio_split.h index 6fca63e4d..8c34ed794 100644 --- a/src/io/indexed_recordio_split.h +++ b/src/io/indexed_recordio_split.h @@ -38,6 +38,9 @@ class IndexedRecordIOSplitter : public InputSplitBase { this->ResetPartition(rank, nsplit); } + bool IsTextParser(void) override { + return false; + } bool ExtractNextRecord(Blob *out_rec, Chunk *chunk) override; bool ReadChunk(void *buf, size_t *size) override; bool NextChunk(Blob *out_chunk) override; diff --git a/src/io/input_split_base.cc b/src/io/input_split_base.cc index e817e6fcb..8ea506936 100644 --- a/src/io/input_split_base.cc +++ b/src/io/input_split_base.cc @@ -175,6 +175,8 @@ void InputSplitBase::InitInputFileInfo(const std::string& uri, } size_t InputSplitBase::Read(void *ptr, size_t size) { + const bool is_text_parser = this->IsTextParser(); + if (fs_ == NULL) { return 0; } @@ -191,7 +193,9 @@ size_t InputSplitBase::Read(void *ptr, size_t size) { offset_curr_ += n; if (nleft == 0) break; if (n == 0) { - buf[0] = '\n'; ++buf; --nleft; // insert a newline between files + if (is_text_parser) { + buf[0] = '\n'; ++buf; --nleft; // insert a newline between files + } if (offset_curr_ != file_offset_[file_ptr_ + 1]) { LOG(ERROR) << "curr=" << offset_curr_ << ",begin=" << offset_begin_ @@ -226,12 +230,18 @@ bool InputSplitBase::ReadChunk(void *buf, size_t *size) { max_size - olen); nread += olen; if (nread == 0) return false; - if (nread == olen) { // add extra newline to handle files with NOEOL - char *bufptr = reinterpret_cast(buf); - bufptr[nread] = '\n'; - nread++; + if (this->IsTextParser()) { + if (nread == olen) { // add extra newline to handle files with NOEOL + char *bufptr = reinterpret_cast(buf); + bufptr[nread] = '\n'; + nread++; + } + } else { + if (nread != max_size) { + *size = nread; + return true; + } } - const char *bptr = reinterpret_cast(buf); // return the last position where a record starts const char *bend = this->FindLastRecordBegin(bptr, bptr + nread); diff --git a/src/io/input_split_base.h b/src/io/input_split_base.h index 334d0d78f..0679eec5f 100644 --- a/src/io/input_split_base.h +++ b/src/io/input_split_base.h @@ -92,6 +92,12 @@ class InputSplitBase : public InputSplit { * false if the chunk is already finishes its life */ virtual bool ExtractNextRecord(Blob *out_rec, Chunk *chunk) = 0; + /*! + * \brief query whether this object is a text parser + * \return true if this object represents a text parser; false if it represents + * a binary parser + */ + virtual bool IsTextParser(void) = 0; /*! * \brief fill the given * chunk with new data without using internal diff --git a/src/io/line_split.h b/src/io/line_split.h index 9962b6322..d0d525e13 100644 --- a/src/io/line_split.h +++ b/src/io/line_split.h @@ -27,6 +27,9 @@ class LineSplitter : public InputSplitBase { this->ResetPartition(rank, nsplit); } + bool IsTextParser(void) { + return true; + } virtual bool ExtractNextRecord(Blob *out_rec, Chunk *chunk); protected: virtual size_t SeekRecordBegin(Stream *fi); diff --git a/src/io/recordio_split.h b/src/io/recordio_split.h index afc033856..63386b8c0 100644 --- a/src/io/recordio_split.h +++ b/src/io/recordio_split.h @@ -29,6 +29,9 @@ class RecordIOSplitter : public InputSplitBase { this->ResetPartition(rank, nsplit); } + bool IsTextParser(void) { + return false; + } virtual bool ExtractNextRecord(Blob *out_rec, Chunk *chunk); protected: diff --git a/test/unittest/.gitignore b/test/unittest/.gitignore index 1a0e29fa0..5bfd3288d 100644 --- a/test/unittest/.gitignore +++ b/test/unittest/.gitignore @@ -1 +1,3 @@ dmlc_unittest +build_config.h +cifar10_val.rec diff --git a/test/unittest/build_config.h.in b/test/unittest/build_config.h.in new file mode 100644 index 000000000..c31d345d7 --- /dev/null +++ b/test/unittest/build_config.h.in @@ -0,0 +1 @@ +#cmakedefine CMAKE_CURRENT_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}" diff --git a/test/unittest/sample.rec b/test/unittest/sample.rec new file mode 100644 index 000000000..d950c211b Binary files /dev/null and b/test/unittest/sample.rec differ diff --git a/test/unittest/unittest_inputsplit.cc b/test/unittest/unittest_inputsplit.cc index adcb2ae0d..d18e9ceb3 100644 --- a/test/unittest/unittest_inputsplit.cc +++ b/test/unittest/unittest_inputsplit.cc @@ -10,8 +10,10 @@ #include #include -static inline void CountDimensions(dmlc::Parser* parser, - size_t* out_num_row, size_t* out_num_col) { +namespace { + +inline void CountDimensions(dmlc::Parser* parser, + size_t* out_num_row, size_t* out_num_col) { size_t num_row = 0; size_t num_col = 0; while (parser->Next()) { @@ -26,6 +28,14 @@ static inline void CountDimensions(dmlc::Parser* parser, *out_num_col = num_col; } +struct RecordIOHeader { + uint32_t flag; + float label; + uint64_t image_id[2]; +}; + +} // namespace anonymous + TEST(InputSplit, test_split_csv_noeol) { size_t num_row, num_col; { @@ -133,3 +143,43 @@ TEST(InputSplit, test_split_libsvm_distributed) { } } } + +#ifdef DMLC_UNIT_TESTS_USE_CMAKE +/* Don't run the following when CMake is not used */ + +#include + +TEST(InputSplit, test_recordio) { + dmlc::TemporaryDirectory tempdir; + + std::unique_ptr source( + dmlc::InputSplit::Create(CMAKE_CURRENT_SOURCE_DIR "/sample.rec", 0, 1, "recordio")); + + source->BeforeFirst(); + dmlc::InputSplit::Blob rec; + char* content; + RecordIOHeader header; + size_t content_size; + + int idx = 1; + + while (source->NextRecord(&rec)) { + ASSERT_GT(rec.size, sizeof(header)); + std::memcpy(&header, rec.dptr, sizeof(header)); + content = reinterpret_cast(rec.dptr) + sizeof(header); + content_size = rec.size - sizeof(header); + + std::string expected; + for (int i = 0; i < 10; ++i) { + expected += std::to_string(idx) + "\n"; + } + + ASSERT_EQ(header.label, static_cast(idx % 2)); + ASSERT_EQ(header.image_id[0], idx); + ASSERT_EQ(std::string(content, content_size), expected); + + ++idx; + } +} + +#endif // DMLC_UNIT_TESTS_USE_CMAKE