Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Jun 18, 2021
1 parent 3201f17 commit f8bb543
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
7 changes: 4 additions & 3 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,8 @@ def build_data_pipeline(
else:
data_source = preprocess.data_source_of_name(data_source)

deserializer = deserializer or getattr(preprocess, "deserializer", None)
if type(deserializer) == Deserializer:
deserializer = getattr(preprocess, "deserializer", deserializer)

data_pipeline = DataPipeline(data_source, preprocess, postprocess, deserializer, serializer)
self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state)
Expand All @@ -400,7 +401,7 @@ def build_data_pipeline(
@torch.jit.unused
@property
def is_servable(self) -> bool:
return self.build_data_pipeline()._deserializer is not None
return type(self.build_data_pipeline()._deserializer) != Deserializer

@torch.jit.unused
@property
Expand Down Expand Up @@ -608,7 +609,7 @@ def run_serve_sanity_check(self):

from flash.core.serve.flash_components import build_flash_serve_model_component

print("Running sanity check")
print("Running serve sanity check")
comp = build_flash_serve_model_component(self)
composition = Composition(predict=comp, TESTING=True, DEBUG=True)
app = composition.serve(host="0.0.0.0", port=8000)
Expand Down
4 changes: 2 additions & 2 deletions flash/image/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from torch import nn
from torch.optim.lr_scheduler import _LRScheduler

from flash.core.classification import ClassificationTask
from flash.core.classification import ClassificationTask, Labels
from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.process import Serializer
from flash.core.registry import FlashRegistry
Expand Down Expand Up @@ -91,7 +91,7 @@ def __init__(
metrics=metrics,
learning_rate=learning_rate,
multi_label=multi_label,
serializer=serializer,
serializer=serializer or Labels(),
)

self.save_hyperparameters()
Expand Down
2 changes: 1 addition & 1 deletion flash_examples/finetuning/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def fn_resnet(pretrained: bool = True):
print(ImageClassifier.available_backbones())

# 4. Build the model
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, serializer=Labels())
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

# 5. Create the trainer
trainer = flash.Trainer(max_epochs=3)
Expand Down

0 comments on commit f8bb543

Please sign in to comment.