Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

add predict_kwargs in ObjectDetectionModel in order to filter the pre… #990

2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added predict_kwargs in `ObjectDetector`, `InstanceSegmentation`, `KeypointDetector` ([#990](https://github.com/PyTorchLightning/lightning-flash/pull/990))

- Added backbones for `GraphClassifier` ([#592](https://github.com/PyTorchLightning/lightning-flash/pull/592))

- Added `GraphEmbedder` task ([#592](https://github.com/PyTorchLightning/lightning-flash/pull/592))
Expand Down
10 changes: 7 additions & 3 deletions flash/core/integrations/icevision/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,14 @@ class IceVisionAdapter(Adapter):

required_extras: str = "image"

def __init__(self, model_type, model, icevision_adapter, backbone):
def __init__(self, model_type, model, icevision_adapter, backbone, predict_kwargs):
super().__init__()

self.model_type = model_type
self.model = model
self.icevision_adapter = icevision_adapter
self.backbone = backbone
self.predict_kwargs = predict_kwargs

@classmethod
@catch_url_error
Expand All @@ -62,6 +63,7 @@ def from_task(
num_classes: int,
backbone: str,
head: str,
predict_kwargs: Dict,
pretrained: bool = True,
metrics: Optional["IceVisionMetric"] = None,
image_size: Optional = None,
Expand All @@ -77,7 +79,7 @@ def from_task(
**kwargs,
)
icevision_adapter = icevision_adapter(model=model, metrics=metrics)
return cls(model_type, model, icevision_adapter, backbone)
return cls(model_type, model, icevision_adapter, backbone, predict_kwargs)

@staticmethod
def _collate_fn(collate_fn, samples, metadata: Optional[List[Dict[str, Any]]] = None):
Expand Down Expand Up @@ -198,7 +200,9 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
return batch

def forward(self, batch: Any) -> Any:
return from_icevision_predictions(self.model_type.predict_from_dl(self.model, [batch], show_pbar=False))
return from_icevision_predictions(
self.model_type.predict_from_dl(self.model, [batch], show_pbar=False, **self.predict_kwargs)
)

def training_epoch_end(self, outputs) -> None:
return self.icevision_adapter.training_epoch_end(outputs)
Expand Down
13 changes: 13 additions & 0 deletions flash/image/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class ObjectDetector(AdapterTask):
lr_scheduler: The LR scheduler to use during training.
learning_rate: The learning rate to use for training.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
predict_kwargs: dictionary containing parameters that will be used during the prediction phase.
kwargs: additional kwargs nessesary for initializing the backbone task
"""

Expand All @@ -50,17 +51,20 @@ def __init__(
lr_scheduler: LR_SCHEDULER_TYPE = None,
learning_rate: float = 5e-3,
output: OUTPUT_TYPE = None,
predict_kwargs: Dict = None,
**kwargs: Any,
):
self.save_hyperparameters()

predict_kwargs = predict_kwargs if predict_kwargs else {}
metadata = self.heads.get(head, with_metadata=True)
adapter = metadata["metadata"]["adapter"].from_task(
self,
num_classes=num_classes,
backbone=backbone,
head=head,
pretrained=pretrained,
predict_kwargs=predict_kwargs,
**kwargs,
)

Expand All @@ -75,3 +79,12 @@ def __init__(
def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None:
"""This function is used only for debugging usage with CI."""
# todo

@property
def predict_kwargs(self) -> Dict[str, Any]:
"""The kwargs used for the prediction step."""
return self.adapter.predict_kwargs

@predict_kwargs.setter
def predict_kwargs(self, predict_kwargs: Dict[str, Any]):
self.adapter.predict_kwargs = predict_kwargs
13 changes: 13 additions & 0 deletions flash/image/instance_segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class InstanceSegmentation(AdapterTask):
lr_scheduler: The LR scheduler to use during training.
learning_rate: The learning rate to use for training.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
predict_kwargs: dictionary containing parameters that will be used during the prediction phase.
**kwargs: additional kwargs used for initializing the task
"""

Expand All @@ -57,17 +58,20 @@ def __init__(
lr_scheduler: LR_SCHEDULER_TYPE = None,
learning_rate: float = 5e-4,
output: OUTPUT_TYPE = None,
predict_kwargs: Dict = None,
**kwargs: Any,
):
self.save_hyperparameters()

predict_kwargs = predict_kwargs if predict_kwargs else {}
metadata = self.heads.get(head, with_metadata=True)
adapter = metadata["metadata"]["adapter"].from_task(
self,
num_classes=num_classes,
backbone=backbone,
head=head,
pretrained=pretrained,
predict_kwargs=predict_kwargs,
**kwargs,
)

Expand Down Expand Up @@ -96,3 +100,12 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
input_transform=InstanceSegmentationInputTransform(),
output_transform=InstanceSegmentationOutputTransform(),
)

@property
def predict_kwargs(self) -> Dict[str, Any]:
"""The kwargs used for the prediction step."""
return self.adapter.predict_kwargs

@predict_kwargs.setter
def predict_kwargs(self, predict_kwargs: Dict[str, Any]):
self.adapter.predict_kwargs = predict_kwargs
13 changes: 13 additions & 0 deletions flash/image/keypoint_detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class KeypointDetector(AdapterTask):
lr_scheduler: The LR scheduler to use during training.
learning_rate: The learning rate to use for training.
output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
predict_kwargs: dictionary containing parameters that will be used during the prediction phase.
**kwargs: additional kwargs used for initializing the task
"""

Expand All @@ -52,10 +53,12 @@ def __init__(
lr_scheduler: LR_SCHEDULER_TYPE = None,
learning_rate: float = 5e-4,
output: OUTPUT_TYPE = None,
predict_kwargs: Dict = None,
**kwargs: Any,
):
self.save_hyperparameters()

predict_kwargs = predict_kwargs if predict_kwargs else {}
metadata = self.heads.get(head, with_metadata=True)
adapter = metadata["metadata"]["adapter"].from_task(
self,
Expand All @@ -64,6 +67,7 @@ def __init__(
backbone=backbone,
head=head,
pretrained=pretrained,
predict_kwargs=predict_kwargs,
**kwargs,
)

Expand All @@ -78,3 +82,12 @@ def __init__(
def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None:
"""This function is used only for debugging usage with CI."""
# todo

@property
def predict_kwargs(self) -> Dict[str, Any]:
"""The kwargs used for the prediction step."""
return self.adapter.predict_kwargs

@predict_kwargs.setter
def predict_kwargs(self, predict_kwargs: Dict[str, Any]):
self.adapter.predict_kwargs = predict_kwargs
17 changes: 17 additions & 0 deletions tests/image/detection/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,20 @@ def test_cli():
main()
except SystemExit:
pass


@pytest.mark.parametrize("head", ["retinanet"])
@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _ICEVISION_AVAILABLE, reason="IceVision is not installed for testing")
def test_predict(tmpdir, head):
model = ObjectDetector(num_classes=2, head=head, pretrained=False)
ds = DummyDetectionDataset((128, 128, 3), 1, 2, 10)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
dl = model.process_train_dataset(ds, trainer, 2, 0, False, None)
trainer.fit(model, dl)
dl = model.process_predict_dataset(ds, batch_size=2)
predictions = trainer.predict(model, dl)
assert len(predictions[0][0]["bboxes"]) > 0
model.predict_kwargs = {"detection_threshold": 2}
predictions = trainer.predict(model, dl)
assert len(predictions[0][0]["bboxes"]) == 0