diff --git a/tests/models/test_arch_mapde.py b/tests/models/test_arch_mapde.py index 61bfde817..2a65583c6 100644 --- a/tests/models/test_arch_mapde.py +++ b/tests/models/test_arch_mapde.py @@ -14,14 +14,14 @@ ON_GPU = toolbox_env.has_gpu() -def _load_mapde(name: str) -> torch.nn.Module: +def _load_mapde(name: str) -> MapDe: """Loads MapDe model with specified weights.""" model = MapDe() weights_path = fetch_pretrained_weights(name) map_location = select_device(on_gpu=ON_GPU) pretrained = torch.load(weights_path, map_location=map_location) model.load_state_dict(pretrained) - + model.to(map_location) return model @@ -45,7 +45,6 @@ def test_functionality(remote_sample: Callable) -> None: model = _load_mapde(name="mapde-conic") patch = model.preproc(patch) batch = torch.from_numpy(patch)[None] - model = model.to() output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU)) output = model.postproc(output[0]) assert np.all(output[0:2] == [[19, 171], [53, 89]]) diff --git a/tests/models/test_arch_sccnn.py b/tests/models/test_arch_sccnn.py index 2729d2b3a..16c99cc49 100644 --- a/tests/models/test_arch_sccnn.py +++ b/tests/models/test_arch_sccnn.py @@ -12,14 +12,14 @@ from tiatoolbox.wsicore.wsireader import WSIReader -def _load_sccnn(name: str) -> torch.nn.Module: +def _load_sccnn(name: str) -> SCCNN: """Loads SCCNN model with specified weights.""" model = SCCNN() weights_path = fetch_pretrained_weights(name) map_location = select_device(on_gpu=env_detection.has_gpu()) pretrained = torch.load(weights_path, map_location=map_location) model.load_state_dict(pretrained) - + model.to(map_location) return model @@ -48,7 +48,6 @@ def test_functionality(remote_sample: Callable) -> None: ) output = model.postproc(output[0]) assert np.all(output == [[8, 7]]) - model = _load_sccnn(name="sccnn-conic") output = model.infer_batch( model,