File tree Expand file tree Collapse file tree 2 files changed +24
-0
lines changed
src/lightning/pytorch/core
tests/tests_pytorch/models Expand file tree Collapse file tree 2 files changed +24
-0
lines changed Original file line number Diff line number Diff line change @@ -1472,6 +1472,10 @@ def forward(self, x):
14721472 )
14731473 example_inputs = self .example_input_array
14741474
1475+ if kwargs .get ("check_inputs" ) is not None :
1476+ kwargs ["check_inputs" ] = self ._on_before_batch_transfer (kwargs ["check_inputs" ])
1477+ kwargs ["check_inputs" ] = self ._apply_batch_transfer_handler (kwargs ["check_inputs" ])
1478+
14751479 # automatically send example inputs to the right device and use trace
14761480 example_inputs = self ._on_before_batch_transfer (example_inputs )
14771481 example_inputs = self ._apply_batch_transfer_handler (example_inputs )
Original file line number Diff line number Diff line change @@ -105,6 +105,26 @@ def test_torchscript_device(device_str):
105105 assert script_output .device == device
106106
107107
108+ @pytest .mark .parametrize (
109+ "device_str" ,
110+ [
111+ "cpu" ,
112+ pytest .param ("cuda:0" , marks = RunIf (min_cuda_gpus = 1 )),
113+ pytest .param ("mps:0" , marks = RunIf (mps = True )),
114+ ],
115+ )
116+ def test_torchscript_device_with_check_inputs (device_str ):
117+ """Test that scripted module is on the correct device."""
118+ device = torch .device (device_str )
119+ model = BoringModel ().to (device )
120+ model .example_input_array = torch .randn (5 , 32 )
121+
122+ check_inputs = torch .rand (5 , 32 )
123+
124+ script = model .to_torchscript (method = "trace" , check_inputs = check_inputs )
125+ assert isinstance (script , torch .jit .ScriptModule )
126+
127+
108128def test_torchscript_retain_training_state ():
109129 """Test that torchscript export does not alter the training mode of original model."""
110130 model = BoringModel ()
You can’t perform that action at this time.
0 commit comments