Generalize IO API to support any number of data / labels #468
Conversation
You might want to merge in my PR here: #456
@piiswrong Yes, please merge that PR and I will rebase later.
Do we need to differentiate between data and label?
The only difference is that during training both data and label are copied into the network, but during prediction only data is copied. When evaluating on the validation set, data and label are copied into different places. I also found it a bit redundant, but I did not have a better way to handle this in the more general case.
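For illustration, the copy behavior described here might look roughly like the following sketch (`load_batch`, `data_targets`, and `label_targets` are hypothetical names, not the actual implementation):

```python
def load_batch(batch, data_targets, label_targets, is_train):
    # Data is always copied into the network's input arrays.
    for src, dst in zip(batch.data, data_targets):
        src.copyto(dst)
    # Labels are only copied during training / evaluation;
    # prediction skips this step entirely.
    if is_train:
        for src, dst in zip(batch.label, label_targets):
            src.copyto(dst)
```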
I see.
@piiswrong The main change is:
I will make a more detailed document soon.
@tqchen Hmm, I did not notice that. I agree this is a potential problem with
Yes, that will need some special treatment in the Python API, i.e. have a cached instance that is returned in the call to next, instead of calling MXIterNext, and support a function like
Thanks for the summary! With regard to the shape check, perhaps implement a peek method?
I guess a peek method in the Python API might be a good idea; we may need to implement peek and next on the base class, and ask the child class to implement a special _next method to get the data.
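A minimal sketch of that split, with class and method names that are illustrative rather than the actual mxnet code:

```python
class DataIter(object):
    # Base class: child classes only implement _next().
    def __init__(self):
        self._cache = None

    def _next(self):
        # Return the next batch, or None when exhausted.
        raise NotImplementedError

    def peek(self):
        # Fetch one batch ahead and cache it (e.g. for shape checks)
        # without consuming it from the stream.
        if self._cache is None:
            self._cache = self._next()
        return self._cache

    def next(self):
        # Serve the cached batch first, then fall back to _next().
        if self._cache is not None:
            batch, self._cache = self._cache, None
            return batch
        return self._next()
```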
I think the current code is ready to be merged after we get the peek method in.
k, I should be able to figure that out by the weekend
Actually, I would like to propose simplifying the current data iter base class. The idea is to have as few methods as possible, so that the user can write a customized data iter very easily, say by just writing a for loop with the Python generator API. Common tools, like shuffling within a large cache of in-memory pre-loaded mini-batches, could be provided so that users do not need to implement the same thing over and over again. But this requires more effort and discussion about the interface.
As long as we have a fixed set of APIs, it is OK. For example, we could have a subclass Iter that can take a generator, while having default implementations for all the other functions.
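Wrapping a generator could look like this sketch (building on the illustrative DataIter base class above; GeneratorIter is a made-up name):

```python
class GeneratorIter(DataIter):
    # Wraps a user-supplied generator function; peek/next come
    # from the base-class defaults.
    def __init__(self, gen_fn):
        super(GeneratorIter, self).__init__()
        self._gen_fn = gen_fn      # callable returning a fresh generator
        self._gen = gen_fn()

    def reset(self):
        # Restart the generator for a new epoch and drop the cache.
        self._gen = self._gen_fn()
        self._cache = None

    def _next(self):
        return next(self._gen, None)
```

A user could then supply batches with a plain generator expression, e.g. `GeneratorIter(lambda: (make_batch(i) for i in range(n)))` (where `make_batch` and `n` are whatever the user defines).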
@tqchen Yes, but writing default impl, for example, for
I get your point; yes, I agree, iter and next are the two things we need. The rest of the interface that is helpful (but not needed) is getpad (to get the number of padded instances, for the predictor).
We might need provide_data and provide_label, either as a list of functions. In that sense, I agree that peek should be implemented in the subclass, to support provide_data and provide_label for MXDataIter.
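One way peek could back provide_data/provide_label in a subclass, again with illustrative names and assuming the DataIter sketch above:

```python
class CachedShapeIter(DataIter):
    # Shapes are only known once a batch is fetched from the backend,
    # so derive them from the cached first batch via peek().
    def __init__(self, batches):
        super(CachedShapeIter, self).__init__()
        self._it = iter(batches)

    def _next(self):
        return next(self._it, None)

    @property
    def provide_data(self):
        # (name, shape) pairs, one per data array in the batch.
        return [('data_%d' % i, d.shape)
                for i, d in enumerate(self.peek().data)]

    @property
    def provide_label(self):
        return [('label_%d' % i, l.shape)
                for i, l in enumerate(self.peek().label)]
```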
Yes, getpad is not being suggested for removal; it now exists in the batch object, which is the object returned in each iteration and contains the data, label, and other information (including the pad) for that minibatch.
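Such a batch object can be very small; a sketch, with illustrative field names:

```python
from collections import namedtuple
import numpy as np

# One minibatch: lists of data/label arrays, plus the number of
# padded instances in the batch (what getpad used to report).
DataBatch = namedtuple('DataBatch', ['data', 'label', 'pad'])

batch = DataBatch(data=[np.zeros((32, 784))],
                  label=[np.zeros((32,))],
                  pad=0)
```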
Sounds good to me
Since we now support multiple data/labels, should the metrics interface also be updated?
Yes, that is one aspect we also need to consider a bit. Currently I just replaced all existing metrics with one that naively deals with pred and label lists of the same length, by evaluating them pair by pair and then accumulating. This should be backward compatible with existing behavior, but for the long term we need to think about what the best interface would be.
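The naive scheme described above, in sketch form (this Accuracy class is illustrative, not the actual metric code):

```python
import numpy as np

class Accuracy(object):
    # Accumulates accuracy over same-length lists of preds and labels,
    # evaluating each (pred, label) pair and summing the counts.
    def __init__(self):
        self.num_correct = 0
        self.num_inst = 0

    def update(self, labels, preds):
        assert len(labels) == len(preds)
        for label, pred in zip(labels, preds):
            # pred is (batch, num_classes); label is (batch,).
            pred_cls = np.argmax(pred, axis=1)
            self.num_correct += int((pred_cls == label).sum())
            self.num_inst += label.shape[0]

    def get(self):
        return self.num_correct / float(self.num_inst)
```

With a single output this reduces to the old single-metric behavior, which is why it stays backward compatible.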
Yes, I guess that should be addressed in a separate issue; let us first aim to get this merged.
Please review and merge it. I added a simple cache for the first batch to avoid calling reset. The discussed simplification of the base iter interface can wait until later; otherwise accumulating conflicts with the base branch could get quite complicated. I hope I did not overwrite anything during merging.
# reset the training data if we reach the end of train_data; we only
# need to deal with the following two situations:
# 1. epoch_size is None:
This logic was recently added by mu for distributed training, so it may be needed here. The logic is to run epoch_size batches on each machine, to avoid a different number of epochs on different machines.
I have one comment on supporting the epoch_size parameter, which was added by @mli for distributed training. Maybe we need to add that back. It is used when different machines have different numbers of batches in their local data, and we need to make sure each machine runs the same number of batches in one epoch in a distributed setting. Otherwise one machine will run an additional batch and hang, because the other machines did not send statistics over in the BSP setting.
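In sketch form, the behavior being described might look like this (`run_epoch` and `process_batch` are hypothetical names, and it assumes next() returns None at the end of the local data, as in the earlier sketches):

```python
def run_epoch(data_iter, epoch_size, process_batch):
    # Run exactly epoch_size batches on this machine; wrap around when
    # local data runs out, so all machines stay in lockstep under BSP.
    for _ in range(epoch_size):
        batch = data_iter.next()
        if batch is None:          # end of local data: start over
            data_iter.reset()
            batch = data_iter.next()
        process_batch(batch)
```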
Other parts LGTM
I tried to recover that part here. |
cool, this is merged |
- MXDataIter interface
- Model train to use extended API
- Model predict to use extended API