forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update multi-task learning example (apache#12964)
* Update multi task learning example * Updating README.md
- Loading branch information
1 parent
03ef2b7
commit cb5f011
Showing
3 changed files
with
462 additions
and
164 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,13 @@ | ||
# Mulit-task learning example | ||
|
||
This is a simple example to show how to use mxnet for multi-task learning. It uses MNIST as an example and mocks up the multi-label task. | ||
This is a simple example to show how to use mxnet for multi-task learning. It uses MNIST as an example, trying to predict jointly the digit and whether this digit is odd or even. | ||
|
||
## Usage | ||
First, you need to write a multi-task iterator on your own. The iterator needs to generate multiple labels according to your applications, and the label names should be specified in the `provide_label` function, which needs to be consist with the names of output layers. | ||
For example: | ||
|
||
Then, if you want to show metrics of different tasks separately, you need to write your own metric class and specify the `num` parameter. In the `update` function of metric, calculate the metrics separately for different tasks. | ||
![](https://camo.githubusercontent.com/ed3cf256f47713335dc288f32f9b0b60bf1028b7/68747470733a2f2f7777772e636c61737365732e63732e756368696361676f2e6564752f617263686976652f323031332f737072696e672f31323330302d312f70612f7061312f64696769742e706e67) | ||
|
||
The example script uses gpu as device by default, if gpu is not available for your environment, you can change `device` to be `mx.cpu()`. | ||
Should be jointly classified as 4, and Even. | ||
|
||
In this example we don't expect the tasks to contribute to each other much, but for example multi-task learning has been successfully applied to the domain of image captioning. In [A Multi-task Learning Approach for Image Captioning](https://www.ijcai.org/proceedings/2018/0168.pdf) by Wei Zhao, Benyou Wang, Jianbo Ye, Min Yang, Zhou Zhao, Ruotian Luo, Yu Qiao, they train a network to jointly classify images and generate text captions | ||
|
||
Please refer to the notebook for a fully worked example. |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.