Skip to content

Update rl_replacements.py, merging my changes with dattas original ch…#3

Merged
Datta0 merged 3 commits into
Datta0:trl_upgrade_fixfrom
pluesclues:trl_upgrade_fix
May 26, 2025
Merged

Update rl_replacements.py, merging my changes with dattas original ch…#3
Datta0 merged 3 commits into
Datta0:trl_upgrade_fixfrom
pluesclues:trl_upgrade_fix

Conversation

@pluesclues
Copy link
Copy Markdown
Collaborator

…anges

@Datta0 Datta0 merged commit 4424b21 into Datta0:trl_upgrade_fix May 26, 2025
danielhanchen added a commit that referenced this pull request Apr 19, 2026
…n, finalize_huggingface_model

- patch_gemma4_vllm_lora_support: use functools.wraps on patched_create_lora_manager so
  _call_create_lora_manager's signature inspection still sees vllm_config; pass model
  positionally to lora_manager_cls to avoid "multiple values for 'model'".
- patch_gemma4_vllm_k_eq_v_support: also handle split k_proj/v_proj layout (current
  upstream Gemma4) by duplicating k quant-state to synthetic v entry; keep packed
  qkv_proj path as fallback.
- load_vllm: gate Gemma4 patches on enable_lora / use_bitsandbytes (not is_vision_model),
  so text-only Gemma4 + LoRA / BnB also works.
- extract_gdn_layers: derive qkvz offsets from gdn.key_dim/value_dim when
  ColumnParallelLinear has no output_sizes; manually split in_proj_ba into b/a instead
  of calling get_state_dict with kk=1 (IndexError); preserve BnB quant_state sidecars;
  handle FP8 weight_scale (not only weight_scale_inv) and dynamic/row-wise FP8;
  export linear_attn.norm.weight.
- finalize_huggingface_model: fix layer_idx for standard causal LMs (not only VLM path);
  rebuild Gemma4 vision rotary_emb from vision_config with fp32 buffers; guard
  rotary_pos_emb on vision_config availability; mirror language_model detection from
  set_additional_modules.
- get_model_layer_config: register Gemma4 per_layer_input_gate / per_layer_projection /
  post_per_layer_input_norm; add Qwen3.5 visual.merger.linear_fc1 / linear_fc2 and drop
  the broken linear_fc{kk} template.
- set_dtype_in_config (hf_utils): prefer the modern 'dtype' field; fall back to
  'torch_dtype' only when 'dtype' is absent, avoiding the deprecation warning on
  current transformers.
- vllm_utils state-dict loop: skip layer.mlp extraction for linear-attn-only layers
  (defensive) while still capturing layer_scalar.
- _normalize_state_dict_tensor: guard is_sparse behind isinstance(value, torch.Tensor)
  so non-tensor state-dict values pass through.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants