
Impl adam_w #2957

Merged: 17 commits into develop on Jun 26, 2024
Conversation

et16kr
Collaborator

@et16kr et16kr commented May 10, 2024

  • Add AdamW to the Adam optimizer. (ref)
  • Added a driver test and gtest for AdamW.
  • Performance was measured against PyTorch Fused Adam.
**adam float16**

| dims | rocm pytorch | miopen | improvement | model |
|---|---:|---:|---:|---|
| 96 | 36544 | 6416 | 5.70 | gpt |
| 288 | 25408 | 6608 | 3.85 | gpt |
| 384 | 17200 | 6592 | 2.61 | gpt |
| 512 | 17296 | 6928 | 2.50 | t5 |
| 768 | 25776 | 7152 | 3.60 | gpt |
| 4x8 | 21712 | 6464 | 3.36 | t5 |
| 64x512 | 67344 | 8977 | 7.50 | t5 |
| 64x1024 | 122544 | 9392 | 13.05 | t5 |
| 96x768 | 118240 | 9488 | 12.46 | gpt |
| 96x3072 | 118400 | 12304 | 9.62 | gpt |
| 128x512 | 117648 | 9280 | 12.68 | t5 |
| 128x768 | 116464 | 9872 | 11.80 | gpt |
| 192x512 | 117408 | 9808 | 11.97 | t5 |
| 288x768 | 119424 | 11345 | 10.53 | gpt |
| 384x768 | 118176 | 12273 | 9.63 | gpt |
| 6278x512 | 129632 | 62259 | 2.08 | t5 |
| 6278x768 | 136704 | 92516 | 1.48 | gpt |
| 6283x512 | 128288 | 64419 | 1.99 | t5 |
| 6283x768 | 135824 | 91732 | 1.48 | gpt |
**adam float32**

| dims | rocm pytorch | miopen | improvement | model |
|---|---:|---:|---:|---|
| 96 | 34816 | 6464 | 5.39 | gpt |
| 288 | 23312 | 6848 | 3.40 | gpt |
| 384 | 15728 | 7121 | 2.21 | gpt |
| 512 | 16336 | 7344 | 2.22 | t5 |
| 768 | 28112 | 8096 | 3.47 | gpt |
| 4x8 | 22320 | 6416 | 3.48 | t5 |
| 64x512 | 64960 | 9200 | 7.06 | t5 |
| 64x1024 | 116896 | 9792 | 11.94 | t5 |
| 96x768 | 115488 | 9937 | 11.62 | gpt |
| 96x3072 | 117248 | 14081 | 8.33 | gpt |
| 128x512 | 114720 | 9713 | 11.81 | t5 |
| 128x768 | 113664 | 10273 | 11.06 | gpt |
| 192x512 | 114240 | 10353 | 11.03 | t5 |
| 288x768 | 114496 | 11921 | 9.60 | gpt |
| 384x768 | 115680 | 14241 | 8.12 | gpt |
| 6278x512 | 158352 | 69412 | 2.28 | t5 |
| 6278x768 | 192768 | 101557 | 1.90 | gpt |
| 6283x512 | 141088 | 69076 | 2.04 | t5 |
| 6283x768 | 209808 | 101973 | 2.06 | gpt |
**amp adam (parameter float32, gradient float16)**

| dims | rocm pytorch | miopen | improvement | model |
|---|---:|---:|---:|---|
| 96 | 39408 | 9216 | 4.28 | gpt |
| 288 | 45856 | 9856 | 4.65 | gpt |
| 384 | 43552 | 9808 | 4.44 | gpt |
| 512 | 30544 | 10496 | 2.91 | t5 |
| 768 | 23728 | 10496 | 2.26 | gpt |
| 4x8 | 24416 | 9632 | 2.53 | t5 |
| 64x512 | 80032 | 12336 | 6.49 | t5 |
| 64x1024 | 130272 | 12896 | 10.10 | t5 |
| 96x768 | 126800 | 12928 | 9.81 | gpt |
| 96x3072 | 130000 | 17664 | 7.36 | gpt |
| 128x512 | 132096 | 12624 | 10.46 | t5 |
| 128x768 | 129696 | 13744 | 9.44 | gpt |
| 192x512 | 126368 | 13936 | 9.07 | t5 |
| 288x768 | 128784 | 15616 | 8.25 | gpt |
| 384x768 | 128672 | 17760 | 7.25 | gpt |
| 6278x512 | 167648 | 73088 | 2.29 | t5 |
| 6278x768 | 204128 | 105552 | 1.93 | gpt |
| 6283x512 | 151792 | 73168 | 2.07 | t5 |
| 6283x768 | 224144 | 105616 | 2.12 | gpt |
| type | average improvement |
|---|---:|
| adam float16 | 6.27 |
| adam float32 | 6.73 |
| amp adam | 5.27 |
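For reference, the key difference AdamW introduces over plain Adam is decoupled weight decay: the decay term is applied to the parameter directly rather than being folded into the gradient. Below is a minimal single-tensor CPU sketch of one AdamW step, not the MIOpen kernel itself; all names and default hyperparameters are illustrative.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// One AdamW step over a flat parameter tensor (CPU sketch).
// t is the 1-based step count used for bias correction.
void adamw_step(std::vector<float>& param,
                const std::vector<float>& grad,
                std::vector<float>& m,
                std::vector<float>& v,
                int t,
                float lr           = 1e-3f,
                float beta1        = 0.9f,
                float beta2        = 0.999f,
                float eps          = 1e-8f,
                float weight_decay = 1e-2f)
{
    const float bc1 = 1.0f - std::pow(beta1, t); // bias correction for m
    const float bc2 = 1.0f - std::pow(beta2, t); // bias correction for v
    for(std::size_t i = 0; i < param.size(); ++i)
    {
        m[i] = beta1 * m[i] + (1.0f - beta1) * grad[i];
        v[i] = beta2 * v[i] + (1.0f - beta2) * grad[i] * grad[i];
        const float mhat = m[i] / bc1;
        const float vhat = v[i] / bc2;
        // Decoupled weight decay: the decay acts on param directly,
        // unlike classic Adam + L2 where it is added to the gradient.
        param[i] -= lr * (mhat / (std::sqrt(vhat) + eps) + weight_decay * param[i]);
    }
}
```

With Adam + L2 regularization the decay would instead be added to `grad[i]` before the moment updates, which couples it to the adaptive learning rate; AdamW avoids that coupling.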

@et16kr et16kr marked this pull request as ready for review May 21, 2024 06:02
@JehandadKhan
Collaborator

Restarted the CI. @et16kr, please resolve the merge conflict.

src/tensor.cpp
Comment on lines 238 to 239
template <typename Tgpu, typename Tref, bool adamw, bool is_amp, typename Tgrad>
int AdamDriver<Tgpu, Tref, adamw, is_amp, Tgrad>::ParseCmdLineArgs(int argc, char* argv[])
Contributor

I missed a previous PR, but since a variable has been added here, I would ask: is it necessary to have all of those variables as template parameters?
At least `is_amp` and `adamw` should be runtime parameters, to avoid code bloat and long compilation times.
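To illustrate the suggestion: with `adamw` and `is_amp` as template parameters, every boolean combination produces a separate instantiation of the driver. A hypothetical sketch of the runtime-flag alternative (the class and method names here are illustrative, not the actual MIOpen driver API):

```cpp
#include <string>

// Hypothetical sketch: keep only the dtype-like parameters as templates
// and hold the mode flags (adamw, is_amp) as runtime state, so a single
// instantiation per type combination is compiled instead of four.
template <typename Tgpu, typename Tref, typename Tgrad>
class AdamDriver
{
public:
    AdamDriver(bool adamw, bool is_amp) : adamw_(adamw), is_amp_(is_amp) {}

    // Example of branching at runtime on the mode flags instead of
    // specializing the whole class on them.
    std::string ModeName() const
    {
        std::string name = adamw_ ? "adamw" : "adam";
        if(is_amp_)
            name = "amp_" + name;
        return name;
    }

private:
    bool adamw_;
    bool is_amp_;
};
```

The trade-off is that template flags let the compiler remove dead branches inside hot loops, while runtime flags keep binary size and build time down; for a driver's command-line parsing path, the runtime branch costs essentially nothing.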

@et16kr et16kr mentioned this pull request Jun 13, 2024
@junliume junliume merged commit 19988a7 into develop Jun 26, 2024
141 checks passed
@junliume junliume deleted the impl_AdamW branch June 26, 2024 04:55