Conversation
@thomelane @ThomasDelteil @safrooze or @indhub please take a look
The loss function is used to calculate how different the output of the network is from the ground truth. In the case of logistic regression, the ground truth are class labels, which can be either 0 or 1. Because of that, we are using `SigmoidBinaryCrossEntropyLoss`, which suits this scenario well.
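For reference, creating this loss in Gluon is a one-liner; a minimal sketch:

```python
from mxnet.gluon.loss import SigmoidBinaryCrossEntropyLoss

# By default this loss applies the sigmoid to the network output
# internally, so the model itself can emit raw scores (logits)
loss_function = SigmoidBinaryCrossEntropyLoss()
```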
The `Trainer` object allows us to specify the method of training to be used. There are various methods available; for our tutorial we use the widely accepted method of Stochastic Gradient Descent. We also need to parametrize it with a learning rate value, which defines how fast training happens, and weight decay, which is used for regularization.
"We also need to parametrize it with learning rate value, which defines how fast training happens" The learning rate of SGD defines the weight updates, not necessarily how fast training happens. I propose to reword it a bit to avoid confusion that large LR will lead to fast training.
Good suggestion, fixed.
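For reference, wiring up such a trainer might look like the following minimal sketch (the hyperparameter values are illustrative placeholders, and `net` is assumed to be the Gluon model defined elsewhere in the tutorial):

```python
from mxnet import gluon

# SGD parametrized with a learning rate and weight decay (regularization);
# both values below are placeholders, not tuned settings
trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        {'learning_rate': 0.1, 'wd': 0.001})
```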
@@ -0,0 +1,215 @@
# Logistic regression using Gluon API explained
This tutorial is very similar to this Straight Dope tutorial. Did you consider focusing purely on the confusion around `Accuracy` and pointing the user to Straight Dope for further reading?
There is a lot of similarity with the tutorial you have mentioned. My personal pet peeve with the tutorial in The Straight Dope is that it doesn't use library functions for sigmoid, loss and accuracy. I believe it creates an impression that you need to write a lot of boilerplate code instead of using what MXNet offers.
Also, there was confusion around the number of neurons on the forum and StackOverflow, which directly relates to the usage of `Accuracy` but needs a bit more context to cover.
mx.random.seed(12345) # Added for reproducibility
```
In this tutorial we will use a fake dataset, which contains 10 features drawn from a normal distribution with mean 0 and standard deviation 1, and a class label, which can be either 0 or 1. The length of the dataset is an arbitrary value. The function below helps us to generate the dataset.
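Such a generator might look like the following minimal sketch (the function name and argument names are illustrative assumptions, not the tutorial's actual code):

```python
import mxnet as mx

def get_random_data(size, ctx):
    # 10 features per example, drawn from N(0, 1)
    x = mx.nd.random.normal(0, 1, shape=(size, 10), ctx=ctx)
    # a binary class label: 0 or 1
    y = mx.nd.round(mx.nd.random.uniform(0, 1, shape=(size, 1), ctx=ctx))
    return x, y
```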
The logistic regression tutorial in Straight Dope uses the Adult dataset. Did you consider using a similar dataset (or perhaps simply pointing the user to that tutorial)?
I didn't want to add extra code for data loading and data processing, because this is not the point of this tutorial. The optimal way would be if there were a binary classification dataset in MXNet itself, so it could be loaded in one line and no pre-processing would be required.
## Working with data
To work with data, Apache MXNet provides the Dataset and DataLoader classes. The former provides indexed access to the data; the latter is used to shuffle and batchify the data.
Point the user to links to the API docs and tutorials for further reading, and perhaps skip the explanation of Dataset/DataLoader.
Yes, added links and removed my explanation. Added a link to the Datasets and Dataloaders tutorial instead.
This separation is done because the source of a Dataset can vary from a simple array of numbers to complex data structures like text and images. DataLoader doesn't need to be aware of the source of the data as long as Dataset provides a way to get the number of records and to load a record by index. As an outcome, Dataset doesn't need to hold all the data in memory at once. Needless to say, one can implement their own versions of Dataset and DataLoader, but we are going to use the existing implementations.
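In code, the pairing might look like this minimal sketch (variable names and the batch size are illustrative):

```python
from mxnet.gluon.data import ArrayDataset, DataLoader

# ArrayDataset provides indexed access to in-memory arrays;
# DataLoader shuffles and batches whatever Dataset it receives
dataset = ArrayDataset(x, y)   # x, y: feature and label NDArrays
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
```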
Below we define 2 datasets: a training dataset and a validation dataset. It is good practice to measure the performance of a trained model on data that the network hasn't seen before. That is why we are going to use the training set for training the model and the validation set to calculate the model's accuracy.
Recommend keeping it brief and mentioning "we define training and validation dataset". The difference between train/validation/test datasets is outside of this tutorial's scope.
Agree. Removed this one as well.
"training set" -> "training dataset"
"validation set" -> "validation dataset"
## Defining and training the model
In real application, model can be arbitrary complex. The only requirement for the logistic regression is that the last layer of the network must be a single neuron. Apache MXNet allows us to do so by using `Dense` layer and specifying the number of units to 1.
Make references to functions links to the API docs.
Done
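A minimal sketch of such a network (the hidden layer and the initializer are illustrative choices, not prescribed by the tutorial):

```python
import mxnet as mx
from mxnet.gluon import nn

net = nn.HybridSequential()
with net.name_scope():
    net.add(nn.Dense(units=10, activation='relu'))  # arbitrary hidden layer
    net.add(nn.Dense(units=1))                      # single output neuron
net.initialize(mx.init.Xavier())
```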
## Tip 4: Convert probabilities to classes before calculating Accuracy
The `Accuracy` metric requires 2 arguments: 1) a vector of ground-truth classes and 2) a tensor of predictions. When the tensor of predictions is of the same shape as the vector of ground-truth classes, the `Accuracy` class assumes that it contains predicted classes. So, it converts the tensor to `Int32` and compares each item of the ground-truth vector to the prediction vector.
Why is one called a vector and the other called a tensor?
Replaced it with "vector or matrix".
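To illustrate the shape-dependent behaviour, a self-contained sketch with made-up values:

```python
import mxnet as mx

acc = mx.metric.Accuracy()
labels = mx.nd.array([1, 0, 1])

# Predictions with the same shape as the labels are treated as
# predicted classes and cast to Int32, truncating 0.99 and 0.8 to 0
sigmoid_out = mx.nd.array([0.99, 0.1, 0.8])
acc.update([labels], [sigmoid_out])
print(acc.get())   # ('accuracy', 0.3333...) despite confident predictions
```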
Because of the behaviour above, you will get an unexpected result if you just pass the output of the `Sigmoid` function as is. The `Sigmoid` function produces output in the range [0, 1], and all numbers in that range are going to be cast to 0, even a value as high as 0.99. To avoid this we write a custom bit of code that:
Instead of having a tip here, I recommend writing the rounding as a separate function in the notebook and including a paragraph on top explaining what it's doing. This paragraph is really the only motivation for writing this tutorial.
The same is not true if the output shape of your function is different from the shape of the ground-truth classes vector. For example, when doing multiclass regression with `Softmax` as an output, the shape of the output is going to be *number_of_examples* x *number_of_classes*. In that case we don't need to do the transformation above, because the `Accuracy` metric understands that the shapes are different and assumes that the prediction contains the probabilities of an example belonging to each of these classes - exactly what we want it to be.
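Continuing the sketch above with made-up values, the matrix-shaped case behaves as described:

```python
import mxnet as mx

acc = mx.metric.Accuracy()
labels = mx.nd.array([1, 0, 1])

# Shape (3, 2) differs from the labels' shape, so Accuracy takes
# the argmax over the class axis instead of casting to Int32
softmax_out = mx.nd.array([[0.2, 0.8], [0.7, 0.3], [0.1, 0.9]])
acc.update([labels], [softmax_out])
print(acc.get())   # ('accuracy', 1.0)
```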
This makes things a little bit easier, and that's why I have seen examples where `Softmax` is used as the output of the prediction. If you want to do that, make sure to change the output layer size to 2 neurons, where each neuron provides the probability of an example belonging to class 0 or class 1 respectively.
Wouldn't call it easier. Softmax allows the output layer to have more flexibility in learning because it contains twice the number of parameters to learn.
## Conclusion
In this tutorial I explained some potential pitfalls to be aware about when doing logistic regression in Apache MXNet Gluon API. There might be some other challenging scenarios, which are not covered in this tutorial, like dealing with imbalanced classes, but I hope this tutorial will serve as a guidance and all other potential pitfalls would be covered in future tutorials.
"be aware of"
Fixed
Suggest keeping the conclusion factual and not including a reference to one of many challenges or "hoping" the tutorial is useful. Also, this is not a blog (there is no posting time associated with it), so "future tutorials" has no meaning.
Rewrote the conclusion by reiterating the main points of the tutorial.
The tutorial looks awesome! A few comments:
@@ -59,9 +55,9 @@ val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
## Defining and training the model
In real application, model can be arbitrary complex. The only requirement for the logistic regression is that the last layer of the network must be a single neuron. Apache MXNet allows us to do so by using `Dense` layer and specifying the number of units to 1.
The only requirement for the logistic regression is that the last layer of the network must be a single neuron. Apache MXNet allows us to do so by using [Dense](https://mxnet.incubator.apache.org/api/python/gluon/nn.html#mxnet.gluon.nn.Dense) layer and specifying the number of units to 1. The rest of the network can be arbitrary complex. |
"arbitrarily complex"
Fixed
Usually, it is not enough to pass the training data through a network only once to achieve high accuracy; it helps when the network sees each example multiple times. Each full pass over the training data is called an `epoch`. How many epochs are needed is unknown in advance, and usually it is estimated using a trial-and-error approach.
Below we are defining the main training loop, which go through each example in batches specified number of times (epochs). After each epoch we display training loss, validation loss and calculate accuracy of the model using validation set. For now, let's take a look into the code, and I am going to explain the details later.
The next step is to define the training function in which we iterate over all batches of training data, execute the forward pass on each batch and calculate training loss. On line 19, we sum losses of every batch per an epoch into a single variable, because we calculate loss per single batch, but want to display it per epoch. |
"batch per epoch"
Fixed
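For context, the training function under discussion might look like this minimal sketch (`net`, `trainer`, and `loss_function` are assumed from the earlier sketches; the "line 19" mentioned above refers to the notebook, not this snippet):

```python
import mxnet as mx

def train_model(net, trainer, loss_function, train_dataloader, epochs):
    for epoch in range(epochs):
        cumulative_train_loss = 0
        for features, labels in train_dataloader:
            with mx.autograd.record():         # record the forward pass
                output = net(features)
                loss = loss_function(output, labels)
            loss.backward()                    # backpropagate
            trainer.step(features.shape[0])    # update the weights
            # sum batch losses so loss can be reported per epoch
            cumulative_train_loss += mx.nd.sum(loss).asscalar()
        print("Epoch %d, training loss: %.4f" % (epoch, cumulative_train_loss))
```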
1. Subtracts a threshold from the original prediction. Usually, the threshold is equal to 0.5, but it can be higher if you want to increase the certainty required for an item to belong to class 1.
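A minimal sketch of the whole conversion, assuming the threshold subtraction above is followed by taking the ceiling (the helper name is illustrative):

```python
import mxnet as mx

def probabilities_to_classes(probs, threshold=0.5):
    # values above the threshold become 1, the rest become 0
    return mx.nd.ceil(probs - threshold)
```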
In case when there are 3 or more classes, one cannot use a single Logistic regression, but should do multiclass regression. The solution would be to increase the number of output neurons to the number of classes and use `SoftmaxCrossEntropyLoss`. |
- "In case there are".
- I think the correct term would be multi-class classification. I suggest to simply remove this sentence to not create any confusion.
Removed
@yifeim, thanks for the comments, I am answering them in order:
I added a showcase of how to use the F1 metric.
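For readers landing here, using the built-in metric might look like this sketch with made-up values (`mx.metric.F1` expects per-class probabilities, so the single sigmoid column is stacked with its complement):

```python
import mxnet as mx

f1 = mx.metric.F1()

probs = mx.nd.array([[0.7], [0.2], [0.9]])   # sigmoid output, shape (n, 1)
labels = mx.nd.array([1, 0, 1])

# stack P(class 0) and P(class 1) into shape (n, 2)
both_classes = mx.nd.concat(1 - probs, probs, dim=1)
f1.update([labels], [both_classes])
print(f1.get())   # ('f1', 1.0) for this toy example
```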
Awesome. Done and done! Thanks for addressing the questions.
* Add logistic regression tutorial
* Code review fix
* Add F1 metric, fix code review comments
* Add Download buttons script
Description
Add a new tutorial on how to do logistic regression with the Gluon API (writing a minimal amount of custom code), plus some explanation of why it should be done that way.