Skip to content

Commit

Permalink
Improving text memmap generated index files error messages (NVIDIA#6093)
Browse files Browse the repository at this point in the history
* 1. Improved testing of text memmap generated index files, and improved error message when files are missing.

Signed-off-by: Micha Livne <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* 1. Improved error messages to help debugging failure cases of text memmap.

Signed-off-by: Micha Livne <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Micha Livne <[email protected]>
Co-authored-by: Micha Livne <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored and titu1994 committed Mar 24, 2023
1 parent 43566d0 commit 7463b0a
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,22 @@ def __getitem__(self, idx):

# fetch sample from memmap

sample = self._fetch_sample_from_memmap(mdata, i, j)
try:
sample = self._fetch_sample_from_memmap(mdata, i, j)
except Exception as e:
logging.error(f"Error while fetching sample from memmap: {e}")
logging.error(f"file_id: {file_id}, file_idx: {file_idx}, i: {i}, j: {j}")
raise e

# parse raw text (e.g., tokenize)
data = self._build_data_from_text(sample)
try:
data = self._build_data_from_text(sample)
except Exception as e:
logging.error(
f"Error while building data from text, possible issue with sample expected format (see offending sample below): {e}"
)
logging.error(f"sample: {sample}, file_id: {file_id}, file_idx: {file_idx}, i: {i}, j: {j}")
raise e

return data

Expand Down Expand Up @@ -195,7 +208,7 @@ def load_file(self, fn):
# create data map
mdata = np.memmap(fn, dtype=np.uint8, mode='r')

if os.path.exists(idx_fn + ".npy"):
if _index_file_exists(idx_fn):
# load index file into memory map
midx = np.load(idx_fn + ".npy", allow_pickle=True, mmap_mode='r')
# test for header
Expand All @@ -219,7 +232,9 @@ def load_file(self, fn):
f"Version mismatch: Please delete existing '.{__idx_suffix__}' files. Expected version = {__idx_version__}, but file version = {idx_version}. File path = {idx_fn}"
)
else:
raise ValueError(f'Memory Map for {fn} is not found')
raise ValueError(
f'Memory Map for {fn} is not found, missing one or more of files: {idx_fn}.{{.npy,.info}}'
)

return (mdata, midx)

Expand Down Expand Up @@ -281,12 +296,20 @@ def _build_data_from_text(self, text):
return json.loads(text)


def _index_file_exists(idx_fn):
    """Helper function to test if the index files for ``idx_fn`` exist.

    Both ``idx_fn + ".npy"`` and ``idx_fn + ".info"`` must be present; if
    either one is missing the index is reported as not existing.

    Args:
        idx_fn: path prefix of the index files (without the ``.npy`` /
            ``.info`` extension).

    Returns:
        bool: True if both files exist, False otherwise.
    """
    # The `and` of two os.path.exists() calls is already a bool, so it can be
    # returned directly without an if/else.
    return os.path.exists(idx_fn + ".npy") and os.path.exists(idx_fn + ".info")


def _build_memmap_index_files(newline_int, build_index_fn, fn):
"""Helper function to build an index file"""
idx_fn = f"{fn}.{__idx_suffix__}"

# create data map
if os.path.exists(idx_fn + ".npy"):
if _index_file_exists(idx_fn):
return False
else:
logging.info(f"Building indexing for fn = {fn}")
Expand Down

0 comments on commit 7463b0a

Please sign in to comment.