diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 4e6e745cf..065f21213 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -59,15 +59,16 @@ def model_provider(pre_process=True, post_process=True): attention_mask = torch.tril(torch.ones( (1, args.seq_length, args.seq_length), device=torch.cuda.current_device())).view( 1, 1, args.seq_length, args.seq_length) - + # Convert attention mask to binary: attention_mask = (attention_mask < 0.5) if args.fp16: attention_mask = attention_mask.half() elif args.bf16: attention_mask = attention_mask.bfloat16() - - args.attn_mask = attention_mask + + # must be bool or the training crashes expecting bool, but getting Half + args.attn_mask = attention_mask.to(torch.bool) else: model = GPTModel(