
Commit 6c782eb

Support new compile API from autoparallel PR #77
1 parent 60f5f11 commit 6c782eb

File tree

1 file changed: +1 / -5 lines changed

torchtitan/experiments/auto_parallel/parallelize_llama.py

Lines changed: 1 addition & 5 deletions
@@ -64,7 +64,7 @@ def input_fn():
     param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
     reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce]
     mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
-    with AutoParallel(model, input_fn, world_mesh, mp_policy=mp_policy) as autop:
+    with AutoParallel(model, input_fn, world_mesh, mp_policy=mp_policy, compile=job_config.training.compile) as autop:
         autop.add_parameter_memory_constraint(low=None, high=None)

         possible_input_shardings = {
@@ -87,8 +87,4 @@ def input_fn():
     logger.info(f"AutoParallel took {t1 - t0} seconds")
     parallel_mod = autop.apply_placement(sharding_placement)

-    if job_config.training.compile:
-        torch._inductor.config.reorder_for_peak_memory = False
-        parallel_mod.compile(fullgraph=True)
-
     return parallel_mod
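
For reference, a minimal sketch of the call site after this commit. All names below (model, input_fn, world_mesh, job_config, TORCH_DTYPE_MAP, MixedPrecisionPolicy, AutoParallel, sharding_placement) are assumed to be defined elsewhere in parallelize_llama.py; the point is only that compilation is now requested through AutoParallel's compile argument instead of a follow-up parallel_mod.compile(fullgraph=True) call.

    # Sketch only: surrounding setup comes from the rest of parallelize_llama.py.
    param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
    reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce]
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)

    # Compilation is handled by AutoParallel itself; the old
    # torch._inductor.config.reorder_for_peak_memory workaround and the explicit
    # parallel_mod.compile(fullgraph=True) call are no longer needed.
    with AutoParallel(model, input_fn, world_mesh, mp_policy=mp_policy,
                      compile=job_config.training.compile) as autop:
        autop.add_parameter_memory_constraint(low=None, high=None)
        ...
        parallel_mod = autop.apply_placement(sharding_placement)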
