@@ -284,94 +284,6 @@ def test_compile_state_dict(self) -> None:
                 torch.allclose(my_module_state_dict[k], compiled_state_dict[k])
             )
 
-    @unittest.skipUnless(
-        condition=COMPILE_AVAIL,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
-    def test_compile_eager(self) -> None:
-        """
-        e2e torch compile test
-        """
-
-        my_module = torch.nn.Linear(2, 2)
-
-        input_dim = 2
-        dataset_len = 16
-        batch_size = 2
-        max_epochs = 1
-
-        auto_unit = DummyAutoUnit(
-            module=my_module,
-            torch_compile_params=TorchCompileParams(backend="eager"),
-        )
-
-        train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
-        self.assertFalse(auto_unit._compile_used)
-        train(auto_unit, train_dl, max_epochs=max_epochs)
-        self.assertTrue(auto_unit._compile_used)
-
-    @unittest.skipUnless(
-        condition=COMPILE_AVAIL,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
-    @unittest.skipUnless(
-        condition=cuda_available, reason="This test needs a GPU host to run."
-    )
-    def test_compile_train(self) -> None:
-        """
-        e2e torch compile on train
-        """
-
-        my_module = torch.nn.Linear(2, 2)
-
-        input_dim = 2
-        dataset_len = 16
-        batch_size = 2
-        max_epochs = 1
-
-        auto_unit = DummyAutoUnit(
-            module=my_module,
-            torch_compile_params=TorchCompileParams(backend="inductor"),
-        )
-
-        train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
-
-        self.assertFalse(auto_unit._compile_used)
-        train(auto_unit, train_dl, max_epochs=max_epochs)
-        self.assertTrue(auto_unit._compile_used)
-
-    @unittest.skipUnless(
-        condition=COMPILE_AVAIL,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
-    @unittest.skipUnless(
-        condition=cuda_available, reason="This test needs a GPU host to run."
-    )
-    def test_compile_eval(self) -> None:
-        """
-        e2e torch compile on eval
-        """
-
-        my_module = torch.nn.Linear(2, 2)
-
-        input_dim = 2
-        dataset_len = 16
-        batch_size = 2
-
-        auto_unit = DummyAutoUnit(
-            module=my_module,
-            torch_compile_params=TorchCompileParams(backend="inductor"),
-        )
-
-        input_dim = 2
-        dataset_len = 8
-        batch_size = 2
-
-        eval_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
-        self.assertFalse(auto_unit._compile_used)
-        evaluate(auto_unit, eval_dl)
-        self.assertTrue(auto_unit._compile_used)
-
     @unittest.skipUnless(
         condition=COMPILE_AVAIL,
         reason="This test needs PyTorch 1.13 or greater to run.",
@@ -983,16 +895,9 @@ def test_predict_detect_anomaly(self, mock_detect_anomaly) -> None:
 
 # pyre-fixme[11]: Annotation `Batch` is not defined as a type.
 class DummyAutoUnit(AutoUnit[Batch]):
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._compile_used = False
 
     # pyre-fixme[3]: Return annotation cannot contain `Any`.
     def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
-        if COMPILE_AVAIL:
-            self._compile_used = torch._dynamo.is_compiling()
         inputs, targets = data
         outputs = self.module(inputs)
         loss = torch.nn.functional.cross_entropy(outputs, targets)
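
For context on what the removed tests exercised: the deleted `_compile_used` flag flipped to `True` only when `compute_loss` actually ran under a dynamo-compiled frame. Below is a minimal, self-contained sketch of that detection pattern, independent of this diff and of torchtnt; `SimpleUnit` and the driver lines are illustrative assumptions, while `torch._dynamo.is_compiling()` is the same check the removed code used.

```python
# Hypothetical sketch of the compile-detection pattern from the removed tests.
# SimpleUnit is an illustrative stand-in, not a torchtnt API.
import torch


class SimpleUnit:
    def __init__(self, module: torch.nn.Module) -> None:
        self.module = module
        self._compile_used = False  # flips to True once a compiled frame runs

    def compute_loss(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # torch._dynamo.is_compiling() is True only while dynamo traces this frame,
        # so the flag records whether the compiled path was actually taken.
        self._compile_used = torch._dynamo.is_compiling()
        outputs = self.module(inputs)
        return torch.nn.functional.cross_entropy(outputs, targets)


unit = SimpleUnit(torch.nn.Linear(2, 2))
compiled_loss = torch.compile(unit.compute_loss, backend="eager")
compiled_loss(torch.randn(4, 2), torch.randint(0, 2, (4,)))
assert unit._compile_used  # mirrors self.assertTrue(auto_unit._compile_used) above
```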