
[Auto Parallel] Logical Partition & Dist Op #35117

Merged
merged 28 commits on Sep 2, 2021

Conversation

JZ-LIANG
Contributor

@JZ-LIANG JZ-LIANG commented Aug 24, 2021

PR types

New features

PR changes

Others

Describe

Add the Partitioner and Dist Op implementations for Auto Parallel.
Partitioner: converts a serial network into distributed networks, where ops and vars are partitioned across different ranks.
Dist Op: implements the computation and communication logic of distributed operators.

Tensor-Parallel & Data-Parallel are now supported in Auto Parallel~

The functions added by this PR are not meant to be called by users directly; for how to use Auto Parallel, please refer to PR
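For context, here is a minimal sketch of what the logical partition means for a single tensor: given a process-mesh topology and a dims_mapping (the per-dimension sharding description used by Auto Parallel's dist attributes), the per-rank local shape is derived from the global shape. The helper below is illustrative only, not the API added by this PR:

```python
# Illustrative helper (not this PR's API): compute a tensor's per-rank shape.
# dims_mapping[i] == -1 means tensor dim i is replicated; otherwise dim i is
# sharded along mesh axis dims_mapping[i] of the process-mesh topology.
def local_shape(global_shape, dims_mapping, mesh_topology):
    shape = []
    for dim, mesh_axis in zip(global_shape, dims_mapping):
        if mesh_axis == -1:
            shape.append(dim)
        else:
            assert dim % mesh_topology[mesh_axis] == 0
            shape.append(dim // mesh_topology[mesh_axis])
    return shape

# e.g. a [1024, 4096] weight sharded column-wise over the MP axis of a 2x4 mesh
print(local_shape([1024, 4096], [-1, 1], [2, 4]))   # -> [1024, 1024]
```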

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@JZ-LIANG JZ-LIANG changed the title from "[Auto Parallel] Logical Partition & Update Dist Op" to "[Auto Parallel] Logical Partition & Dist Op" on Aug 24, 2021
process_mesh.topology,
model_parallel_axis, rank_id)
group = new_process_group(group_ranks)
# print("@@@@@@@@@@@@@@@@@@@@@ 5", group)
Contributor

Is this comment necessary?

Contributor Author

removed ~
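As a side note on the snippet above: the ranks that form one communication group can be read off the process-mesh topology, since the group along a given mesh axis contains every rank that shares this rank's coordinates on all other axes. A minimal sketch (not the exact helper used in this PR) of deriving those ranks before handing them to new_process_group:

```python
import numpy as np

# Illustrative only: derive the ranks of the communication group along `axis`
# for `rank_id` in a process mesh described by `topology`.
def comm_group_ranks(topology, axis, rank_id):
    mesh = np.arange(int(np.prod(topology))).reshape(topology)
    coord = list(zip(*np.where(mesh == rank_id)))[0]
    index = list(coord)
    index[axis] = slice(None)          # vary only the chosen mesh axis
    return mesh[tuple(index)].tolist()

# e.g. a 2x4 mesh (dp axis 0, mp axis 1): rank 5's MP group is [4, 5, 6, 7]
print(comm_group_ranks([2, 4], axis=1, rank_id=5))
```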

# NOTE Theoretically, the MP param init broadcast should be handled by
# each dist op itself, but if we inserted the broadcast op at that moment, the broadcast
# would run before the initializer, which leads to an undetermined case.
if self._enable_tensor_parallel:
Contributor

A little curious: why doesn't MP split the parameters here? The purpose of MP is to split parameters.

Contributor Author

We expose nn.Linear to the user, which consists of two ops: weight matmul & bias add.
In MP all (matmul) weights are split, but in row parallel the bias in nn.Linear is not split.

This NOTE covers that special case for the nn.Linear bias in row parallel.
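A tiny numeric sketch of that row-parallel case (two simulated ranks, made-up shapes): the matmul weight is split along its rows, each rank produces a partial product, the partials are summed (standing in for an allreduce), and only then is the unsplit bias added once:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4))       # global input
w = rng.standard_normal((4, 6))       # global weight
b = rng.standard_normal(6)            # bias: NOT split in row parallel

x0, x1 = np.split(x, 2, axis=1)       # each rank's input shard (columns)
w0, w1 = np.split(w, 2, axis=0)       # each rank's weight shard (rows)
partials = [x0 @ w0, x1 @ w1]         # per-rank partial matmul results
out = sum(partials) + b               # allreduce(sum) of partials, bias added once
assert np.allclose(out, x @ w + b)
```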

return no_grad_set_name


def _get_no_grad_set(loss, no_grad_set=None):
Contributor

Why do we need to take care of the no_grad_set ourselves?

Contributor Author

This is used for fine-tuning: in fine-tuning we should allow the user to specify which parameters will not be updated (no grad).
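A minimal static-graph sketch of that fine-tuning use case (parameter names and shapes are made up): the frozen parameters are put into no_grad_set so no gradient is computed or applied for them.

```python
import paddle
paddle.enable_static()

main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name="x", shape=[None, 16], dtype="float32")
    # "backbone" layer whose parameters we want to freeze during fine-tuning
    hidden = paddle.static.nn.fc(
        x, size=8,
        weight_attr=paddle.ParamAttr(name="backbone_w"),
        bias_attr=paddle.ParamAttr(name="backbone_b"))
    out = paddle.static.nn.fc(hidden, size=1)
    loss = paddle.mean(out)

    frozen = {"backbone_w", "backbone_b"}        # parameter names to skip
    opt = paddle.optimizer.SGD(learning_rate=1e-3)
    opt.minimize(loss, no_grad_set=frozen)       # no grad ops for frozen params
```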

for idx, op in reversed(list(enumerate(main_global_block.ops))):
if is_loss_grad_op(op):
loss_grad_var = main_global_block.vars[op.output_arg_names[0]]
main_global_block._insert_op_without_sync(
Contributor

When should we use without_sync and when should we use sync_with? What's the purpose of the sync_with function?

Contributor Author

The sync refers to synchronization between the Python-side program and the C++-side program. Every time you modify one side and want the modification to take effect on the other side, you should sync them.
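A simplified sketch of that pattern around the loss-grad snippet above (internal Block helpers, shown for illustration rather than as this PR's exact code): modify the Python-side program description first, then push the pending change to the C++ side once.

```python
# Illustrative only: insert a scale op right after the loss-grad op without
# syncing, then synchronize the python-side block with its C++ counterpart.
def scale_loss_grad(block, loss_grad_op_idx, loss_grad_var, dp_degree):
    block._insert_op_without_sync(          # python-side modification only
        loss_grad_op_idx + 1,
        type="scale",
        inputs={"X": loss_grad_var},
        outputs={"Out": loss_grad_var},
        attrs={"scale": 1.0 / dp_degree})
    block._sync_with_cpp()                  # make the change visible to C++
```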

# NOTE naive gradient sync without overlapping,
# so there is no need to sync between calc and comm
# collecting grad var
grad_to_sync = []
Contributor

The following statements to build grad_to_sync are very hard for me to understand.

Contributor Author

The actual meaning is: the allreduce (of that var) needs to be synced. We should sync the allreduce to ensure the optimizer update is conducted after the gradient allreduce.
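A hedged sketch of that "naive gradient sync" ordering (index bookkeeping simplified for illustration; only the op types are the standard Paddle collectives): allreduce every data-parallel gradient and place a communication-stream sync in front of the optimizer ops, so the update can only run after the allreduces finish.

```python
# Illustrative only: append allreduce ops for the collected gradients and a
# c_sync_comm_stream barrier before the first optimizer op.
def append_grad_allreduce(block, grad_vars, ring_id, first_optimize_idx):
    offset = first_optimize_idx
    for grad in grad_vars:
        block._insert_op_without_sync(
            offset,
            type="c_allreduce_sum",
            inputs={"X": grad},
            outputs={"Out": grad},
            attrs={"ring_id": ring_id})
        offset += 1
    block._insert_op_without_sync(          # optimizer runs only after comm done
        offset,
        type="c_sync_comm_stream",
        inputs={"X": grad_vars},
        outputs={"Out": grad_vars},
        attrs={"ring_id": ring_id})
    block._sync_with_cpp()
```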

By now, the fleet.distributed_strategy options that need to transpile the forward program are the following:
1. AMP
2. Recompute
4. sharding
Contributor

In my opinion, it would be better to remove statements without corresponding implementations.

Contributor Author

done!

self._data_parallel_axis = 0
self._model_parallel_axis = 1
else:
self._data_parallel_axis = -1
Contributor

delete this dp/mp strategy in the next step

Contributor Author

Got it, this will be deleted in the next major update in Sep. Auto Parallel will NOT hold a global view of MP/DP/PP.

@sandyhouse sandyhouse left a comment

LGTM

@aoyulong
Contributor

LGTM

@JZ-LIANG JZ-LIANG merged commit a622b70 into PaddlePaddle:develop Sep 2, 2021