Skip to content

Commit

Permalink
Do not add newline for RecordIO InputSplit (#471)
Browse files Browse the repository at this point in the history
* Do not add newline for RecordIO InputSplit

Adding newlines caused apache/mxnet#12800.

TODO: Add a test case

* Add a test case

* Cache value of IsTextParser() in a tight loop to avoid virtual function call

* Address reviewer comment: use tiny example

* Fix typo
  • Loading branch information
hcho3 authored and Kumar committed Nov 5, 2018
1 parent 5ba446c commit aab67a9
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 8 deletions.
3 changes: 3 additions & 0 deletions src/io/indexed_recordio_split.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class IndexedRecordIOSplitter : public InputSplitBase {
this->ResetPartition(rank, nsplit);
}

/*!
 * \brief report that indexed RecordIO input is binary data, not text
 * \return always false; InputSplitBase::Read and InputSplitBase::ReadChunk
 *  only insert '\n' characters into the buffer when this returns true
 */
bool IsTextParser(void) override {
return false;
}
bool ExtractNextRecord(Blob *out_rec, Chunk *chunk) override;
bool ReadChunk(void *buf, size_t *size) override;
bool NextChunk(Blob *out_chunk) override;
Expand Down
22 changes: 16 additions & 6 deletions src/io/input_split_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ void InputSplitBase::InitInputFileInfo(const std::string& uri,
}

size_t InputSplitBase::Read(void *ptr, size_t size) {
const bool is_text_parser = this->IsTextParser();

if (fs_ == NULL) {
return 0;
}
Expand All @@ -191,7 +193,9 @@ size_t InputSplitBase::Read(void *ptr, size_t size) {
offset_curr_ += n;
if (nleft == 0) break;
if (n == 0) {
buf[0] = '\n'; ++buf; --nleft; // insert a newline between files
if (is_text_parser) {
buf[0] = '\n'; ++buf; --nleft; // insert a newline between files
}
if (offset_curr_ != file_offset_[file_ptr_ + 1]) {
LOG(ERROR) << "curr=" << offset_curr_
<< ",begin=" << offset_begin_
Expand Down Expand Up @@ -226,12 +230,18 @@ bool InputSplitBase::ReadChunk(void *buf, size_t *size) {
max_size - olen);
nread += olen;
if (nread == 0) return false;
if (nread == olen) { // add extra newline to handle files with NOEOL
char *bufptr = reinterpret_cast<char*>(buf);
bufptr[nread] = '\n';
nread++;
if (this->IsTextParser()) {
if (nread == olen) { // add extra newline to handle files with NOEOL
char *bufptr = reinterpret_cast<char*>(buf);
bufptr[nread] = '\n';
nread++;
}
} else {
if (nread != max_size) {
*size = nread;
return true;
}
}

const char *bptr = reinterpret_cast<const char*>(buf);
// return the last position where a record starts
const char *bend = this->FindLastRecordBegin(bptr, bptr + nread);
Expand Down
6 changes: 6 additions & 0 deletions src/io/input_split_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ class InputSplitBase : public InputSplit {
* false if the chunk has already finished its life
*/
virtual bool ExtractNextRecord(Blob *out_rec, Chunk *chunk) = 0;
/*!
* \brief query whether this object is a text parser
* \return true if this object represents a text parser; false if it represents
* a binary parser
*/
virtual bool IsTextParser(void) = 0;
/*!
* \brief fill the given
* chunk with new data without using internal
Expand Down
3 changes: 3 additions & 0 deletions src/io/line_split.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ class LineSplitter : public InputSplitBase {
this->ResetPartition(rank, nsplit);
}

/*!
 * \brief report that line-based input is text
 * \return always true, which lets InputSplitBase::Read insert a newline
 *  between files and InputSplitBase::ReadChunk append one for input that
 *  lacks a trailing end-of-line
 */
bool IsTextParser(void) {
return true;
}
virtual bool ExtractNextRecord(Blob *out_rec, Chunk *chunk);
protected:
virtual size_t SeekRecordBegin(Stream *fi);
Expand Down
3 changes: 3 additions & 0 deletions src/io/recordio_split.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class RecordIOSplitter : public InputSplitBase {
this->ResetPartition(rank, nsplit);
}

/*!
 * \brief report that RecordIO input is binary data, not text
 * \return always false; InputSplitBase::Read and InputSplitBase::ReadChunk
 *  only insert '\n' characters into the buffer when this returns true
 */
bool IsTextParser(void) {
return false;
}
virtual bool ExtractNextRecord(Blob *out_rec, Chunk *chunk);

protected:
Expand Down
2 changes: 2 additions & 0 deletions test/unittest/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
dmlc_unittest
build_config.h
cifar10_val.rec
1 change: 1 addition & 0 deletions test/unittest/build_config.h.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#cmakedefine CMAKE_CURRENT_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
Binary file added test/unittest/sample.rec
Binary file not shown.
54 changes: 52 additions & 2 deletions test/unittest/unittest_inputsplit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
#include <cstdlib>
#include <gtest/gtest.h>

static inline void CountDimensions(dmlc::Parser<uint32_t>* parser,
size_t* out_num_row, size_t* out_num_col) {
namespace {

inline void CountDimensions(dmlc::Parser<uint32_t>* parser,
size_t* out_num_row, size_t* out_num_col) {
size_t num_row = 0;
size_t num_col = 0;
while (parser->Next()) {
Expand All @@ -26,6 +28,14 @@ static inline void CountDimensions(dmlc::Parser<uint32_t>* parser,
*out_num_col = num_col;
}

/*!
 * \brief fixed-size header that test_recordio copies from the front of each record
 *
 * NOTE(review): the field layout presumably mirrors the image record header
 * written by the tool that produced sample.rec -- confirm against that tool.
 */
struct RecordIOHeader {
  /*! \brief not inspected by test_recordio */
  uint32_t flag;
  /*! \brief test_recordio expects this to equal (record ordinal % 2) */
  float label;
  /*!
   * \brief test_recordio expects element 0 to equal the record ordinal;
   *  element 1 is not inspected
   */
  uint64_t image_id[2];
};

// The header is memcpy'd straight out of binary record data, so any
// compiler-inserted padding would silently shift every field after it.
static_assert(sizeof(RecordIOHeader) ==
              sizeof(uint32_t) + sizeof(float) + 2 * sizeof(uint64_t),
              "RecordIOHeader must not contain padding");

} // namespace anonymous

TEST(InputSplit, test_split_csv_noeol) {
size_t num_row, num_col;
{
Expand Down Expand Up @@ -133,3 +143,43 @@ TEST(InputSplit, test_split_libsvm_distributed) {
}
}
}

#ifdef DMLC_UNIT_TESTS_USE_CMAKE
/* The test below needs build_config.h, which only the CMake build generates */

#include <build_config.h>

/*!
 * \brief check that RecordIO records are returned byte-for-byte, i.e. that the
 *  splitter does not inject newline characters into binary data
 *  (regression test for apache/mxnet#12800)
 *
 * Record number k (counting from 1) of sample.rec is expected to hold a
 * RecordIOHeader with label == k % 2 and image_id[0] == k, followed by a
 * payload consisting of the decimal text of k plus '\n', repeated ten times.
 */
TEST(InputSplit, test_recordio) {
  std::unique_ptr<dmlc::InputSplit> source(
    dmlc::InputSplit::Create(CMAKE_CURRENT_SOURCE_DIR "/sample.rec", 0, 1, "recordio"));
  ASSERT_TRUE(source != nullptr);
  source->BeforeFirst();

  const int kRepeatsPerRecord = 10;
  dmlc::InputSplit::Blob rec;
  // ordinal of the record being checked; unsigned so that it compares
  // against image_id without a signed/unsigned mismatch
  uint64_t idx = 1;

  while (source->NextRecord(&rec)) {
    RecordIOHeader header;
    // the record must hold a full header plus a non-empty payload
    ASSERT_GT(rec.size, sizeof(header));
    std::memcpy(&header, rec.dptr, sizeof(header));
    const char* content = static_cast<const char*>(rec.dptr) + sizeof(header);
    const size_t content_size = rec.size - sizeof(header);

    // build the payload this record is expected to carry
    const std::string line = std::to_string(idx) + "\n";
    std::string expected;
    expected.reserve(line.size() * kRepeatsPerRecord);
    for (int i = 0; i < kRepeatsPerRecord; ++i) {
      expected += line;
    }

    ASSERT_EQ(header.label, static_cast<float>(idx % 2));
    ASSERT_EQ(header.image_id[0], idx);
    ASSERT_EQ(std::string(content, content_size), expected);

    ++idx;
  }

  // guard against a vacuous pass when no record could be read at all
  ASSERT_GT(idx, static_cast<uint64_t>(1)) << "no records were read from sample.rec";
}

#endif // DMLC_UNIT_TESTS_USE_CMAKE

0 comments on commit aab67a9

Please sign in to comment.