diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py
index c9ecb6af2..143c730e9 100644
--- a/tools/preprocess_data.py
+++ b/tools/preprocess_data.py
@@ -20,6 +20,7 @@
 import multiprocessing
 import os
 import sys
+import functools
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              os.path.pardir)))
 import time
@@ -53,8 +54,9 @@ def tokenize(self, *text):
         return text
 
 class Encoder(object):
-    def __init__(self, args):
+    def __init__(self, args, chunksize):
         self.args = args
+        self.chunksize = chunksize
 
     def initializer(self):
         # Use Encoder class as a container for global data
@@ -74,8 +76,12 @@ def initializer(self):
 
         else:
             Encoder.splitter = IdentitySplitter()
+        Encoder.chunk_index = -1
 
-    def encode(self, json_line):
+    def encode(self, json_line, semaphore: multiprocessing.Semaphore):
+        Encoder.chunk_index = (Encoder.chunk_index + 1) % self.chunksize
+        if Encoder.chunk_index == 0:
+            semaphore.acquire()
         data = json.loads(json_line)
         ids = {}
         for key in self.args.json_keys:
@@ -126,6 +132,8 @@ def get_args():
                       help='Number of worker processes to launch')
    group.add_argument('--log-interval', type=int, default=100,
                       help='Interval between progress updates')
+    group.add_argument('--max-sample-in-memory', type=int, default=100,
+                       help='Maximum number of samples held in memory while awaiting storage in the preprocessed dataset.')
 
     args = parser.parse_args()
     args.keep_empty = False
@@ -151,10 +159,19 @@ def main():
     if nltk_available and args.split_sentences:
         nltk.download("punkt", quiet=True)
 
-    encoder = Encoder(args)
+    chunksize = 25
+    encoder = Encoder(args, chunksize)
     tokenizer = build_tokenizer(args)
+
+    # Necessary to share a semaphore across processes
+    m = multiprocessing.Manager()
+    semaphore = m.Semaphore(args.max_sample_in_memory) # we're going to release/acquire by chunks
+
+    # This helps prevent deadlock
+    assert args.max_sample_in_memory >= args.workers * chunksize
+
     pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
-    encoded_docs = pool.imap(encoder.encode, fin, 25)
+    encoded_docs = pool.imap_unordered(functools.partial(encoder.encode, semaphore=semaphore), fin, chunksize)
     #encoded_docs = map(encoder.encode, fin)
 
     level = "document"
@@ -181,6 +198,8 @@ def main():
     print("Time to startup:", startup_end - startup_start)
 
     for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
+        if i % chunksize == 0:
+            semaphore.release()
         total_bytes_processed += bytes_processed
         for key, sentences in doc.items():
             if len(sentences) == 0:
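
For context, below is a minimal, self-contained sketch (not part of the patch) of the back-pressure pattern the diff introduces: each worker takes one Manager-semaphore token before starting a chunk of `chunksize` items, and the consuming loop hands a token back for every chunk it drains, so the number of encoded-but-unwritten chunks stays bounded. Names such as `max_chunks_in_flight`, the two-worker pool, and the toy `encode` body are illustrative assumptions; the real script budgets the semaphore with `--max-sample-in-memory` and writes results to the indexed dataset instead of summing them.

import functools
import multiprocessing


class Encoder:
    """Stand-in for the script's Encoder: counts items per worker process and
    takes one semaphore token per chunk of `chunksize` items."""

    def __init__(self, chunksize):
        self.chunksize = chunksize

    def initializer(self):
        Encoder.chunk_index = -1  # per-process counter, set in each worker

    def encode(self, item, semaphore):
        Encoder.chunk_index = (Encoder.chunk_index + 1) % self.chunksize
        if Encoder.chunk_index == 0:
            semaphore.acquire()  # blocks workers once the consumer falls behind
        return item * item  # pretend this is the expensive tokenization


def main():
    chunksize = 4
    workers = 2
    max_chunks_in_flight = 8  # caps memory at roughly max_chunks_in_flight * chunksize items
    assert max_chunks_in_flight >= workers  # leave every worker at least one token

    encoder = Encoder(chunksize)
    manager = multiprocessing.Manager()
    semaphore = manager.Semaphore(max_chunks_in_flight)  # proxy object, picklable

    pool = multiprocessing.Pool(workers, initializer=encoder.initializer)
    results = pool.imap_unordered(
        functools.partial(encoder.encode, semaphore=semaphore), range(100), chunksize
    )
    total = 0
    for i, res in enumerate(results, start=1):
        if i % chunksize == 0:
            semaphore.release()  # hand one chunk's worth of budget back to the workers
        total += res  # stand-in for writing `res` to the output dataset
    pool.close()
    pool.join()
    print("processed", i, "items, checksum", total)


if __name__ == "__main__":
    main()

A Manager-backed semaphore proxy is used (as in the patch's "Necessary to share a semaphore across processes" comment) because it can be pickled into the `functools.partial` arguments sent to pool workers, unlike a raw `multiprocessing.Semaphore`.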