Skip to content

Commit ede15a7

Browse files
committed
Review updates.
1 parent a251fd4 commit ede15a7

File tree

11 files changed

+85
-85
lines changed

11 files changed

+85
-85
lines changed

docs/debugging.md

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -99,11 +99,11 @@ Use `--debug.deterministic_warn_only` to only warn about (not stop running) kern
9999

100100
The following debug configs are available for AC.
101101

102-
`ac_preserve_rng_state` - if deterministic output compared to non-checkpointed passes is required, set to true. Results in stashing and restoring the RNG state during each checkpoint, may be slower.
102+
`preserve_rng_state` - if deterministic output compared to non-checkpointed passes is required, set to true. Results in stashing and restoring the RNG state during each checkpoint, may be slower.
103103

104-
`ac_determinism_check` - A string specifying the determinism function
104+
`determinism_check` - A string specifying the determinism function
105105

106-
`ac_debug` - capture ac debug information. Will be slower.
106+
`debug` - capture ac debug information. Will be slower.
107107

108108
See https://docs.pytorch.org/docs/stable/checkpoint.html for details.
109109

tests/unit_tests/test_activation_checkpoint.py

Lines changed: 28 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111
from torch.utils.flop_counter import FlopCounterMode
1212

1313
from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
14-
from torchtitan.config.job_config import JobConfig
1514
from torchtitan.distributed.activation_checkpoint import apply_ac
1615

1716

@@ -75,16 +74,15 @@ def get_bw_flops(model_fn):
7574
# 2. SAC
7675
# Per-op SAC's policy is to save every other mm
7776
model_selective_ac = ToyModule()
78-
job_config = JobConfig()
79-
job_config.activation_checkpoint = ACConfig(
77+
ac_config_no_force = ACConfig(
8078
mode="selective",
8179
selective_ac_option="op",
8280
per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list
8381
early_stop=False,
8482
)
8583
apply_ac(
8684
model_selective_ac,
87-
job_config,
85+
ac_config_no_force,
8886
model_compile_enabled=False,
8987
use_flex_attn=False,
9088
op_sac_save_list=_op_sac_save_list,
@@ -94,15 +92,15 @@ def get_bw_flops(model_fn):
9492
# 3. Per-op SAC with force recompute "moe.router.gate"
9593
# This leads to two mms being recomputed since they share the same shape!
9694
model_with_force_first = ToyModule()
97-
job_config.activation_checkpoint = ACConfig(
95+
ac_config_with_force_first = ACConfig(
9896
mode="selective",
9997
selective_ac_option="op",
10098
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
10199
early_stop=False,
102100
)
103101
apply_ac(
104102
model_with_force_first,
105-
job_config,
103+
ac_config_with_force_first,
106104
model_compile_enabled=False,
107105
use_flex_attn=False,
108106
op_sac_save_list=_op_sac_save_list,
@@ -111,15 +109,15 @@ def get_bw_flops(model_fn):
111109

112110
# 4. Per-op SAC with force recompute "output"
113111
model_with_force_last = ToyModule()
114-
job_config.activation_checkpoint = ACConfig(
112+
ac_config_with_force_last = ACConfig(
115113
mode="selective",
116114
selective_ac_option="op",
117115
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
118116
early_stop=False,
119117
)
120118
apply_ac(
121119
model_with_force_last,
122-
job_config,
120+
ac_config_with_force_last,
123121
model_compile_enabled=False,
124122
use_flex_attn=False,
125123
op_sac_save_list=_op_sac_save_list,
@@ -128,13 +126,13 @@ def get_bw_flops(model_fn):
128126

129127
# 5. Full AC
130128
model_with_full_ac = ToyModule()
131-
job_config.activation_checkpoint = ACConfig(
129+
ac_config_full_ac = ACConfig(
132130
mode="full",
133131
early_stop=False,
134132
)
135133
apply_ac(
136134
model_with_full_ac,
137-
job_config,
135+
ac_config_full_ac,
138136
model_compile_enabled=False,
139137
use_flex_attn=False,
140138
op_sac_save_list=_op_sac_save_list,
@@ -170,14 +168,14 @@ def get_act_mem(model_fn):
170168
# 2. SAC
171169
# Per-op SAC's policy is to save every other mm
172170
model_selective_ac = ToyModule().cuda()
173-
job_config.activation_checkpoint = ACConfig(
171+
ac_config_no_force = ACConfig(
174172
mode="selective",
175173
selective_ac_option="op",
176174
per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list
177175
)
178176
apply_ac(
179177
model_selective_ac,
180-
job_config,
178+
ac_config_no_force,
181179
model_compile_enabled=False,
182180
use_flex_attn=False,
183181
op_sac_save_list=_op_sac_save_list,
@@ -187,14 +185,14 @@ def get_act_mem(model_fn):
187185
# 3. Per-op SAC with force recompute "moe.router.gate"
188186
# This leads to two mms being recomputed since they share the same shape!
189187
model_with_force_first = ToyModule().cuda()
190-
job_config.activation_checkpoint = ACConfig(
188+
ac_config_with_force_first = ACConfig(
191189
mode="selective",
192190
selective_ac_option="op",
193191
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
194192
)
195193
apply_ac(
196194
model_with_force_first,
197-
job_config,
195+
ac_config_with_force_first,
198196
model_compile_enabled=False,
199197
use_flex_attn=False,
200198
op_sac_save_list=_op_sac_save_list,
@@ -203,14 +201,14 @@ def get_act_mem(model_fn):
203201

204202
# 4. Per-op SAC with force recompute "output"
205203
model_with_force_last = ToyModule().cuda()
206-
job_config.activation_checkpoint = ACConfig(
204+
ac_config_with_force_last = ACConfig(
207205
mode="selective",
208206
selective_ac_option="op",
209207
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
210208
)
211209
apply_ac(
212210
model_with_force_last,
213-
job_config,
211+
ac_config_with_force_last,
214212
model_compile_enabled=False,
215213
use_flex_attn=False,
216214
op_sac_save_list=_op_sac_save_list,
@@ -219,12 +217,12 @@ def get_act_mem(model_fn):
219217

220218
# 5. Full AC
221219
model_with_full_ac = ToyModule().cuda()
222-
job_config.activation_checkpoint = ACConfig(
220+
ac_config_full_ac = ACConfig(
223221
mode="full",
224222
)
225223
apply_ac(
226224
model_with_full_ac,
227-
job_config,
225+
ac_config_full_ac,
228226
model_compile_enabled=False,
229227
use_flex_attn=False,
230228
op_sac_save_list=_op_sac_save_list,
@@ -245,44 +243,40 @@ def test_correctness(self):
245243

246244
model_selective_ac = ToyModule()
247245
model_selective_ac.load_state_dict(model_no_ac.state_dict())
248-
job_config = JobConfig()
249-
job_config.activation_checkpoint = ACConfig(
246+
apply_ac(
247+
model_selective_ac,
248+
ACConfig(
250249
mode="selective",
251250
selective_ac_option="op",
252251
per_op_sac_force_recompute_mm_shapes_by_fqns=[],
253-
)
254-
apply_ac(
255-
model_selective_ac,
256-
job_config,
252+
),
257253
model_compile_enabled=False,
258254
use_flex_attn=False,
259255
op_sac_save_list=_op_sac_save_list,
260256
)
261257
model_force_first = ToyModule()
262258
model_force_first.load_state_dict(model_no_ac.state_dict())
263-
job_config.activation_checkpoint = ACConfig(
259+
apply_ac(
260+
model_force_first,
261+
ACConfig(
264262
mode="selective",
265263
selective_ac_option="op",
266264
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
267-
)
268-
apply_ac(
269-
model_force_first,
270-
job_config,
265+
),
271266
model_compile_enabled=False,
272267
use_flex_attn=False,
273268
op_sac_save_list=_op_sac_save_list,
274269
)
275270

276271
model_force_last = ToyModule()
277272
model_force_last.load_state_dict(model_no_ac.state_dict())
278-
job_config.activation_checkpoint = ACConfig(
273+
apply_ac(
274+
model_force_last,
275+
ACConfig(
279276
mode="selective",
280277
selective_ac_option="op",
281278
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
282-
)
283-
apply_ac(
284-
model_force_last,
285-
job_config,
279+
),
286280
model_compile_enabled=False,
287281
use_flex_attn=False,
288282
op_sac_save_list=_op_sac_save_list,

torchtitan/config/job_config.py

Lines changed: 20 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -623,6 +623,26 @@ class ActivationCheckpoint:
623623
https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015
624624
"""
625625

626+
preserve_rng_state: bool = False
627+
"""
628+
If deterministic output compared to non-checkpointed passes is required, set
629+
to true. Results in stashing and restoring the RNG state during each checkpoint,
630+
may be slower. See https://docs.pytorch.org/docs/stable/checkpoint.html
631+
for details.
632+
"""
633+
634+
determinism_check: str = "default"
635+
"""
636+
A string specifying the determinism function. See
637+
https://docs.pytorch.org/docs/stable/checkpoint.html for details.
638+
"""
639+
640+
debug: bool = False
641+
"""
642+
Capture ac debug information. Will be slower. See
643+
https://docs.pytorch.org/docs/stable/checkpoint.html for details.
644+
"""
645+
626646

627647
@dataclass
628648
class Compile:
@@ -882,15 +902,6 @@ class Debug:
882902
deterministic_warn_only: bool = False
883903
"""Only warns about ops without deterministic implementations rather than erroring out """
884904

885-
ac_preserve_rng_state: bool = False
886-
"""If deterministic output compared to non-checkpointed passes is required, set to true. Results in stashing and restoring the RNG state during each checkpoint, may be slower. See https://docs.pytorch.org/docs/stable/checkpoint.html for details."""
887-
888-
ac_determinism_check: str = "default"
889-
"""A string specifying the determinism function. See https://docs.pytorch.org/docs/stable/checkpoint.html for details."""
890-
891-
ac_debug: bool = False
892-
""" Capture ac debug information. Will be slower. See https://docs.pytorch.org/docs/stable/checkpoint.html for details."""
893-
894905
moe_force_load_balance: bool = False
895906
"""If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only."""
896907

0 commit comments

Comments (0)