
Commit 05537e9

galrotem authored and facebook-github-bot committed
avoid compiling compute_loss (#511)
Summary:
Pull Request resolved: #511

# Context
As discussed in D48361308, it looks like we can remove the torch.compile wrapping of compute_loss, since it is only used for torch dynamo and is not officially supported.

# This diff
Separates this change from the one above, since we also need to clean up some existing UTs that no longer make sense. The change above will be rebased on top of this one.

Reviewed By: JKSenthil

Differential Revision: D48581289

fbshipit-source-id: 94219fa45190467b6b87533500f66735a909d1cb
1 parent a690136 commit 05537e9
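
For context, here is a minimal sketch (not from the diff) of the two compile strategies involved, assuming PyTorch 2.1+ so that the in-place nn.Module.compile() is available, and using the "eager" backend so it runs on CPU; the names below are illustrative stand-ins, not TorchTNT code:

import torch

module = torch.nn.Linear(2, 2)

# Old approach (removed in this commit): wrap the loss computation with
# torch.compile, which returns a new callable and goes through torch._dynamo.
def compute_loss(inputs, targets):
    outputs = module(inputs)
    return torch.nn.functional.cross_entropy(outputs, targets)

compiled_compute_loss = torch.compile(compute_loss, backend="eager")
compiled_compute_loss(torch.randn(4, 2), torch.randint(0, 2, (4,)))

# Kept approach: compile the module in place. The compiled callable is stored
# on the module itself (what the updated test asserts via _compiled_call_impl)
# and the state_dict keys are left untouched.
module.compile(backend="eager")
assert module._compiled_call_impl is not None
assert set(module.state_dict().keys()) == {"weight", "bias"}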

File tree: 2 files changed (+2, -100 lines)

tests/framework/test_auto_unit.py (+2, -95)
@@ -266,6 +266,7 @@ def test_compile_state_dict(self) -> None:
         """
         device = init_from_env()
         my_module = torch.nn.Linear(2, 2, device=device)
+        self.assertIsNone(my_module._compiled_call_impl)
         my_module_state_dict = my_module.state_dict()
         auto_unit = DummyAutoUnit(
             module=my_module,
@@ -283,94 +284,7 @@ def test_compile_state_dict(self) -> None:
             self.assertTrue(
                 torch.allclose(my_module_state_dict[k], compiled_state_dict[k])
             )
-
-    @unittest.skipUnless(
-        condition=COMPILE_AVAIL,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
-    def test_compile_eager(self) -> None:
-        """
-        e2e torch compile test
-        """
-
-        my_module = torch.nn.Linear(2, 2)
-
-        input_dim = 2
-        dataset_len = 16
-        batch_size = 2
-        max_epochs = 1
-
-        auto_unit = DummyAutoUnit(
-            module=my_module,
-            torch_compile_params=TorchCompileParams(backend="eager"),
-        )
-
-        train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
-        self.assertFalse(auto_unit._compile_used)
-        train(auto_unit, train_dl, max_epochs=max_epochs)
-        self.assertTrue(auto_unit._compile_used)
-
-    @unittest.skipUnless(
-        condition=COMPILE_AVAIL,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
-    @unittest.skipUnless(
-        condition=cuda_available, reason="This test needs a GPU host to run."
-    )
-    def test_compile_train(self) -> None:
-        """
-        e2e torch compile on train
-        """
-
-        my_module = torch.nn.Linear(2, 2)
-
-        input_dim = 2
-        dataset_len = 16
-        batch_size = 2
-        max_epochs = 1
-
-        auto_unit = DummyAutoUnit(
-            module=my_module,
-            torch_compile_params=TorchCompileParams(backend="inductor"),
-        )
-
-        train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
-
-        self.assertFalse(auto_unit._compile_used)
-        train(auto_unit, train_dl, max_epochs=max_epochs)
-        self.assertTrue(auto_unit._compile_used)
-
-    @unittest.skipUnless(
-        condition=COMPILE_AVAIL,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
-    @unittest.skipUnless(
-        condition=cuda_available, reason="This test needs a GPU host to run."
-    )
-    def test_compile_eval(self) -> None:
-        """
-        e2e torch compile on eval
-        """
-
-        my_module = torch.nn.Linear(2, 2)
-
-        input_dim = 2
-        dataset_len = 16
-        batch_size = 2
-
-        auto_unit = DummyAutoUnit(
-            module=my_module,
-            torch_compile_params=TorchCompileParams(backend="inductor"),
-        )
-
-        input_dim = 2
-        dataset_len = 8
-        batch_size = 2
-
-        eval_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
-        self.assertFalse(auto_unit._compile_used)
-        evaluate(auto_unit, eval_dl)
-        self.assertTrue(auto_unit._compile_used)
+        self.assertIsNotNone(auto_unit.module._compiled_call_impl)

     @unittest.skipUnless(
         condition=COMPILE_AVAIL,
@@ -983,16 +897,9 @@ def test_predict_detect_anomaly(self, mock_detect_anomaly) -> None:

 # pyre-fixme[11]: Annotation `Batch` is not defined as a type.
 class DummyAutoUnit(AutoUnit[Batch]):
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._compile_used = False

     # pyre-fixme[3]: Return annotation cannot contain `Any`.
     def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
-        if COMPILE_AVAIL:
-            self._compile_used = torch._dynamo.is_compiling()
         inputs, targets = data
         outputs = self.module(inputs)
         loss = torch.nn.functional.cross_entropy(outputs, targets)

torchtnt/framework/auto_unit.py (-5)
@@ -470,11 +470,6 @@ def __init__(
         )

         if torch_compile_params:
-            # pyre-ignore
-            self.compute_loss = torch.compile(
-                self.compute_loss,
-                **asdict(torch_compile_params),
-            )
             try:
                 # use in-place compile to avoid altering the state_dict keys
                 module.compile(**asdict(torch_compile_params))
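
As a hedged illustration (a standalone sketch, not TorchTNT code) of how the retained setup behaves after this change: compute_loss itself now runs eagerly, while the self.module(inputs) call inside it still dispatches to the module's in-place-compiled forward. The names and the "eager" backend below are assumptions chosen so the snippet runs on CPU:

import torch

module = torch.nn.Linear(2, 2)
module.compile(backend="eager")  # in-place compile, as in auto_unit.py above

def compute_loss(data):
    inputs, targets = data
    outputs = module(inputs)  # dispatches to the compiled forward
    loss = torch.nn.functional.cross_entropy(outputs, targets)
    return loss, outputs

loss, _ = compute_loss((torch.randn(8, 2), torch.randint(0, 2, (8,))))
loss.backward()  # gradients flow through the compiled module as usual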
