fix: Support tensors and arrays for class_weight #1413

Merged
merged 1 commit into from
Jul 7, 2023

ntw-au
Contributor

@ntw-au ntw-au commented Jun 13, 2023

Avoids the ambiguous truth value ValueError raised when the class_weights input parameter is either a PyTorch tensor or a NumPy array.

Repro follows. It also fails with np.array instead of torch.tensor, and succeeds with a Python sequence.

import torch
from torchgeo.trainers.segmentation import SemanticSegmentationTask

SemanticSegmentationTask(
    model='unet',
    backbone='resnet101',
    loss='ce',
    learning_rate=0.01,
    weights=None,
    class_weights=torch.tensor([0.25, 0.5, 0.25]),
    in_channels=1,
    num_classes=3,
    ignore_index=None
)

Expected:

<prints PyTorch model>

Actual:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<path>/torchgeo/torchgeo/trainers/segmentation.py", line 171, in __init__
    self.config_task()
  File "<path>/torchgeo/torchgeo/trainers/segmentation.py", line 69, in config_task
    torch.FloatTensor(self.class_weights) if self.class_weights else None
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
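
For illustration, here is a minimal sketch of why the truthiness test in the failing line raises, and of one possible way around it: comparing against None explicitly. The helper name to_class_weight_tensor is hypothetical and is not claimed to be the code merged in this PR; the sketch only assumes that NumPy and PyTorch are installed.

```python
import numpy as np
import torch


def to_class_weight_tensor(class_weights):
    """Hypothetical helper: return a float32 tensor, or None if no weights were given."""
    # Compare against None explicitly instead of relying on truthiness,
    # which raises for multi-element arrays and tensors.
    if class_weights is None:
        return None
    return torch.as_tensor(class_weights, dtype=torch.float32)


# Truthiness of a multi-element array or tensor is ambiguous and raises.
for weights in (np.array([0.25, 0.5, 0.25]), torch.tensor([0.25, 0.5, 0.25])):
    try:
        bool(weights)
    except (ValueError, RuntimeError) as err:
        print(type(weights).__name__, "->", type(err).__name__)

# The explicit None check accepts lists, arrays, tensors, and None alike.
for weights in (
    [0.25, 0.5, 0.25],
    np.array([0.25, 0.5, 0.25]),
    torch.tensor([0.25, 0.5, 0.25]),
    None,
):
    print(to_class_weight_tensor(weights))
```

A non-empty Python list is simply truthy, which is why the original check only worked for plain sequences.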

@github-actions github-actions bot added the trainers (PyTorch Lightning trainers) and testing (Continuous integration testing) labels Jun 13, 2023
@ntw-au
Contributor Author

ntw-au commented Jun 14, 2023

@microsoft-github-policy-service agree company="Geomatic.AI Pty Ltd"

Avoids ambiguous truth value ValueError when the class_weight input
parameter is either a PyTorch tensor or a NumPy array.

Includes new tests for SemanticSegmentationTask's class_weight
parameter.
Collaborator

@adamjstewart adamjstewart left a comment

At some point before the 0.5 release (expected in August) I'm planning on making all arguments explicit (as opposed to using **kwargs) so we can add type hints. So remind me to make sure all sequences are permitted, not just Tensors.

torchgeo/trainers/segmentation.py (review comment resolved)
Collaborator

@adamjstewart adamjstewart left a comment

I don't yet see a simpler way to handle this other than documenting that only lists are allowed. Maybe @nsutezo can review too.

@adamjstewart adamjstewart added this to the 0.5.0 milestone Jun 16, 2023
@adamjstewart adamjstewart merged commit b0ae5be into microsoft:main Jul 7, 2023
@adamjstewart
Collaborator

Working on type hints in #1541 and it's actually difficult to support lists, arrays, and tensors. Might just support tensors since that's what nn.CrossEntropyLoss supports.

@ntw-au
Contributor Author

ntw-au commented Sep 4, 2023

I think that's a good idea. As a user, I'd expect tensors as the first preference, and it's easy for callers to convert arrays and lists into a tensor themselves.
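
As a rough sketch of that caller-side conversion (assuming only PyTorch, with made-up weights and random dummy data), a list or array can be turned into a float tensor once and then passed as the weight argument of nn.CrossEntropyLoss:

```python
import torch
from torch import nn

# Class weights as a caller might hold them: a plain list
# (a NumPy array can be passed to torch.as_tensor the same way).
raw_weights = [0.25, 0.5, 0.25]

# Convert once, on the caller's side, to the float tensor the loss expects.
class_weights = torch.as_tensor(raw_weights, dtype=torch.float32)

criterion = nn.CrossEntropyLoss(weight=class_weights)

# Dummy segmentation-shaped batch: (N, C, H, W) logits and (N, H, W) integer targets.
logits = torch.randn(2, 3, 4, 4)
targets = torch.randint(0, 3, (2, 4, 4))

loss = criterion(logits, targets)
print(loss.item())
```

The same tensor could then be handed to a task's class_weights argument, if the API ends up accepting only tensors as discussed above.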
