Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Auto Parallel] Improve the APIs #45776

Merged
merged 60 commits into from
Sep 15, 2022
Merged

Conversation

aoyulong
Copy link
Contributor

@aoyulong aoyulong commented Sep 6, 2022

PR types

Others

PR changes

APIs

Describe

This pr systematically improve the APIs of auto parallel, and update the codes and unittests as well. These APIs can work in both the dynamic and static graphs, which are transparent for users. The main contributions:

  • Engine class (in Paddle/python/paddle/distributed/auto_parallel/engine.py) provides the high-level APIs for distributed training, evaluating and predicting. For example:

    import paddle
    import paddle.vision.transforms as T
    import paddle.distributed.auto_parallel as auto
    from paddle.vision.datasets import MNIST
    
    transform = T.Compose([
        T.Transpose(),
        T.Normalize([127.5], [127.5])
    ])
    train_dataset = MNIST(mode='train', transform=transform)
    valid_dataset = MNIST(mode='test', transform=transform)
    
    model = paddle.vision.models.LeNet()
    loss = paddle.nn.CrossEntropyLoss() 
    optimizer = paddle.optimizer.Adam(
        learning_rate=0.001, parameters=model.parameters())
    metrics = paddle.metric.Accuracy(topk=(1, 2))
    
    engine = auto.Engine(model, loss, optimizer, metrics) 
    # fit 
    engine.fit(train_dataset,
               epochs=2,
               batch_size=64)
    # evaluate 
    engine.evaluate(valid_dataset,
                    batch_size=64)
    # predict
    engine.predict(valid_dataset,
                   batch_size=64)
    # save
    engine.save("./my_model")
    # load 
    engine.load("./my_model")
  • shard_tensor and shard_op (in Paddle/python/paddle/distributed/auto_parallel/interface.py) provides the mid-level APIs for users to shard tensors or operators according to their own choices. For example:

    import paddle
    import paddle.distributed.auto_parallel as auto 
    
    mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
    a = paddle.ones([4, 6])
    b = paddle.zeros([4, 6])
    
    # shard_tensor
    auto.shard_tensor(a, mesh, shard_spec=["x", "y"])
    # shard_op, functional style
    auto.shard_op(paddle.add, mesh,
                  in_shard_specs=[["x", "y"], ["y", None]],
                  out_shard_specs=[[None, "x"]])(a, b)
  • Strategy (in Paddle/python/paddle/distributed/auto_parallel/strategy.py) is used to configure the paralleization and optimization behaviors.

  • ProcessMesh (in Paddle/python/paddle/distributed/auto_parallel/process_mesh.py) is used to describe the topology of the used processes in the distributed computation.

@paddle-bot
Copy link

paddle-bot bot commented Sep 6, 2022

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@aoyulong aoyulong changed the title [Auto Parallel] Improve the APIs [WIP: Auto Parallel] Improve the APIs Sep 7, 2022
@aoyulong aoyulong changed the title [WIP: Auto Parallel] Improve the APIs [Auto Parallel] WIP: Improve the APIs Sep 7, 2022
@aoyulong aoyulong changed the title [Auto Parallel] WIP: Improve the APIs [Auto Parallel] Improve the APIs Sep 14, 2022
valid_data=None,
valid_freq=1,
valid_batch_size=1,
valid_sample_split=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

下面evaluate接口中用的eval_data,这里命名最好保持一致

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Contributor

@JiabinYang JiabinYang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@XieYunshen XieYunshen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM
单测时间阈值修改
单测迁移

@aoyulong aoyulong merged commit b042a3b into PaddlePaddle:develop Sep 15, 2022
aoyulong added a commit to aoyulong/Paddle that referenced this pull request Sep 17, 2022
* [Auto Parallel] Use c++ dist attr in the completion process

* [Auto Parallel] Add minor changes

* [Auto Parallel] Use c++ dist attr in the completion process

* [Auto Parallel] Add minor changes

* [Auto Parallel] Add the serialization process for dist attrs

* [Auto Parallel] Remove unnecessary comments

* [Auto Parallel] Fix some bugs

* [Auto Parallel] Fix the code style

* [Auto Parallel] Remove unnecessary impls

* [Auto Parallel] Fix the importing error

* [Auto Parallel] Fix the copy from bugs of op dist attr

* [Auto Parallel] Replace the use of constexpr if

* [Auto Parallel] Redesign the shard_tensor, shard_op and ProcessMesh

* [Auto Parallel] Change API of the completion unittest

* [Auto Parallel] Fix the bug when set_attr an int

* [Auto Parallel] Add the unittest for the serialization

* [Auto Parallel] Add some unit tests

* [Auto Paralle] Unify the strategy

* [Auto Parallel] Improve the engine api

* [Auto Parallel] Reset the changes made to the framework

* [Auto Parallel] Change the engine unittest

* [Auto Parallel] Update API of the completion and partitioner

* [Auto Parallel] Update unit tests using engine api

* update shard annotation

* [Auto Parallel] Remove the modifications of other modules

* [Auto Parallel] Add docs for APIs

* add new strategy

* [Auto Parallel] Replace the logger

* [Auto Parallel] Restore the test_program.py

* [Auto Parallel] Change the import rules

* [Auto Parallel] Add the examples for Engine

* [Auto Parallel] Do some minor changes

* [Auto Parallel] Remove yaml dependency

* [Auto Parallel] Fix the unittests

* add valid after train

* bug fix

Co-authored-by: zhaoyingli <[email protected]>
Co-authored-by: caozhou <[email protected]>
Co-authored-by: caozhou <[email protected]>
fuyinno4 pushed a commit that referenced this pull request Sep 19, 2022
* [AutoParallel] adapt gradient merge pass (#45915)

* adapt gradient merge

* fix op_role

* fix strategy

* [Auto Parallel] Gradient Fuse Allreduce (#45643)

* bugfix (#45332)

* dist embedding support lookup table v1

* add unitest

* customize wait_comm

* group gradients

* bugfix

* update program

* [Auto Parallel] Improve the APIs (#45776)

* [Auto Parallel] Use c++ dist attr in the completion process

* [Auto Parallel] Add minor changes

* [Auto Parallel] Use c++ dist attr in the completion process

* [Auto Parallel] Add minor changes

* [Auto Parallel] Add the serialization process for dist attrs

* [Auto Parallel] Remove unnecessary comments

* [Auto Parallel] Fix some bugs

* [Auto Parallel] Fix the code style

* [Auto Parallel] Remove unnecessary impls

* [Auto Parallel] Fix the importing error

* [Auto Parallel] Fix the copy from bugs of op dist attr

* [Auto Parallel] Replace the use of constexpr if

* [Auto Parallel] Redesign the shard_tensor, shard_op and ProcessMesh

* [Auto Parallel] Change API of the completion unittest

* [Auto Parallel] Fix the bug when set_attr an int

* [Auto Parallel] Add the unittest for the serialization

* [Auto Parallel] Add some unit tests

* [Auto Paralle] Unify the strategy

* [Auto Parallel] Improve the engine api

* [Auto Parallel] Reset the changes made to the framework

* [Auto Parallel] Change the engine unittest

* [Auto Parallel] Update API of the completion and partitioner

* [Auto Parallel] Update unit tests using engine api

* update shard annotation

* [Auto Parallel] Remove the modifications of other modules

* [Auto Parallel] Add docs for APIs

* add new strategy

* [Auto Parallel] Replace the logger

* [Auto Parallel] Restore the test_program.py

* [Auto Parallel] Change the import rules

* [Auto Parallel] Add the examples for Engine

* [Auto Parallel] Do some minor changes

* [Auto Parallel] Remove yaml dependency

* [Auto Parallel] Fix the unittests

* add valid after train

* bug fix

Co-authored-by: zhaoyingli <[email protected]>
Co-authored-by: caozhou <[email protected]>
Co-authored-by: caozhou <[email protected]>

* [Auto Parallel] Bugfix allreduce fuse for MP (#46086)

* bugfix

* bugfix

* typos fixed

* update strategy (#46138)

Co-authored-by: zhaoyingli <[email protected]>
Co-authored-by: JZ-LIANG <[email protected]>
Co-authored-by: zhaoyingli <[email protected]>
Co-authored-by: caozhou <[email protected]>
Co-authored-by: caozhou <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants