tests/unit_tests/test_job_config.py (73 changes: 39 additions & 34 deletions)
@@ -52,73 +52,78 @@ def test_job_config_file_cmd_overrides(self):
         )
         assert config.job.dump_folder == "/tmp/test_tt/"
 
-    def test_parse_pp_split_points(self):
-        toml_splits = ["layers.2", "layers.4", "layers.6"]
-        cmdline_splits = ["layers.1", "layers.3", "layers.5"]
-        # no split points specified
-        config_manager = ConfigManager()
-        config = config_manager.parse_args(
-            [
-                "--job.config_file",
-                "./torchtitan/models/llama3/train_configs/debug_model.toml",
-            ]
-        )
-        assert config.parallelism.pipeline_parallel_split_points == []
+    def test_parse_module_fqns_per_model_part(self):
+        toml_chunks = [
+            ["tok_embeddings", "layers.0"],
+            ["layers.1", "layers.2"],
+            ["layers.3", "norm", "output"],
+        ]
+        cmdline_chunks = [
+            ["tok_embeddings", "layers.0", "layers.1"],
+            ["layers.2", "layers.3", "norm", "output"],
+        ]
 
-        # toml has no split points, but cmdline splits are specified
+        # no module names specified
         config_manager = ConfigManager()
         config = config_manager.parse_args(
             [
                 "--job.config_file",
                 "./torchtitan/models/llama3/train_configs/debug_model.toml",
-                "--parallelism.pipeline_parallel_split_points",
-                ",".join(cmdline_splits),
             ]
         )
-        assert (
-            config.parallelism.pipeline_parallel_split_points == cmdline_splits
-        ), config.parallelism.pipeline_parallel_split_points
+        assert config.parallelism.module_fqns_per_model_part is None
 
-        # toml has split points, cmdline does not
+        # toml has module names, cmdline does not
         with tempfile.NamedTemporaryFile() as fp:
             with open(fp.name, "wb") as f:
                 tomli_w.dump(
                     {
                         "parallelism": {
-                            "pipeline_parallel_split_points": toml_splits,
+                            "module_fqns_per_model_part": toml_chunks,
                         }
                     },
                     f,
                 )
             config_manager = ConfigManager()
             config = config_manager.parse_args(["--job.config_file", fp.name])
             assert (
-                config.parallelism.pipeline_parallel_split_points == toml_splits
-            ), config.parallelism.pipeline_parallel_split_points
+                config.parallelism.module_fqns_per_model_part == toml_chunks
+            ), config.parallelism.module_fqns_per_model_part
 
-        # toml has split points, cmdline overrides them
+        # test that the field accepts list of lists structure
         with tempfile.NamedTemporaryFile() as fp:
             with open(fp.name, "wb") as f:
                 tomli_w.dump(
                     {
                         "parallelism": {
-                            "pipeline_parallel_split_points": toml_splits,
+                            "module_fqns_per_model_part": cmdline_chunks,
                         }
                     },
                     f,
                 )
             config_manager = ConfigManager()
-            config = config_manager.parse_args(
-                [
-                    "--job.config_file",
-                    fp.name,
-                    "--parallelism.pipeline_parallel_split_points",
-                    ",".join(cmdline_splits),
-                ]
-            )
+            config = config_manager.parse_args(["--job.config_file", fp.name])
+            assert (
+                config.parallelism.module_fqns_per_model_part == cmdline_chunks
+            ), config.parallelism.module_fqns_per_model_part
+
+        # test empty chunks are handled correctly
+        empty_chunks = [[], ["tok_embeddings"], []]
+        with tempfile.NamedTemporaryFile() as fp:
+            with open(fp.name, "wb") as f:
+                tomli_w.dump(
+                    {
+                        "parallelism": {
+                            "module_fqns_per_model_part": empty_chunks,
+                        }
+                    },
+                    f,
+                )
+            config_manager = ConfigManager()
+            config = config_manager.parse_args(["--job.config_file", fp.name])
             assert (
-                config.parallelism.pipeline_parallel_split_points == cmdline_splits
-            ), config.parallelism.pipeline_parallel_split_points
+                config.parallelism.module_fqns_per_model_part == empty_chunks
+            ), config.parallelism.module_fqns_per_model_part
 
     def test_parse_exclude_from_loading(self):
         toml_splits = ["optimizer", "dataloader"]
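For reference, the nested dict the new test passes to tomli_w.dump corresponds to an array-of-arrays under the [parallelism] table. A minimal hand-written TOML sketch mirroring toml_chunks (the layout is illustrative; tomli_w may format it differently):

[parallelism]
module_fqns_per_model_part = [
    ["tok_embeddings", "layers.0"],
    ["layers.1", "layers.2"],
    ["layers.3", "norm", "output"],
]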
torchtitan/config/job_config.py (25 changes: 24 additions & 1 deletion)
@@ -290,6 +290,7 @@ class Parallelism:
 
     pipeline_parallel_split_points: list[str] = field(default_factory=list)
     """
+    DEPRECATED: Use module_fqns_per_model_part instead.
     Specify comma-separated names of modules to use as the beginning of a split point.
     e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
     the first containing all the layers up to layers.0,
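To make the deprecation concrete: under the docstring's rule that each split point begins a new stage, a legacy split-point config maps onto the new field roughly as sketched below, assuming a hypothetical 4-layer model whose non-layer modules are tok_embeddings, norm, and output (module placement for the non-layer modules is an assumption).

# deprecated form: two split points yield 3 stages
pipeline_parallel_split_points = ["layers.0", "layers.2"]

# equivalent explicit form with the new field
module_fqns_per_model_part = [
    ["tok_embeddings"],
    ["layers.0", "layers.1"],
    ["layers.2", "layers.3", "norm", "output"],
]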
@@ -299,9 +300,31 @@
     but currently the split points must be specified manually.
     """
 
+    module_fqns_per_model_part: list[list[str]] | None = None
+    """
+    Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model chunk.
+    Each inner list represents one model chunk and contains the module names that belong to that chunk.
+    e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']]
+    will create 3 chunks: the first containing tok_embeddings and layers.0,
+    the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4.
+    This provides more explicit control over which modules belong to each chunk compared to split points.
+    """
+
+    pipeline_parallel_first_stage_less_layers: int = 1
+    """
+    The number of layers to reduce in the first stage of pipeline parallelism. This is because
+    the first stage has the extra overhead of the embedding layer, which is not present in the other stages.
+    """
+
+    pipeline_parallel_last_stage_less_layers: int = 1
+    """
+    The number of layers to reduce in the last stage of pipeline parallelism. This is because
+    the last stage has the extra overhead of the output layer, which is not present in the other stages.
+    """
+
     pipeline_parallel_layers_per_stage: int | None = None
     """
-    The number of layers per (virtual) pipeline stage. If specified, the split points will be
+    The number of layers per (virtual) pipeline stage. If specified, the module_fqns_per_model_part will be
     calculated from the number of layers and pipeline_parallel_degree. If not specified, the
     layers per stage will be inferred from the model, schedule, and pipeline_parallel_degree.
     """