diff --git a/tests/vision/segmentation/test_serialization.py b/tests/vision/segmentation/test_serialization.py index a971c91fbf..872fcc2420 100644 --- a/tests/vision/segmentation/test_serialization.py +++ b/tests/vision/segmentation/test_serialization.py @@ -1,6 +1,7 @@ import pytest import torch +from flash.data.data_source import DefaultDataKeys from flash.vision.segmentation.serialization import SegmentationLabels @@ -30,7 +31,7 @@ def test_serialize(self): sample[1, 1, 2] = 1 # add peak in class 2 sample[3, 0, 1] = 1 # add peak in class 4 - classes = serial.serialize(sample) + classes = serial.serialize({DefaultDataKeys.PREDS: sample}) assert classes[1, 2] == 1 assert classes[0, 1] == 3