We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 409d3af commit 0ea49e4Copy full SHA for 0ea49e4
torchtune/training/_activation_offloading.py
@@ -133,10 +133,9 @@ def get_num_bytes_tensor(x: torch.Tensor) -> int:
133
def pack_tensor(activation: torch.Tensor) -> int:
134
# activations are passed in during forward pass - from here we take over and return a unique id
135
if self.is_first_forward_call:
136
- if len(self.tracker) == 0:
137
- raise AssertionError(
138
- "backward pass should have cleared tracker of all tensors"
139
- )
+ assert (
+ len(self.tracker) == 0
+ ), "backward pass should have cleared tracker of all tensors"
140
141
# set training phase trackers
142
self.is_first_forward_call = False
0 commit comments