Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Aug 17, 2021
1 parent 68acdcf commit 438c0b2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
8 changes: 6 additions & 2 deletions flash/image/instance_segmentation/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Optional
from typing import Callable, Optional

from flash.core.utilities.flash_cli import FlashCLI
from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires_extras
Expand All @@ -29,17 +29,21 @@ def from_pets(
val_split: float = 0.1,
batch_size: int = 4,
num_workers: Optional[int] = None,
parser: Optional[Callable] = None,
**preprocess_kwargs,
) -> InstanceSegmentationData:
"""Downloads and loads the pets data set from icedata."""
data_dir = icedata.pets.load_data()

if parser is None:
parser = partial(icedata.pets.parser, mask=True)

return InstanceSegmentationData.from_folders(
train_folder=data_dir,
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
parser=partial(icedata.pets.parser, mask=True),
parser=parser,
**preprocess_kwargs,
)

Expand Down
9 changes: 7 additions & 2 deletions flash/image/keypoint_detection/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Callable, Optional

from flash.core.utilities.flash_cli import FlashCLI
from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires_extras
Expand All @@ -28,17 +28,21 @@ def from_biwi(
val_split: float = 0.1,
batch_size: int = 4,
num_workers: Optional[int] = None,
parser: Optional[Callable] = None,
**preprocess_kwargs,
) -> KeypointDetectionData:
"""Downloads and loads the BIWI data set from icedata."""
data_dir = icedata.biwi.load_data()

if parser is None:
parser = icedata.biwi.parser

return KeypointDetectionData.from_folders(
train_folder=data_dir,
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
parser=icedata.biwi.parser,
parser=parser,
**preprocess_kwargs,
)

Expand All @@ -50,6 +54,7 @@ def keypoint_detection():
KeypointDetectionData,
default_datamodule_builder=from_biwi,
default_arguments={
"model.num_keypoints": 1,
"trainer.max_epochs": 3,
},
)
Expand Down

0 comments on commit 438c0b2

Please sign in to comment.