Skip to content

Commit

Permalink
Merge branch 'main' into athitten/ptl_2.0
Browse files: browse the repository at this point in the history
  • Loading branch information
athitten committed Apr 25, 2023
2 parents 36424ec + 4205006 commit 7c8c0f0
Showing 1 changed file with 14 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import os
import subprocess
import time
from typing import Any

import numpy as np
import torch
Expand Down Expand Up @@ -1255,6 +1256,7 @@ def get_samples_mapping(
name,
binary_head,
index_mapping_dir: str = None,
samples_mapping: Any = None,
):
"""Get a list that maps a sample index to a starting sentence index, end sentence index, and length"""

Expand All @@ -1280,8 +1282,8 @@ def get_samples_mapping(
indexmap_filename += '_{}s'.format(seed)
indexmap_filename += '.npy'

# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename):
# Build the indexed mapping if it does not exist and was not provided externally.
if samples_mapping is None and torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename):
# Fake index mapping if missing
if (getattr(indexed_dataset, 'doc_idx', None) is None) and (getattr(indexed_dataset, 'sizes', None) is None):
make_indexed_dataset_compatibility(indexed_dataset)
Expand Down Expand Up @@ -1334,15 +1336,16 @@ def get_samples_mapping(
torch.distributed.get_world_size()
// torch.distributed.get_world_size(group=parallel_state.get_tensor_model_parallel_group())
)
# Load indexed dataset.
logging.info(' > loading indexed mapping from {}'.format(indexmap_filename))
start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
logging.info(' loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time))
logging.info(' total number of samples: {}'.format(samples_mapping.shape[0]))
# Load the indexed mapping from disk if it was not given externally.
if samples_mapping is None:
logging.info(' > loading indexed mapping from {}'.format(indexmap_filename))
start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
logging.info(' loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time))
logging.info(' total number of samples: {}'.format(samples_mapping.shape[0]))

# Deallocate temporary numpy arrays that were created for `get_samples_mapping()` when needed
if hasattr(indexed_dataset, 'doc_idx') and hasattr(indexed_dataset, 'sizes'):
deallocate_indexed_dataset_memory(indexed_dataset)
# Deallocate temporary numpy arrays that were created for `get_samples_mapping()` when needed
if hasattr(indexed_dataset, 'doc_idx') and hasattr(indexed_dataset, 'sizes'):
deallocate_indexed_dataset_memory(indexed_dataset)

return samples_mapping

0 comments on commit 7c8c0f0

Please sign in to comment.