Commit
update on comments
tchaton committed Apr 20, 2021
1 parent 3ac9f8e commit 64b90e7
Showing 5 changed files with 25 additions and 20 deletions.
6 changes: 3 additions & 3 deletions flash/core/classification.py
```diff
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional
+from typing import Any, List, Optional

 import torch
 import torch.nn.functional as F
@@ -26,7 +26,7 @@ def __init__(self, multi_label: bool = False, save_path: Optional[str] = None):
         super().__init__(save_path=save_path)
         self.multi_label = multi_label

-    def per_sample_transform(self, samples: Any) -> Any:
+    def per_sample_transform(self, samples: Any) -> List[Any]:
         if self.multi_label:
             return F.sigmoid(samples).tolist()
         else:
@@ -42,5 +42,5 @@ def __init__(self, *args, postprocess: Optional[Preprocess] = None, **kwargs):

     def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
         if getattr(self.hparams, "multi_label", False):
-            return F.sigmoid(x).int()
+            return F.sigmoid(x)
         return F.softmax(x, -1)
```
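A note on the `to_metrics_format` change: sigmoid outputs lie strictly in (0, 1), so the previous `.int()` cast truncated every probability to 0. A minimal sketch (not part of the commit) of the difference:

```python
import torch

logits = torch.tensor([[-2.0, 0.5, 3.0]])
probs = torch.sigmoid(logits)  # tensor([[0.1192, 0.6225, 0.9526]])

print(probs.int())  # tensor([[0, 0, 0]], dtype=torch.int32): every value truncates to 0
print(probs)        # probabilities survive, so metrics can apply their own threshold
```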
25 changes: 13 additions & 12 deletions flash/core/model.py
```diff
@@ -60,8 +60,8 @@ class Task(LightningModule):
         optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
         metrics: Metrics to compute for training and evaluation.
         learning_rate: Learning rate to use for training, defaults to `5e-5`.
-        preprocess: :class:`.Preprocess` to use as the default for this task.
-        postprocess: :class:`.Postprocess` to use as the default for this task.
+        preprocess: :class:`~flash.data.process.Preprocess` to use as the default for this task.
+        postprocess: :class:`~flash.data.process.Postprocess` to use as the default for this task.
     """

     def __init__(
@@ -99,7 +99,8 @@ def step(self, batch: Any, batch_idx: int) -> Any:
         y_hat = self.to_metrics_format(y_hat)
         for name, metric in self.metrics.items():
             if isinstance(metric, torchmetrics.metric.Metric):
-                logs[name] = metric(y_hat, y)  # log the metric itself if it is of type Metric
+                metric(y_hat, y)
+                logs[name] = metric  # log the metric itself if it is of type Metric
             else:
                 logs[name] = metric(y_hat, y)
         logs.update(losses)
```
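The `step` change follows the usual torchmetrics pattern in Lightning: calling the `Metric` updates its running state, and logging the `Metric` object itself (rather than the tensor it returns) lets the epoch-level value be reduced correctly. A standalone sketch, using the torchmetrics API contemporary with this commit:

```python
import torch
import torchmetrics

# Accuracy() with no arguments matches the torchmetrics API of early 2021;
# recent releases require a `task` argument instead.
metric = torchmetrics.Accuracy()

preds = torch.tensor([0, 1, 1, 0])
target = torch.tensor([0, 1, 0, 0])

batch_value = metric(preds, target)  # updates internal state, returns the batch value
logs = {"accuracy": metric}          # log the Metric object, not the returned tensor,
                                     # so epoch-level aggregation is handled for you
```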
```diff
@@ -180,17 +181,17 @@ def _resolve(
         new_preprocess: Optional[Preprocess],
         new_postprocess: Optional[Postprocess],
     ) -> Tuple[Optional[Preprocess], Optional[Postprocess]]:
-        """Resolves the correct :class:`.Preprocess` and :class:`.Postprocess` to use, choosing ``new_*`` if it is not
-        None or a base class (:class:`.Preprocess` or :class:`.Postprocess`) and ``old_*`` otherwise.
+        """Resolves the correct :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` to use, choosing ``new_*`` if it is not
+        None or a base class (:class:`~flash.data.process.Preprocess` or :class:`~flash.data.process.Postprocess`) and ``old_*`` otherwise.

         Args:
-            old_preprocess: :class:`.Preprocess` to be overridden.
-            old_postprocess: :class:`.Postprocess` to be overridden.
-            new_preprocess: :class:`.Preprocess` to override with.
-            new_postprocess: :class:`.Postprocess` to override with.
+            old_preprocess: :class:`~flash.data.process.Preprocess` to be overridden.
+            old_postprocess: :class:`~flash.data.process.Postprocess` to be overridden.
+            new_preprocess: :class:`~flash.data.process.Preprocess` to override with.
+            new_postprocess: :class:`~flash.data.process.Postprocess` to override with.

         Returns:
-            The resolved :class:`.Preprocess` and :class:`.Postprocess`.
+            The resolved :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`.
         """
         preprocess = old_preprocess
         if new_preprocess is not None and type(new_preprocess) != Preprocess:
@@ -203,7 +204,7 @@ def _resolve(
         return preprocess, postprocess

     def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]:
-        """Build a :class:`.DataPipeline` incorporating available :class:`.Preprocess` and :class:`.Postprocess`
+        """Build a :class:`.DataPipeline` incorporating available :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`
         objects. These will be overridden in the following resolution order (lowest priority first):

         - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`.
@@ -212,7 +213,7 @@ def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]:
         - :class:`.DataPipeline` passed to this method.

         Args:
-            data_pipeline: Optional highest priority source of :class:`.Preprocess` and :class:`.Postprocess`.
+            data_pipeline: Optional highest priority source of :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`.

         Returns:
             The fully resolved :class:`.DataPipeline`.
```
5 changes: 3 additions & 2 deletions flash/vision/classification/model.py
```diff
@@ -24,7 +24,8 @@
 from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES


-def binary_cross_entropy_with_logits(x, y):
+def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """Calls BCE with logits and casts the target one_hot (y) encoding to floating point precision."""
     return F.binary_cross_entropy_with_logits(x, y.float())


@@ -112,7 +113,7 @@ def __init__(
             nn.Linear(num_features, num_classes),
         )

-    def forward(self, x) -> Any:
+    def forward(self, x) -> torch.Tensor:
         x = self.backbone(x)
         if self.hparams.multi_label:
             return self.head(x)
```
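For context on the `y.float()` call: `F.binary_cross_entropy_with_logits` expects floating-point targets, while one-hot/multi-hot labels usually arrive as integer tensors. A small sketch with illustrative values (not from the repository):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1.2, -0.4, 0.3]])
one_hot = torch.tensor([[1, 0, 1]])  # integer multi-hot target

# Passing the integer target directly raises a dtype error:
# F.binary_cross_entropy_with_logits(logits, one_hot)
loss = F.binary_cross_entropy_with_logits(logits, one_hot.float())
print(loss)  # scalar loss tensor
```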
2 changes: 1 addition & 1 deletion flash/vision/embedding/model.py
```diff
@@ -98,7 +98,7 @@ def apply_pool(self, x):
         x = self.pooling_fn(x, dim=-1)
         return x

-    def forward(self, x) -> Any:
+    def forward(self, x) -> torch.Tensor:
         x = self.backbone(x)

         # bolts ssl models return lists
```
7 changes: 5 additions & 2 deletions tests/vision/classification/test_model.py
```diff
@@ -89,6 +89,9 @@ def test_multilabel(tmpdir):
     train_dl = torch.utils.data.DataLoader(ds, batch_size=2)
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
     trainer.finetune(model, train_dl, strategy="freeze_unfreeze")
-    image, _ = ds[0]
+    image, label = ds[0]
     predictions = model.predict(image.unsqueeze(0))
-    assert len(predictions[0]) == num_classes
+    assert (torch.tensor(predictions) > 1).sum() == 0
+    assert (torch.tensor(predictions) < 0).sum() == 0
+    assert len(predictions[0]) == num_classes == len(label)
+    assert len(torch.unique(label)) <= 2
```
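The strengthened assertions capture two properties of the multi-label path: predictions are now sigmoid probabilities, so every entry must lie in [0, 1], and a multi-hot label vector contains at most the two values 0 and 1. A toy illustration with assumed values:

```python
import torch

predictions = [[0.12, 0.62, 0.95]]  # sigmoid outputs, one list per sample
label = torch.tensor([1, 0, 1])     # multi-hot target

assert (torch.tensor(predictions) > 1).sum() == 0  # no probability above 1
assert (torch.tensor(predictions) < 0).sum() == 0  # none below 0
assert len(torch.unique(label)) <= 2               # only the values 0 and 1 appear
```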
