Skip to content

Commit b22a3ae

Browse files
authored
actually patch tp plan (#2897)
1 parent 833df4f commit b22a3ae

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torchtune/modules/loss/cross_entropy_loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ def set_model_output(self, model: nn.Module) -> None:
7272
def patch_tp_plan(self, tp_plan) -> dict:
7373
if "output" not in tp_plan and "decoder.output" not in tp_plan:
7474
raise KeyError("`tp_plan` requires `output` key")
75-
76-
tp_plan["output"] = ColwiseParallel(
75+
key = "output" if "output" in tp_plan else "decoder.output"
76+
tp_plan[key] = ColwiseParallel(
7777
input_layouts=Shard(1),
7878
output_layouts=Shard(-1),
7979
use_local_output=False,

0 commit comments

Comments
 (0)