Skip to content

Commit 2eaaefb

Browse files
committed
Simplify tests
1 parent 36b639b commit 2eaaefb

File tree

1 file changed

+56
-85
lines changed

1 file changed

+56
-85
lines changed

tests/unit/test_callbacks_writer_file.py

+56-85
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
import pytest
77
import torch
88
from PIL import Image
9+
import tempfile
910

1011
from lighter.callbacks.writer.file import LighterFileWriter
1112

1213

13-
def test_file_writer_initialization():
14+
def test_file_writer_initialization(tmp_path):
1415
"""Test the initialization of LighterFileWriter class.
1516
1617
This test verifies that:
@@ -21,17 +22,12 @@ def test_file_writer_initialization():
2122
The test creates a temporary directory, initializes a writer, checks its attributes,
2223
and then cleans up the directory.
2324
"""
24-
path = Path("test_dir")
25-
path.mkdir(exist_ok=True) # Ensure the directory exists
26-
try:
27-
writer = LighterFileWriter(path=path, writer="tensor")
28-
assert writer.path == Path("test_dir")
29-
assert writer.writer.__name__ == "write_tensor" # Verify writer function
30-
finally:
31-
shutil.rmtree(path) # Clean up after test
25+
writer = LighterFileWriter(path=tmp_path, writer="tensor")
26+
assert writer.path == tmp_path
27+
assert writer.writer.__name__ == "write_tensor" # Verify writer function
3228

3329

34-
def test_file_writer_write_tensor():
30+
def test_file_writer_write_tensor(tmp_path):
3531
"""Test tensor writing functionality of LighterFileWriter.
3632
3733
This test verifies that:
@@ -42,25 +38,20 @@ def test_file_writer_write_tensor():
4238
The test creates a simple tensor, saves it, loads it back, and verifies
4339
the content matches the original.
4440
"""
45-
test_dir = Path("test_dir")
46-
test_dir.mkdir(exist_ok=True)
47-
try:
48-
writer = LighterFileWriter(path=test_dir, writer="tensor")
49-
tensor = torch.tensor([1, 2, 3])
50-
writer.write(tensor, id=1)
41+
writer = LighterFileWriter(path=tmp_path, writer="tensor")
42+
tensor = torch.tensor([1, 2, 3])
43+
writer.write(tensor, id=1)
5144

52-
# Verify file exists
53-
saved_path = writer.path / "1.pt"
54-
assert saved_path.exists()
45+
# Verify file exists
46+
saved_path = writer.path / "1.pt"
47+
assert saved_path.exists()
5548

56-
# Verify tensor contents
57-
loaded_tensor = torch.load(saved_path) # nosec B614
58-
assert torch.equal(loaded_tensor, tensor)
59-
finally:
60-
shutil.rmtree(test_dir)
49+
# Verify tensor contents
50+
loaded_tensor = torch.load(saved_path) # nosec B614
51+
assert torch.equal(loaded_tensor, tensor)
6152

6253

63-
def test_file_writer_write_image():
54+
def test_file_writer_write_image(tmp_path):
6455
"""Test image writing functionality of LighterFileWriter.
6556
6657
This test verifies that:
@@ -71,26 +62,21 @@ def test_file_writer_write_image():
7162
The test creates a random RGB image tensor, saves it, and verifies
7263
the saved image properties.
7364
"""
74-
test_dir = Path("test_dir")
75-
test_dir.mkdir(exist_ok=True)
76-
try:
77-
writer = LighterFileWriter(path=test_dir, writer="image")
78-
tensor = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
79-
writer.write(tensor, id="image_test")
80-
81-
# Verify file exists
82-
saved_path = writer.path / "image_test.png"
83-
assert saved_path.exists()
84-
85-
# Verify image contents
86-
image = Image.open(saved_path)
87-
image_array = np.array(image)
88-
assert image_array.shape == (64, 64, 3)
89-
finally:
90-
shutil.rmtree(test_dir)
65+
writer = LighterFileWriter(path=tmp_path, writer="image")
66+
tensor = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
67+
writer.write(tensor, id="image_test")
68+
69+
# Verify file exists
70+
saved_path = writer.path / "image_test.png"
71+
assert saved_path.exists()
72+
73+
# Verify image contents
74+
image = Image.open(saved_path)
75+
image_array = np.array(image)
76+
assert image_array.shape == (64, 64, 3)
9177

9278

93-
def test_file_writer_write_video():
79+
def test_file_writer_write_video(tmp_path):
9480
"""Test video writing functionality of LighterFileWriter.
9581
9682
This test verifies that:
@@ -100,21 +86,16 @@ def test_file_writer_write_video():
10086
The test creates a random RGB video tensor and verifies it can be saved
10187
to disk in the correct format.
10288
"""
103-
test_dir = Path("test_dir")
104-
test_dir.mkdir(exist_ok=True)
105-
try:
106-
writer = LighterFileWriter(path=test_dir, writer="video")
107-
tensor = torch.randint(0, 256, (3, 10, 64, 64), dtype=torch.uint8)
108-
writer.write(tensor, id="video_test")
89+
writer = LighterFileWriter(path=tmp_path, writer="video")
90+
tensor = torch.randint(0, 256, (3, 10, 64, 64), dtype=torch.uint8)
91+
writer.write(tensor, id="video_test")
10992

110-
# Verify file exists
111-
saved_path = writer.path / "video_test.mp4"
112-
assert saved_path.exists()
113-
finally:
114-
shutil.rmtree(test_dir)
93+
# Verify file exists
94+
saved_path = writer.path / "video_test.mp4"
95+
assert saved_path.exists()
11596

11697

117-
def test_file_writer_write_grayscale_video():
98+
def test_file_writer_write_grayscale_video(tmp_path):
11899
"""Test grayscale video writing functionality of LighterFileWriter.
119100
120101
This test verifies that:
@@ -125,22 +106,17 @@ def test_file_writer_write_grayscale_video():
125106
The test creates a grayscale video tensor and verifies it can be properly
126107
converted and saved as an MP4 file.
127108
"""
128-
test_dir = Path("test_dir")
129-
test_dir.mkdir(exist_ok=True)
130-
try:
131-
writer = LighterFileWriter(path=test_dir, writer="video")
132-
# Create a grayscale video tensor with 1 channel
133-
tensor = torch.randint(0, 256, (1, 10, 64, 64), dtype=torch.uint8)
134-
writer.write(tensor, id="grayscale_video_test")
135-
136-
# Verify file exists
137-
saved_path = writer.path / "grayscale_video_test.mp4"
138-
assert saved_path.exists()
139-
finally:
140-
shutil.rmtree(test_dir)
109+
writer = LighterFileWriter(path=tmp_path, writer="video")
110+
# Create a grayscale video tensor with 1 channel
111+
tensor = torch.randint(0, 256, (1, 10, 64, 64), dtype=torch.uint8)
112+
writer.write(tensor, id="grayscale_video_test")
141113

114+
# Verify file exists
115+
saved_path = writer.path / "grayscale_video_test.mp4"
116+
assert saved_path.exists()
142117

143-
def test_file_writer_write_itk_image():
118+
119+
def test_file_writer_write_itk_image(tmp_path):
144120
"""Test ITK image writing functionality of LighterFileWriter.
145121
146122
This test verifies that:
@@ -151,25 +127,20 @@ def test_file_writer_write_itk_image():
151127
The test attempts to write both regular tensors and MetaTensors,
152128
verifying proper error handling and successful writes.
153129
"""
154-
test_dir = Path("test_dir")
155-
test_dir.mkdir(exist_ok=True)
156-
try:
157-
writer = LighterFileWriter(path=test_dir, writer="itk_nrrd")
158-
tensor = torch.rand(1, 1, 64, 64, 64) # Example 3D tensor
130+
writer = LighterFileWriter(path=tmp_path, writer="itk_nrrd")
131+
tensor = torch.rand(1, 1, 64, 64, 64) # Example 3D tensor
159132

160-
# Test with regular tensor
161-
with pytest.raises(TypeError, match="Tensor must be in MONAI MetaTensor format"):
162-
writer.write(tensor, id="itk_image_test")
133+
# Test with regular tensor
134+
with pytest.raises(TypeError, match="Tensor must be in MONAI MetaTensor format"):
135+
writer.write(tensor, id="itk_image_test")
163136

164-
# Test with proper MetaTensor
165-
meta_tensor = monai.data.MetaTensor(tensor, affine=torch.eye(4), meta={"original_channel_dim": 1})
166-
writer.write(meta_tensor, id="itk_image_test")
137+
# Test with proper MetaTensor
138+
meta_tensor = monai.data.MetaTensor(tensor, affine=torch.eye(4), meta={"original_channel_dim": 1})
139+
writer.write(meta_tensor, id="itk_image_test")
167140

168-
# Verify file exists
169-
saved_path = writer.path / "itk_image_test.nrrd"
170-
assert saved_path.exists()
171-
finally:
172-
shutil.rmtree(test_dir)
141+
# Verify file exists
142+
saved_path = writer.path / "itk_image_test.nrrd"
143+
assert saved_path.exists()
173144

174145

175146
def test_file_writer_invalid_directory():

0 commit comments

Comments
 (0)