Describe the bug
When both the rope position_embedding_type and sequence_parallel are enabled in NeMo GPT pretraining, training fails with a runtime error:
File "/workspace/NeMo/nemo/collections/nlp/modules/common/megatron/rotary_pos_embedding.py", line 59, in apply_rotary_pos_emb
t = (t * freqs.cos()) + (_rotate_half(t) * freqs.sin())
RuntimeError: The size of tensor a (2048) must match the size of tensor b (1024) at non-singleton dimension 0
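For illustration, here is a minimal sketch of the shape mismatch, assuming the rotary frequency table ends up sized for the per-rank sequence shard (seq_len / tensor_parallel_size) while the query/key tensor inside attention spans the full sequence. `_rotate_half` below is a simplified stand-in for the NeMo helper, and all shapes are made up (2048-token sequence, tensor-parallel size 2):

```python
import torch

def _rotate_half(x):
    # Simplified stand-in for the helper in rotary_pos_embedding.py:
    # split the last dim in half and rotate (x1, x2) -> (-x2, x1).
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

seq_len, tp_size, heads, head_dim = 2048, 2, 16, 64

# Query/key tensor covering the full sequence: [s, b, heads, head_dim].
t = torch.randn(seq_len, 1, heads, head_dim)

# Hypothetical frequency table built from the per-rank shard length
# (2048 / 2 = 1024) instead of the full sequence length.
freqs = torch.randn(seq_len // tp_size, 1, 1, head_dim)

# Raises: RuntimeError: The size of tensor a (2048) must match the size
# of tensor b (1024) at non-singleton dimension 0
t = (t * freqs.cos()) + (_rotate_half(t) * freqs.sin())
```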
Steps/Code to reproduce bug
Follow the tutorial to train a GPT model with model.position_embedding_type=rope and model.sequence_parallel=True (a sample launch command is sketched below).
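Something along these lines should trigger the error. The script path follows the NeMo examples layout; the trainer.devices and model.tensor_model_parallel_size values are illustrative (sequence parallelism requires a tensor-parallel size greater than 1) and not taken from the tutorial:

```
python examples/nlp/language_modeling/megatron_gpt_pretraining.py \
    trainer.devices=2 \
    model.tensor_model_parallel_size=2 \
    model.position_embedding_type=rope \
    model.sequence_parallel=True
```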
Expected behavior
Training should run without a runtime error when both options are enabled.
Environment overview (please complete the following information)
docker pull & docker run commands used
Environment details
If an NVIDIA docker image is used, you don't need to specify these.
Otherwise, please provide:
Additional context
Add any other context about the problem here.
Example: GPU model