tools/preprocess_data.py (27 changes: 23 additions & 4 deletions)
@@ -20,6 +20,7 @@
 import multiprocessing
 import os
 import sys
+import functools
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              os.path.pardir)))
 import time
@@ -53,8 +54,9 @@ def tokenize(self, *text):
         return text
 
 class Encoder(object):
-    def __init__(self, args):
+    def __init__(self, args, chunksize):
         self.args = args
+        self.chunksize = chunksize
 
     def initializer(self):
         # Use Encoder class as a container for global data
@@ -74,8 +76,12 @@ def initializer(self):
 
         else:
             Encoder.splitter = IdentitySplitter()
+        Encoder.chunk_index = -1
 
-    def encode(self, json_line):
+    def encode(self, json_line, semaphore: multiprocessing.Semaphore):
+        Encoder.chunk_index = (Encoder.chunk_index + 1) % self.chunksize
+        if Encoder.chunk_index == 0:
+            semaphore.acquire()
         data = json.loads(json_line)
         ids = {}
         for key in self.args.json_keys:
@@ -126,6 +132,8 @@ def get_args():
                        help='Number of worker processes to launch')
     group.add_argument('--log-interval', type=int, default=100,
                        help='Interval between progress updates')
+    group.add_argument('--max-sample-in-memory', type=int, default=100,
+                       help='Maximum number of samples held in memory while waiting to be written to the preprocessed dataset.')
     args = parser.parse_args()
     args.keep_empty = False
@@ -151,10 +159,19 @@ def main():
     if nltk_available and args.split_sentences:
         nltk.download("punkt", quiet=True)
 
-    encoder = Encoder(args)
+    chunksize = 25
+    encoder = Encoder(args, chunksize)
     tokenizer = build_tokenizer(args)
+
+    # Necessary to share a semaphore across processes
+    m = multiprocessing.Manager()
+    semaphore = m.Semaphore(args.max_sample_in_memory) # we're going to release/acquire by chunks
+
+    # This helps prevent deadlock
+    assert args.max_sample_in_memory >= args.workers * chunksize
+
     pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
-    encoded_docs = pool.imap(encoder.encode, fin, 25)
+    encoded_docs = pool.imap_unordered(functools.partial(encoder.encode, semaphore=semaphore), fin, chunksize)
     #encoded_docs = map(encoder.encode, fin)
 
     level = "document"
@@ -181,6 +198,8 @@
     print("Time to startup:", startup_end - startup_start)
 
     for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
+        if i % chunksize == 0:
+            semaphore.release()
         total_bytes_processed += bytes_processed
         for key, sentences in doc.items():
            if len(sentences) == 0:
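
To make the mechanism easier to follow outside the diff, below is a minimal, self-contained sketch of the chunked-semaphore backpressure pattern this change introduces. The class and variable names mirror the patch, but the input generator, the trivial encode body, and the parameter values (4 workers, chunk size 25, budget of 200) are illustrative stand-ins, not the real preprocessing code.

import functools
import multiprocessing


class Encoder(object):
    """Stand-in for the script's Encoder: only the chunk/semaphore logic is kept."""

    def __init__(self, chunksize):
        self.chunksize = chunksize

    def initializer(self):
        # Per-worker counter, stored on the class as in the real script.
        Encoder.chunk_index = -1

    def encode(self, line, semaphore):
        # Acquire once at the start of every chunk of `chunksize` samples.
        Encoder.chunk_index = (Encoder.chunk_index + 1) % self.chunksize
        if Encoder.chunk_index == 0:
            semaphore.acquire()  # blocks when too many chunks are waiting to be consumed
        return line.upper(), len(line)  # stand-in for real tokenization


def main():
    workers, chunksize, max_sample_in_memory = 4, 25, 200
    assert max_sample_in_memory >= workers * chunksize  # same deadlock guard as the patch

    # A Manager-backed semaphore can be pickled into pool workers.
    manager = multiprocessing.Manager()
    semaphore = manager.Semaphore(max_sample_in_memory)

    encoder = Encoder(chunksize)
    lines = (f"document {i}\n" for i in range(10_000))  # stand-in for the input file

    with multiprocessing.Pool(workers, initializer=encoder.initializer) as pool:
        encoded = pool.imap_unordered(
            functools.partial(encoder.encode, semaphore=semaphore), lines, chunksize)
        total_bytes = 0
        for i, (doc, nbytes) in enumerate(encoded, start=1):
            if i % chunksize == 0:
                semaphore.release()  # one chunk drained, let a worker start another
            total_bytes += nbytes
        print(f"processed {total_bytes} bytes")


if __name__ == '__main__':
    main()

The key design point is that the semaphore is acquired in the worker processes but released in the consumer loop, so a slow consumer throttles the pool instead of letting encoded documents pile up in the result queue without bound; the assert keeps the initial budget large enough that all workers cannot block before the consumer has had a chance to release anything.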