Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-331] Single machine All Reduce Topology-aware Communication (Updated) #11591

Merged
merged 50 commits into from
Jul 24, 2018

Conversation

ctcyang
Copy link
Contributor

@ctcyang ctcyang commented Jul 6, 2018

Description

Single machine All Reduce Topology-aware Communication

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Proposed communication method shows speed-up compared to both existing methods (parameter server and NCCL) on small batch sizes for ResNet-50, VGG-16, Inception-v3 and AlexNet.
  • Communication method queries the single-machine multi-GPU link topology, and determines a suitable communication pattern to use.
  • Use feature by MXNET_KVSTORE_USETREE=1, default is 0.
  • Add knobs for tuning this communication method.
  • In future, will add auto-tuner to automatically choose between single-machine communication protocols (parameter server, NCCL, method proposed here).

Comments

Carl Yang added 29 commits June 4, 2018 03:51
…se PCI-E as fallback for GPUs that are not linked by NVLink
@ctcyang
Copy link
Contributor Author

ctcyang commented Jul 20, 2018

Yeah, I'm blocked by the test. I can't replicate it on local machine, but I can replicate it on Docker image.

@ctcyang
Copy link
Contributor Author

ctcyang commented Jul 23, 2018

@haojin2 @eric-haibin-lin @rahul003 Trying to get this into 1.3 release

if (dest_id != topo_id) {
CopyFromTo(buf_from.merged[merged_row],
&(buf_dest.copy_buf[merged_row][is_dest-1]),
priority);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Align the lines like:

CopyFromTo(buf_...,
           &(buf_...,
           priority);


// ComputeTreesTest with backtracking
// TODO(carlyang): comment out test for now
/*TEST(GpuTopology, TestComputeTrees1) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's wrong with these tests? Do they not work?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They used to segfault only on CI, but now they should be fine. I fixed an off-by-1 bug.

def __exit__(self, ptype, value, trace):
os.environ[self._key] = self._prev_val

def test_device_pushpull():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this test in the file test_nccl?

CommDeviceTree() {
inited_ = false;
gpuarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_GPUARRAY_BOUND", 10000000);
backtrack_ = dmlc::GetEnv("MXNET_KVSTORE_BACKTRACK", 0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added documentation.

std::vector<float> link_matrix(devs_.size()*devs_.size());
GetP2PWeight(devs_, &link_matrix);
if (backtrack_)
LOG(WARNING) << "Using Backtracking to generate trees";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be LOG(INFO)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Many other places as well where this change ought to be done

}
}

// Performs partition on each existing partition in graph W if partition has
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the great comments for the functions, but could you fix comment style to how it's standard in the codebase. Here's an example https://github.com/apache/incubator-mxnet/blob/b4156da26cfe741619227ae726872b1255194900/src/kvstore/kvstore_utils.h#L37

std::vector<Context> devs_;

/// \brief Highest numbered device
int max_dev_;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to see where this variable is used to ensure that cases when gpus '1,5,3,7' are given work. But it looks like this variable is not used? Please remove this then

// dev_id: 4 2 3 1 7 5 0
// and generated an n_gpus x n_gpus link topology matrix:
//
// 1) The reduction trees are saved as indices on 0, 1, ..., n_gpus
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you clarify how many are generated

@ctcyang
Copy link
Contributor Author

ctcyang commented Jul 23, 2018

@Roshrini For keeping track of PR

@ctcyang ctcyang requested a review from szha as a code owner July 24, 2018 00:51
- Values: 0(false) or 1(true) ```(default=0)```
- If true and MXNET_KVSTORE_USETREE is set to 1, MXNet will log the reduction trees that have been generated.

* MXNET_KVSTORE_GPUARRAY_BOUND
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize it says multiple trees, but could you call out that this is for tree kvstore? especially because we have a similar variable MXNET_KVSTORE_BIGARRAY_BOUND

- When the array size is bigger than this threshold and MXNET_KVSTORE_USETREE is set to 1, multiple trees are used to load balance the big gradient being communicated in order to better saturate link bandwidth.

* MXNET_KVSTORE_BACKTRACK
- Values: 0(false) or 1(true) ```(Default=0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Formatting issue

@eric-haibin-lin eric-haibin-lin merged commit fe07d50 into apache:master Jul 24, 2018
@ctcyang ctcyang deleted the feature_multirootv9merge branch July 24, 2018 21:40
XinYao1994 pushed a commit to XinYao1994/incubator-mxnet that referenced this pull request Aug 29, 2018
…pdated) (apache#11591)

* add multiroot all-reduce communication pattern

* fix bug with UpdateWeight

* fix PCI-E links appearing in weight matrix bug

* optimization to skip CopyFromTo in ReduceInner gains a bit of throughput

* remove unnecessary if statement

* Add tests

* add more tests, 6 tests left to add

* get rid of some dead code

* Add comments

* Add randomized tests for backtrack and kernighan-lin

* Fix Postprocess

* Add switch for first valid tree when num_gpus > 8, and for maximum weight when num_gpus <= 8

* Kernighan-Lin seems to find better trees

* get rid of printfs

* change defaults

* inherit from CommDevice instead of Comm

* Fix lint errors

* Add Python test using MXNET_KVSTORE_USETREE, fix CMake compilation problem, add header guard

* fix lint errors

* better header guard that works for tests

* get rid of unused variable warning

* retrigger jenkins

* resolve 2 comments

* address comment using Class to do test, get rid of extraneous test, use PCI-E as fallback for GPUs that are not linked by NVLink

* address comments

* fix a few bugs

* get rid of printfs

* get rid of print

* Comment out test for now

* fix 2 more bugs

* fix segfault

* change PrintVector, PrintTopo, PrintMatrix to LOG(INFO) instead of stdout

* Fix code alignment

* get rid of todo

* Make changes to env variable names to indicate they are TREE-related

* Add note saying when ARRAY_BOUND env var takes effect
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants