This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-563] Refactor R optimizers to fix memory leak #11374

Merged
merged 11 commits into from Jul 25, 2018

Conversation

jeremiedb
Contributor

Fixes the R memory leak through a refactor of the optimizers.
Since mutable NDArrays aren't supported in R, the optimizers have been ported to symbolic updates, with an executor created for each weight.

Note: only the optimizer update symbols have been used for now, so only SGD, RMSProp and Adam are currently supported. The manual updates for Adagrad and Adadelta will need to be reimplemented.

Memory usage now stays low, even for very large networks and embeddings.
Tested on CPU and a single GPU, not on multiple GPUs.
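
A minimal plain-R sketch of the structure described above (illustrative only; the closures stand in for the per-weight executors, and names such as make_sgd_updater are hypothetical, not part of the package):

# Each weight gets its own updater holding the hyper-parameters (and, for
# stateful optimizers, its own state), mirroring the one-executor-per-weight
# design of the refactor.
make_sgd_updater <- function(lr, wd = 0) {
  function(weight, grad) weight - lr * (grad + wd * weight)
}

weights  <- list(fc1_weight = matrix(rnorm(6), 2, 3), fc1_bias = rnorm(3))
grads    <- lapply(weights, function(w) 0 * w + 0.1)   # dummy gradients
updaters <- lapply(weights, function(w) make_sgd_updater(lr = 0.01))

# One update step for all weights, returning the updated list.
weights <- Map(function(u, w, g) u(w, g), updaters, weights, grads)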

@jeremiedb
Contributor Author

Provides a fix for #10721 and #10928.

@anirudhacharya
Member


@jeremiedb can you add tests for the optimizers? See #7196; none of the optimizers have tests.

Let me know if you want me to pitch in; we can collaborate on this. I can begin work on fixing some of the broken optimizers starting next month.

@thirdwing
Contributor

@hetong007 Please take a look at this.

@jeremiedb
Contributor Author

@anirudhacharya Sure, I'll add tests.
It would be great if you could jump in as well. I'm expecting to add the missing Adagrad and Adadelta optimizers within a week, in order to match the existing functionality as soon as possible. Would you be willing to look at the non-mutable NDArrays, which were actually the root cause that led to refactoring the optimizers into symbolic execution? Thanks!

@anirudhacharya
Member

anirudhacharya commented Jun 26, 2018

@jeremiedb sure!

@jeremiedb
Contributor Author

@hetong007 With Adadelta and Adagrad, the same functionality is now supported (and non-centered RMSProp has been added within rmsprop). Tests to be added.
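
For reference, a plain-R sketch of the two RMSProp variants (illustrative only; the actual updates are performed by MXNet's RMSProp update operators, and the exact epsilon placement may differ slightly):

rmsprop_step <- function(w, grad, state, lr, gamma1 = 0.95, gamma2 = 0.9,
                         eps = 1e-8, centered = FALSE) {
  # running average of squared gradients (shared by both variants)
  state$n <- gamma1 * state$n + (1 - gamma1) * grad^2
  if (!centered) {
    # non-centered: scale the step by the RMS of the gradients
    w <- w - lr * grad / (sqrt(state$n) + eps)
  } else {
    # centered: also track the mean gradient and a momentum-like delta
    state$g     <- gamma1 * state$g + (1 - gamma1) * grad
    state$delta <- gamma2 * state$delta - lr * grad / sqrt(state$n - state$g^2 + eps)
    w <- w + state$delta
  }
  list(weight = w, state = state)
}

# Usage: the state starts at zero and is carried from step to step.
state <- list(n = 0, g = 0, delta = 0)
step  <- rmsprop_step(w = 1, grad = 0.5, state = state, lr = 0.01)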


count <- 0
num_update <- 0
Contributor


Please eliminate trailing whitespace.

#' Step size.
#' @param gamma1 float, default=0.95
#'
#' @param learning.rate float, default=1e-3
Contributor


Is there a strong reason to change the default values? It may break other people's code.

Contributor Author


I wanted to align with the Python package's defaults. I'll revert to the existing defaults if you see more harm than good in this change.

Contributor


I suggest keeping the default values as they are in the R package. It's not necessary for default values to be the same across interfaces, and changing them may break users' scripts, especially for optimizer parameters.

mx.opt.get.updater <- function(optimizer, weights, ctx) {

exec_list <- lapply(seq_along(weights), function(i) {
if (is.null(weights[[i]])) return(NULL) else
Contributor


Please format here and below as:

if (condition) {
  xxx
} else {
  yyy
}

@hetong007
Contributor

You may use your editor's features to eliminate trailing spaces. Also, please add tests for the new optimizers.

@jeremiedb
Contributor Author

@hetong007 Tests added and trailing spaces fixed.
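
For readers following along, a hedged sketch of what such a test can look like with testthat (values and argument names are illustrative and only loosely mirror the tests added in this PR; the mx.opt.* calls assume the API shown elsewhere in this thread):

library(testthat)
library(mxnet)

test_that("sgd updater moves the weight against the gradient", {
  # plain SGD: no momentum, no weight decay
  opt     <- mx.opt.create("sgd", learning.rate = 0.1, momentum = 0, wd = 0)
  weights <- list(w = mx.nd.array(c(1, 2, 3)))
  grads   <- list(w = mx.nd.array(c(1, 1, 1)))
  updater <- mx.opt.get.updater(opt, weights, ctx = mx.cpu())

  updated <- updater(weights, grads)
  # expected: w - lr * grad
  expect_equal(as.array(updated$w), c(0.9, 1.9, 2.9), tolerance = 1e-6)
})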

#' @param learning.rate float, default=0.002
#' Step size.
#' The initial learning rate.
#' @param gamma1 float, default=0.95
#' decay factor of moving average for gradient, gradient^2.
#' @param gamm2 float, default=0.9
Contributor


gamm2 -> gamma2

epsilon = 1e-8,
wd = 0,
rescale.grad = 1,
clip_gradient = -1,
Contributor


What is the difference between setting it to 1 and -1?

Contributor Author


When clip_gradient is < 0, no clipping is applied.
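
A tiny plain-R illustration of that convention (not the operator's actual implementation):

# clip_gradient < 0 disables clipping; otherwise gradients are clamped
# element-wise to [-clip_gradient, clip_gradient].
clip_grad <- function(grad, clip_gradient = -1) {
  if (clip_gradient < 0) return(grad)
  pmax(pmin(grad, clip_gradient), -clip_gradient)
}

clip_grad(c(-5, 0.3, 2), clip_gradient = 1)   # -1.0  0.3  1.0
clip_grad(c(-5, 0.3, 2), clip_gradient = -1)  # unchanged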

Contributor


Please add that to the docstring; otherwise people may be confused without looking at the code.

@aplikaplik

aplikaplik commented Jul 18, 2018

Hi, I have the same problem with GPU memory (package mxnet, function mx.mlp, just started with that). But as I'm a traffic engineer, I have a little trouble finding my way around the provided solution (I've used R for a long time, but not at this level of detail). Could you please be so kind as to provide a "Refactor R optimizers for dummies" version of the solution? :)

@jeremiedb
Contributor Author

@hetong007 Is there any blocking element remaining?

@jeremiedb
Contributor Author

@aplikaplik Is your issue specifically about the mxnet MLP function, or is it also related to gpuR? As for mx.mlp, this PR should fix the memory issues encountered with both mx.mlp and mx.model.FeedForward.create, since they both rely on the same optimizer update routine.
Roughly speaking, the idea of the fix was to create a symbolic graph for each parameter to be trained. At each update, the weights and states are updated through a forward pass on a graph whose calculations are those of the chosen optimization routine (SGD, Adam, ...). In the end, it's the same approach, but it uses the symbolic representation (mx.symbol) rather than the imperative interface (mx.nd), since the latter had an apparent memory leak.
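
A rough sketch of that pattern, assuming the auto-generated mx.symbol.sgd_update wrapper and the executor helpers from the R package (the exact wrapper and argument names are assumptions and may differ from the merged optimizer.R):

library(mxnet)

# Build a small graph that computes the SGD update for one parameter.
weight <- mx.symbol.Variable("weight")
grad   <- mx.symbol.Variable("grad")
update <- mx.symbol.sgd_update(weight = weight, grad = grad, lr = 0.01,
                               wd = 0, rescale_grad = 1, clip_gradient = -1)

# One executor is bound per parameter (here a single 2x3 weight on CPU).
exec <- mx.simple.bind(update, ctx = mx.cpu(), grad.req = "null",
                       weight = c(2, 3), grad = c(2, 3))

# At every step, feed the current weight and gradient, run a forward pass,
# and read the updated weight back from the executor's outputs.
mx.exec.update.arg.arrays(exec,
                          list(weight = mx.nd.zeros(c(2, 3)),
                               grad   = mx.nd.ones(c(2, 3))),
                          match.name = TRUE)
mx.exec.forward(exec, is.train = FALSE)
new_weight <- exec$ref.outputs[[1]]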

@hetong007 merged commit be47870 into apache:master Jul 25, 2018
@aplikaplik

@jeremiedb Hi, I referenced gpuR because it could be useful for @cdeterman. My problem is much simpler: I don't know how to apply this code (optimizer.R) to my R installation or mxnet package. I'm sorry for such an elementary question.

XinYao1994 pushed a commit to XinYao1994/incubator-mxnet that referenced this pull request Aug 29, 2018
* refactor R optimizers to fix memory leak

* add Adadelta and Adagrad

* fix comments

* fix comments

* fix comments

* add tests

* fix whitespaces

* fix whitespaces

* fix typo

* fix typo

* add doc on clipping
@jeremiedb deleted the optim-R branch September 6, 2018 03:46
@onomatet

@hetong007 @jeremiedb
Hi, could you please check if #17207 is related to the changes in this pull request?
If yes, are there any temporary solutions to alter the learning rate of the optimizers in symbolic representation?
