Fine-tune with R API #4817
I think the last line may not be possible. The others should be OK.
Hello, I think you should actually be able to do it by using get.output:
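A minimal sketch of that step, assuming an Inception-BN checkpoint and a layer named flatten_output (both are assumptions, not from the original post):

```r
library(mxnet)

# Load a pre-trained checkpoint; the model name and iteration are assumptions.
model <- mx.model.load("Inception-BN", iteration = 126)

# List all internal symbols of the network.
internals <- model$symbol$get.internals()

# get.output only accepts a numeric index, so resolve the layer name first.
# "flatten_output" and num_hidden = 2 are placeholders for illustration.
flatten <- internals$get.output(which(internals$outputs == "flatten_output"))

# Attach a fresh classification head for the new task.
new_fc   <- mx.symbol.FullyConnected(data = flatten, num_hidden = 2,
                                     name = "new_fc")
new_soft <- mx.symbol.SoftmaxOutput(data = new_fc, name = "softmax")
```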
The tricky part is that get.output seems to accept only a numeric index, not a named one. After that, it should be relatively straightforward using the initializers:
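Something along these lines; mxnet:::mx.model.init.params is a package-internal helper whose exact signature may vary between versions, and the input shape here is an assumption:

```r
# Initialise a complete parameter set for the modified symbol.
# The internal helper and its signature are assumptions on my side.
new_params <- mxnet:::mx.model.init.params(
  symbol      = new_soft,
  input.shape = c(224, 224, 3, 8),   # assumed: width, height, channels, batch
  initializer = mx.init.uniform(0.1),
  ctx         = mx.cpu())
```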
And finally, reassign the original weights onto the newly initialised parameters for all but the new FC layer:
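For example, reusing the names from the sketches above:

```r
# Copy the pre-trained weights over every argument except the new FC head,
# which keeps its fresh random initialisation.
keep <- intersect(names(new_params$arg.params), names(model$arg.params))
keep <- setdiff(keep, c("new_fc_weight", "new_fc_bias"))
new_params$arg.params[keep] <- model$arg.params[keep]
```

The resulting arg.params can then be supplied to mx.model.FeedForward.create together with the modified symbol.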
Thank you for the response. Is it possible to replace the line …?
I'm not an expert, but I am pretty sure that if you set all the weights to zero, the network is not able to learn anything (if that is what you meant to do by setting the params to NULL).
I hope that setting a subset of parameters to NULL will force some kind of default initialization for those parameters during the training process (similar to …).
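Presumably something along these lines was meant (a hypothetical reconstruction; the names come from the earlier sketches):

```r
# Drop the new head's entries and hope the training wrapper re-initialises
# them (as the following comments explain, it does not).
new_params$arg.params[["new_fc_weight"]] <- NULL
new_params$arg.params[["new_fc_bias"]]   <- NULL
```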
Well, I used jeremiedb's answer for fine-tuning a network and had to add some lines, because after calling the FeedForward function I got an error that the bias for new_fc_24 was NULL. So in that case, at least, there was no such default initialization.
Thanks OwlECoyote for pointing out the bias argument. Lines 416-417 in /model.R should make the behavior of the model clearer:
To make the picture clearer: the FeedForward wrapper calls the model.train function, which initializes the arg.params to 0 before assigning the pre-trained arg.params. Therefore, if there are missing arguments, as statist-bhfz mentioned, their weights will be "initialised" to 0. In that case the model will run but won't learn, so it's not a good idea!
Bottom line: arg.params, if provided, should simply be a list containing the ndarrays for each of the model's arguments (symbol$arguments, excluding data and label), and the names of that list must match the model's argument names.
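As a sanity check reflecting that bottom line (a sketch reusing the assumed names from the earlier snippets; "softmax_label" follows from the head named "softmax"):

```r
# The supplied list should cover every model argument except data and label;
# anything missing would be silently zero-initialised and never learn.
wanted  <- setdiff(new_soft$arguments, c("data", "softmax_label"))
missing <- setdiff(wanted, names(new_params$arg.params))
if (length(missing) > 0)
  stop("zero-initialised (will not learn): ", paste(missing, collapse = ", "))
```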
Hi, I recently attempted fine-tuning ResNet101 based on your comments. However, when I start training (on a GPU), R outputs the following message:
Then it crashes. I updated mxnet to 1.0.0. Fine-tuning using Inception-BN at iteration 126 does not crash. I am wondering whether there is an internal bug in the mxnet R package causing this crash.
Hi,
Many thanks for this library, which provides almost the only way to do deep learning in R. The current R documentation is not very comprehensive, but the examples and discussions in the issues help a lot.
Now I am stuck looking for the R equivalent of the get_internals() method in Python (http://mxnet.io/how_to/finetune.html). Unfortunately, model$symbol$get.internals() doesn't return any usable R object, so modifying the symbol of a pre-trained model seems to be impossible in the R API without directly editing the symbol.json file. Is it possible to add capabilities to the R package for writing code like this?