Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton committed Apr 20, 2021
1 parent 64b90e7 commit 7a807ba
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
12 changes: 8 additions & 4 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,10 @@ def _resolve(
new_preprocess: Optional[Preprocess],
new_postprocess: Optional[Postprocess],
) -> Tuple[Optional[Preprocess], Optional[Postprocess]]:
"""Resolves the correct :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` to use, choosing ``new_*`` if it is not
None or a base class (:class:`~flash.data.process.Preprocess` or :class:`~flash.data.process.Postprocess`) and ``old_*`` otherwise.
"""Resolves the correct :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` to use,
choosing ``new_*`` if it is not None or a base class
(:class:`~flash.data.process.Preprocess` or :class:`~flash.data.process.Postprocess`)
and ``old_*`` otherwise.
Args:
old_preprocess: :class:`~flash.data.process.Preprocess` to be overridden.
Expand All @@ -204,7 +206,8 @@ def _resolve(
return preprocess, postprocess

def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]:
"""Build a :class:`.DataPipeline` incorporating available :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`
"""Build a :class:`.DataPipeline` incorporating available
:class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`
objects. These will be overridden in the following resolution order (lowest priority first):
- Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`.
Expand All @@ -213,7 +216,8 @@ def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> O
- :class:`.DataPipeline` passed to this method.
Args:
data_pipeline: Optional highest priority source of :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`.
data_pipeline: Optional highest priority source of
:class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`.
Returns:
The fully resolved :class:`.DataPipeline`.
Expand Down
5 changes: 1 addition & 4 deletions flash/vision/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,4 @@ def __init__(

def forward(self, x) -> torch.Tensor:
x = self.backbone(x)
if self.hparams.multi_label:
return self.head(x)
else:
return torch.softmax(self.head(x), -1)
return self.head(x)

0 comments on commit 7a807ba

Please sign in to comment.