
Conversation

@yzhangcs (Contributor) commented Mar 11, 2025

What does this PR do?

Fix some minor issues in PR #938

  1. Fix the `decay_ratio` in debug_model.toml, ensuring that `warmup_stable_steps` > `warmup_steps`
  2. Make sure `warmup_stable_steps` is rounded to an integer
  3. Move the LR check into `JobConfig` (a rough sketch of items 2 and 3 follows below)
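
A rough sketch of what items 2 and 3 amount to; the helper name, signature, and exact placement inside `JobConfig` are assumptions for illustration, not the merged torchtitan code:

```python
import logging

logger = logging.getLogger(__name__)


def check_lr_schedule(training_steps: int, warmup_steps: int, decay_ratio: float) -> int:
    """Round the warmup-stable boundary to an integer and warn if warmup exceeds it."""
    # Item 2: keep the warmup-stable boundary an integer step count.
    warmup_stable_steps = round(training_steps * (1 - decay_ratio))
    # Item 3: warn instead of silently producing a degenerate schedule.
    if warmup_steps > warmup_stable_steps:
        logger.warning(
            "The warmup steps should be less than or equal to the warmup-stable steps "
            "(%s). Consider reducing either the decay ratio (%s) or the warmup steps (%s).",
            warmup_stable_steps, decay_ratio, warmup_steps,
        )
    return warmup_stable_steps
```

With the settings used in the run further down (40 training steps, decay ratio 0.95, warmup 4), this yields `warmup_stable_steps = 2` and fires the same warning that appears in the log.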

@tianyu-l (Contributor) left a comment

Thanks, left some comments.

f"The warmup steps should be less than or equal to the warmup-stable steps ({warmup_stable_steps}). "
f"Consider reducing either the decay ratio ({lr_decay_ratio}) or the warmup steps ({warmup_steps})."
)
warmup_stable_steps = round(training_steps * (1 - lr_decay_ratio))
@tianyu-l (Contributor):

we can avoid this recomputation by feeding in warmup_stable_steps as an arg (and remove lr_decay_ratio)

@yzhangcs (Contributor Author):

I think lr_decay_ratio would be a better choice as it is typically a fixed ratio of training_steps. Specifying both warmup_stable_steps and training_steps might introduce unnecessary cognitive load.

@tianyu-l (Contributor):

Sure, but does it make sense to do this computation and rounding at every single training step? Btw, currently we are not explicitly passing training_steps to this function anyway.

An alternative proposal is:

  • We can make warmup_steps, stable_steps, and decay_steps inputs to this function, matching the function name properly (so cognitive load is lower?).
  • We calculate these outside the function, and only calculate them once.
  • If warmup_steps (directly specified) + decay_steps (calculated based on ratio) > training_steps, we can set stable_steps = 0.

The benefit is that linear_warmup_stable_decay is only responsible for the actual piecewise LR function, and the logic of computing the WSD phase lengths lives outside and runs only once. What do you think?
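
A minimal sketch of this split, using the names from the thread; the exact signatures and boundary handling are assumptions, not the merged torchtitan code:

```python
def compute_wsd_phases(training_steps: int, warmup_steps: int, decay_ratio: float):
    """Compute warmup/stable/decay lengths once, outside the per-step LR lambda."""
    decay_steps = round(training_steps * decay_ratio)
    # If warmup + decay already exceed the step budget, drop the stable phase.
    stable_steps = max(training_steps - warmup_steps - decay_steps, 0)
    return warmup_steps, stable_steps, decay_steps


def linear_warmup_stable_decay(step: int, warmup_steps: int, stable_steps: int, decay_steps: int) -> float:
    """Piecewise LR multiplier: linear warmup, constant plateau, linear decay."""
    if step < warmup_steps:
        return (step + 1) / max(warmup_steps, 1)
    if step < warmup_steps + stable_steps:
        return 1.0
    progress = (step - warmup_steps - stable_steps + 1) / max(decay_steps, 1)
    return max(1.0 - progress, 0.0)
```

The phase lengths would be computed just once (e.g. when building the LR lambda), so the per-step scheduler call only evaluates the piecewise function.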

@yzhangcs (Contributor Author):

@tianyu-l Hi, just updated the code based on your suggestions :)

@yzhangcs (Contributor Author) commented:

```sh
bash run_train.sh --lr_scheduler.warmup_steps=4 --lr_scheduler.decay_ratio=0.95 --lr_scheduler.decay_type=linear --training.steps=40
```

Here is the output:

```
[rank0]:[titan] 2025-03-12 11:10:11,808 - root - WARNING - The warmup steps should be less than or equal to the warmup-stable steps (2). Consider reducing either the decay ratio (0.95) or the warmup steps (4).
[rank0]:[titan] 2025-03-12 11:10:11,808 - root - INFO - Starting job: Llama 3 debug training
[rank0]:[titan] 2025-03-12 11:10:11,809 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank0]:[titan] 2025-03-12 11:10:12,612 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:[titan] 2025-03-12 11:10:12,615 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank0]:[titan] 2025-03-12 11:10:18,277 - root - INFO - TikTokenizer built: #words 2256, BOS ID 2000, EOS ID 2001
[rank0]:[titan] 2025-03-12 11:10:18,277 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test
[rank0]:[titan] 2025-03-12 11:10:18,424 - root - INFO - Building llama3 debugmodel with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=256, n_layers=8, n_heads=16, n_kv_heads=None, vocab_size=2256, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=500000, max_seq_len=2048, depth_init=True, norm_type='rmsnorm')
[rank0]:[titan] 2025-03-12 11:10:18,491 - root - INFO - CUDA capacity: NVIDIA H100 80GB HBM3 with 79.19GiB memory
[rank0]:[titan] 2025-03-12 11:10:18,494 - root - WARNING - Error running lspci: [Errno 2] No such file or directory: 'lspci', fallback to use device_name
[rank0]:[titan] 2025-03-12 11:10:18,495 - root - INFO - Model llama3 debugmodel size: 7,975,168 total parameters
[rank0]:[titan] 2025-03-12 11:10:18,495 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:[titan] 2025-03-12 11:10:18,528 - root - INFO - Applied FSDP to the model
[rank0]:[titan] 2025-03-12 11:10:18,689 - root - WARNING - Error running lspci: [Errno 2] No such file or directory: 'lspci', fallback to use device_name
[rank0]:[titan] 2025-03-12 11:10:18,689 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:[titan] 2025-03-12 11:10:18,689 - root - INFO - CUDA memory usage for model: 0.01GiB(0.01%)
[rank0]:[titan] 2025-03-12 11:10:18,690 - root - INFO - Training starts at step 1, with local batch size 8, global batch size 64, sequence length 2048, total steps 40 (warmup 4)
[rank0]:[titan] 2025-03-12 11:10:19,620 - root - INFO - step:  1  loss:  8.2023  memory:  1.39GiB(1.75%)  tps: 14,563  tflops: 1.38  mfu: 0.14%
[rank0]:[titan] 2025-03-12 11:10:19,620 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-03-12 11:10:19,706 - root - INFO - step:  2  loss:  8.1718  memory:  1.41GiB(1.79%)  tps: 190,528  tflops: 18.05  mfu: 1.82%
[rank0]:[titan] 2025-03-12 11:10:19,793 - root - INFO - step:  3  loss:  8.1348  memory:  1.41GiB(1.79%)  tps: 188,922  tflops: 17.89  mfu: 1.81%
[rank0]:[titan] 2025-03-12 11:10:19,879 - root - INFO - step:  4  loss:  8.0639  memory:  1.41GiB(1.79%)  tps: 192,283  tflops: 18.21  mfu: 1.84%
[rank0]:[titan] 2025-03-12 11:10:19,957 - root - INFO - step:  5  loss:  7.9554  memory:  1.41GiB(1.79%)  tps: 211,260  tflops: 20.01  mfu: 2.02%
[rank0]:[titan] 2025-03-12 11:10:20,037 - root - INFO - step:  6  loss:  7.7464  memory:  1.41GiB(1.79%)  tps: 204,153  tflops: 19.34  mfu: 1.96%
[rank0]:[titan] 2025-03-12 11:10:20,133 - root - INFO - step:  7  loss:  7.4730  memory:  1.41GiB(1.79%)  tps: 172,727  tflops: 16.36  mfu: 1.65%
[rank0]:[titan] 2025-03-12 11:10:20,221 - root - INFO - step:  8  loss:  7.1972  memory:  1.41GiB(1.79%)  tps: 187,061  tflops: 17.72  mfu: 1.79%
[rank0]:[titan] 2025-03-12 11:10:20,308 - root - INFO - step:  9  loss:  6.9878  memory:  1.41GiB(1.79%)  tps: 187,481  tflops: 17.76  mfu: 1.80%
[rank0]:[titan] 2025-03-12 11:10:20,396 - root - INFO - step: 10  loss:  6.8936  memory:  1.41GiB(1.79%)  tps: 187,707  tflops: 17.78  mfu: 1.80%
[rank0]:[titan] 2025-03-12 11:10:20,497 - root - INFO - step: 11  loss:  6.7785  memory:  1.41GiB(1.79%)  tps: 163,028  tflops: 15.44  mfu: 1.56%
[rank0]:[titan] 2025-03-12 11:10:20,511 - root - WARNING - Dataset c4_test is being re-looped
[rank0]:[titan] 2025-03-12 11:10:20,597 - root - INFO - step: 12  loss:  6.6870  memory:  1.41GiB(1.79%)  tps: 163,552  tflops: 15.49  mfu: 1.57%
[rank0]:[titan] 2025-03-12 11:10:20,702 - root - INFO - step: 13  loss:  6.6189  memory:  1.41GiB(1.79%)  tps: 157,297  tflops: 14.90  mfu: 1.51%
[rank0]:[titan] 2025-03-12 11:10:20,806 - root - INFO - step: 14  loss:  6.5221  memory:  1.41GiB(1.79%)  tps: 158,442  tflops: 15.01  mfu: 1.52%
[rank0]:[titan] 2025-03-12 11:10:20,908 - root - INFO - step: 15  loss:  6.4979  memory:  1.41GiB(1.79%)  tps: 160,132  tflops: 15.17  mfu: 1.53%
[rank0]:[titan] 2025-03-12 11:10:21,006 - root - INFO - step: 16  loss:  6.4628  memory:  1.41GiB(1.79%)  tps: 167,695  tflops: 15.88  mfu: 1.61%
[rank0]:[titan] 2025-03-12 11:10:21,095 - root - INFO - step: 17  loss:  6.3759  memory:  1.41GiB(1.79%)  tps: 184,477  tflops: 17.47  mfu: 1.77%
[rank0]:[titan] 2025-03-12 11:10:21,178 - root - INFO - step: 18  loss:  6.3192  memory:  1.41GiB(1.79%)  tps: 198,526  tflops: 18.80  mfu: 1.90%
[rank0]:[titan] 2025-03-12 11:10:21,263 - root - INFO - step: 19  loss:  6.2907  memory:  1.41GiB(1.79%)  tps: 193,687  tflops: 18.35  mfu: 1.85%
[rank0]:[titan] 2025-03-12 11:10:21,346 - root - INFO - step: 20  loss:  6.2619  memory:  1.41GiB(1.79%)  tps: 198,550  tflops: 18.81  mfu: 1.90%
[rank0]:[titan] 2025-03-12 11:10:21,426 - root - INFO - step: 21  loss:  6.2519  memory:  1.41GiB(1.79%)  tps: 204,828  tflops: 19.40  mfu: 1.96%
[rank0]:[titan] 2025-03-12 11:10:21,511 - root - INFO - step: 22  loss:  6.2250  memory:  1.41GiB(1.79%)  tps: 193,831  tflops: 18.36  mfu: 1.86%
[rank0]:[titan] 2025-03-12 11:10:21,589 - root - INFO - step: 23  loss:  6.2236  memory:  1.41GiB(1.79%)  tps: 212,462  tflops: 20.12  mfu: 2.03%
[rank0]:[titan] 2025-03-12 11:10:21,603 - root - WARNING - Dataset c4_test is being re-looped
[rank0]:[titan] 2025-03-12 11:10:21,667 - root - INFO - step: 24  loss:  6.1861  memory:  1.41GiB(1.79%)  tps: 209,645  tflops: 19.86  mfu: 2.01%
[rank0]:[titan] 2025-03-12 11:10:21,751 - root - INFO - step: 25  loss:  6.1657  memory:  1.41GiB(1.79%)  tps: 197,367  tflops: 18.69  mfu: 1.89%
[rank0]:[titan] 2025-03-12 11:10:21,848 - root - INFO - step: 26  loss:  6.1338  memory:  1.41GiB(1.79%)  tps: 169,197  tflops: 16.03  mfu: 1.62%
[rank0]:[titan] 2025-03-12 11:10:21,929 - root - INFO - step: 27  loss:  6.1061  memory:  1.41GiB(1.79%)  tps: 203,673  tflops: 19.29  mfu: 1.95%
[rank0]:[titan] 2025-03-12 11:10:22,009 - root - INFO - step: 28  loss:  6.0646  memory:  1.41GiB(1.79%)  tps: 205,287  tflops: 19.44  mfu: 1.97%
[rank0]:[titan] 2025-03-12 11:10:22,092 - root - INFO - step: 29  loss:  6.1441  memory:  1.41GiB(1.79%)  tps: 196,687  tflops: 18.63  mfu: 1.88%
[rank0]:[titan] 2025-03-12 11:10:22,174 - root - INFO - step: 30  loss:  6.0883  memory:  1.41GiB(1.79%)  tps: 202,703  tflops: 19.20  mfu: 1.94%
[rank0]:[titan] 2025-03-12 11:10:22,255 - root - INFO - step: 31  loss:  6.0294  memory:  1.41GiB(1.79%)  tps: 202,888  tflops: 19.22  mfu: 1.94%
[rank0]:[titan] 2025-03-12 11:10:22,363 - root - INFO - step: 32  loss:  6.0592  memory:  1.41GiB(1.79%)  tps: 152,631  tflops: 14.46  mfu: 1.46%
[rank0]:[titan] 2025-03-12 11:10:22,446 - root - INFO - step: 33  loss:  6.0040  memory:  1.41GiB(1.79%)  tps: 197,792  tflops: 18.73  mfu: 1.89%
[rank0]:[titan] 2025-03-12 11:10:22,525 - root - INFO - step: 34  loss:  6.0207  memory:  1.41GiB(1.79%)  tps: 208,519  tflops: 19.75  mfu: 2.00%
[rank0]:[titan] 2025-03-12 11:10:22,605 - root - INFO - step: 35  loss:  6.0129  memory:  1.41GiB(1.79%)  tps: 203,645  tflops: 19.29  mfu: 1.95%
[rank0]:[titan] 2025-03-12 11:10:22,618 - root - WARNING - Dataset c4_test is being re-looped
[rank0]:[titan] 2025-03-12 11:10:22,684 - root - INFO - step: 36  loss:  6.0163  memory:  1.41GiB(1.79%)  tps: 208,557  tflops: 19.75  mfu: 2.00%
[rank0]:[titan] 2025-03-12 11:10:22,772 - root - INFO - step: 37  loss:  6.0309  memory:  1.41GiB(1.79%)  tps: 188,585  tflops: 17.86  mfu: 1.81%
[rank0]:[titan] 2025-03-12 11:10:22,852 - root - INFO - step: 38  loss:  5.9621  memory:  1.41GiB(1.79%)  tps: 204,681  tflops: 19.39  mfu: 1.96%
[rank0]:[titan] 2025-03-12 11:10:22,935 - root - INFO - step: 39  loss:  5.9948  memory:  1.41GiB(1.79%)  tps: 199,496  tflops: 18.90  mfu: 1.91%
[rank0]:[titan] 2025-03-12 11:10:23,012 - root - INFO - step: 40  loss:  5.9468  memory:  1.41GiB(1.79%)  tps: 212,193  tflops: 20.10  mfu: 2.03%
[rank0]:[titan] 2025-03-12 11:10:23,012 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:[titan] 2025-03-12 11:10:25,012 - root - INFO - Training completed
```

@tianyu-l (Contributor) left a comment

Looks beautiful, thanks!

@tianyu-l merged commit 7e614f1 into pytorch:main Mar 14, 2025
6 checks passed
MaxiBoether pushed a commit to eth-easl/torchtitan-mixtera that referenced this pull request Apr 17, 2025