Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions megatron/core/transformer/pipeline_parallel_layer_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
class PipelineParallelLayerLayout:
"""Configuration of custom pipeline parallel layer partitioning."""

def __repr__(self):
return self.input_data
def __repr__(self) -> str:
if isinstance(self.input_data, str):
return self.input_data
else:
return str(self.input_data)

def __init__(self, layout: str | list, pipeline_model_parallel_size: int):
"""Initialize PipelineParallelLayerLayout from a list or a str.
Expand Down
35 changes: 35 additions & 0 deletions tests/unit_tests/transformer/test_transformer_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,3 +613,38 @@ def test_parsing_layout_from_str(self, pp_size, input_layout_str, input_layout_l
parsed_layout_from_str.virtual_pipeline_model_parallel_size
== parsed_layout_baseline.virtual_pipeline_model_parallel_size
)

@pytest.mark.parametrize(
"pp_size, input_layout",
[
(2, "Et|t*4|t|tL"),
(2, [["embedding", "decoder"], ["decoder"] * 4, ["decoder"], ["decoder", "loss"]]),
(8, [["embedding"] + ["decoder"] * 3] + [["decoder"] * 2] * 29 + [["mtp"], ["loss"]]),
],
)
def test_repr_returns_string(self, pp_size, input_layout):
"""Test that __repr__ always returns a string for both str and list inputs."""
layout = PipelineParallelLayerLayout(input_layout, pp_size)
repr_result = repr(layout)

# Assert that repr returns a string
assert isinstance(
repr_result, str
), f"__repr__ must return a string, but got {type(repr_result).__name__}"

# Assert that the returned string matches the expected value
if isinstance(input_layout, str):
# For string input, repr should return the exact same string
assert repr_result == input_layout, (
f"For string input, repr should return the original string.\n"
f"Expected: {input_layout!r}\n"
f"Got: {repr_result!r}"
)
else:
# For list input, repr should return str(input_layout)
expected_repr = str(input_layout)
assert repr_result == expected_repr, (
f"For list input, repr should return str(input_layout).\n"
f"Expected: {expected_repr!r}\n"
f"Got: {repr_result!r}"
)
Loading