# Patch summary
# -------------
# 1. megatron/core/transformer/pipeline_parallel_layer_layout.py
#    PipelineParallelLayerLayout.__repr__ used to return self.input_data
#    unchanged. It now returns self.input_data as-is when it is a str and
#    str(self.input_data) otherwise, and is annotated `-> str`.
#    Why: the constructor is annotated `layout: str | list`, and Python's
#    repr() raises TypeError when __repr__ returns a non-str object.
#    NOTE(review): presumably input_data holds the constructor's `layout`
#    argument; the assignment is outside this patch -- confirm.
#
# 2. tests/unit_tests/transformer/test_transformer_block.py
#    Adds test_repr_returns_string, parametrized over one str layout and two
#    list layouts. It checks that repr(layout) is a str and that it equals
#    the original string (str input) or str(input_layout) (list input).
#
# NOTE(review): line breaks and indentation were reconstructed from a
# whitespace-collapsed copy of this patch. Hunk line counts match the @@
# headers (8 -> 11 and 3 -> 38), but context-line whitespace is inferred;
# verify it against the target files (or use `git apply --ignore-whitespace`).
diff --git a/megatron/core/transformer/pipeline_parallel_layer_layout.py b/megatron/core/transformer/pipeline_parallel_layer_layout.py
index 835d8b5c33a..56467bf0e9d 100644
--- a/megatron/core/transformer/pipeline_parallel_layer_layout.py
+++ b/megatron/core/transformer/pipeline_parallel_layer_layout.py
@@ -15,8 +15,11 @@
 class PipelineParallelLayerLayout:
     """Configuration of custom pipeline parallel layer partitioning."""
 
-    def __repr__(self):
-        return self.input_data
+    def __repr__(self) -> str:
+        if isinstance(self.input_data, str):
+            return self.input_data
+        else:
+            return str(self.input_data)
 
     def __init__(self, layout: str | list, pipeline_model_parallel_size: int):
         """Initialize PipelineParallelLayerLayout from a list or a str.
diff --git a/tests/unit_tests/transformer/test_transformer_block.py b/tests/unit_tests/transformer/test_transformer_block.py
index 6a37ceac3c2..83b29613157 100644
--- a/tests/unit_tests/transformer/test_transformer_block.py
+++ b/tests/unit_tests/transformer/test_transformer_block.py
@@ -613,3 +613,38 @@ def test_parsing_layout_from_str(self, pp_size, input_layout_str, input_layout_l
             parsed_layout_from_str.virtual_pipeline_model_parallel_size
             == parsed_layout_baseline.virtual_pipeline_model_parallel_size
         )
+
+    @pytest.mark.parametrize(
+        "pp_size, input_layout",
+        [
+            (2, "Et|t*4|t|tL"),
+            (2, [["embedding", "decoder"], ["decoder"] * 4, ["decoder"], ["decoder", "loss"]]),
+            (8, [["embedding"] + ["decoder"] * 3] + [["decoder"] * 2] * 29 + [["mtp"], ["loss"]]),
+        ],
+    )
+    def test_repr_returns_string(self, pp_size, input_layout):
+        """Test that __repr__ always returns a string for both str and list inputs."""
+        layout = PipelineParallelLayerLayout(input_layout, pp_size)
+        repr_result = repr(layout)
+
+        # Assert that repr returns a string
+        assert isinstance(
+            repr_result, str
+        ), f"__repr__ must return a string, but got {type(repr_result).__name__}"
+
+        # Assert that the returned string matches the expected value
+        if isinstance(input_layout, str):
+            # For string input, repr should return the exact same string
+            assert repr_result == input_layout, (
+                f"For string input, repr should return the original string.\n"
+                f"Expected: {input_layout!r}\n"
+                f"Got: {repr_result!r}"
+            )
+        else:
+            # For list input, repr should return str(input_layout)
+            expected_repr = str(input_layout)
+            assert repr_result == expected_repr, (
+                f"For list input, repr should return str(input_layout).\n"
+                f"Expected: {expected_repr!r}\n"
+                f"Got: {repr_result!r}"
+            )