From 8349d40441146311b170b47248e596810824e85e Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 12 May 2021 19:46:32 +0100 Subject: [PATCH] update --- tests/vision/segmentation/test_serialization.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/vision/segmentation/test_serialization.py b/tests/vision/segmentation/test_serialization.py index a971c91fbf..872fcc2420 100644 --- a/tests/vision/segmentation/test_serialization.py +++ b/tests/vision/segmentation/test_serialization.py @@ -1,6 +1,7 @@ import pytest import torch +from flash.data.data_source import DefaultDataKeys from flash.vision.segmentation.serialization import SegmentationLabels @@ -30,7 +31,7 @@ def test_serialize(self): sample[1, 1, 2] = 1 # add peak in class 2 sample[3, 0, 1] = 1 # add peak in class 4 - classes = serial.serialize(sample) + classes = serial.serialize({DefaultDataKeys.PREDS: sample}) assert classes[1, 2] == 1 assert classes[0, 1] == 3