# Conduct lr check only once #950
## Conversation
**tianyu-l** left a comment:
Thanks, left some comments.
In `torchtitan/components/optimizer.py` (outdated):
| f"The warmup steps should be less than or equal to the warmup-stable steps ({warmup_stable_steps}). " | ||
| f"Consider reducing either the decay ratio ({lr_decay_ratio}) or the warmup steps ({warmup_steps})." | ||
| ) | ||
| warmup_stable_steps = round(training_steps * (1 - lr_decay_ratio)) |
**tianyu-l:**
we can avoid this recomputation by feeding in warmup_stable_steps as an arg (and remove lr_decay_ratio)
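In code, that first suggestion might look like this (a signature-only sketch; the parameter names are hypothetical, not necessarily the merged code):

```python
# Sketch: accept the precomputed boundary instead of re-deriving it from
# lr_decay_ratio inside the function on every call.
def linear_warmup_stable_decay(
    current_step: int, warmup_steps: int, warmup_stable_steps: int
) -> float:
    ...
```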
**PR author:**
I think `lr_decay_ratio` would be a better choice, as it is typically a fixed ratio of `training_steps`. Specifying both `warmup_stable_steps` and `training_steps` might introduce unnecessary cognitive load.
**tianyu-l:**
Sure, but does it make sense to do this computation and rounding every single training step? Btw, currently we are not explicitly passing `training_steps` to this function anyway.

An alternative proposal is:

- We can make `warmup_steps`, `stable_steps`, and `decay_steps` inputs to this function, matching the function name properly (so cognitive load is lower?).
- We calculate these values outside and only calculate them once.
- If `warmup_steps` (directly specified) + `decay_steps` (calculated based on ratio) > `training_steps`, we can set `stable_steps = 0`.

The benefit is that `linear_warmup_stable_decay` is only responsible for the actual piecewise LR function, while the logic for computing the WSD phase lengths lives outside and runs only once. What do you think?
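A minimal sketch of this proposal, with hypothetical names mirroring the discussion (`compute_wsd_phases` is illustrative, not the actual torchtitan API):

```python
def compute_wsd_phases(
    training_steps: int, warmup_steps: int, decay_ratio: float
) -> tuple[int, int, int]:
    """Compute the warmup/stable/decay phase lengths once, up front."""
    decay_steps = round(training_steps * decay_ratio)
    stable_steps = training_steps - warmup_steps - decay_steps
    if stable_steps < 0:
        # warmup + decay already exceed the training budget: drop the stable phase
        stable_steps = 0
    return warmup_steps, stable_steps, decay_steps


def linear_warmup_stable_decay(
    current_step: int, warmup_steps: int, stable_steps: int, decay_steps: int
) -> float:
    """Pure piecewise LR multiplier; no phase arithmetic per step."""
    if current_step < warmup_steps:
        return (current_step + 1) / (warmup_steps + 1)  # linear warmup toward 1.0
    if current_step < warmup_steps + stable_steps:
        return 1.0  # stable phase at peak LR
    # linear decay toward 0 over decay_steps
    progress = (current_step - warmup_steps - stable_steps) / max(decay_steps, 1)
    return max(0.0, 1.0 - progress)
```

Under this split, the phase arithmetic runs once when the scheduler is built, and only the pure piecewise function is called per step, e.g. baked into a `torch.optim.lr_scheduler.LambdaLR` via `functools.partial`.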
**PR author:**
@tianyu-l Hi, just updated the code based on your suggestions :)
```bash
bash run_train.sh --lr_scheduler.warmup_steps=4 --lr_scheduler.decay_ratio=0.95 --lr_scheduler.decay_type=linear --training.steps=40
```

Here are the outputs:

```
[rank0]:[titan] 2025-03-12 11:10:11,808 - root - WARNING - The warmup steps should be less than or equal to the warmup-stable steps (2). Consider reducing either the decay ratio (0.95) or the warmup steps (4).
[rank0]:[titan] 2025-03-12 11:10:11,808 - root - INFO - Starting job: Llama 3 debug training
[rank0]:[titan] 2025-03-12 11:10:11,809 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank0]:[titan] 2025-03-12 11:10:12,612 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:[titan] 2025-03-12 11:10:12,615 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank0]:[titan] 2025-03-12 11:10:18,277 - root - INFO - TikTokenizer built: #words 2256, BOS ID 2000, EOS ID 2001
[rank0]:[titan] 2025-03-12 11:10:18,277 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test
[rank0]:[titan] 2025-03-12 11:10:18,424 - root - INFO - Building llama3 debugmodel with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=256, n_layers=8, n_heads=16, n_kv_heads=None, vocab_size=2256, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=500000, max_seq_len=2048, depth_init=True, norm_type='rmsnorm')
[rank0]:[titan] 2025-03-12 11:10:18,491 - root - INFO - CUDA capacity: NVIDIA H100 80GB HBM3 with 79.19GiB memory
[rank0]:[titan] 2025-03-12 11:10:18,494 - root - WARNING - Error running lspci: [Errno 2] No such file or directory: 'lspci', fallback to use device_name
[rank0]:[titan] 2025-03-12 11:10:18,495 - root - INFO - Model llama3 debugmodel size: 7,975,168 total parameters
[rank0]:[titan] 2025-03-12 11:10:18,495 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:[titan] 2025-03-12 11:10:18,528 - root - INFO - Applied FSDP to the model
[rank0]:[titan] 2025-03-12 11:10:18,689 - root - WARNING - Error running lspci: [Errno 2] No such file or directory: 'lspci', fallback to use device_name
[rank0]:[titan] 2025-03-12 11:10:18,689 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:[titan] 2025-03-12 11:10:18,689 - root - INFO - CUDA memory usage for model: 0.01GiB(0.01%)
[rank0]:[titan] 2025-03-12 11:10:18,690 - root - INFO - Training starts at step 1, with local batch size 8, global batch size 64, sequence length 2048, total steps 40 (warmup 4)
[rank0]:[titan] 2025-03-12 11:10:19,620 - root - INFO - step: 1 loss: 8.2023 memory: 1.39GiB(1.75%) tps: 14,563 tflops: 1.38 mfu: 0.14%
[rank0]:[titan] 2025-03-12 11:10:19,620 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-03-12 11:10:19,706 - root - INFO - step: 2 loss: 8.1718 memory: 1.41GiB(1.79%) tps: 190,528 tflops: 18.05 mfu: 1.82%
[rank0]:[titan] 2025-03-12 11:10:19,793 - root - INFO - step: 3 loss: 8.1348 memory: 1.41GiB(1.79%) tps: 188,922 tflops: 17.89 mfu: 1.81%
[rank0]:[titan] 2025-03-12 11:10:19,879 - root - INFO - step: 4 loss: 8.0639 memory: 1.41GiB(1.79%) tps: 192,283 tflops: 18.21 mfu: 1.84%
[rank0]:[titan] 2025-03-12 11:10:19,957 - root - INFO - step: 5 loss: 7.9554 memory: 1.41GiB(1.79%) tps: 211,260 tflops: 20.01 mfu: 2.02%
[rank0]:[titan] 2025-03-12 11:10:20,037 - root - INFO - step: 6 loss: 7.7464 memory: 1.41GiB(1.79%) tps: 204,153 tflops: 19.34 mfu: 1.96%
[rank0]:[titan] 2025-03-12 11:10:20,133 - root - INFO - step: 7 loss: 7.4730 memory: 1.41GiB(1.79%) tps: 172,727 tflops: 16.36 mfu: 1.65%
[rank0]:[titan] 2025-03-12 11:10:20,221 - root - INFO - step: 8 loss: 7.1972 memory: 1.41GiB(1.79%) tps: 187,061 tflops: 17.72 mfu: 1.79%
[rank0]:[titan] 2025-03-12 11:10:20,308 - root - INFO - step: 9 loss: 6.9878 memory: 1.41GiB(1.79%) tps: 187,481 tflops: 17.76 mfu: 1.80%
[rank0]:[titan] 2025-03-12 11:10:20,396 - root - INFO - step: 10 loss: 6.8936 memory: 1.41GiB(1.79%) tps: 187,707 tflops: 17.78 mfu: 1.80%
[rank0]:[titan] 2025-03-12 11:10:20,497 - root - INFO - step: 11 loss: 6.7785 memory: 1.41GiB(1.79%) tps: 163,028 tflops: 15.44 mfu: 1.56%
[rank0]:[titan] 2025-03-12 11:10:20,511 - root - WARNING - Dataset c4_test is being re-looped
[rank0]:[titan] 2025-03-12 11:10:20,597 - root - INFO - step: 12 loss: 6.6870 memory: 1.41GiB(1.79%) tps: 163,552 tflops: 15.49 mfu: 1.57%
[rank0]:[titan] 2025-03-12 11:10:20,702 - root - INFO - step: 13 loss: 6.6189 memory: 1.41GiB(1.79%) tps: 157,297 tflops: 14.90 mfu: 1.51%
[rank0]:[titan] 2025-03-12 11:10:20,806 - root - INFO - step: 14 loss: 6.5221 memory: 1.41GiB(1.79%) tps: 158,442 tflops: 15.01 mfu: 1.52%
[rank0]:[titan] 2025-03-12 11:10:20,908 - root - INFO - step: 15 loss: 6.4979 memory: 1.41GiB(1.79%) tps: 160,132 tflops: 15.17 mfu: 1.53%
[rank0]:[titan] 2025-03-12 11:10:21,006 - root - INFO - step: 16 loss: 6.4628 memory: 1.41GiB(1.79%) tps: 167,695 tflops: 15.88 mfu: 1.61%
[rank0]:[titan] 2025-03-12 11:10:21,095 - root - INFO - step: 17 loss: 6.3759 memory: 1.41GiB(1.79%) tps: 184,477 tflops: 17.47 mfu: 1.77%
[rank0]:[titan] 2025-03-12 11:10:21,178 - root - INFO - step: 18 loss: 6.3192 memory: 1.41GiB(1.79%) tps: 198,526 tflops: 18.80 mfu: 1.90%
[rank0]:[titan] 2025-03-12 11:10:21,263 - root - INFO - step: 19 loss: 6.2907 memory: 1.41GiB(1.79%) tps: 193,687 tflops: 18.35 mfu: 1.85%
[rank0]:[titan] 2025-03-12 11:10:21,346 - root - INFO - step: 20 loss: 6.2619 memory: 1.41GiB(1.79%) tps: 198,550 tflops: 18.81 mfu: 1.90%
[rank0]:[titan] 2025-03-12 11:10:21,426 - root - INFO - step: 21 loss: 6.2519 memory: 1.41GiB(1.79%) tps: 204,828 tflops: 19.40 mfu: 1.96%
[rank0]:[titan] 2025-03-12 11:10:21,511 - root - INFO - step: 22 loss: 6.2250 memory: 1.41GiB(1.79%) tps: 193,831 tflops: 18.36 mfu: 1.86%
[rank0]:[titan] 2025-03-12 11:10:21,589 - root - INFO - step: 23 loss: 6.2236 memory: 1.41GiB(1.79%) tps: 212,462 tflops: 20.12 mfu: 2.03%
[rank0]:[titan] 2025-03-12 11:10:21,603 - root - WARNING - Dataset c4_test is being re-looped
[rank0]:[titan] 2025-03-12 11:10:21,667 - root - INFO - step: 24 loss: 6.1861 memory: 1.41GiB(1.79%) tps: 209,645 tflops: 19.86 mfu: 2.01%
[rank0]:[titan] 2025-03-12 11:10:21,751 - root - INFO - step: 25 loss: 6.1657 memory: 1.41GiB(1.79%) tps: 197,367 tflops: 18.69 mfu: 1.89%
[rank0]:[titan] 2025-03-12 11:10:21,848 - root - INFO - step: 26 loss: 6.1338 memory: 1.41GiB(1.79%) tps: 169,197 tflops: 16.03 mfu: 1.62%
[rank0]:[titan] 2025-03-12 11:10:21,929 - root - INFO - step: 27 loss: 6.1061 memory: 1.41GiB(1.79%) tps: 203,673 tflops: 19.29 mfu: 1.95%
[rank0]:[titan] 2025-03-12 11:10:22,009 - root - INFO - step: 28 loss: 6.0646 memory: 1.41GiB(1.79%) tps: 205,287 tflops: 19.44 mfu: 1.97%
[rank0]:[titan] 2025-03-12 11:10:22,092 - root - INFO - step: 29 loss: 6.1441 memory: 1.41GiB(1.79%) tps: 196,687 tflops: 18.63 mfu: 1.88%
[rank0]:[titan] 2025-03-12 11:10:22,174 - root - INFO - step: 30 loss: 6.0883 memory: 1.41GiB(1.79%) tps: 202,703 tflops: 19.20 mfu: 1.94%
[rank0]:[titan] 2025-03-12 11:10:22,255 - root - INFO - step: 31 loss: 6.0294 memory: 1.41GiB(1.79%) tps: 202,888 tflops: 19.22 mfu: 1.94%
[rank0]:[titan] 2025-03-12 11:10:22,363 - root - INFO - step: 32 loss: 6.0592 memory: 1.41GiB(1.79%) tps: 152,631 tflops: 14.46 mfu: 1.46%
[rank0]:[titan] 2025-03-12 11:10:22,446 - root - INFO - step: 33 loss: 6.0040 memory: 1.41GiB(1.79%) tps: 197,792 tflops: 18.73 mfu: 1.89%
[rank0]:[titan] 2025-03-12 11:10:22,525 - root - INFO - step: 34 loss: 6.0207 memory: 1.41GiB(1.79%) tps: 208,519 tflops: 19.75 mfu: 2.00%
[rank0]:[titan] 2025-03-12 11:10:22,605 - root - INFO - step: 35 loss: 6.0129 memory: 1.41GiB(1.79%) tps: 203,645 tflops: 19.29 mfu: 1.95%
[rank0]:[titan] 2025-03-12 11:10:22,618 - root - WARNING - Dataset c4_test is being re-looped
[rank0]:[titan] 2025-03-12 11:10:22,684 - root - INFO - step: 36 loss: 6.0163 memory: 1.41GiB(1.79%) tps: 208,557 tflops: 19.75 mfu: 2.00%
[rank0]:[titan] 2025-03-12 11:10:22,772 - root - INFO - step: 37 loss: 6.0309 memory: 1.41GiB(1.79%) tps: 188,585 tflops: 17.86 mfu: 1.81%
[rank0]:[titan] 2025-03-12 11:10:22,852 - root - INFO - step: 38 loss: 5.9621 memory: 1.41GiB(1.79%) tps: 204,681 tflops: 19.39 mfu: 1.96%
[rank0]:[titan] 2025-03-12 11:10:22,935 - root - INFO - step: 39 loss: 5.9948 memory: 1.41GiB(1.79%) tps: 199,496 tflops: 18.90 mfu: 1.91%
[rank0]:[titan] 2025-03-12 11:10:23,012 - root - INFO - step: 40 loss: 5.9468 memory: 1.41GiB(1.79%) tps: 212,193 tflops: 20.10 mfu: 2.03%
[rank0]:[titan] 2025-03-12 11:10:23,012 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:[titan] 2025-03-12 11:10:25,012 - root - INFO - Training completed
```
**tianyu-l** left a comment:
Looks beautiful, thanks!
### What does this PR do?

Fix some minor issues in PR pytorch#938:

1. Fix the `decay_ratio` in [debug_model.toml](https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/train_configs/debug_model.toml), ensuring that `warmup_stable_steps` > `warmup_steps`
2. Make sure `warmup_stable_steps` is rounded to an integer
3. Move the lr check into `JobConfig`
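For item 3, a minimal sketch of what a one-time, config-level check could look like (the function and field names are hypothetical, mirroring the warning seen in the logs above; the actual `JobConfig` layout may differ):

```python
import logging

logger = logging.getLogger(__name__)


def validate_lr_schedule(job_config) -> None:
    """Run the warmup/decay sanity check once at startup instead of per step."""
    warmup_steps = job_config.lr_scheduler.warmup_steps  # hypothetical fields
    decay_ratio = job_config.lr_scheduler.decay_ratio
    training_steps = job_config.training.steps
    warmup_stable_steps = round(training_steps * (1 - decay_ratio))
    if warmup_steps > warmup_stable_steps:
        logger.warning(
            f"The warmup steps should be less than or equal to the warmup-stable "
            f"steps ({warmup_stable_steps}). Consider reducing either the decay "
            f"ratio ({decay_ratio}) or the warmup steps ({warmup_steps})."
        )
```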