incorrect batch_sizes when Dataloader returns a dict with multiple tensors. #3668

Closed
gerardsn opened this issue Sep 26, 2020 · 21 comments · Fixed by #3888

🐛 Bug

Tracked batch sizes in the result object are incorrect when a Dataloader returns a dict with multiple tensors.

To Reproduce

Create a data loader that returns a dict, e.g. batch = {'batchA': tensor_A, 'batchB': tensor_B}.
Both entries have batch size N with N != 2.
For this example a batch size of 2 will be logged, since len(batch) == 2 (see the sketch below).
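
A minimal sketch of the setup (names and sizes here are illustrative):

import torch
from torch.utils.data import DataLoader, Dataset

class DictDataset(Dataset):
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return {'batchA': torch.randn(10), 'batchB': torch.randn(10)}

loader = DataLoader(DictDataset(), batch_size=8)
batch = next(iter(loader))
print(len(batch))            # 2  -> what gets tracked as the batch size
print(len(batch['batchA']))  # 8  -> the actual batch size N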

https://github.com/PyTorchLightning/pytorch-lightning/blob/05e5f03fd7c851b06ca5e34b39eb660857b8f00c/pytorch_lightning/trainer/evaluation_loop.py#L147-L150
https://github.com/PyTorchLightning/pytorch-lightning/blob/05e5f03fd7c851b06ca5e34b39eb660857b8f00c/pytorch_lightning/trainer/training_loop.py#L304-L306

Expected behavior

Log the correct batch size.
I'm not sure what counts as the 'correct' batch size when there are multiple tensors, but I expect each tensor in the dict to have the same batch size. So, maybe something like:

if is_result_obj:
    if isinstance(batch, dict):
        # use the first entry of the dict as a proxy for the whole batch
        batch = next(iter(batch.values()))
    result_obj.track_batch_size(len(batch))
@gerardsn gerardsn added bug Something isn't working help wanted Open to be worked on labels Sep 26, 2020
rohitgr7 (Contributor) commented Oct 1, 2020

I think doing just len(batch) is still wrong: if the batch is a tuple or some kind of custom batch datatype, len(batch) will give the wrong value. Even in the basic MNIST example it will give 2, which is wrong.
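
For instance, with an MNIST-style (X, y) batch (sizes here are arbitrary):

import torch

X = torch.randn(32, 1, 28, 28)   # a batch of 32 images
y = torch.randint(0, 10, (32,))  # their labels
batch = (X, y)
print(len(batch))                # 2 -- the tuple length, not the batch size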

gerardsn (Author) commented Oct 2, 2020

This should probably catch most things, though it might be a bit much.
It returns 1 if it fails to determine the batch size, to prevent issues with weighted averaging in reduce_on_epoch_end.

if is_result_obj:
    result_obj.track_batch_size(unpack_batchsize(batch))

# maybe add as a staticmethod to ResultObj?
import torch
from collections.abc import Iterable

def unpack_batchsize(sample):
    """
    Recursively unpack `sample` to find a torch.Tensor.
    Returns len(tensor) when one is found, or 1 when it hits an empty or
    non-iterable value.
    """
    if isinstance(sample, torch.Tensor):
        sample = len(sample)
    elif isinstance(sample, dict):
        # peek at the first value of the dict
        sample = next(iter(sample.values()), 1)
    elif isinstance(sample, Iterable) and not isinstance(sample, str):
        # peek at the first element of any other iterable (tuple, list, ...)
        sample = next(iter(sample), 1)
    else:
        # non-iterable (or a string): fall back to 1 so the weighted
        # average degrades to a plain mean
        sample = 1

    if isinstance(sample, int):
        return sample
    return unpack_batchsize(sample)
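
For example, with batches like those above this returns the actual batch size:

unpack_batchsize({'batchA': torch.randn(8, 10), 'batchB': torch.randn(8, 10)})  # 8
unpack_batchsize((torch.randn(32, 1, 28, 28), torch.randint(0, 10, (32,))))     # 32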

carmocca (Contributor) commented Oct 2, 2020

I suggest adding a function to the LightningModule, batch_len_fx, which defaults to len if it is not overridden. Anything could be a batch, and Lightning shouldn't have the responsibility of supporting every batch type.
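
A rough sketch of how such a hook could be overridden (batch_len_fx is the hypothetical name proposed here, not an existing Lightning API):

import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def batch_len_fx(self, batch):
        # hypothetical hook: report the true batch size for a dict-of-tensors batch
        return len(next(iter(batch.values())))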

rohitgr7 (Contributor) commented Oct 2, 2020

Exactly what I had in mind @carmocca. Or maybe simply ask users to put batch_size in .log itself if on_epoch=True?

.log('some_metric', metric_value, on_epoch=True, batch_size=batch_size)
.log('some_metric', metric_value, on_epoch=False)

gerardsn (Author) commented Oct 2, 2020

Lightning currently defaults to weighted_mean for reduction on epoch end by substituting the reduction method if it is torch.mean:

https://github.com/PyTorchLightning/pytorch-lightning/blob/ebc1b23fa38d54e9805aa4356867369f064c7031/pytorch_lightning/core/step_result.py#L389-L390

If this is the desired behaviour, I think Lightning should at least attempt to get a reasonable estimate of the batch size. In most use cases the dataloader will return multiple tensors, resulting in an incorrect estimate if len is the default (e.g. any supervised method has at least (X, y) in its batch, producing len(batch) == 2, as mentioned by @rohitgr7).

This could still be done using batch_len_fx, though. On the first call, if the method is not overridden, replace batch_len_fx with a reasonable estimate based on the type of batch (e.g. len of [the tensor, the first value in the Iterable]).

gerardsn (Author) commented Oct 2, 2020

Exactly what I had in mind @carmocca. Or maybe simple ask to put batch_size in .log itself if on_epoch=True??

.log('some_metric', metric_value, on_epoch=True, batch_size=batch_size)
.log('some_metric', metric_value, on_epoch=False)

This should work too. It should probably default to 1 if not provided, since len is likely to be wrong.

fogside commented Oct 2, 2020

@gerardsn I have a problem exactly with this weighted_mean function.
I'm working with the latest Lightning version from master.

https://github.com/PyTorchLightning/pytorch-lightning/blob/ebc1b23fa38d54e9805aa4356867369f064c7031/pytorch_lightning/core/step_result.py#L369

It gets outputs = [{'checkpoint_on': tensor(28.3303, device='cuda:0'), 'val_loss': tensor(28.3303, device='cuda:0'), 'val_precision1': 0.12652068126520682}], because I have only one batch per epoch in validation.

Lightning tries to reduce on epoch end. It feeds result = tensor([27.8364], device='cuda:0'), weights=tensor([2]) into the weighted_mean function, and I get an error here:
https://github.com/PyTorchLightning/pytorch-lightning/blob/ebc1b23fa38d54e9805aa4356867369f064c7031/pytorch_lightning/core/step_result.py#L897
AttributeError: 'list' object has no attribute 'device'
I think it's related to this issue. It would be nice to not reduce anything if it's just one batch per epoch.

@edenlightning edenlightning added this to the 0.9.x milestone Oct 2, 2020
@edenlightning edenlightning added the priority: 0 High priority task label Oct 2, 2020
@edenlightning edenlightning changed the title Result object incorrect batch_sizes incorrect batch_sizes when Dataloader returns a dict with multiple tensors. Oct 2, 2020
rohitgr7 (Contributor) commented Oct 2, 2020

@fogside in your example result is a tensor, so result.device should not throw an error.

fogside commented Oct 2, 2020

@fogside in your example result is a tensor, so result.device should not throw an error.

But it's a list with a tensor inside.

rohitgr7 (Contributor) commented Oct 2, 2020

result = tensor([27.8364], device='cuda:0'), weights=tensor([2])

You're referring to this, right?

fogside commented Oct 2, 2020

result = tensor([27.8364], device='cuda:0'), weights=tensor([2])

You're referring to this, right?

Sorry, I just realized that I was mistaken.
Actually, it calls this method twice for some reason.
I added prints at the beginning of the weighted_mean function and in reduce_on_epoch_end (I also changed the number of batches in this example):

Result[k] tensor([23.6331, 26.0617, 24.0941, 25.3255], device='cuda:0')
result:  tensor([23.6331, 26.0617, 24.0941, 25.3255], device='cuda:0')
weights:  tensor([2, 2, 2, 2])
Result[k] [0.14285714285714285, 0.06451612903225806, 0.056179775280898875, 0.13793103448275862]
result:  [0.14285714285714285, 0.06451612903225806, 0.056179775280898875, 0.13793103448275862]
weights:  tensor([2, 2, 2, 2])

And the second time it gives me the error.
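
For context, a simplified stand-in (not Lightning's exact code) for why the second call fails: the weights are a tensor, but a metric logged as a plain Python float comes back as a list, which has no .device attribute.

import torch

def weighted_mean(result, weights):
    # simplified stand-in for the reduction applied on epoch end
    weights = weights.to(result.device)  # AttributeError if result is a plain list
    return (result * weights.float()).sum() / weights.float().sum()

weighted_mean(torch.tensor([23.6331, 26.0617]), torch.tensor([2, 2]))  # works
# weighted_mean([0.1428, 0.0645], torch.tensor([2, 2]))                # AttributeError: 'list' object has no attribute 'device'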

rohitgr7 (Contributor) commented Oct 2, 2020

Are you logging non-tensor values? Maybe doing .item() somewhere in the logs? If not, can you put your .log statements here?

fogside commented Oct 2, 2020

Are you logging non-tensor values?

Yes, I was calculating precision in numpy. Isn't it possible to log non-tensor values?

rohitgr7 (Contributor) commented Oct 2, 2020

No. Also, to calculate precision or any other metric you can try the pl.metrics package, which computes these metrics on the current device itself.

Or you can just do torch.tensor(numpy_value) in .log.
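
For example (a hedged sketch; the metric name and precision_np are illustrative):

# inside validation_step, where precision_np is a numpy / Python float
self.log('val_precision1', torch.tensor(precision_np), on_epoch=True)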

fogside commented Oct 2, 2020

No. Also, to calculate precision or any other metric you can try the pl.metrics package, which computes these metrics on the current device itself.

I see. Thank you!
Actually, I was trying to work with pytorch-metric-learning and used a function for top-k precision estimation from there, but it looks quite tough to merge these two frameworks. I see now that top-k precision should be calculated in PyTorch. Another thing is that I need as large a batch as possible to get a good top-k estimate (it's even better to have the whole validation set), which is why I found it hard to do these estimates in validation_step. Maybe I should look into some callbacks?
But it's not related to this issue.

rohitgr7 (Contributor) commented Oct 2, 2020

Already working on top-k accuracy. Maybe we will add top-k precision and recall in pl.metrics as well. Can you point me to the implementation of top-k precision in the pytorch-metric-learning package? It would be helpful. Thanks :)

fogside commented Oct 3, 2020

Already working on top-k accuracy. Maybe we will add top-k precision and recall in pl.metrics as well. Can you point me to the implementation of top-k precision in the pytorch-metric-learning package? It would be helpful. Thanks :)

That's great!
Sure, I used the class AccuracyCalculator like this:

from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

# note the trailing comma: include must be a tuple, not a plain string
accuracy_calculator = AccuracyCalculator(include=("mean_average_precision_at_r",), k=5)
# query and reference are the same set here, hence True
# (embeddings_come_from_same_source) as the last argument
accuracies = accuracy_calculator.get_accuracy(embeddings,
                                              embeddings,
                                              labels,
                                              labels,
                                              True)

Implementation:
https://github.com/KevinMusgrave/pytorch-metric-learning/blob/10bed5ee8719a543827aa32ea658603c2fcb0130/src/pytorch_metric_learning/utils/accuracy_calculator.py#L45

rohitgr7 (Contributor) commented Oct 3, 2020

So I guess 2 things should be fixed:

  • Track correct batch_size
  • Allow non-tensor numeric values in .log(...)

williamFalcon (Contributor) commented:

It feeds result = tensor([27.8364], device='cuda:0'), weights=tensor([2]) into the weighted_mean function and I get an error here: AttributeError: 'list' object has no attribute 'device' [...] It would be nice to not reduce anything if it's just one batch per epoch.

this is fixed on master

williamFalcon (Contributor) commented:

OK, making changes to this today.

What do we want as the default behavior? Doesn't the custom reduce function solve the problem of custom batches, etc.?

@rohitgr7
Copy link
Contributor

rohitgr7 commented Oct 5, 2020

The batch sizes are not tracked correctly.

@edenlightning edenlightning removed the help wanted Open to be worked on label Oct 5, 2020
williamFalcon added a commit that referenced this issue Oct 6, 2020
williamFalcon added a commit that referenced this issue Oct 6, 2020
williamFalcon added a commit that referenced this issue Oct 6, 2020
* Fixes #3668, #3887 as a bonus
