
Commit f3172ac

Support layer parallelism in transformer application (#2420)

This PR adds support for layer parallelism in transformers, a variable-length version of The Pile pretokenized dataset, updates to the LBANN graph visualizer script, and some minor tweaks to the weights layer.

1 parent 6011d03 commit f3172ac

File tree

18 files changed: +660 −128 lines
Diff for: (new file)

@@ -0,0 +1,34 @@
from tqdm import trange
from multiprocessing import Pool
import numpy as np
import pickle


class Processor:

    def __init__(self, total_threads: int):
        self.threads = total_threads

    def __call__(self, tid: int):
        import thepile as dataset
        num_samples = dataset.num_val_samples()
        filename = f'/p/vast1/data/datasets/the-pile-huggingface/pretokenized-varlen/val.bin'
        len_filename = f'/p/vast1/data/datasets/the-pile-huggingface/pretokenized-varlen/val-seqlen.bin'

        with open(filename, 'ab') as fp:
            with open(len_filename, 'ab') as slfp:
                for i in trange(num_samples):
                    text = dataset.dataset_val[i]['text']
                    tokenized = dataset.tokenize(text)
                    sample = np.array(tokenized, dtype=np.uint16)
                    sample_len = np.array([len(sample)], dtype=np.uint32)
                    sample.tofile(fp)
                    sample_len.tofile(slfp)

        print('Done')


if __name__ == '__main__':
    threads = 1
    with Pool(threads) as pool:
        pool.map(Processor(threads), range(threads))
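The script above appends each validation sample's token ids (uint16) to a single packed stream and the sample's length (uint32) to a parallel val-seqlen.bin file. A minimal sketch of reading that pair back, not part of the commit; it mirrors the offset computation used by the variable-length dataset module added later in this diff, and assumes it is run from the output directory:

# Minimal sketch (not part of this commit): read back the packed token stream
# and per-sample lengths written by the script above.
import numpy as np

tokens = np.memmap('val.bin', dtype=np.uint16, mode='r')
lengths = np.fromfile('val-seqlen.bin', dtype=np.uint32).astype(np.uint64)

# Sample offsets are the exclusive prefix sum of the lengths.
offsets = np.zeros_like(lengths)
offsets[1:] = np.cumsum(lengths)[:-1]

# Reconstruct one sample (index 42 chosen arbitrarily) as its own array.
i = 42
sample = np.array(tokens[offsets[i]:offsets[i] + lengths[i]])
print(len(sample), 'tokens')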
Diff for: (new file)

@@ -0,0 +1,78 @@
from tqdm import trange
from multiprocessing import Pool
import numpy as np
import os
import argparse
from pathlib import Path


class Processor:

    def __init__(self, total_threads: int):
        self.threads = total_threads

    def __call__(self, tid: int):
        import thepile as dataset
        num_samples = dataset.num_train_samples()
        np.random.seed(20231023)
        indices = np.random.permutation(num_samples)
        local_samples = num_samples // self.threads
        offset = tid * local_samples
        # Add remainder
        if tid == self.threads - 1:
            local_samples += num_samples % self.threads
        section = indices[offset:offset + local_samples]
        filename = f'/p/vast1/data/datasets/the-pile-huggingface/pretokenized-varlen/train-pretokenized-{tid:02d}-of-{self.threads}.bin'
        len_filename = f'/p/vast1/data/datasets/the-pile-huggingface/pretokenized-varlen/train-seqlen-{tid:02d}-of-{self.threads}.bin'

        # Create file
        if not os.path.isfile(filename):
            Path(filename).touch()
        if not os.path.isfile(len_filename):
            Path(len_filename).touch()

        sz = os.path.getsize(len_filename)
        assert sz % 4 == 0
        sequences_processed = sz // 4
        print(tid, ': Size in bytes:', sz, '. Sequences processed:',
              sequences_processed)

        with open(filename, 'ab') as fp:
            with open(len_filename, 'ab') as slfp:
                for i in trange(sequences_processed,
                                section.shape[0],
                                desc=f'Thread {tid}'):
                    text = dataset.dataset_train[int(section[i])]['text']
                    sample = dataset.tokenize(text)
                    sample = np.array(sample, dtype=np.uint16)
                    sample.tofile(fp)
                    sample_len = np.array([len(sample)], dtype=np.uint32)
                    sample_len.tofile(slfp)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('-j',
                        action='store',
                        default=0,
                        type=int,
                        help='Threads (default 0 = number of cores)')
    parser.add_argument('-t',
                        action='store',
                        default=0,
                        type=int,
                        help='Total chunks (default 0 = number of threads)')
    parser.add_argument('-o',
                        action='store',
                        default=0,
                        type=int,
                        help='Chunk offset (default 0)')
    args = parser.parse_args()

    threads = args.j or os.cpu_count()
    total_chunks = args.t or threads
    offset = args.o
    assert offset + threads <= total_chunks
    with Pool(threads) as pool:
        pool.map(Processor(total_chunks), range(offset, offset + threads))
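The -j/-t/-o flags let the training-set tokenization be split into more chunks than a single job runs: each invocation handles `threads` chunks out of `total_chunks`, starting at `offset`, so several jobs (for example on different nodes) can cover disjoint chunk ranges. Because the output files are opened in append mode and `sequences_processed` is derived from the existing length file, an interrupted chunk can also be resumed. A rough sketch of how the flags compose (illustrative only, not part of the commit):

# Illustrative sketch (not part of this commit): which chunk ids a single
# invocation of the script above would process, given -j/-t/-o.
import os

def chunks_for_run(j: int, t: int, o: int):
    threads = j or os.cpu_count()   # -j: worker processes in this run
    total_chunks = t or threads     # -t: total number of chunks overall
    offset = o                      # -o: first chunk handled by this run
    assert offset + threads <= total_chunks
    return list(range(offset, offset + threads))

# Two 32-process runs together covering 64 chunks (e.g., one run per node):
print(chunks_for_run(32, 64, 0))    # chunks 0..31
print(chunks_for_run(32, 64, 32))   # chunks 32..63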

Diff for: applications/nlp/transformer/datasets/thepile.py (+10 −1)

@@ -91,7 +91,7 @@ def get_train_sample(index):


def get_val_sample(index):
    """Token indices for a data sample from the validation set."""
-    text = dataset_train[index]['text']
+    text = dataset_val[index]['text']
    tokenized = tokenize(text)

    # Trim long sequences, left-pad short sequences
@@ -120,3 +120,12 @@ def sample_dims():

def vocab_size():
    return tokenizer.get_vocab_size()
+
+
+if __name__ == '__main__':
+    print('Training samples:', num_train_samples())
+    print('Validation samples:', num_val_samples())
+    print('Training sample 101:')
+    print(tokenizer.decode(get_train_sample(101)))
+    print('Validation sample 233:')
+    print(tokenizer.decode(get_val_sample(233)))

Diff for: applications/nlp/transformer/datasets/thepile_pretokenized.py (+4 −2)

@@ -1,5 +1,5 @@
 """
-The Pile dataset, stored as pre-tokenized binary files for optimized processing.
+The Pile dataset, stored as pre-tokenized, pre-packed binary files for optimized processing.
 """
 import os
 import os.path
@@ -10,7 +10,9 @@
 # Options
 # ----------------------------------------------

-sequence_length = int(os.getenv('THE_PILE_SEQUENCE_LENGTH', default='512'))
+# Sequence length is hardcoded to 512 in the pre-packed binary dataset.
+# To use other sequence lengths, see ``thepile_pretokenized_varlen.py``
+sequence_length = 512

 # ----------------------------------------------
 # Setup
Diff for: (new file)

@@ -0,0 +1,105 @@
"""
The Pile dataset, stored as pre-tokenized binary files for optimized processing.
"""
import os
import os.path

import numpy as np
# ----------------------------------------------
# Options
# ----------------------------------------------

sequence_length = int(os.getenv('THE_PILE_SEQUENCE_LENGTH', default='512'))

# ----------------------------------------------
# Setup
# ----------------------------------------------

# Load the datasets
data_dir = os.getenv('THE_PILE_DATA_DIR',
                     '/p/vast1/data/datasets/the-pile-pretokenized')
dataset_train = np.memmap(os.path.join(data_dir, 'train.bin'),
                          dtype=np.uint16,
                          mode='r')
sample_lengths_train = np.fromfile(os.path.join(data_dir, 'train-seqlen.bin'),
                                   dtype=np.uint32).astype(np.uint64)
sample_offsets_train = np.zeros_like(sample_lengths_train)
sample_offsets_train[1:] = np.cumsum(sample_lengths_train)[:-1]
dataset_val = np.memmap(os.path.join(data_dir, 'val.bin'),
                        dtype=np.uint16,
                        mode='r')
sample_lengths_val = np.fromfile(os.path.join(data_dir, 'val-seqlen.bin'),
                                 dtype=np.uint32).astype(np.uint64)
sample_offsets_val = np.zeros_like(sample_lengths_val)
sample_offsets_val[1:] = np.cumsum(sample_lengths_val)[:-1]

# Uses the definition from the GPT-NeoX-20B tokenizer
pad_index = 1  # '<|padding|>'
_vocab_size = 50277

# ----------------------------------------------
# Sample access functions
# ----------------------------------------------


def trim_and_pad(sample, random: bool):
    # Trim long sequences
    if len(sample) > sequence_length:
        if random:
            pos = np.random.rand()
            offset = (len(sample) - sequence_length + 1) * pos
            offset = int(np.floor(offset))
            sample = sample[offset:offset + sequence_length]
        else:
            sample = sample[0:sequence_length]

    # Left-pad short sequences
    if len(sample) < sequence_length:
        sample_pad = np.full(sequence_length, pad_index, dtype=np.int32)
        if len(sample) > 0:
            sample_pad[-len(sample):] = sample
        return sample_pad

    return sample


def get_train_sample(index: int):
    sample = np.copy(
        dataset_train[sample_offsets_train[index]:sample_offsets_train[index] +
                      sample_lengths_train[index]]).astype(np.int32)
    return trim_and_pad(sample, True)


def get_val_sample(index):
    sample = np.copy(
        dataset_val[sample_offsets_val[index]:sample_offsets_val[index] +
                    sample_lengths_val[index]]).astype(np.int32)
    return trim_and_pad(sample, False)


def num_train_samples():
    return sample_lengths_train.shape[0]


def num_val_samples():
    return sample_lengths_val.shape[0]


def sample_dims():
    return (sequence_length, )


def vocab_size():
    return _vocab_size


if __name__ == '__main__':
    print('Training samples:', num_train_samples())
    print('Validation samples:', num_val_samples())
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file(
        os.path.join(data_dir, '20B_tokenizer.json'))
    print('Training sample 101:')
    print(tokenizer.decode(get_train_sample(101)))
    print('Validation sample 233:')
    print(tokenizer.decode(get_val_sample(233)))
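Because the loader derives sample offsets purely from the cumulative lengths, the *-seqlen.bin file and the packed token file can be cross-checked directly. A hedged sketch of such a sanity check; the path of the new module is not shown in this view, so the import name `thepile_pretokenized_varlen` is an assumption:

# Hedged sketch (not part of this commit): sanity-check the packed dataset.
# Assumes the module above is importable as `thepile_pretokenized_varlen`.
import numpy as np
import thepile_pretokenized_varlen as ds

# The final offset plus the final length must equal the total token count.
assert int(ds.sample_offsets_train[-1] + ds.sample_lengths_train[-1]) \
    == ds.dataset_train.shape[0]

# Every sample comes back trimmed or left-padded to the fixed sequence length.
sample = ds.get_train_sample(0)
assert sample.shape == ds.sample_dims()
assert sample.dtype == np.int32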

Diff for: applications/nlp/transformer/modeling.py (+4 −0)

@@ -87,6 +87,7 @@ def create_encoder_decoder_transformer(dataset, args: argparse.Namespace):
                                                  transformer, args)
     parallelism.apply_ffn_model_parallelism(transformer, args)
     parallelism.apply_fsdp_mlp(transformer, [embedding_weights], args)
+    parallelism.apply_layer_parallelism(transformer, args)

     # Run through transformer
     result = transformer(encoder_input, decoder_input, sequence_length - 1)
@@ -124,6 +125,7 @@ def create_encoder_decoder_transformer(dataset, args: argparse.Namespace):
     )

     parallelism.apply_fsdp_allweights(result, args)
+    parallelism.apply_layer_parallelism_postamble(result, args)
     return result


@@ -186,6 +188,7 @@ def create_causal_lm_decoder_transformer(dataset, embed_dim: int,
                                                  transformer, args)
     parallelism.apply_ffn_model_parallelism(transformer, args)
     parallelism.apply_fsdp_mlp(transformer, [embedding_weights], args)
+    parallelism.apply_layer_parallelism(transformer, args)

     # Run through transformer with the same sequence
     result = transformer(decoder_input, decoder_input, sequence_length)
@@ -227,6 +230,7 @@ def create_causal_lm_decoder_transformer(dataset, embed_dim: int,
     )

     parallelism.apply_fsdp_allweights(result, args)
+    parallelism.apply_layer_parallelism_postamble(result, args)
     return result
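`apply_layer_parallelism` and `apply_layer_parallelism_postamble` come from the application's `parallelism` module, which is not part of this diff; the lines above only wire them into both transformer builders. As a rough, framework-agnostic illustration of the underlying idea (contiguous groups of transformer blocks assigned to pipeline stages), not LBANN's implementation:

# Rough, framework-agnostic sketch of layer (pipeline) parallelism:
# assign contiguous groups of transformer blocks to stages/ranks.
# This is NOT LBANN's implementation; all names here are illustrative.
def assign_layers_to_stages(num_layers: int, num_stages: int):
    """Return a stage id for every layer, keeping each stage contiguous."""
    base, rem = divmod(num_layers, num_stages)
    mapping = []
    for stage in range(num_stages):
        count = base + (1 if stage < rem else 0)  # spread the remainder
        mapping.extend([stage] * count)
    return mapping

# Example: 12 decoder blocks across 4 stages
# -> [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
print(assign_layers_to_stages(12, 4))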