diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml
index d83041774f..76dba6cca3 100644
--- a/.github/workflows/cicd-main.yml
+++ b/.github/workflows/cicd-main.yml
@@ -395,6 +395,7 @@ jobs:
           - script: L2_Launch_models_qwen
           # - script: L2_Launch_models_qwen_quantization
           - script: L2_Launch_models_qwen_vl
+          - script: L2_Launch_recipes_gemma_vl
           - script: L2_Launch_recipes_gpt_oss
           - script: L2_Launch_recipes_llama_1b
           - script: L2_Launch_recipes_llama_3b
diff --git a/src/megatron/bridge/data/vlm_datasets/mock_provider.py b/src/megatron/bridge/data/vlm_datasets/mock_provider.py
index a21086539b..f24fb3fee4 100644
--- a/src/megatron/bridge/data/vlm_datasets/mock_provider.py
+++ b/src/megatron/bridge/data/vlm_datasets/mock_provider.py
@@ -67,27 +67,45 @@ class MockVLMConversationProvider(DatasetProvider):
     # HF AutoProcessor instance will be set during build
     _processor: Optional[Any] = None
 
-    def _make_base_examples(self) -> List[Dict[str, Any]]:
-        # Single minimal conversation example; dataset will repeat to target length
+    # Enable batch-level online sequence packing
+    pack_sequences_in_batch: bool = False
+
+    def _make_single_example(
+        self, rng: numpy.random.Generator, prompt_text: str, response_text: str
+    ) -> Dict[str, Any]:
+        """Create a single mock conversation example with the given prompt and response text."""
         num_images = max(0, int(getattr(self, "num_images", 1)))
         w, h = self.image_size
-        rng = numpy.random.default_rng(seed=self.random_seed)
         images = None
         if num_images > 0:
-            # Embed in-memory PIL images directly in the conversation so that
-            # qwen_vl_utils.process_vision_info can discover them.
             images = [
                 Image.fromarray(rng.integers(low=0, high=256, size=(h, w, 3), dtype=numpy.uint8), mode="RGB")
                 for _ in range(num_images)
             ]
         content = [{"type": "image", "image": img} for img in images] if images is not None else []
-        content.append({"type": "text", "text": self.prompt})
+        content.append({"type": "text", "text": prompt_text})
         messages = [
             {"role": "user", "content": content},
-            {"role": "assistant", "content": [{"type": "text", "text": "dummy assistant response"}]},
+            {"role": "assistant", "content": [{"type": "text", "text": response_text}]},
+        ]
+        return {"conversation": messages}
+
+    def _make_base_examples(self) -> List[Dict[str, Any]]:
+        rng = numpy.random.default_rng(seed=self.random_seed)
+
+        if not self.pack_sequences_in_batch:
+            # Single minimal conversation example; dataset will repeat to target length
+            return [self._make_single_example(rng, self.prompt, "dummy assistant response")]
+
+        # When packing is enabled, produce several examples with varied response lengths
+        # so that the packing logic concatenates sequences of different sizes.
+        varied_responses = [
+            "Short answer.",
+            "A somewhat longer response that contains more tokens to create length variation in the batch.",
+            "Medium length reply with a bit of detail.",
         ]
-        return [{"conversation": messages}]
+        return [self._make_single_example(rng, self.prompt, resp) for resp in varied_responses]
 
     def build_datasets(self, context: DatasetBuildContext):
         from transformers import AutoProcessor
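
Why varied lengths matter for the mock provider: in-batch packing concatenates every sample in a microbatch into a single row and records the sub-sequence boundaries in `cu_seqlens`. A minimal sketch of that bookkeeping (illustrative only; `pack_microbatch` is a hypothetical helper, not the provider's actual collator):

```python
# Illustrative sketch: how a microbatch of variable-length samples can be
# packed into one row with cu_seqlens marking sub-sequence boundaries.
import torch


def pack_microbatch(samples: list[torch.Tensor]) -> dict[str, torch.Tensor]:
    """Concatenate 1-D token tensors and record cumulative boundaries."""
    lengths = torch.tensor([s.numel() for s in samples], dtype=torch.int32)
    # cu_seqlens marks where each sub-sequence starts/ends in the packed row,
    # e.g. lengths [5, 23, 12] -> cu_seqlens [0, 5, 28, 40]
    cu_seqlens = torch.zeros(len(samples) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
    return {
        "input_ids": torch.cat(samples).unsqueeze(0),  # one packed row
        "cu_seqlens": cu_seqlens,
        "max_seqlen": int(lengths.max()),  # longest sub-sequence, for attention kernels
    }


batch = pack_microbatch(
    [torch.zeros(5, dtype=torch.long), torch.zeros(23, dtype=torch.long), torch.zeros(12, dtype=torch.long)]
)
assert batch["cu_seqlens"].tolist() == [0, 5, 28, 40]
```

If every sample had the same length (the old single repeated example), the test would never exercise the boundary arithmetic, which is why the mock provider now emits three differently sized responses.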
diff --git a/src/megatron/bridge/data/vlm_datasets/preloaded_provider.py b/src/megatron/bridge/data/vlm_datasets/preloaded_provider.py
index 331850203c..f7c5e78f70 100644
--- a/src/megatron/bridge/data/vlm_datasets/preloaded_provider.py
+++ b/src/megatron/bridge/data/vlm_datasets/preloaded_provider.py
@@ -199,6 +199,9 @@ class PreloadedVLMConversationProvider(DatasetProvider):
     # Default dataloader type for VLM providers
     dataloader_type: Optional[Literal["single", "cyclic", "external"]] = "single"
 
+    # Enable batch-level online sequence packing
+    pack_sequences_in_batch: bool = False
+
     def _build_split_dataset(
         self,
         split_path: Optional[str],
diff --git a/src/megatron/bridge/training/config.py b/src/megatron/bridge/training/config.py
index 5d9661b325..693e29fd51 100644
--- a/src/megatron/bridge/training/config.py
+++ b/src/megatron/bridge/training/config.py
@@ -1597,6 +1597,13 @@ def validate(self) -> None:
                 f"https://docs.nvidia.com/nemo-framework/user-guide/latest/sft_peft/packed_sequence.html"
             )
 
+        if getattr(self.dataset, "pack_sequences_in_batch", False) and self.train.micro_batch_size == 1:
+            raise ValueError(
+                "micro_batch_size should be greater than 1 when using pack_sequences_in_batch=True. "
+                "In-batch packing concatenates multiple sequences within a micro-batch, so at least 2 sequences "
+                "are required per micro-batch."
+            )
+
         if self.peft is not None:
             assert self.checkpoint.pretrained_checkpoint is not None, "PEFT requires a pretrained checkpoint path"
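
The new `validate()` guard exists because packing a microbatch of one sequence is a no-op: the packed row is just that single sequence, and `cu_seqlens` degenerates to `[0, L]`. A standalone sketch of the same rule, with stand-in dataclasses rather than the real config objects:

```python
# Standalone sketch of the validate() guard above; _Dataset/_Train are
# hypothetical stand-ins, not the real config classes.
from dataclasses import dataclass


@dataclass
class _Dataset:
    pack_sequences_in_batch: bool = False


@dataclass
class _Train:
    micro_batch_size: int = 1


def check(dataset: _Dataset, train: _Train) -> None:
    # With micro_batch_size == 1 there is only one sequence per microbatch,
    # so "packing" would reduce to the unpacked layout.
    if getattr(dataset, "pack_sequences_in_batch", False) and train.micro_batch_size == 1:
        raise ValueError("micro_batch_size should be greater than 1 when using pack_sequences_in_batch=True.")


check(_Dataset(True), _Train(micro_batch_size=2))  # passes
try:
    check(_Dataset(True), _Train(micro_batch_size=1))
except ValueError as err:
    print(err)  # trips the guard
```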
diff --git a/src/megatron/bridge/training/utils/packed_seq_utils.py b/src/megatron/bridge/training/utils/packed_seq_utils.py
index 98dbd6d5ac..1194fc7ed2 100644
--- a/src/megatron/bridge/training/utils/packed_seq_utils.py
+++ b/src/megatron/bridge/training/utils/packed_seq_utils.py
@@ -43,6 +43,8 @@ def get_packed_seq_params(batch: dict[str, torch.Tensor]) -> PackedSeqParams:
     cu_seqlens_argmin = batch.get("cu_seqlens_argmin")
     cu_seqlens_unpadded_argmin = batch.get("cu_seqlens_unpadded_argmin")
 
+    # note: if argmin is not pre-computed in the dataloader, torch.argmin here will incur a
+    # device-to-host synchronization, which can slow down training
     if cu_seqlens_argmin is not None:
         cu_seqlens_padded = cu_seqlens_padded[: cu_seqlens_argmin.item()]
     else:
diff --git a/src/megatron/bridge/training/vlm_step.py b/src/megatron/bridge/training/vlm_step.py
index 3c9724ca80..88f4f3e72f 100644
--- a/src/megatron/bridge/training/vlm_step.py
+++ b/src/megatron/bridge/training/vlm_step.py
@@ -403,9 +403,11 @@ def forward_step(
 
     # Add packed sequence support
     if cu_seqlens is not None:
+        cu_seqlens_argmin = torch.tensor(len(cu_seqlens))  # no padding in cu_seqlens since packing is done in-batch
        packed_seq_params = {
             "cu_seqlens": cu_seqlens,
             "max_seqlen": max_seqlen,
+            "cu_seqlens_argmin": cu_seqlens_argmin,
         }
         forward_args["packed_seq_params"] = get_packed_seq_params(packed_seq_params)
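
Taken together, these two hunks avoid a GPU round-trip: offline-packed datasets pad `cu_seqlens` to a fixed length with sentinel values and locate the first sentinel with `torch.argmin`, while the in-batch path builds `cu_seqlens` without padding and can precompute the trim index on the host. An illustrative, CPU-only sketch of both paths (the tensors are made up):

```python
# Illustrative sketch of the argmin trimming above (hypothetical tensors).
import torch

# Offline-packed datasets pad cu_seqlens with -1 sentinels so every batch
# has the same shape; argmin finds the first sentinel.
padded = torch.tensor([0, 5, 28, 40, -1, -1])

# Fallback path: on a GPU tensor, torch.argmin runs on the device and the
# .item() call forces a device-to-host copy that stalls the CPU.
trim = torch.argmin(padded).item()  # -> 4, with a sync on GPU
assert padded[:trim].tolist() == [0, 5, 28, 40]

# In-batch packing produces cu_seqlens with no padding, so forward_step can
# precompute the trim index as len(cu_seqlens) on the host (no sync);
# slicing with it is then a no-op.
unpadded = torch.tensor([0, 5, 28, 40])
cu_seqlens_argmin = torch.tensor(len(unpadded))  # == 4, computed on the host
assert unpadded[: cu_seqlens_argmin.item()].tolist() == [0, 5, 28, 40]
```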
diff --git a/tests/functional_tests/recipes/test_gemma3_vl_recipes_finetune.py b/tests/functional_tests/recipes/test_gemma3_vl_recipes_finetune.py
index 3b94e2f84c..c691a30d22 100644
--- a/tests/functional_tests/recipes/test_gemma3_vl_recipes_finetune.py
+++ b/tests/functional_tests/recipes/test_gemma3_vl_recipes_finetune.py
@@ -31,6 +31,16 @@
     ),
 ]
 
+GEMMA3_VL_FINETUNE_PACKED_RECIPES = [
+    # Small model with packed sequences, only use 2 layers
+    (
+        gemma3_vl_4b_finetune_config,
+        "gemma3_vl_4b_packed",
+        {"tensor_model_parallel_size": 1, "pipeline_model_parallel_size": 1, "num_layers": 2},
+        {"pack_sequences_in_batch": True},
+    ),
+]
+
 
 class TestGemma3VLRecipes:
     """Test class for Gemma3-VL recipe functional tests."""
@@ -40,3 +50,19 @@ class TestGemma3VLRecipes:
     def test_gemma3_vl_finetune_recipes(self, config_func, recipe_name, model_overrides, tmp_path):
         """Functional test for Gemma3-VL recipes with appropriate parallelism configurations."""
         run_pretrain_vl_recipe_test(config_func, recipe_name, tmp_path, model_overrides=model_overrides)
+
+    @pytest.mark.run_only_on("GPU")
+    @pytest.mark.parametrize(
+        "config_func,recipe_name,model_overrides,dataset_overrides", GEMMA3_VL_FINETUNE_PACKED_RECIPES
+    )
+    def test_gemma3_vl_finetune_packed_recipes(
+        self, config_func, recipe_name, model_overrides, dataset_overrides, tmp_path
+    ):
+        """Functional test for Gemma3-VL recipes with packed sequences enabled."""
+        run_pretrain_vl_recipe_test(
+            config_func,
+            recipe_name,
+            tmp_path,
+            model_overrides=model_overrides,
+            dataset_overrides=dataset_overrides,
+        )
diff --git a/tests/functional_tests/recipes/test_glm_45v_recipes_finetune.py b/tests/functional_tests/recipes/test_glm_45v_recipes_finetune.py
index c2d7a5e3cf..00a341f33f 100644
--- a/tests/functional_tests/recipes/test_glm_45v_recipes_finetune.py
+++ b/tests/functional_tests/recipes/test_glm_45v_recipes_finetune.py
@@ -43,6 +43,26 @@
     ),
 ]
 
+GLM_45V_FINETUNE_PACKED_RECIPES = [
+    # Small model with packed sequences, only use 2 layers
+    (
+        partial(glm_45v_finetune_config, peft=None),
+        "glm_45v_packed",
+        {
+            "tensor_model_parallel_size": 1,
+            "pipeline_model_parallel_size": 1,
+            "expert_model_parallel_size": 1,
+            "num_layers": 2,
+            "num_moe_experts": 8,
+            "hidden_size": 4096,
+            "ffn_hidden_size": 512,
+            "moe_layer_freq": [0, 1],
+            "pipeline_model_parallel_layout": None,
+        },
+        {"pack_sequences_in_batch": True},
+    ),
+]
+
 
 class TestGLM45VRecipes:
     """Test class for GLM 4.5V recipe functional tests."""
@@ -52,3 +72,19 @@ class TestGLM45VRecipes:
     def test_glm_45v_finetune_recipes(self, config_func, recipe_name, model_overrides, tmp_path):
         """Functional test for GLM 4.5V recipes with appropriate parallelism configurations."""
         run_pretrain_vl_recipe_test(config_func, recipe_name, tmp_path, model_overrides=model_overrides)
+
+    @pytest.mark.run_only_on("GPU")
+    @pytest.mark.parametrize(
+        "config_func,recipe_name,model_overrides,dataset_overrides", GLM_45V_FINETUNE_PACKED_RECIPES
+    )
+    def test_glm_45v_finetune_packed_recipes(
+        self, config_func, recipe_name, model_overrides, dataset_overrides, tmp_path
+    ):
+        """Functional test for GLM 4.5V recipes with packed sequences enabled."""
+        run_pretrain_vl_recipe_test(
+            config_func,
+            recipe_name,
+            tmp_path,
+            model_overrides=model_overrides,
+            dataset_overrides=dataset_overrides,
+        )
diff --git a/tests/functional_tests/recipes/test_ministral3_recipes_finetune.py b/tests/functional_tests/recipes/test_ministral3_recipes_finetune.py
index 649a244bc5..fa9dedb4c0 100644
--- a/tests/functional_tests/recipes/test_ministral3_recipes_finetune.py
+++ b/tests/functional_tests/recipes/test_ministral3_recipes_finetune.py
@@ -33,6 +33,16 @@
     ),
 ]
 
+MINISTRAL3_FINETUNE_PACKED_RECIPES = [
+    # Small model with packed sequences, only use 2 layers
+    (
+        partial(ministral3_3b_finetune_config, peft=None),
+        "ministral3_3b_packed",
+        {"tensor_model_parallel_size": 1, "pipeline_model_parallel_size": 1, "num_layers": 2},
+        {"pack_sequences_in_batch": True},
+    ),
+]
+
 
 class TestMinistral3Recipes:
     """Test class for Ministral 3 recipe functional tests."""
@@ -47,3 +57,24 @@ def test_ministral3_finetune_recipes(self, config_func, recipe_name, model_overr
         except ImportError:
             pytest.skip("Ministral 3 not available in transformers")
         run_pretrain_vl_recipe_test(config_func, recipe_name, tmp_path, model_overrides=model_overrides)
+
+    @pytest.mark.run_only_on("GPU")
+    @pytest.mark.parametrize(
+        "config_func,recipe_name,model_overrides,dataset_overrides", MINISTRAL3_FINETUNE_PACKED_RECIPES
+    )
+    def test_ministral3_finetune_packed_recipes(
+        self, config_func, recipe_name, model_overrides, dataset_overrides, tmp_path
+    ):
+        """Functional test for Ministral 3 recipes with packed sequences enabled."""
+        try:
+            from transformers import Ministral3ForCausalLM, Mistral3ForConditionalGeneration  # noqa: F401
+            from transformers.models.mistral3.configuration_mistral3 import Mistral3Config  # noqa: F401
+        except ImportError:
+            pytest.skip("Ministral 3 not available in transformers")
+        run_pretrain_vl_recipe_test(
+            config_func,
+            recipe_name,
+            tmp_path,
+            model_overrides=model_overrides,
+            dataset_overrides=dataset_overrides,
+        )
diff --git a/tests/functional_tests/recipes/test_qwen3_vl_recipes_finetune.py b/tests/functional_tests/recipes/test_qwen3_vl_recipes_finetune.py
index 93829bef2e..3d5dfa035a 100644
--- a/tests/functional_tests/recipes/test_qwen3_vl_recipes_finetune.py
+++ b/tests/functional_tests/recipes/test_qwen3_vl_recipes_finetune.py
@@ -43,6 +43,18 @@
     ),
 ]
 
+QWEN3_VL_FINETUNE_PACKED_RECIPES = [
+    # (config_func, recipe_name, parallelism_overrides, model_overrides, dataset_overrides)
+    # Qwen3-VL 8B finetune with packed sequences
+    (
+        qwen3_vl_8b_finetune_config,
+        "qwen3_vl_8b_finetune_packed",
+        {"tensor_model_parallel_size": 2, "pipeline_model_parallel_size": 1},
+        {"num_layers": 4, "deepstack_visual_indexes": [0, 1, 2]},
+        {"pack_sequences_in_batch": True},
+    ),
+]
+
 
 class TestQwen3VLFinetuneRecipes:
     """Test class for Qwen3-VL finetune recipe functional tests."""
@@ -75,3 +87,27 @@ def test_qwen3_vl_finetune_recipes(
             model_overrides=model_overrides,
             **parallelism_overrides,
         )
+
+    @pytest.mark.run_only_on("GPU")
+    @pytest.mark.parametrize(
+        "config_func,recipe_name,parallelism_overrides,model_overrides,dataset_overrides",
+        QWEN3_VL_FINETUNE_PACKED_RECIPES,
+    )
+    def test_qwen3_vl_finetune_packed_recipes(
+        self,
+        config_func,
+        recipe_name,
+        parallelism_overrides,
+        model_overrides,
+        dataset_overrides,
+        tmp_path,
+    ):
+        """Functional test for Qwen3-VL finetune recipes with packed sequences enabled."""
+        run_pretrain_vl_recipe_test(
+            config_func,
+            recipe_name,
+            tmp_path,
+            model_overrides=model_overrides,
+            dataset_overrides=dataset_overrides,
+            **parallelism_overrides,
+        )
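
All four test files follow the same pattern: a module-level list of tuples, unpacked by `pytest.mark.parametrize`, with the override dicts forwarded untouched to the shared runner. A stripped-down sketch with stub names (not the real recipes or harness):

```python
# Stripped-down sketch of the parametrization pattern used above; the recipe
# config and assertions are stubs, not the real harness.
import pytest


def fake_recipe_config():
    return {"name": "stub"}


PACKED_RECIPES = [
    (fake_recipe_config, "stub_packed", {"num_layers": 2}, {"pack_sequences_in_batch": True}),
]


@pytest.mark.parametrize("config_func,recipe_name,model_overrides,dataset_overrides", PACKED_RECIPES)
def test_packed_recipe(config_func, recipe_name, model_overrides, dataset_overrides, tmp_path):
    # Each tuple becomes one collected test case; the override dicts are
    # forwarded verbatim to the runner, which applies them with setattr.
    assert dataset_overrides["pack_sequences_in_batch"] is True
```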
{"tensor_model_parallel_size": 2, "pipeline_model_parallel_size": 1}, + {"num_layers": 2}, + {"pack_sequences_in_batch": True}, + ), +] + class TestQwenVLRecipes: """Test class for Qwen2.5-VL recipe functional tests.""" @@ -44,3 +56,21 @@ def test_qwen25_vl_pretrain_recipes( run_pretrain_vl_recipe_test( config_func, recipe_name, tmp_path, model_overrides=model_overrides, **parallelism_overrides ) + + @pytest.mark.run_only_on("GPU") + @pytest.mark.parametrize( + "config_func,recipe_name,parallelism_overrides,model_overrides,dataset_overrides", + QWEN_VL_PRETRAIN_PACKED_RECIPES, + ) + def test_qwen25_vl_pretrain_packed_recipes( + self, config_func, recipe_name, parallelism_overrides, model_overrides, dataset_overrides, tmp_path + ): + """Functional test for Qwen2.5-VL recipes with packed sequences enabled.""" + run_pretrain_vl_recipe_test( + config_func, + recipe_name, + tmp_path, + model_overrides=model_overrides, + dataset_overrides=dataset_overrides, + **parallelism_overrides, + ) diff --git a/tests/functional_tests/recipes/utils.py b/tests/functional_tests/recipes/utils.py index 4bf7f256b3..49399c805e 100644 --- a/tests/functional_tests/recipes/utils.py +++ b/tests/functional_tests/recipes/utils.py @@ -198,6 +198,7 @@ def run_pretrain_vl_recipe_test( tensor_model_parallel_size: Optional[int] = None, pipeline_model_parallel_size: Optional[int] = None, model_overrides: Optional[dict] = None, + dataset_overrides: Optional[dict] = None, forward_step_func: Optional[Callable] = None, ): """ @@ -213,6 +214,7 @@ def run_pretrain_vl_recipe_test( tensor_model_parallel_size: Override tensor parallelism (None = use recipe default) pipeline_model_parallel_size: Override pipeline parallelism (None = use recipe default) model_overrides: Optional mapping of model attribute overrides to apply + dataset_overrides: Optional mapping of dataset attribute overrides to apply """ if forward_step_func is None: # Import locally to avoid loading VLM stack for non-VL tests @@ -269,6 +271,14 @@ def run_pretrain_vl_recipe_test( for attribute_name, attribute_value in model_overrides.items(): setattr(config.model, attribute_name, attribute_value) + # Apply any dataset-specific overrides provided by the caller + if dataset_overrides: + for attribute_name, attribute_value in dataset_overrides.items(): + setattr(config.dataset, attribute_name, attribute_value) + + if config.dataset.pack_sequences_in_batch: + config.train.micro_batch_size = 2 + pretrain(config, vlm_forward_step) # Basic verification that training completed successfully diff --git a/tests/unit_tests/training/test_config.py b/tests/unit_tests/training/test_config.py index 3bce9d61f7..f230258334 100644 --- a/tests/unit_tests/training/test_config.py +++ b/tests/unit_tests/training/test_config.py @@ -891,6 +891,52 @@ def test_packed_sequence_validation_skipped_for_gpt_dataset(self, monkeypatch): finally: restore_get_world_size_safe(og_ws, cfg_mod) + def test_pack_sequences_in_batch_requires_micro_batch_size_gt_1(self, monkeypatch): + """Test validation error when micro_batch_size == 1 with pack_sequences_in_batch=True.""" + gpt_model_cfg = create_test_gpt_config() + train_cfg = create_test_training_config(micro_batch_size=1, global_batch_size=32) + dataset_cfg = create_test_finetuning_dataset_config(sequence_length=512) + dataset_cfg.pack_sequences_in_batch = True + + container, og_ws, cfg_mod = create_test_config_container( + world_size_override=1, + model_config=gpt_model_cfg, + train_config=train_cfg, + dataset_config_override=dataset_cfg, + ) + error_msg = 
diff --git a/tests/unit_tests/training/test_config.py b/tests/unit_tests/training/test_config.py
index 3bce9d61f7..f230258334 100644
--- a/tests/unit_tests/training/test_config.py
+++ b/tests/unit_tests/training/test_config.py
@@ -891,6 +891,52 @@ def test_packed_sequence_validation_skipped_for_gpt_dataset(self, monkeypatch):
         finally:
             restore_get_world_size_safe(og_ws, cfg_mod)
 
+    def test_pack_sequences_in_batch_requires_micro_batch_size_gt_1(self, monkeypatch):
+        """Test validation error when micro_batch_size == 1 with pack_sequences_in_batch=True."""
+        gpt_model_cfg = create_test_gpt_config()
+        train_cfg = create_test_training_config(micro_batch_size=1, global_batch_size=32)
+        dataset_cfg = create_test_finetuning_dataset_config(sequence_length=512)
+        dataset_cfg.pack_sequences_in_batch = True
+
+        container, og_ws, cfg_mod = create_test_config_container(
+            world_size_override=1,
+            model_config=gpt_model_cfg,
+            train_config=train_cfg,
+            dataset_config_override=dataset_cfg,
+        )
+        error_msg = (
+            "micro_batch_size should be greater than 1 when using pack_sequences_in_batch=True. "
+            "In-batch packing concatenates multiple sequences within a micro-batch, so at least 2 sequences "
+            "are required per micro-batch."
+        )
+        try:
+            with pytest.raises(
+                ValueError,
+                match=error_msg,
+            ):
+                container.validate()
+        finally:
+            restore_get_world_size_safe(og_ws, cfg_mod)
+
+    def test_pack_sequences_in_batch_passes_with_micro_batch_size_gt_1(self, monkeypatch):
+        """Test validation passes when micro_batch_size > 1 with pack_sequences_in_batch=True."""
+        gpt_model_cfg = create_test_gpt_config()
+        train_cfg = create_test_training_config(micro_batch_size=4, global_batch_size=32)
+        dataset_cfg = create_test_finetuning_dataset_config(sequence_length=512)
+        dataset_cfg.pack_sequences_in_batch = True
+
+        container, og_ws, cfg_mod = create_test_config_container(
+            world_size_override=1,
+            model_config=gpt_model_cfg,
+            train_config=train_cfg,
+            dataset_config_override=dataset_cfg,
+        )
+
+        try:
+            container.validate()  # Should pass without error
+        finally:
+            restore_get_world_size_safe(og_ws, cfg_mod)
+
     @pytest.mark.parametrize(
         "seq_length, context_parallel_size, expect_assertion_error",
         [
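
One caveat worth noting about the first test (a suggestion, not part of the diff): `pytest.raises(match=...)` interprets the string as a regular expression searched against the exception message, so the dots in `error_msg` match any character and the check is slightly looser than it looks. Escaping makes the assertion exact:

```python
# Sketch: making the match= assertion literal with re.escape.
import re

import pytest


def boom():
    raise ValueError("micro_batch_size should be greater than 1 when using pack_sequences_in_batch=True.")


# re.escape turns the message into a regex that matches it literally,
# so "." can no longer match arbitrary characters.
with pytest.raises(ValueError, match=re.escape("pack_sequences_in_batch=True.")):
    boom()
```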