Skip to content

Commit 204c30a

Browse files
modify test
Signed-off-by: yechank <[email protected]>
1 parent 8a91c21 commit 204c30a

File tree

1 file changed

+51
-26
lines changed

1 file changed

+51
-26
lines changed

tests/unittest/_torch/multimodal/test_share_multiparams.py

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -53,32 +53,6 @@ def test_to_handle_unsupported_element(self):
5353

5454
self.assertIn("Unsupported element 'multimodal_input'",
5555
str(context.exception))
56-
self.assertIn("Supported elements: 'multimodal_data'",
57-
str(context.exception))
58-
59-
def test_to_tensor_unsupported_element(self):
60-
"""Test to_tensor raises ValueError for unsupported elements."""
61-
params = MultimodalParams()
62-
63-
with self.assertRaises(ValueError) as context:
64-
params.to_tensor("multimodal_input")
65-
66-
self.assertIn("Unsupported element 'multimodal_input'",
67-
str(context.exception))
68-
self.assertIn("Supported elements: 'multimodal_data'",
69-
str(context.exception))
70-
71-
def test_to_device_unsupported_element(self):
72-
"""Test to_device raises ValueError for unsupported elements."""
73-
params = MultimodalParams()
74-
75-
with self.assertRaises(ValueError) as context:
76-
params.to_device("multimodal_input", device="cuda", pin_memory=True)
77-
78-
self.assertIn("Unsupported element 'multimodal_input'",
79-
str(context.exception))
80-
self.assertIn("Supported elements: 'multimodal_data'",
81-
str(context.exception))
8256

8357
def test_to_tensor_basic_handle(self):
8458
"""Test converting a basic handle back to tensor."""
@@ -122,5 +96,56 @@ def test_to_tensor_all_handles(self):
12296
self.image["image_width"])
12397

12498

99+
class TestMultimodalParamsDeviceTransfer(unittest.TestCase):
100+
"""Test cases for to_device method in MultimodalParams."""
101+
102+
def setUp(self):
103+
"""Set up test fixtures."""
104+
self.mm_embedding = torch.randn(3, 4, 5)
105+
self.mrope_config = {
106+
"mrope_rotary_cos_sin": torch.randn(2, 3),
107+
"mrope_position_deltas": torch.randn(5),
108+
}
109+
self.image = {
110+
"pixel_values": torch.randn(1, 3, 224, 224),
111+
"image_height": [224],
112+
"image_width": [224],
113+
}
114+
self.sample_multimodal_data = {
115+
"multimodal_embedding": self.mm_embedding,
116+
"mrope_config": self.mrope_config,
117+
"image": self.image,
118+
}
119+
120+
def test_to_device_basic(self):
121+
"""Test converting a basic data to device."""
122+
params = MultimodalParams()
123+
params.multimodal_data = {"multimodal_embedding": self.mm_embedding}
124+
125+
params.to_device("multimodal_data", device="cuda:0", pin_memory=True)
126+
127+
result = params.multimodal_data["multimodal_embedding"]
128+
self.assertEqual(result.device, torch.device("cuda:0"))
129+
130+
def test_to_device_all_data(self):
131+
"""Test converting all data to device."""
132+
params = MultimodalParams()
133+
params.multimodal_data = self.sample_multimodal_data.copy()
134+
135+
params.to_device("multimodal_data", device="cuda:0", pin_memory=True)
136+
137+
result = params.multimodal_data["multimodal_embedding"]
138+
self.assertEqual(result.device, torch.device("cuda:0"))
139+
140+
result = params.multimodal_data["mrope_config"]["mrope_rotary_cos_sin"]
141+
self.assertEqual(result.device, torch.device("cuda:0"))
142+
143+
result = params.multimodal_data["mrope_config"]["mrope_position_deltas"]
144+
self.assertEqual(result.device, torch.device("cuda:0"))
145+
146+
result = params.multimodal_data["image"]["pixel_values"]
147+
self.assertEqual(result.device, torch.device("cuda:0"))
148+
149+
125150
if __name__ == "__main__":
126151
unittest.main()

0 commit comments

Comments
 (0)