Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMP] add get() and set() for Grad_scaler #33835

Merged
merged 4 commits into from
Jul 1, 2021

Conversation

zhangbo9674
Copy link
Contributor

PR types

New features

PR changes

APIs

Describe

Add get() and set() function for amp properties:

scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
print("get_enable=",scaler.get_enable())
print("get_decr_every_n_nan_or_inf=",scaler.get_decr_every_n_nan_or_inf())
print("get_decr_ratio=",scaler.get_decr_ratio())
print("get_incr_every_n_steps=",scaler.get_incr_every_n_steps())
print("get_incr_ratio=",scaler.get_incr_ratio())
print("get_init_loss_scaling=",scaler.get_init_loss_scaling())
print("get_use_dynamic_loss_scaling=",scaler.get_use_dynamic_loss_scaling())

scaler.set_decr_every_n_nan_or_inf(2)
print("get_decr_every_n_nan_or_inf=",scaler.get_decr_every_n_nan_or_inf())
scaler.set_decr_ratio(0.1)
print("get_decr_ratio=",scaler.get_decr_ratio())
scaler.set_incr_every_n_steps(200)
print("get_incr_every_n_steps=",scaler.get_incr_every_n_steps())
scaler.set_incr_ratio(3)
print("get_incr_ratio=",scaler.get_incr_ratio())
scaler.set_init_loss_scaling(100)
print("get_init_loss_scaling=",scaler.get_init_loss_scaling())

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

def is_use_loss_scaling(self):
"""
Enable loss scaling or not.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里需要加一个 Returns: 然后说明返回类型,以及什么情况下会返回什么样的值

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tks,done!

def is_use_dynamic_loss_scaling(self):
"""
Whether to use dynamic loss scaling.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,需要补充 Returns

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tks,done!


Examples:
.. code-block:: python
import paddle
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.. code & import 这两行中间需要加一个空行 否则预览会有问题

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tks,done!


Examples:
.. code-block:: python
import paddle
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tks,done!


Examples:
.. code-block:: python
import paddle
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tks,done!

def get_init_loss_scaling(self):
"""
Return the initial loss scaling factor.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要加Returns 同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tks,done!

@@ -145,3 +145,289 @@ def minimize(self, optimizer, *args, **kwargs):
optimizer.clear_grad()
"""
return super(GradScaler, self).minimize(optimizer, *args, **kwargs)

def is_enable(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_enable -> is_enabled

"""
return super(GradScaler, self).is_enable()

def is_use_dynamic_loss_scaling(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_use_dynamic_loss_scaling -> is_dynamic_loss_scaling_used?

zhiqiu
zhiqiu previously approved these changes Jun 30, 2021
Copy link
Contributor

@zhiqiu zhiqiu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

incr_every_n_steps=1000,
decr_every_n_nan_or_inf=2,
use_dynamic_loss_scaling=True)
enable = scaler.get_enable()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scaler.get_enable -> scaler.is_enable

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tks,done.

incr_every_n_steps=1000,
decr_every_n_nan_or_inf=2,
use_dynamic_loss_scaling=True)
use_dynamic_loss_scaling = scaler.get_use_dynamic_loss_scaling()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scaler.get_use_dynamic_loss_scaling -> scaler.is_use_dynamic_loss_scaling

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tks,done.


Examples:
.. code-block:: python
import paddle
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加空行

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tks. done.

Copy link
Contributor

@TCChenlong TCChenlong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@zhiqiu zhiqiu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhiqiu zhiqiu merged commit 8568734 into PaddlePaddle:develop Jul 1, 2021
@zhangbo9674 zhangbo9674 deleted the dev/amp_get_set branch September 14, 2022 02:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants