
[Dygraph]Add group sharded api #40129

Merged
merged 1 commit into PaddlePaddle:develop on Mar 9, 2022

Conversation

Baibaifan
Contributor

@Baibaifan Baibaifan commented Mar 3, 2022

PR types

New features

PR changes

APIs

Describe

Add group sharded api

  1. group_sharded_parallel
  2. save_group_sharded_model
import paddle
from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel, save_group_sharded_model

fleet.init(is_collective=True)
group = paddle.distributed.new_group([0, 1])
model = Linear(1000, 1000)

clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters(), weight_decay=0.00001, grad_clip=clip)
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

# wrap sharding model, optimizer and scaler; level must be one of "os", "os_g", "p_g_os"
model, optimizer, scaler = group_sharded_parallel(model, optimizer, "p_g_os", scaler=scaler)

# placeholder batch standing in for one element of a real dataloader
img = paddle.randn([4, 1000])
label = paddle.randint(0, 1000, shape=[4, 1])
label.stop_gradient = True
img.stop_gradient = True

out = model(img)
loss = paddle.nn.functional.cross_entropy(input=out, label=label)

loss.backward()
optimizer.step()
optimizer.clear_grad()

# save model and optimizer state_dict
output_dir = "./group_sharded_checkpoint"  # placeholder output directory
save_group_sharded_model(model, output=output_dir, optimizer=optimizer)
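
Since fleet.init(is_collective=True) assumes a multi-process collective job, a script built from the example above would typically be started with Paddle's distributed launcher, for example (train.py is a placeholder script name):

python -m paddle.distributed.launch --gpus=0,1 train.py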

@paddle-bot-old

paddle-bot-old bot commented Mar 3, 2022

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@@ -136,7 +136,7 @@ def __init__(self,
# Update optimizer parameters and adjust parameter storage and use according to rank.
self._update_opt_status()

@paddle.no_grad()
@fluid.dygraph.no_grad()
Contributor

What is the reason for importing fluid here? The APIs under fluid will be deprecated.

Contributor Author

Changed to paddle.autograd.no_grad().
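
For reference, a minimal sketch of the agreed replacement; the function below is a hypothetical stand-in for the decorated _update_opt_status method, not code from this PR:

import paddle

# paddle.autograd.no_grad() works as a decorator (as here) or as a context
# manager and replaces the deprecated fluid.dygraph.no_grad() in the diff above.
@paddle.autograd.no_grad()
def global_param_norm(model):
    # any tensor work done here runs without recording a gradient graph
    squares = [paddle.sum(p * p) for p in model.parameters()]
    return paddle.sqrt(paddle.add_n(squares))

layer = paddle.nn.Linear(8, 8)
print(global_param_norm(layer))  # plain tensor, no gradient graph attached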

logger_ = get_logger(logging.INFO)


class ShardedLevel(Enum):
Contributor

Suggest removing this object; there is no need to introduce a separate object just to define a parameter.

  1. Use level='os' directly in the group_sharded_parallel function, or simply level=1, following the level definition in amp; level is usually understood as an integer, similar to verbose.
  2. What are os, os_g and p_g_os abbreviations for? The readability is poor; is there a better way to express them?

Contributor Author
@Baibaifan Baibaifan Mar 8, 2022

After discussion, ShardedLevel is removed and the string names "os", "os_g", "p_g_os" are used as the level; the level names are aligned with the paper.
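
For readers skimming the thread, a minimal sketch of the agreed-upon call; it reuses the model, optimizer and scaler from the example in the PR description, and the reading of the abbreviations as ZeRO-style stages is an interpretation of the discussion, not wording from the PR:

# "os"     : shard optimizer states only                      (ZeRO stage-1 style)
# "os_g"   : shard optimizer states and gradients             (ZeRO stage-2 style)
# "p_g_os" : shard parameters, gradients and optimizer states (ZeRO stage-3 style)
model, optimizer, scaler = group_sharded_parallel(model, optimizer, level="os_g", scaler=scaler)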


def group_sharded_parallel(model,
optimizer,
shard_level,
Contributor

shard_level -> level
Since the API name already contains "sharded", the parameters here are by default sharding-related.

Contributor Author

Fixed.

p_g_os = 3


def group_sharded_parallel(model,
Contributor

Besides group_sharded, are there any other sharding approaches?

Contributor Author

group_sharded here means grouped parameter sharding, a distributed strategy that sits alongside data parallelism, hence the name group_sharded. There is currently no other sharding approach.

from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler

__all__ = ['ShardedLevel', 'group_sharded_parallel', 'save_for_group_sharded']
Contributor

There is no need to expose the API via __all__ here; exposing it through __init__.py is enough, i.e.
paddle.distributed.sharding.group_sharded_parallel
rather than
paddle.distributed.sharding.group_sharded.group_sharded_parallel
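
A minimal sketch of what the reviewer is asking for: re-export the public functions from the package __init__.py so callers import them from paddle.distributed.sharding directly. The module name group_sharded is taken from the paths quoted above, and the function names used are the final ones from the PR description:

# paddle/distributed/sharding/__init__.py
from .group_sharded import group_sharded_parallel  # noqa: F401
from .group_sharded import save_group_sharded_model  # noqa: F401

__all__ = ["group_sharded_parallel", "save_group_sharded_model"]

User code then only needs:

from paddle.distributed.sharding import group_sharded_parallel, save_group_sharded_model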

Contributor Author

Fixed.

return model, optimizer, scaler


def save_for_group_sharded(model, output, optimizer=None):
Contributor

Besides group, are there other parameter forms?
save_for_group_sharded -> save_sharded_model? Or save_group_sharded_model?
Similar to save_inference_model.

Contributor Author

After discussion, renamed to save_group_sharded_model.

Contributor
@XiaoguangHu01 XiaoguangHu01 left a comment

LGTM

Contributor
@XieYunshen XieYunshen left a comment

LGTM for set_tests_properties(test_dygraph_group_sharded_api PROPERTIES TIMEOUT 120)

@Baibaifan Baibaifan closed this Mar 9, 2022
@Baibaifan Baibaifan reopened this Mar 9, 2022
@Baibaifan Baibaifan merged commit f40ed5f into PaddlePaddle:develop Mar 9, 2022
@@ -55,6 +55,7 @@
from . import cloud_utils # noqa: F401
from . import utils # noqa: F401

from .sharding import * # noqa: F401
Contributor
@gongweibao gongweibao Mar 10, 2022

Why import *?
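
To make the question concrete, a hedged sketch of the explicit alternative to the wildcard import, listing only the two public functions this PR adds:

from .sharding import group_sharded_parallel  # noqa: F401
from .sharding import save_group_sharded_model  # noqa: F401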
