
Commit cfd8d67

galrotem authored and facebook-github-bot committed
avoid compiling compute_loss
Summary:

# Context
[Describe motivations and existing situation that led to creating this diff. Don't be cheap with context, it is the basis for a good code review.]

# This diff
[List all the changes that this diff introduces and explain the ones that are not trivial. Give directions for the reviewer if needed.]

# What's next
[If this diff is part of a stack or if it has direct continuation in a future diff, share these plans with your reviewer.]

Differential Revision: D48581289

fbshipit-source-id: 4abd5932a1a18f2983605cacc886abe6433a0711
1 parent f11f72a commit cfd8d67

File tree: 2 files changed, +0 −100 lines


tests/framework/test_auto_unit.py (−95 lines)
@@ -284,94 +284,6 @@ def test_compile_state_dict(self) -> None:
                 torch.allclose(my_module_state_dict[k], compiled_state_dict[k])
             )
 
-    @unittest.skipUnless(
-        condition=COMPILE_AVAIL,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
-    def test_compile_eager(self) -> None:
-        """
-        e2e torch compile test
-        """
-
-        my_module = torch.nn.Linear(2, 2)
-
-        input_dim = 2
-        dataset_len = 16
-        batch_size = 2
-        max_epochs = 1
-
-        auto_unit = DummyAutoUnit(
-            module=my_module,
-            torch_compile_params=TorchCompileParams(backend="eager"),
-        )
-
-        train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
-        self.assertFalse(auto_unit._compile_used)
-        train(auto_unit, train_dl, max_epochs=max_epochs)
-        self.assertTrue(auto_unit._compile_used)
-
-    @unittest.skipUnless(
-        condition=COMPILE_AVAIL,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
-    @unittest.skipUnless(
-        condition=cuda_available, reason="This test needs a GPU host to run."
-    )
-    def test_compile_train(self) -> None:
-        """
-        e2e torch compile on train
-        """
-
-        my_module = torch.nn.Linear(2, 2)
-
-        input_dim = 2
-        dataset_len = 16
-        batch_size = 2
-        max_epochs = 1
-
-        auto_unit = DummyAutoUnit(
-            module=my_module,
-            torch_compile_params=TorchCompileParams(backend="inductor"),
-        )
-
-        train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
-
-        self.assertFalse(auto_unit._compile_used)
-        train(auto_unit, train_dl, max_epochs=max_epochs)
-        self.assertTrue(auto_unit._compile_used)
-
-    @unittest.skipUnless(
-        condition=COMPILE_AVAIL,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
-    @unittest.skipUnless(
-        condition=cuda_available, reason="This test needs a GPU host to run."
-    )
-    def test_compile_eval(self) -> None:
-        """
-        e2e torch compile on eval
-        """
-
-        my_module = torch.nn.Linear(2, 2)
-
-        input_dim = 2
-        dataset_len = 8
-        batch_size = 2
-
-        auto_unit = DummyAutoUnit(
-            module=my_module,
-            torch_compile_params=TorchCompileParams(backend="inductor"),
-        )
-
-        input_dim = 2
-        dataset_len = 8
-        batch_size = 2
-
-        eval_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
-        self.assertFalse(auto_unit._compile_used)
-        evaluate(auto_unit, eval_dl)
-        self.assertTrue(auto_unit._compile_used)
-
     @unittest.skipUnless(
         condition=COMPILE_AVAIL,
         reason="This test needs PyTorch 1.13 or greater to run.",
@@ -983,16 +895,9 @@ def test_predict_detect_anomaly(self, mock_detect_anomaly) -> None:
 
 # pyre-fixme[11]: Annotation `Batch` is not defined as a type.
 class DummyAutoUnit(AutoUnit[Batch]):
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._compile_used = False
 
     # pyre-fixme[3]: Return annotation cannot contain `Any`.
     def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
-        if COMPILE_AVAIL:
-            self._compile_used = torch._dynamo.is_compiling()
         inputs, targets = data
         outputs = self.module(inputs)
         loss = torch.nn.functional.cross_entropy(outputs, targets)
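
Note: the deleted tests asserted end-to-end compilation by having DummyAutoUnit.compute_loss record torch._dynamo.is_compiling(), which is True only while TorchDynamo traces the function. A minimal, self-contained sketch of that detection pattern (assumes PyTorch 2.x; LossHolder and its members are illustrative, not torchtnt APIs):

# Sketch only: not part of this diff; LossHolder is a made-up example class.
import torch

class LossHolder:
    def __init__(self) -> None:
        self.compile_used = False

    def compute_loss(self, x: torch.Tensor) -> torch.Tensor:
        # torch._dynamo.is_compiling() is True only while TorchDynamo is
        # tracing/compiling this code, so the flag records whether the
        # compiled path was actually taken.
        self.compile_used = torch._dynamo.is_compiling()
        return (x * 2).sum()

holder = LossHolder()
compiled_loss = torch.compile(holder.compute_loss, backend="eager")
compiled_loss(torch.randn(4))
print(holder.compile_used)  # expected: True when the compiled path ran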

torchtnt/framework/auto_unit.py (−5 lines)
@@ -470,11 +470,6 @@ def __init__(
         )
 
         if torch_compile_params:
-            # pyre-ignore
-            self.compute_loss = torch.compile(
-                self.compute_loss,
-                **asdict(torch_compile_params),
-            )
             try:
                 # use in-place compile to avoid altering the state_dict keys
                 module.compile(**asdict(torch_compile_params))
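
Net effect: AutoUnit no longer wraps compute_loss with torch.compile and relies solely on the in-place module.compile(...), which keeps state_dict keys unchanged. A rough sketch of the difference, outside torchtnt (assumes a PyTorch version with nn.Module.compile, i.e. 2.2+; the printed key names in comments are illustrative):

# Sketch only: not part of this diff.
import torch

module = torch.nn.Linear(2, 2)

# Wrapping returns an OptimizedModule whose state_dict keys gain an
# "_orig_mod." prefix, which breaks checkpoint key compatibility:
wrapped = torch.compile(module, backend="eager")
print(list(wrapped.state_dict().keys()))  # e.g. ['_orig_mod.weight', '_orig_mod.bias']

# In-place compile keeps the original module object and its original keys,
# which is why the diff keeps module.compile(**asdict(torch_compile_params))
# and drops the separate compute_loss wrapper:
fresh = torch.nn.Linear(2, 2)
fresh.compile(backend="eager")
print(list(fresh.state_dict().keys()))    # ['weight', 'bias']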
