Skip to content

Commit b674c37

Browse files
galrotemfacebook-github-bot
authored andcommitted
avoid compiling compute_loss (pytorch#511)
Summary: Pull Request resolved: pytorch#511 # Context As per discussion in D48361308, looks like we can remove this since it's only used for torch dynamo but it's not officially supported # This diff Separating the change from the above since we also need to clean up some existing UTs which don't make sense anymore. Going to rebase the above change on top of this Differential Revision: D48581289 fbshipit-source-id: 2808db907b77af373234a12d70c1314217635945
1 parent f11f72a commit b674c37

File tree

2 files changed

+0
-100
lines changed

2 files changed

+0
-100
lines changed

tests/framework/test_auto_unit.py

-95
Original file line numberDiff line numberDiff line change
@@ -284,94 +284,6 @@ def test_compile_state_dict(self) -> None:
284284
torch.allclose(my_module_state_dict[k], compiled_state_dict[k])
285285
)
286286

287-
@unittest.skipUnless(
288-
condition=COMPILE_AVAIL,
289-
reason="This test needs PyTorch 1.13 or greater to run.",
290-
)
291-
def test_compile_eager(self) -> None:
292-
"""
293-
e2e torch compile test
294-
"""
295-
296-
my_module = torch.nn.Linear(2, 2)
297-
298-
input_dim = 2
299-
dataset_len = 16
300-
batch_size = 2
301-
max_epochs = 1
302-
303-
auto_unit = DummyAutoUnit(
304-
module=my_module,
305-
torch_compile_params=TorchCompileParams(backend="eager"),
306-
)
307-
308-
train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
309-
self.assertFalse(auto_unit._compile_used)
310-
train(auto_unit, train_dl, max_epochs=max_epochs)
311-
self.assertTrue(auto_unit._compile_used)
312-
313-
@unittest.skipUnless(
314-
condition=COMPILE_AVAIL,
315-
reason="This test needs PyTorch 1.13 or greater to run.",
316-
)
317-
@unittest.skipUnless(
318-
condition=cuda_available, reason="This test needs a GPU host to run."
319-
)
320-
def test_compile_train(self) -> None:
321-
"""
322-
e2e torch compile on train
323-
"""
324-
325-
my_module = torch.nn.Linear(2, 2)
326-
327-
input_dim = 2
328-
dataset_len = 16
329-
batch_size = 2
330-
max_epochs = 1
331-
332-
auto_unit = DummyAutoUnit(
333-
module=my_module,
334-
torch_compile_params=TorchCompileParams(backend="inductor"),
335-
)
336-
337-
train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
338-
339-
self.assertFalse(auto_unit._compile_used)
340-
train(auto_unit, train_dl, max_epochs=max_epochs)
341-
self.assertTrue(auto_unit._compile_used)
342-
343-
@unittest.skipUnless(
344-
condition=COMPILE_AVAIL,
345-
reason="This test needs PyTorch 1.13 or greater to run.",
346-
)
347-
@unittest.skipUnless(
348-
condition=cuda_available, reason="This test needs a GPU host to run."
349-
)
350-
def test_compile_eval(self) -> None:
351-
"""
352-
e2e torch compile on eval
353-
"""
354-
355-
my_module = torch.nn.Linear(2, 2)
356-
357-
input_dim = 2
358-
dataset_len = 16
359-
batch_size = 2
360-
361-
auto_unit = DummyAutoUnit(
362-
module=my_module,
363-
torch_compile_params=TorchCompileParams(backend="inductor"),
364-
)
365-
366-
input_dim = 2
367-
dataset_len = 8
368-
batch_size = 2
369-
370-
eval_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
371-
self.assertFalse(auto_unit._compile_used)
372-
evaluate(auto_unit, eval_dl)
373-
self.assertTrue(auto_unit._compile_used)
374-
375287
@unittest.skipUnless(
376288
condition=COMPILE_AVAIL,
377289
reason="This test needs PyTorch 1.13 or greater to run.",
@@ -983,16 +895,9 @@ def test_predict_detect_anomaly(self, mock_detect_anomaly) -> None:
983895

984896
# pyre-fixme[11]: Annotation `Batch` is not defined as a type.
985897
class DummyAutoUnit(AutoUnit[Batch]):
986-
# pyre-fixme[3]: Return type must be annotated.
987-
# pyre-fixme[2]: Parameter must be annotated.
988-
def __init__(self, *args, **kwargs):
989-
super().__init__(*args, **kwargs)
990-
self._compile_used = False
991898

992899
# pyre-fixme[3]: Return annotation cannot contain `Any`.
993900
def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
994-
if COMPILE_AVAIL:
995-
self._compile_used = torch._dynamo.is_compiling()
996901
inputs, targets = data
997902
outputs = self.module(inputs)
998903
loss = torch.nn.functional.cross_entropy(outputs, targets)

torchtnt/framework/auto_unit.py

-5
Original file line numberDiff line numberDiff line change
@@ -470,11 +470,6 @@ def __init__(
470470
)
471471

472472
if torch_compile_params:
473-
# pyre-ignore
474-
self.compute_loss = torch.compile(
475-
self.compute_loss,
476-
**asdict(torch_compile_params),
477-
)
478473
try:
479474
# use in-place compile to avoid altering the state_dict keys
480475
module.compile(**asdict(torch_compile_params))

0 commit comments

Comments
 (0)