
Dynamic Graph Support to ASP #38517

Closed
wants to merge 3 commits

Conversation

@mingxu1067 (Collaborator) commented Dec 28, 2021

PR types

New features

PR changes

APIs

Describe

Dynamic graph support for Automatic SParsity (ASP).

  1. Added dynamic graph support to the ASP module (paddle.fluid.contrib.sparsity).
  2. Moved the API namespace from paddle.static.sparsity to paddle.sparsity.
  3. Added ASP-related unit tests covering the above changes.

Usage Examples.

import paddle
from paddle import sparsity

# Dynamic Graph ----------------------------------------------------------
class MyLayer(paddle.nn.Layer):
    def __init__(self):
        super(MyLayer, self).__init__()
        self.conv1 = paddle.nn.Conv2D(
            in_channels=3, out_channels=4, kernel_size=3, padding=2)
        self.linear1 = paddle.nn.Linear(4624, 32)
        self.linear2 = paddle.nn.Linear(32, 32)
        self.linear3 = paddle.nn.Linear(32, 10)

    def forward(self, img):
        hidden = self.conv1(img)
        hidden = paddle.flatten(hidden, start_axis=1)
        hidden = self.linear1(hidden)
        hidden = self.linear2(hidden)
        prediction = self.linear3(hidden)
        return prediction

my_layer = MyLayer()
optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=my_layer.parameters())

# Need to set excluded layers before calling decorate
sparsity.set_excluded_layers(["linear_0"])

# `decorate` creates the necessary ASP mask variables when dynamic graph mode is enabled.
optimizer = sparsity.decorate(optimizer)

loss_fn = paddle.nn.MSELoss(reduction='mean')

# Input shaped (batch, 3, 32, 32): conv1 output is 4 x 34 x 34, which flattens to 4624.
input_x = paddle.rand((4, 3, 32, 32))
labels = paddle.rand((4, 10))

output = my_layer(input_x)
loss = loss_fn(output, labels)
loss.backward()
# `step` will call masking operations for ASP workflow.
optimizer.step()
optimizer.clear_grad()

# Static Graph ----------------------------------------------------------
paddle.enable_static()

main_program = paddle.static.Program()
startup_program = paddle.static.Program()

with paddle.static.program_guard(main_program, startup_program):
    input_data = paddle.static.data(name='data', shape=[None, 128])
    label = paddle.static.data(name='label', shape=[None, 10])
    hidden = paddle.static.nn.fc(x=input_data, num_flatten_dims=-1, size=32, activation=None, name="need_sparse_fc")
    hidden = paddle.static.nn.fc(x=hidden, num_flatten_dims=-1, size=32, activation=None, name="need_dense_fc")
    prob = paddle.static.nn.fc(x=hidden, num_flatten_dims=-1, size=10, activation=None)
    loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label))

    # Set up layers excluded from the ASP workflow.
    # Note: excluded_layers must be set before calling `optimizer.minimize()`.
    sparsity.set_excluded_layers(["need_dense_fc"], main_program)

    optimizer = paddle.optimizer.SGD(learning_rate=0.1)
    optimizer = paddle.static.amp.decorate(optimizer)
    # Calling sparsity.decorate() to wrap minimize() in optimizer, which 
    # will insert necessary masking operations for ASP workflow.
    optimizer = sparsity.decorate(optimizer)
    optimizer.minimize(loss, startup_program)

  1. Added step and clear_grad functions to OptimizerWithSparsityGuarantee.
  2. Added a step function to ASPHelper.
  3. Added prune_model_by_layer and renamed the original prune_model to prune_model_by_program.
  4. Moved paddle.static.sparsity to paddle.sparsity.
paddle-bot-old bot commented Dec 28, 2021

✅ This PR's description meets the template requirements!
Please wait for other CI results.

paddle-bot-old bot commented:
Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

paddle-bot-old bot commented Jan 5, 2022

Sorry to inform you that 1511702's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@@ -309,20 +441,24 @@ def decorate(optimizer):
    r"""
    This is the implementation of `sparsity.decorate`; for details please see the explanation in `sparsity.decorate`.
    """
    if in_dygraph_mode():
        main_prog = paddle.static.default_main_program()
        startup_prog = paddle.static.default_startup_program()
Contributor:

In dynamic graph mode, there is no program anymore.

mingxu1067 (Collaborator, Author):

Result of offline discussion:
Here main_program and startup_program serve only as keys under which ASP internally stores the mask variables; in the step phase the masks are fetched via these keys to perform weight sparse masking.
To avoid scenarios that mix dynamic and static graphs (where the static graph uses default_main_program as main_program), we are considering using None as the key to distinguish dynamic from static graph mode; the corresponding changes will be made once the reviewers' opinions are confirmed.
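The key scheme described above can be sketched as a small mask store keyed by program, with None standing in for dynamic graph mode. The class and method names here are hypothetical, for illustration only:

```python
class MaskStoreSketch:
    """ASP-style mask storage: masks are grouped under a key, which is a
    static-graph Program object, or None when running in dynamic graph mode."""

    def __init__(self):
        self._masks = {}

    def set_mask(self, program, param_name, mask):
        # `program` is None in dynamic graph mode.
        self._masks.setdefault(program, {})[param_name] = mask

    def get_masks(self, program=None):
        # At `step` time the masks are fetched via the same key.
        return self._masks.get(program, {})
```

Using None as the dynamic-graph key avoids ever touching default_main_program from dygraph code paths.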

@@ -281,12 +413,12 @@ class ASPHelper(object):
    """

    MASK_APPENDDED_NAME = '_asp_mask'
    SUPPORTED_LAYERS = {'fc': 'w_0', 'linear': 'w_0', 'conv2d': 'w_0'}
    PADDLE_WEIGHT_SUFFIX = "w_"
Contributor:

There may also be operators other than fc, linear, and conv whose parameters are named with the w_ pattern; this approach does not seem very safe.

mingxu1067 (Collaborator, Author):

The actual support-check rules are the following three:

  1. If the parameter name is in supported_list, return True.
  2. If the parameter name with its type suffix stripped is in supported_list, return True.
  3. If the parameter's class name, with type suffix and counter stripped, is in supported_list, return True.

See ASPHelper._is_supported_layer for the detailed rules.
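The three rules above can be illustrated with a small sketch. The helper name and the exact suffix/counter regexes are assumptions for illustration; the real logic lives in ASPHelper._is_supported_layer:

```python
import re

def is_supported_layer_sketch(supported_list, param_name):
    # Rule 1: exact parameter-name match, e.g. "my_special_param".
    if param_name in supported_list:
        return True
    # Rule 2: name with its type suffix stripped,
    # e.g. "linear_3.w_0" -> "linear_3".
    base = re.sub(r'\.\w+$', '', param_name)
    if base in supported_list:
        return True
    # Rule 3: class name with suffix and trailing counter stripped,
    # e.g. "linear_3" -> "linear".
    cls = re.sub(r'_\d+$', '', base)
    return cls in supported_list
```

Rule 1 lets users whitelist individual parameters by full name, while rule 3 is what makes the built-in entries ('fc', 'linear', 'conv2d') match auto-numbered layer instances.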

'The "mask_algo" should be one of ["mask_1d", "mask_2d_greedy", "mask_2d_best"]'

prune_func = None
if isinstance(model, paddle.nn.Layer):
Contributor:

Could the prune_layer / prune_program selection here and the dynamic/static if-else inside prune_model_by_layer be merged, so that it is all controlled uniformly in one place?


add_supported_layer('fc')
add_supported_layer('linear')
add_supported_layer('conv2d')
Contributor:

Minor question: for conv2d, is the kernel weight pruned 4:2?
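For context, ASP targets 2:4 (n:m) structured sparsity: in every group of 4 consecutive weights, the 2 smallest-magnitude entries are zeroed. A minimal numpy illustration of such a mask follows; it is not Paddle's mask_1d implementation, just the underlying idea:

```python
import numpy as np

def make_2_to_4_mask(weights):
    """Return a 0/1 mask keeping the 2 largest-magnitude entries per group of 4.

    `weights` is a 2-D array whose total size is a multiple of 4.
    """
    flat = weights.reshape(-1, 4)
    mask = np.zeros_like(flat)
    # Column indices of the 2 largest |w| within each group of 4.
    top2 = np.argsort(np.abs(flat), axis=1)[:, -2:]
    np.put_along_axis(mask, top2, 1.0, axis=1)
    return mask.reshape(weights.shape)

w = np.array([[0.9, -0.1, 0.4, 0.05, -0.7, 0.2, 0.01, 0.3]])
m = make_2_to_4_mask(w)
# Every group of 4 keeps exactly 2 nonzero positions.
```

For a conv2d kernel the same idea applies after the 4-D weight tensor is viewed as a 2-D matrix.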

if with_mask:
    weight_mask_param = global_scope().find_var(
        ASPHelper._get_mask_name(param.name))
    assert weight_mask_param is not None, \
        'Cannot find {} variable, please call optimizer.minimize (' \
        'paddle.sparsity.decorate(optimizer).minimize(loss)' \
        ' and initialization (exe.run(startup_program)) first!'.format(ASPHelper._get_mask_name(param.name))
    weight_mask_tensor = weight_mask_param.get_tensor()
Contributor:

Could this part of the computation be done on the GPU via an Op?

@mingxu1067 (Collaborator, Author) commented:
Split this PR into two parts: #40253 (Add Support Layer List to ASP) and #41177 (Dynamic graph support to Automatic SParsity).
Will close this one after the above two are merged.

@mingxu1067 (Collaborator, Author) commented:
Closing this, as the functionality has been merged in #41177.

paddle-bot-old bot commented:
Sorry to inform you that, after our discussions, your PR does not meet the merging standard (reference: Paddle Custom Operator Design Doc). You can submit a new PR; we will close this one. Thank you for your contribution.
