If logging is disabled (or very infrequent), memory usage slowly grows because the max and average loss tensors are kept in an on-device list: https://github.com/pytorch/torchtitan/blob/main/train.py#L353-L354

The training loop should offload these tensors to the CPU right after their aggregation is finished, especially since the logging prints will do that anyway under the hood.
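A minimal sketch of what the proposed change could look like, assuming the loop appends each step's aggregated loss to a Python list (the names `record_loss` and `losses_since_last_log` are illustrative, not the exact identifiers in train.py):

```python
import torch

# Illustrative stand-in for the per-step losses collected between log steps.
losses_since_last_log: list[torch.Tensor] = []

def record_loss(loss: torch.Tensor) -> None:
    # Proposed behavior: detach and move the loss to the CPU as soon as it is
    # aggregated, so the list no longer holds GPU memory between log steps.
    # Note that, like .item(), this device-to-host copy synchronizes.
    losses_since_last_log.append(loss.detach().cpu())
```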
I think the reason we keep it is that we don't want to call .item() (which forces a synchronization between CPU and GPU) unless we hit a log step. I do agree that if logging is disabled or infrequent, this overhead is unnecessary. May I ask, though, what the use case is where you log so infrequently that this overhead becomes unacceptable?
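For context, the trade-off described above looks roughly like this (a sketch under assumed names and log frequency, not the actual train.py code): the on-device tensors are only materialized on the CPU via .item() when a log step is hit, deferring the synchronization at the cost of keeping one small tensor per step in GPU memory.

```python
import torch

losses_since_last_log: list[torch.Tensor] = []
LOG_FREQ = 100  # illustrative log frequency

def step_metrics(step: int, loss: torch.Tensor) -> None:
    # Current approach: keep the detached loss on-device and only
    # synchronize (via .item()) when a log step is actually hit.
    losses_since_last_log.append(loss.detach())
    if step % LOG_FREQ == 0:
        stacked = torch.stack(losses_since_last_log)
        avg_loss, max_loss = stacked.mean().item(), stacked.max().item()
        losses_since_last_log.clear()
        print(f"step {step}: avg_loss={avg_loss:.4f} max_loss={max_loss:.4f}")
```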