diff --git a/flash/audio/speech_recognition/output_transform.py b/flash/audio/speech_recognition/output_transform.py
index d5f65a2df0..56f282bed5 100644
--- a/flash/audio/speech_recognition/output_transform.py
+++ b/flash/audio/speech_recognition/output_transform.py
@@ -62,7 +62,7 @@ def per_batch_transform(self, batch: Any) -> Any:
 
     def __getstate__(self):  # TODO: Find out why this is being pickled
         state = self.__dict__.copy()
-        state.pop("_tokenizer")
+        state.pop("_tokenizer", None)
         return state
 
     def __setstate__(self, state):
diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py
index 84d3ed566f..c27b589958 100644
--- a/flash/image/segmentation/model.py
+++ b/flash/image/segmentation/model.py
@@ -34,9 +34,9 @@
     OUTPUT_TRANSFORM_TYPE,
     OUTPUT_TYPE,
 )
-from flash.image.data import ImageDeserializer
 from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
 from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS
+from flash.image.segmentation.input import SemanticSegmentationDeserializer
 from flash.image.segmentation.output import SegmentationLabelsOutput
 from flash.image.segmentation.transforms import SemanticSegmentationInputTransform
 
@@ -184,7 +184,7 @@ def serve(
         host: str = "127.0.0.1",
         port: int = 8000,
         sanity_check: bool = True,
-        input_cls: Optional[Type[ServeInput]] = ImageDeserializer,
+        input_cls: Optional[Type[ServeInput]] = SemanticSegmentationDeserializer,
         transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform,
         transform_kwargs: Optional[Dict] = None,
     ) -> Composition:
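
Note: a minimal sketch of why the first hunk adds a default to `pop`. `dict.pop(key)` raises `KeyError` when the key is absent, which can happen if `_tokenizer` was never set on the instance (or was already stripped) by the time `__getstate__` runs during pickling; `pop(key, None)` tolerates that. The class and attribute names below are illustrative stand-ins, not the actual Flash `SpeechRecognitionOutputTransform`.

```python
# Illustrative sketch only (hypothetical _TokenizerHolder, not the real Flash
# class). It mimics the __getstate__/__setstate__ pattern from the first hunk:
# the tokenizer is dropped before pickling, and pop(..., None) keeps that safe
# even when the attribute is absent.
import pickle


class _TokenizerHolder:
    def __init__(self, backbone="some/backbone", tokenizer=None):
        self.backbone = backbone      # plain, picklable configuration
        self._tokenizer = tokenizer   # stands in for an unpicklable helper

    def __getstate__(self):
        state = self.__dict__.copy()
        # With the default, this no longer raises KeyError when
        # "_tokenizer" is missing from __dict__.
        state.pop("_tokenizer", None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # The real class would rebuild its tokenizer here; the sketch
        # simply leaves it unset.
        self._tokenizer = None


obj = _TokenizerHolder(tokenizer=object())
restored = pickle.loads(pickle.dumps(obj))         # attribute present: fine either way
assert restored._tokenizer is None

bare = _TokenizerHolder.__new__(_TokenizerHolder)  # __init__ skipped: no _tokenizer at all
pickle.dumps(bare)                                 # would raise KeyError without the default
```

The second file's hunks only swap the default `input_cls` of the segmentation task's `serve` from the generic `ImageDeserializer` to `SemanticSegmentationDeserializer`, so serving decodes requests with the segmentation-specific input by default; callers can still override it by passing their own `input_cls`.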