
[RFC / WIP] refactor _train_multi_device to make more modular #909

Merged: 1 commit merged into apache:master on Dec 18, 2015

Conversation

lukemetz
Contributor

In an effort to write my own training loop for multi-GPU training, I found myself needing many bits and pieces out of the monolithic _train_multi_device. This PR is an attempt to pull out the separate pieces so they can be reused.

@lukemetz force-pushed the lm/model_refactor branch 3 times, most recently from caebf49 to d4ce662 on December 12, 2015 at 03:12
_load_data(data_batch, self.data_arrays)
_load_label(data_batch, self.label_arrays)

def forward(self, is_train=True):
Contributor

is_train defaulting to True is inconsistent with exec.forward, which defaults to False. Consider changing it to False.
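A minimal sketch of the suggested signature change (the executor-list name train_execs is an assumption here, not necessarily what the PR uses):

    def forward(self, is_train=False):
        # Default to is_train=False so the wrapper matches
        # mx.executor.Executor.forward, which also defaults to False.
        for texec in self.train_execs:  # hypothetical per-device executor list
            texec.forward(is_train=is_train)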

@piiswrong
Contributor

Since this is a major refactor of a core module, it would be really helpful if you could do some extra testing to make sure it doesn't break anything.

Could you test it on some examples (single card and multi card) in the example folder?

@lukemetz
Contributor Author

Thanks for the comments! Will do. As of now, there is no front-end change. I did a few tests on train_mnist and train_cifar from the examples, with and without multiple GPUs.
I could not figure out a way to get into the update_on_kvstore == False path. Do any of the tests on Travis exercise it, or is there a way to trigger it?

@piiswrong
Contributor

The examples in image-classification now all use kvstore (which I think should be fixed). You can get into the update_on_kvstore=False path by calling FeedForward.fit(kvstore=None).
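For reference, a minimal invocation that exercises that path might look like the following (the symbol and iterator names are placeholders, and the constructor arguments are only a sketch):

    import mxnet as mx

    # net, train_iter and val_iter are placeholders for a real Symbol and DataIters.
    model = mx.model.FeedForward(symbol=net, ctx=[mx.gpu(0), mx.gpu(1)],
                                 num_epoch=10, learning_rate=0.1)
    # Passing kvstore=None turns off update_on_kvstore, so parameter updates
    # run locally on each device instead of through a KVStore.
    model.fit(X=train_iter, eval_data=val_iter, kvstore=None)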

@piiswrong
Contributor

Overall I like this design. I have been thinking about refactoring this for a while but haven't had the time to do it. Thank you so much for your contribution!

@tqchen @antinucleon what do you think?

@lukemetz
Contributor Author

Updated with the moves and comments. Tested MNIST single- and multi-GPU to completion; CIFAR single and multi are able to start up.
My local lint is funky, so this might fail Travis, but I will fix it.

@lukemetz force-pushed the lm/model_refactor branch 5 times, most recently from 25cda10 to e4e4edd on December 12, 2015 at 05:50
@piiswrong
Contributor

@lukemetz Thanks! Looks good to me, but since this can potentially affect a lot of things, I'll wait for someone else to have a look. They are all at NIPS right now, so it might take some time before we can merge.

@piiswrong
Contributor

@tqchen could you take a look at this?

@tqchen
Member

tqchen commented Dec 15, 2015

Sorry for the delay! Just got back from NIPS. I like the idea of the executor manager, as it wraps the data-parallel pattern we need.

The UpdateManager, however, seems to be a thin wrapper around the steps that were in _train_multi_device. I would suggest we remove it, either by putting things back or by replacing them with two functions in model.py.

Some comments:

  • The general principle for adding modules (another layer of abstraction) is to make pieces reusable in multiple places; in this case, there is no example where the two pieces are reused.
  • Functions are usually preferred over classes when there are only a few operations, as this normally causes fewer side effects between calls.

@tqchen
Member

tqchen commented Dec 15, 2015

Also looping in @pluskid @antinucleon @mli.

@piiswrong
Contributor

It's probably better to make Updater a method on KVStore, since it's only needed when a kvstore is used; otherwise a simple update call will be enough.
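For context, the two update paths being discussed look roughly like this (a sketch against the kvstore API of that era; the exact call sites in the PR may differ):

    # update_on_kvstore == True: the kvstore applies the optimizer itself.
    kv.set_optimizer(optimizer)
    kv.push(idx, grad_list)
    kv.pull(idx, out=weight_list)

    # update_on_kvstore == False: gradients may still be aggregated through the
    # kvstore, but the update is applied locally by an updater closure.
    updater = mx.optimizer.get_updater(optimizer)
    updater(idx, grad, weight)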

@pluskid
Contributor

pluskid commented Dec 15, 2015

I like the idea of breaking the giant _train_multi_device method into pieces. This looks good to me, though I do think using methods instead of extra classes is preferable. For example, ExecutorManager sounds like a very general thing, but here it actually only does data parallelization.

@lukemetz
Contributor Author

@tqchen now that I look at it again, I 100% agree with your comments about UpdateManager. Will move some code around into something along the lines of an initialization step and an update step.

@pluskid How does a rename of ExecutorManager --> DataParallelExecutors sound?

In general I would agree with you all on functions vs. classes. My concern in this case is having to manually manage multiple variables and duplicating them when passing them into functions. I had a hard time refactoring this code because there was no separation or grouping of data like one gets with classes.
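To illustrate the grouping concern, a class can bundle the per-device arrays that would otherwise have to be threaded through every helper call. The following is a hypothetical sketch, not the exact PR code:

    class DataParallelExecutors(object):
        # Hypothetical sketch: holding per-device state together means helpers
        # no longer need (execs, data_arrays, label_arrays, ...) on every call.
        def __init__(self, execs, data_arrays, label_arrays):
            self.execs = execs                # one executor per device
            self.data_arrays = data_arrays    # per-device input slices
            self.label_arrays = label_arrays  # per-device label slices

        def load_data_batch(self, data_batch):
            _load_data(data_batch, self.data_arrays)
            _load_label(data_batch, self.label_arrays)

        def forward(self, is_train=False):
            for device_exec in self.execs:
                device_exec.forward(is_train=is_train)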

@tqchen
Member

tqchen commented Dec 15, 2015

DataParallelExecutors sounds good to me, or something like DataParallelExecutorManager.
To make progress on this PR, let us go with the concrete proposals, since we all agree on the spirit.
I think I am fine with either classes or methods, as long as they remain private to model.py.

If @pluskid has any proposal for a possible list of methods, we can also compare and see which way works better. A refactoring using methods could be adopted by R and Julia more easily.

@pluskid
Contributor

pluskid commented Dec 15, 2015

I'm actually fine with the current refactoring. As @lukemetz mentioned, refactoring with functions makes it a bit messier because a lot of local variables need to be passed around. That is also the primary reason I did not attempt to refactor this big method while implementing the Julia side.

That being said, I think I will not adapt the refactoring to the Julia side soon (though I'm not objecting if someone else wants to do it), because I agree with @tqchen that refactoring is typically triggered by code sharing. At least until we add other models (RNNs, for example), I do not see how those code pieces are going to be used elsewhere. I think after this PR gets merged, @lukemetz could perhaps show some examples of how he writes his own training loop and how the refactored code architecture makes that easier. Then things will be clearer.

@lukemetz
Contributor Author

Updated per the comments. Let me know if there is anything else to change. The build appears to have errored somewhere in the C++ code, and only on Linux???

As for an example training loop: I started writing a wrapper around FeedForward to support a few features that FeedForward didn't have, specifically:

  • Evaluating multiple output costs on multiple data iterators.
  • A log, separate from stdout, that I could write out and that would work with my tools.

The actual code (modelled heavily on Blocks) is currently a mess for a large number of reasons. It is basically a trimmed-down version of _train_multi_device with a few extra callbacks for extensions and with more state exposed to them (https://github.com/lukemetz/mxoid/blob/master/mxoid/loop.py).

I would be happy to go into more detail if anybody wants.
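For a sense of the loop shape described above, here is a hypothetical sketch built on the refactored pieces; the manager and updater names are assumptions, not the actual mxoid code:

    def train_loop(executor_manager, train_iter, do_update, num_epoch, extensions=()):
        # Hypothetical sketch: the per-device work lives in executor_manager,
        # while extensions see the exposed loop state after every batch.
        for epoch in range(num_epoch):
            train_iter.reset()
            for nbatch, data_batch in enumerate(train_iter):
                executor_manager.load_data_batch(data_batch)
                executor_manager.forward(is_train=True)
                executor_manager.backward()
                do_update(executor_manager)   # apply gradients (locally or via kvstore)
                for ext in extensions:        # callbacks for logging, evaluation, etc.
                    ext(epoch=epoch, nbatch=nbatch, manager=executor_manager)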

@mli
Member

mli commented Dec 15, 2015

lgtm, thanks @lukemetz

btw, have you tested the code? i would like to test at least example/image-classification/train_mnist.py with network = lenet, for cpu only, single gpu, and multiple gpus.

@lukemetz
Contributor Author

@mli yep on all counts. Single and multi finished fine. cpu takes forever but started fine and is able to spit out a few samples/sec.

@piiswrong mentioned this pull request on Dec 15, 2015.
@tqchen
Member

tqchen commented Dec 18, 2015

The current code looks to be in good shape. @mli @pluskid @piiswrong, please take a final look; going to merge this in 12 hours.

@pluskid
Contributor

pluskid commented Dec 18, 2015

LGTM

@mli
Member

mli commented Dec 18, 2015

LGTM

tqchen added a commit that referenced this pull request Dec 18, 2015
[RFC / WIP] refactor _train_multi_device to make more modular
tqchen merged commit e316055 into apache:master on Dec 18, 2015
@tqchen
Member

tqchen commented Dec 18, 2015

This is merged!
