Today `init_attention_mask` is called before PP does the microbatch split:
https://github.com/pytorch/torchtitan/blob/main/torchtitan/train.py#L421
I haven't tested this, but it will likely cause the wrong block mask to be applied to non-first microbatches.
E.g. consider a local batch:

```
batch = [
    b0,
    b1,
    b2,
    b3,
]
```
The block mask is created for `batch`, but after (say, size-1) microbatching the same mask will be used for 4 different smaller batches `[b0]`, `[b1]`, `[b2]`, `[b3]`. For `[b0]` it might be fine, but for the others the mask is wrong.
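A rough standalone sketch of the mismatch (not torchtitan code and untested against it; the per-sample `seq_lens` padding-style `mask_mod` is just a toy stand-in for whatever `init_attention_mask` builds):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

B, S = 4, 128
# Toy per-sample metadata: each sample has a different valid length,
# so the mask genuinely depends on the batch index b.
seq_lens = torch.tensor([32, 64, 96, 128])

def mask_mod(b, h, q_idx, kv_idx):
    # Causal mask plus per-sample padding, keyed by *full-batch* position b.
    return (q_idx >= kv_idx) & (kv_idx < seq_lens[b])

# What happens today: one block mask built for the whole local batch.
full_mask = create_block_mask(mask_mod, B, None, S, S, device="cpu")

# What the microbatch holding b2 actually needs: a mask derived from b2's
# metadata, sitting at batch index 0 of a size-1 mask.
needed_mask = create_block_mask(
    lambda b, h, q, kv: mask_mod(torch.tensor(2), h, q, kv),
    1, None, S, S, device="cpu",
)

# full_mask was built for batch positions 0..3; after the split, the
# microbatch [b2] has batch size 1, so reusing full_mask either mismatches
# shapes or (if taken row 0) applies b0's layout to b2.
```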
The solution could be either of:
- call `init_attention_mask` after microbatching, or
- when PP does the microbatch split, adjust the block mask as well (see the sketch after this list).
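Either way, the key step is that batch index 0 of microbatch `i` must map back to sample `i * mb_size` of the local batch. A rough sketch under the same toy setup as above (`make_mask_mod`, `mb_size`, and the loop are made-up names, not torchtitan's API):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

B, S, n_microbatches = 4, 128, 4
mb_size = B // n_microbatches
seq_lens = torch.tensor([32, 64, 96, 128])  # per-sample valid lengths

def make_mask_mod(mb_start: int):
    def mask_mod(b, h, q_idx, kv_idx):
        # Shift the batch index by the microbatch's offset into the local batch.
        return (q_idx >= kv_idx) & (kv_idx < seq_lens[mb_start + b])
    return mask_mod

# One correctly indexed block mask per microbatch, built after the split.
mb_masks = [
    create_block_mask(make_mask_mod(i * mb_size), mb_size, None, S, S, device="cpu")
    for i in range(n_microbatches)
]
```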
cc @fegin @H-Huang @drisspg