Merged
31 commits
66e61e7
add preprocessing of HF datasets directly
RaymondLi0 Nov 7, 2022
a79988a
modify max seq-length from 2048 to 8192
RaymondLi0 Nov 8, 2022
db3809b
add missing cases in fused kernels
RaymondLi0 Nov 14, 2022
acda627
add longer sequence lengths in fused kernels test
RaymondLi0 Nov 14, 2022
d59c85b
larger MAX_TOKENS_TO_OOM
RaymondLi0 Nov 14, 2022
7b0cee2
use custom barrier with device_ids
Nov 18, 2022
93cb6a0
add HF tokenizer
Nov 22, 2022
9f2c442
add special tokens in HF tokenizer
RaymondLi0 Nov 22, 2022
9fe3bcb
fix vocab_size in _HFTokenizer
RaymondLi0 Nov 22, 2022
6982c4e
fix: initialize tokenizer with TokenizerFromFile
Nov 22, 2022
0348b3a
Merge branch 'preprocess-hf' of github.com:bigcode-project/Megatron-L…
Nov 22, 2022
4f060a2
fix: add special_tokens dict for FIM
Nov 22, 2022
332e8db
load attention-head-type from checkpoint
Nov 23, 2022
0717dab
attention-head-type defaults to None instead
Nov 23, 2022
96daa55
use detokenize method in text_generation
Nov 24, 2022
2d36c14
add mqa conversion to huggingface
RaymondLi0 Dec 2, 2022
760eed9
remove config and tokenizer save
RaymondLi0 Dec 2, 2022
baa7b3b
add Readme
RaymondLi0 Dec 2, 2022
2ceaf70
add some documentation
RaymondLi0 Dec 2, 2022
66beabe
add push to hub logic
Dec 2, 2022
de83476
add docs
Dec 2, 2022
1b7c96f
convert_checkpoint as function, push starting from last pushed iteration
RaymondLi0 Dec 2, 2022
5cb878f
add iter_interval argument
RaymondLi0 Dec 2, 2022
ab1c4cc
use relative imports in modeling file
RaymondLi0 Dec 8, 2022
60fbd1d
update readme
RaymondLi0 Dec 15, 2022
732396a
remove debug prints
Jan 24, 2023
9d80f8a
more precise error for attention_type/head_type values
Jan 24, 2023
7457e32
attention-head-type defaults to multihead again to avoid breaking pre…
RaymondLi0 Jan 24, 2023
cdbcfc9
documentation on the --tokenizer-file argument
RaymondLi0 Jan 30, 2023
94306d1
add missing newlines
RaymondLi0 Jan 30, 2023
506fbd4
revert barrier() to torch.distributed.barrier()
RaymondLi0 Jan 30, 2023
6 changes: 5 additions & 1 deletion megatron/arguments.py
@@ -870,6 +870,8 @@ def _add_data_args(parser):
help='Path to the vocab file.')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file.')
group.add_argument('--tokenizer-file', type=str, default=None,
help='Path to the tokenizer.json file. Used for the TokenizerFromFile[...] tokenizers')
group.add_argument('--vocab-extra-ids', type=int, default=0,
help='Number of additional vocabulary tokens. '
'They are used for span masking in the T5 model')
@@ -899,7 +901,9 @@ def _add_data_args(parser):
choices=['BertWordPieceLowerCase',
'BertWordPieceCase',
'GPT2BPETokenizer',
'GPT2BPETokenizerWithFIM'],
'GPT2BPETokenizerWithFIM',
'TokenizerFromFile',
'TokenizerFromFileWithFIM'],
help='What type of tokenizer to use.')
group.add_argument('--data-impl', type=str, default='infer',
choices=['lazy', 'cached', 'mmap', 'infer'],
3 changes: 2 additions & 1 deletion megatron/checkpointing.py
@@ -71,7 +71,7 @@ def _compare(arg_name, old_arg_name=None):
# with alibi we can change `max_position_embeddings`
if args.position_embedding_type != PositionEmbeddingType.alibi:
_compare('max_position_embeddings')
if args.vocab_file:
if args.vocab_file or args.tokenizer_file:
_compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size')
_compare('tokenizer_type')
@@ -504,6 +504,7 @@ def _set_arg(arg_name, old_arg_name=None, force=False):
_set_arg('max_position_embeddings')
_set_arg('tokenizer_type')
_set_arg('padded_vocab_size')
_set_arg('attention_head_type')
if checkpoint_version < 3.0:
_set_arg('tensor_model_parallel_size',
'model_parallel_size')
4 changes: 2 additions & 2 deletions megatron/data/gpt_dataset.py
@@ -205,7 +205,7 @@ def __getitem__(self, idx):
assert (fim_rate <= 1 and fim_rate >= 0), "FIM rate must be a probability 0 <= rate <= 1"

eod = self.tokenizer.eod
pad = self.tokenizer.tokenizer.special_tokens[FIM_PAD]
pad = self.tokenizer.special_tokens[FIM_PAD]
Collaborator

Why?

Collaborator Author

The local GPT2Tokenizer implementation has this special_tokens attribute, but the PreTrainedTokenizerFast from the transformers library does not. So the code here instead relies on the wrapper classes around these tokenizers:

self.special_tokens = {

self.special_tokens = self.tokenizer.special_tokens
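For context, here is a minimal sketch of that wrapper idea, assuming the tokenizers library and illustrative class/token names (this is not the PR's actual _HFTokenizer): it exposes the same special_tokens / eod / detokenize surface that the GPT2BPE wrapper already provides, so call sites such as gpt_dataset.py and text_generation never need to know which backend is in use.

from tokenizers import Tokenizer

# Token strings below are placeholders; the real FIM/EOD token names may differ.
FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD = (
    "<fim-prefix>", "<fim-middle>", "<fim-suffix>", "<fim-pad>")
EOD = "<|endoftext|>"

class HFTokenizerSketch:
    """Illustrative wrapper around a tokenizer.json file (--tokenizer-file)."""

    def __init__(self, tokenizer_file):
        self.tokenizer = Tokenizer.from_file(tokenizer_file)
        # Expose special-token ids under the same attribute name the GPT2BPE
        # wrapper uses, so the FIM code can look them up uniformly.
        self.special_tokens = {
            tok: self.tokenizer.token_to_id(tok)
            for tok in [FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD, EOD]
        }
        self.eod = self.special_tokens[EOD]

    @property
    def vocab_size(self):
        # Include added special tokens so padded_vocab_size is computed correctly.
        return self.tokenizer.get_vocab_size(with_added_tokens=True)

    def tokenize(self, text):
        return self.tokenizer.encode(text).ids

    def detokenize(self, token_ids):
        return self.tokenizer.decode(token_ids)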


segment_breaks = np.argwhere(sample == eod) # split sample by document

@@ -495,7 +495,7 @@ def permute(sample, np_rng, args, tokenizer, truncate_or_pad=True):
"""
fim_rate = args.fim_rate

suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = (tokenizer.tokenizer.special_tokens[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD])
suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id = (tokenizer.special_tokens[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD])

if np_rng.binomial(1, fim_rate): # sample bernoulli dist

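As background for the permute hunk above, a compressed sketch of the fill-in-the-middle idea these special-token ids serve (heavily simplified: it ignores truncation/padding with FIM_PAD, the SPM variant, and the character-level re-tokenization the real permute performs):

import numpy as np

def fim_permute_sketch(sample, np_rng, fim_rate, prefix_id, middle_id, suffix_id):
    # With probability fim_rate, cut the document into prefix/middle/suffix and
    # emit it in PSM order so the model learns to fill in the middle.
    if not np_rng.binomial(1, fim_rate):
        return sample  # keep ordinary left-to-right order
    lo, hi = sorted(np_rng.randint(0, len(sample) + 1, size=2))
    prefix, middle, suffix = sample[:lo], sample[lo:hi], sample[hi:]
    return np.concatenate(
        [[prefix_id], prefix, [suffix_id], suffix, [middle_id], middle]
    ).astype(sample.dtype)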
18 changes: 15 additions & 3 deletions megatron/fused_kernels/scaled_masked_softmax.h
@@ -447,7 +447,7 @@ void dispatch_scaled_softmax_forward(
int batches,
int attn_heads)
{
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 8192 );
if (key_seq_len == 0) {
return;
} else {
@@ -523,6 +523,10 @@ void dispatch_scaled_softmax_forward(
scaled_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
case 13: // 8192
Collaborator

Interesting. Did you double-check that it works as intended? Things can get tricky when kernels grow too big (registers, shared memory, etc.). Don't know if it's relevant here.

Collaborator Author

I admit that I did not double-check this. I relied on the code from https://github.com/NVIDIA/Megatron-LM/pull/243/files.
I tried to train a model with seq-length 4092 and it seemed to work fine.
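For reference, the dispatchers in this header switch on log2_elements = log2_ceil(key_seq_len), so each case N handles element counts up to 2**N; that is why raising the assert to 8192 also needs the new case 13. A tiny illustration of that mapping (Python, not part of the PR):

import math

def dispatch_case(key_seq_len):
    # Smallest N with 2**N >= key_seq_len, mirroring log2_ceil in the kernel header.
    assert 0 < key_seq_len <= 8192
    return math.ceil(math.log2(key_seq_len))

assert dispatch_case(2048) == 11
assert dispatch_case(4096) == 12
assert dispatch_case(8192) == 13   # the new case added in this PR
assert dispatch_case(5000) == 13   # non-power-of-two lengths round up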

scaled_softmax_warp_forward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
break;
default:
break;
}
@@ -541,7 +545,7 @@ void dispatch_scaled_masked_softmax_forward(
int attn_heads,
int pad_batches)
{
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 8192 );
if (key_seq_len == 0) {
return;
} else {
@@ -617,6 +621,10 @@ void dispatch_scaled_masked_softmax_forward(
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 13: // 8192
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
default:
break;
}
@@ -634,7 +642,7 @@ void dispatch_scaled_masked_softmax_backward(
int batches,
int attn_heads)
{
TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 4096 );
TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 8192 );
if (key_seq_len == 0) {
return;
} else {
@@ -709,6 +717,10 @@ void dispatch_scaled_masked_softmax_backward(
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 13: // 8192
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;

default:
break;
2 changes: 1 addition & 1 deletion megatron/fused_kernels/scaled_masked_softmax_cuda.cu
@@ -44,7 +44,7 @@ torch::Tensor fwd_cuda(
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
TORCH_INTERNAL_ASSERT(key_seq_len <= 8192);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
2 changes: 1 addition & 1 deletion megatron/fused_kernels/scaled_softmax_cuda.cu
@@ -37,7 +37,7 @@ torch::Tensor fwd_cuda(
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
TORCH_INTERNAL_ASSERT(key_seq_len <= 8192);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);

// Output
20 changes: 18 additions & 2 deletions megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
@@ -340,7 +340,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 8192 );
if (softmax_elements == 0) {
return;
} else {
@@ -415,6 +415,14 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 12: // 4096
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 13: // 8192
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
@@ -431,7 +439,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
int softmax_elements_stride,
int attn_batches)
{
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 8192 );
if (softmax_elements == 0) {
return;
} else {
@@ -506,6 +514,14 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 12: // 4096
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
case 13: // 8192
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
break;
default:
break;
}
megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu
@@ -35,7 +35,7 @@ torch::Tensor fwd_cuda(
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
TORCH_INTERNAL_ASSERT(seq_len <= 8192);

// Output
auto act_options = input.options().requires_grad(false);
12 changes: 6 additions & 6 deletions megatron/fused_kernels/tests/test_fused_kernels.py
@@ -23,11 +23,11 @@ def test_load_fused_kernels():


def test_fused_softmax():
bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
bert = BertModel.from_pretrained("bert-base-cased", max_position_embeddings=8192, ignore_mismatched_sizes=True).cuda().half()
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
test_text = (
"Hello. How are you? I am fine thank you and you? yes Good. "
"hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32
"hi hi hi hi hi hi hi hi hi hi hi hi hi" * 256 # 32 * 256
)

tokens = tokenizer(
@@ -121,11 +121,11 @@ def test_fused_softmax():


def test_fused_upper_triangle_mask_softmax():
gpt = GPT2Model.from_pretrained("gpt2").cuda().half()
gpt = GPT2Model.from_pretrained("gpt2", max_position_embeddings=8192, ignore_mismatched_sizes=True).cuda().half()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
test_text = (
"Hello. How are you? I am fine thank you and you? yes Good. "
"hi hi hi hi hi hi hi" # 24
"hi hi hi hi hi hi hi" * 256 # 24 * 256
)

tokens = tokenizer(
@@ -221,11 +221,11 @@ def test_fused_upper_triangle_mask_softmax():


def test_layer_norm():
bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
bert = BertModel.from_pretrained("bert-base-cased", max_position_embeddings=8192, ignore_mismatched_sizes=True).cuda().half()
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
test_text = (
"Hello. How are you? I am fine thank you and you? yes Good. "
"hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32
"hi hi hi hi hi hi hi hi hi hi hi hi hi" * 256 # 32
)

tokens = tokenizer(
2 changes: 1 addition & 1 deletion megatron/global_vars.py
@@ -104,7 +104,7 @@ def set_global_variables(args):
set_args(args)

_build_num_microbatches_calculator(args)
if args.vocab_file:
if args.vocab_file or args.tokenizer_file:
_ = _build_tokenizer(args)
_set_tensorboard_writer(args)
_set_adlr_autoresume(args)
2 changes: 1 addition & 1 deletion megatron/initialize.py
@@ -131,7 +131,7 @@ def _compile_dependencies():
args.micro_batch_size
# Constraints on sequence length and attn_batch_size to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = seq_len > 16 and seq_len <=4096 and \
custom_kernel_constraint = seq_len > 16 and seq_len <=8192 and \
seq_len % 4 == 0 and attn_batch_size % 4 == 0
# Print a warning.
if not ((args.fp16 or args.bf16) and
4 changes: 2 additions & 2 deletions megatron/model/fused_softmax.py
@@ -168,11 +168,11 @@ def is_kernel_available(self, mask, b, np, sq, sk):
if (
self.scaled_masked_softmax_fusion # user want to fuse
and self.input_in_float16 # input must be fp16
and 16 < sk <= 4096 # sk must be 16 ~ 2048
and 16 < sk <= 8192 # sk must be 16 ~ 8192
and sq % 4 == 0 # sq must be divisor of 4
and attn_batches % 4 == 0 # np * b must be divisor of 4
):
if 0 <= sk <= 4096:
if 0 <= sk <= 8192:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)

if self.attn_mask_type == AttnMaskType.causal:
4 changes: 3 additions & 1 deletion megatron/model/transformer.py
@@ -522,8 +522,10 @@ def __init__(self, init_method,
2 * projection_size,
gather_output=False,
init_method=init_method)
else:
elif attention_type == AttnType.cross_attn and self.attention_head_type == 'multiquery':
raise NotImplementedError("Multiquery attention not implemented for cross-attention.")
else:
raise ValueError(f"Invalid attention arguments: {attention_type}, {self.attention_head_type}")

if self.attention_head_type == 'multihead':
self.core_attention = CoreAttention(self.layer_number,
2 changes: 1 addition & 1 deletion megatron/text_generation/generation.py
@@ -28,7 +28,7 @@
from .sampling import sample
from .beam_utils import BeamHypotheses

MAX_TOKENS_TO_OOM = 12000 # (rprenger) Perfect value depends on hardware and network
MAX_TOKENS_TO_OOM = 128000 # (rprenger) Perfect value depends on hardware and network

def score_and_return_on_first_stage(model, tokens, lengths):
"""Function for just scoring.
9 changes: 5 additions & 4 deletions megatron/text_generation/tokenization.py
@@ -43,10 +43,11 @@ def detokenize_generations(tokens_gpu_tensor,
if return_segments:
words = []
for token in sequence_tokens:
word = tokenizer.tokenizer.decoder[token]
word = bytearray(
[tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
'utf-8', errors='replace')
# word = tokenizer.tokenizer.decoder[token]
# word = bytearray(
# [tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
# 'utf-8', errors='replace')
word = tokenizer.detokenize([token])
words.append(word)
prompts_plus_generations_segments.append(words)
