@@ -53,32 +53,6 @@ def test_to_handle_unsupported_element(self):
5353
5454 self .assertIn ("Unsupported element 'multimodal_input'" ,
5555 str (context .exception ))
56- self .assertIn ("Supported elements: 'multimodal_data'" ,
57- str (context .exception ))
58-
59- def test_to_tensor_unsupported_element (self ):
60- """Test to_tensor raises ValueError for unsupported elements."""
61- params = MultimodalParams ()
62-
63- with self .assertRaises (ValueError ) as context :
64- params .to_tensor ("multimodal_input" )
65-
66- self .assertIn ("Unsupported element 'multimodal_input'" ,
67- str (context .exception ))
68- self .assertIn ("Supported elements: 'multimodal_data'" ,
69- str (context .exception ))
70-
71- def test_to_device_unsupported_element (self ):
72- """Test to_device raises ValueError for unsupported elements."""
73- params = MultimodalParams ()
74-
75- with self .assertRaises (ValueError ) as context :
76- params .to_device ("multimodal_input" , device = "cuda" , pin_memory = True )
77-
78- self .assertIn ("Unsupported element 'multimodal_input'" ,
79- str (context .exception ))
80- self .assertIn ("Supported elements: 'multimodal_data'" ,
81- str (context .exception ))
8256
8357 def test_to_tensor_basic_handle (self ):
8458 """Test converting a basic handle back to tensor."""
@@ -122,5 +96,56 @@ def test_to_tensor_all_handles(self):
12296 self .image ["image_width" ])
12397
12498
99+ class TestMultimodalParamsDeviceTransfer (unittest .TestCase ):
100+ """Test cases for to_device method in MultimodalParams."""
101+
102+ def setUp (self ):
103+ """Set up test fixtures."""
104+ self .mm_embedding = torch .randn (3 , 4 , 5 )
105+ self .mrope_config = {
106+ "mrope_rotary_cos_sin" : torch .randn (2 , 3 ),
107+ "mrope_position_deltas" : torch .randn (5 ),
108+ }
109+ self .image = {
110+ "pixel_values" : torch .randn (1 , 3 , 224 , 224 ),
111+ "image_height" : [224 ],
112+ "image_width" : [224 ],
113+ }
114+ self .sample_multimodal_data = {
115+ "multimodal_embedding" : self .mm_embedding ,
116+ "mrope_config" : self .mrope_config ,
117+ "image" : self .image ,
118+ }
119+
120+ def test_to_device_basic (self ):
121+ """Test converting a basic data to device."""
122+ params = MultimodalParams ()
123+ params .multimodal_data = {"multimodal_embedding" : self .mm_embedding }
124+
125+ params .to_device ("multimodal_data" , device = "cuda:0" , pin_memory = True )
126+
127+ result = params .multimodal_data ["multimodal_embedding" ]
128+ self .assertEqual (result .device , torch .device ("cuda:0" ))
129+
130+ def test_to_device_all_data (self ):
131+ """Test converting all data to device."""
132+ params = MultimodalParams ()
133+ params .multimodal_data = self .sample_multimodal_data .copy ()
134+
135+ params .to_device ("multimodal_data" , device = "cuda:0" , pin_memory = True )
136+
137+ result = params .multimodal_data ["multimodal_embedding" ]
138+ self .assertEqual (result .device , torch .device ("cuda:0" ))
139+
140+ result = params .multimodal_data ["mrope_config" ]["mrope_rotary_cos_sin" ]
141+ self .assertEqual (result .device , torch .device ("cuda:0" ))
142+
143+ result = params .multimodal_data ["mrope_config" ]["mrope_position_deltas" ]
144+ self .assertEqual (result .device , torch .device ("cuda:0" ))
145+
146+ result = params .multimodal_data ["image" ]["pixel_values" ]
147+ self .assertEqual (result .device , torch .device ("cuda:0" ))
148+
149+
125150if __name__ == "__main__" :
126151 unittest .main ()
0 commit comments