Can't train with num_classes=1 #521

Closed
kethan52 opened this issue May 2, 2022 · 5 comments
Labels
trainers PyTorch Lightning trainers

Comments

@kethan52

kethan52 commented May 2, 2022

# Imports assumed by this snippet:
import pytorch_lightning as pl
from torchgeo.trainers import SemanticSegmentationTask

model = SemanticSegmentationTask(
    segmentation_model="fcn",
    encoder_name="resnet18",
    encoder_weights="imagenet",
    in_channels=9,
    num_classes=1,
    num_filters=128,
    loss="jaccard",
    ignore_zeros=True,
    learning_rate=0.1,
    learning_rate_schedule_patience=0.05,
)

trainer = pl.Trainer(
    gpus=0,
    callbacks=[checkpoint_callback, early_stopping_callback],
    logger=[csv_logger],
    min_epochs=10,
    max_epochs=100,
    precision=32,
    log_every_n_steps=1,  # must be a positive int
    max_steps=8,
)

trainer.fit(model, dl, dl)

Hi, since I want to do binary semantic segmentation I need num_classes=1. Is there any way to train with a single class?

@isaaccorley
Collaborator

This seems to be a duplicate of your other issues #513 and #514. Can you add some context as to why this is a separate issue?

@adamjstewart adamjstewart added the trainers PyTorch Lightning trainers label May 3, 2022
@calebrob6
Member

calebrob6 commented May 6, 2022

num_classes=1 would need to be trained with a loss other than PyTorch's cross entropy, which isn't currently implemented in SemanticSegmentationTask (you would need something like https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html#torch.nn.BCELoss). Maybe we should add an assert that num_classes >= 2?
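
To make this concrete, here is a minimal sketch (not from the issue, and not torchgeo code) of why the single-channel case needs a different criterion; nn.BCEWithLogitsLoss is BCELoss with the sigmoid folded in for numerical stability:

import torch
import torch.nn as nn

# Hypothetical shapes: batch of 4, one output channel, 64x64 masks.
logits = torch.randn(4, 1, 64, 64)                    # raw model output
target = torch.randint(0, 2, (4, 1, 64, 64)).float()  # binary ground-truth mask

# With a single channel, softmax over one class is identically 1, so
# cross entropy would always be zero and nothing would train. The binary
# case instead pairs one output channel with a sigmoid:
criterion = nn.BCEWithLogitsLoss()
loss = criterion(logits, target)
print(loss)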

@adamjstewart
Collaborator

Could we add a conditional that automatically switches to BCE for num_classes=1?
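
A rough sketch of what such a conditional could look like (hypothetical, not the actual torchgeo implementation; the "ce" name mirrors the SemanticSegmentationTask loss argument):

import torch.nn as nn

def configure_loss(loss: str, num_classes: int) -> nn.Module:
    """Hypothetical helper: fall back to BCE for the binary case."""
    if num_classes == 1:
        return nn.BCEWithLogitsLoss()
    if loss == "ce":
        return nn.CrossEntropyLoss()
    raise ValueError(f"Loss type '{loss}' is not valid.")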

@calebrob6
Member

calebrob6 commented May 6, 2022 via email

@adamjstewart
Collaborator

I think this is also related to #245. So it's not just the loss, it's also the IoU metric.
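
On the metric side, a sketch of a binary Jaccard/IoU configuration (assuming torchmetrics >= 0.11, where JaccardIndex takes a task argument; an illustration, not what torchgeo pinned at the time):

import torch
from torchmetrics import JaccardIndex

# Binary IoU: preds are probabilities in [0, 1], targets are 0/1 masks.
iou = JaccardIndex(task="binary")
preds = torch.sigmoid(torch.randn(4, 1, 64, 64))
target = torch.randint(0, 2, (4, 1, 64, 64))
print(iou(preds, target))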

@microsoft microsoft locked and limited conversation to collaborators Jun 18, 2022
@adamjstewart adamjstewart converted this issue into discussion #599 Jun 18, 2022
