diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py index 664bc1a253..fb9c720c20 100644 --- a/nemo_rl/models/dtensor/parallelize.py +++ b/nemo_rl/models/dtensor/parallelize.py @@ -92,7 +92,7 @@ def _parallelize_gemma3( Tensor parallelism is not supported for Gemma3 models because of tied word embeddings. """ if isinstance(model, Gemma3ForConditionalGeneration): - model_prefix = "language_model.model" + model_prefix = "language_model" else: model_prefix = "model" @@ -399,7 +399,7 @@ def _parallelize_model( """ model_cls = type(model) if model_cls == Gemma3ForConditionalGeneration: - layers: torch.nn.ModuleList = model.language_model.model.layers # type: ignore + layers: torch.nn.ModuleList = model.language_model.layers # type: ignore num_attention_heads = model.config.text_config.num_attention_heads num_key_value_heads = model.config.text_config.num_key_value_heads else: