
[dependencies] Upgrade transformers to >=5.0.0,<=5.3.0 #1426

Merged
erictang000 merged 15 commits into main from transformers_v5 on Apr 4, 2026

Conversation

@erictang000 (Collaborator) commented Apr 1, 2026

Upgrade to transformers v5

Summary

Upgrades transformers from >=4.56.1,<5 to >=5.0.0,<=5.3.0 and adapts SkyRL's model initialization, FSDP loading, and test code to accommodate v5 breaking changes.

CI

Round 2 CI: https://github.com/NovaSky-AI/SkyRL/actions/runs/23917102581 -> 10 failures, all pre-existing
Megatron CI Round 2: https://github.com/NovaSky-AI/SkyRL/actions/runs/23959241150/job/69884903884 -> 1 failure, pre-existing

~~Round 1 CI: https://github.com/NovaSky-AI/SkyRL/actions/runs/23876002482~~ -> 17 still failing
Megatron CI Round 1: https://github.com/NovaSky-AI/SkyRL/actions/runs/23920479124

Key changes

Meta-device model initialization (fsdp_utils.py, model_wrapper.py, fsdp_worker.py)

v5 disallows from_pretrained() inside accelerate.init_empty_weights() (TypeError: Parameter.__new__() got an unexpected keyword argument '_is_hf_initialized'). Replaced with:

  • Rank 0: from_pretrained() (loads real weights)
  • Non-rank-0: from_config() inside torch.device("meta") (empty shell; weights broadcast by FSDP)

rope_scaling, rope_theta, and _attn_implementation are applied to the config before the branch so both paths are consistent.
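
The rank split above can be sketched with plain PyTorch (the helper name and stand-in builders are hypothetical; SkyRL's actual implementation lives in fsdp_utils.py and model_wrapper.py and uses from_pretrained / from_config):

```python
import torch
import torch.nn as nn

def init_model(rank: int, build_real, build_empty):
    """On rank 0, build the module with real weights; on other ranks build an
    empty shell on the meta device. FSDP later broadcasts rank 0's weights."""
    if rank == 0:
        return build_real()      # e.g. AutoModelForCausalLM.from_pretrained(...)
    with torch.device("meta"):   # tensors created here carry no storage
        return build_empty()     # e.g. AutoModelForCausalLM.from_config(config)

# Stand-in builders; a real setup would pass the HF loaders shown above.
real = init_model(0, lambda: nn.Linear(4, 4), lambda: nn.Linear(4, 4))
shell = init_model(1, lambda: nn.Linear(4, 4), lambda: nn.Linear(4, 4))
print(real.weight.is_meta, shell.weight.is_meta)  # False True
```

The key property is that the meta-device shell has the full module structure but zero memory cost, which is what FSDP needs before it shards and broadcasts rank 0's weights.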

FSDP2 non-persistent buffer sync (fsdp_utils.py)

from_config on meta produces non-persistent buffers (inv_freq in RotaryEmbedding) with no data. These are excluded from state_dict() and never broadcast. Fixes:

  • _sync_non_persistent_buffers() broadcasts these from rank 0 after state dict loading
  • offload_fsdp2_model_to_cpu() now materializes only meta buffers instead of calling model.to_empty() (which wiped all loaded parameters → NaN)
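
The detection half of this fix can be sketched with plain PyTorch (the module below is a stand-in, not the real RotaryEmbedding): non-persistent buffers are exactly those present in named_buffers() but absent from state_dict(), so the state-dict broadcast never touches them.

```python
import torch
import torch.nn as nn

def non_persistent_buffers(model: nn.Module) -> list[str]:
    """Names of buffers that state_dict() omits and that therefore
    need a separate broadcast from rank 0."""
    in_sd = set(model.state_dict().keys())
    return [name for name, _ in model.named_buffers() if name not in in_sd]

class RotaryStandIn(nn.Module):
    def __init__(self):
        super().__init__()
        # Excluded from state_dict, like inv_freq in RotaryEmbedding:
        self.register_buffer("inv_freq", torch.ones(8), persistent=False)
        # Included in state_dict:
        self.register_buffer("kept", torch.zeros(8))

m = RotaryStandIn()
print(non_persistent_buffers(m))  # ['inv_freq']
```

In the real _sync_non_persistent_buffers, each such buffer still on the meta device would be materialized and then filled via a dist.broadcast from rank 0.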

CriticModel post_init() (model_wrapper.py)

v5 added all_tied_weights_keys handling in PreTrainedModel.post_init(). The dynamically created CriticModel class now calls self.post_init(), and the meta-init path wraps construction in no_init_weights().

Strict dataclass configs (configs.py)

PretrainedConfig is now a strict dataclass. Made ModelConfig.__init__ args optional with defaults; fixed get_text_config() signature for v5.

VLM mm_token_type_ids (model_wrapper.py, VLM tests)

v5 requires mm_token_type_ids for M-RoPE in multimodal models. Threaded through HFModelWrapper.forward() and tests.

Megatron rope_theta (megatron_worker.py)

v5 moved rope_theta into rope_parameters dict. Added workaround to set provider.rotary_base from the new location.
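
A hedged sketch of that workaround (the helper name and 10000.0 fallback are assumptions; the rope_parameters / rope_theta field names follow the PR description): read the rotary base from the new v5 location first, then fall back to the legacy attribute.

```python
from types import SimpleNamespace

def resolve_rotary_base(hf_config, default: float = 10000.0) -> float:
    """Prefer v5's rope_parameters dict; fall back to the pre-v5 attribute."""
    rope_params = getattr(hf_config, "rope_parameters", None) or {}
    if "rope_theta" in rope_params:
        return rope_params["rope_theta"]
    return getattr(hf_config, "rope_theta", default)

# Stand-in configs for the v5 and v4 layouts:
v5_cfg = SimpleNamespace(rope_parameters={"rope_theta": 1000000.0})
v4_cfg = SimpleNamespace(rope_theta=500000.0)
print(resolve_rotary_base(v5_cfg), resolve_rotary_base(v4_cfg))
# 1000000.0 500000.0
```

The resolved value would then be assigned to provider.rotary_base in megatron_worker.py.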

Other fixes

  • cuda_ipc_strategy.py: use .reshape(-1) instead of .view(-1) for non-contiguous weight tensors
  • vllm_server.py: guard sock.close() against uvloop TransportSocket AttributeError
  • test_remote_inference_client_chat_template.py: use render_chat_completion() for prompt token verification
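
The cuda_ipc_strategy.py fix rests on a standard PyTorch rule: .view() requires a tensor whose memory layout is compatible (e.g. contiguous), while .reshape() falls back to a copy when needed. A minimal demonstration:

```python
import torch

w = torch.arange(6.0).reshape(2, 3).t()  # transpose -> non-contiguous
assert not w.is_contiguous()

try:
    w.view(-1)               # raises: view cannot relayout memory
    view_failed = False
except RuntimeError:
    view_failed = True

flat = w.reshape(-1)         # copies when necessary, always succeeds
print(view_failed, tuple(flat.shape))  # True (6,)
```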



mhc_expansion_rate: mHC expansion rate. Connectors are trainable when this is > 1.
"""

# Type hints for config attributes

Collaborator: Do we need to remove these? It would be good to keep them for documentation purposes if possible :)

@pcmoritz (Collaborator) commented Apr 3, 2026

For the tx backend, you will also need to adapt to rope_theta moving into rope_parameters; currently it is failing with

AttributeError: 'ModelConfig' object has no attribute 'rope_theta'

# Broadcast non-persistent buffers (e.g. inv_freq from RotaryEmbedding) that
# are excluded from state_dict. On non-rank-0 meta-init these are still on
# meta device with no data; rank 0 has the correctly computed values.
_sync_non_persistent_buffers(model, sharded_sd)

Collaborator: I'm curious, do you know why upgrading transformers necessitates this change? Seems a little surprising :)


Collaborator (Author): This was the Claude writeup of why this was needed:

[image]


Collaborator: Thanks for sharing :)

if hasattr(provider, "q_lora_rank") and hasattr(hf_config, "q_lora_rank"):
provider.q_lora_rank = hf_config.q_lora_rank

# Workaround for transformers v5 moving rope_theta into rope_parameters

Collaborator: Curious why this is needed, since megatron-bridge already updated (NVIDIA-NeMo/Megatron-Bridge#2068) -- if this is still needed, should we raise an issue against megatron-bridge so we can remove this workaround going forward?


Collaborator (Author): I believe bumping to latest main on megatron-bridge, which I'm planning to do in #1425, should fix this, but I wanted to isolate the transformers v5 update in this PR. I can look into removing this in #1425.

  def __init__(
      self,
-     config: PretrainedConfig | dict,
+     config: PretrainedConfig | dict | None = None,

Collaborator: Do you know why this is needed now? Who is calling this without passing in a config? I'm also concerned that the defaults in

    max_lora_adapters: int = 0,
    max_lora_rank: int = 0,
    shard_attention_heads: bool = True,

could cause trouble, and it would be better to not need the **kwargs part, since it can mask problems.


Collaborator (Author): Seems like it's this PR in transformers 5.4.0: huggingface/transformers#41250

I'm pinning to <=5.3.0, so it actually isn't an issue right now (but I guess I was testing with 5.4.0 when originally changing this code). I can revert the changes here for now and we can revisit when upgrading to >=5.4.0.

It seems megatron-bridge caps at <=5.3.0 as well, and there's some relevant activity on transformers, so these changes may be avoidable going forward: huggingface/transformers#45070


  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
  hf_model = AutoModelForCausalLM.from_pretrained(
-     model_name, attn_implementation="eager", use_safetensors=True, trust_remote_code=True
+     model_name, attn_implementation="eager", use_safetensors=True, torch_dtype=torch.float32

Collaborator (Author): This is needed now since v5 changed from_pretrained's behavior from defaulting to float32 to defaulting to the model's own dtype.


@SumanthRH (Member) left a comment:

Stamp

@erictang000 erictang000 merged commit bcf680f into main Apr 4, 2026
6 checks passed
@erictang000 erictang000 deleted the transformers_v5 branch April 4, 2026 19:42
