Important: Review skipped
Auto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml | Review profile: CHILL | Plan: Pro | Run ID:
You can disable this status message by setting the
Use the checkbox below for a quick retry:
📝 Walkthrough
Adds Energy-Based Fine-Tuning (EBFT) to axolotl: new trainers (structured/async/strided), feature-matching reward logic, Triton kernels, dataset transforms/prompt strategies, config/schema validation, vLLM serving/weight-sync extensions, example configs and documentation.
Changes
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
📖 Documentation Preview: https://69c31535cfa9d9b93c3450c3--resonant-treacle-0fd729.netlify.app
Deployed on Netlify from commit d111a05
Actionable comments posted: 19
🧹 Nitpick comments (15)
src/axolotl/prompt_strategies/ebft/ebft_reasoning.py (2)
242-261: Think-tag masking logic may miss multi-token tags.
The masking logic assumes `<think>` and `</think>` each tokenize to a single token. If the tokenizer splits these tags into multiple tokens (e.g., `<`, `think`, `>`), `think_open_id` will be `unk_token_id` and masking will be silently skipped.
Consider adding a warning or fallback to text-based span detection when single-token IDs are unavailable.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py` around lines 242 - 261, The current mask_thinking block assumes "<think>" and "</think>" map to single token IDs (think_open_id/think_close_id) and silently skips masking when they equal tokenizer.unk_token_id; update it to detect and handle multi-token tags: if think_open_id or think_close_id == tokenizer.unk_token_id, fall back to scanning the decoded substring (e.g., tokenizer.decode(input_ids[scan_start:end]) or joining tokenizer.convert_ids_to_tokens(input_ids[scan_start:end])) to locate "<think>" and "</think>" text spans and then map those spans back to token index ranges to set labels[i] = -100, and/or emit a warning when single-token IDs are unavailable; keep the existing behavior when single-token IDs are present. Ensure you modify the mask_thinking branch and use input_ids, labels, start, end, tokenizer, think_open_id/think_close_id identifiers.
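One possible fallback for tags that tokenize into several pieces is to search for the tag's token-ID sequence rather than a single ID. The sketch below is illustrative and is not the code in `ebft_reasoning.py`: `find_subsequence` and `mask_think_spans` are hypothetical helpers, and the tag ID sequences are assumed to come from something like `tokenizer.encode("<think>", add_special_tokens=False)`. It swaps the text-based span detection suggested above for token-ID subsequence matching, which can still miss tags whose tokenization depends on the surrounding text.

```python
from typing import List, Optional, Sequence

IGNORE_INDEX = -100  # label value skipped by the loss


def find_subsequence(haystack: Sequence[int], needle: Sequence[int], start: int = 0) -> int:
    """Return the first index >= start where needle occurs in haystack, or -1."""
    n = len(needle)
    if n == 0:
        return -1
    for i in range(start, len(haystack) - n + 1):
        if list(haystack[i : i + n]) == list(needle):
            return i
    return -1


def mask_think_spans(
    input_ids: Sequence[int],
    labels: Sequence[int],
    open_ids: Sequence[int],
    close_ids: Sequence[int],
    start: int = 0,
    end: Optional[int] = None,
) -> List[int]:
    """Return a copy of labels with every <think>...</think> span in [start, end) masked."""
    end = len(input_ids) if end is None else end
    window = list(input_ids[start:end])
    masked = list(labels)
    pos = 0
    while True:
        o = find_subsequence(window, open_ids, pos)
        if o < 0:
            break
        c = find_subsequence(window, close_ids, o + len(open_ids))
        # Design choice in this sketch: an unclosed tag masks to the end of the window.
        stop = len(window) if c < 0 else c + len(close_ids)
        for i in range(start + o, start + stop):
            masked[i] = IGNORE_INDEX
        pos = stop
    return masked
```

When both tags do map to single IDs, passing one-element sequences gives the same result as a single-ID scan, so the existing fast path can stay in place.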
35-40: Remove unused helper function `_extract_thinking`.
The function is defined but never called anywhere in the codebase. Since it's prefixed with an underscore (indicating private scope), there's no indication of intended external usage.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py` around lines 35 - 40, The private helper function _extract_thinking is unused; remove its definition from ebft_reasoning.py (delete the def _extract_thinking(...) block) and ensure there are no remaining references to _extract_thinking elsewhere; if the re import is now unused after removal, also remove the import to avoid linter warnings and run tests/linters to confirm no residual uses.
examples/ebft/ebft_opencode.py (1)
18-20: Consider using `remove_columns: "__all__"` to avoid schema drift.
The explicit list is brittle if upstream dataset columns change.
♻️ Suggested simplification
-    return transform_fn, {"remove_columns": ["id", "domain", "generation_algorithm",
-                                              "llm_judgement", "unit_tests",
-                                              "tests_execution_status", "average_test_score"]}
+    return transform_fn, {"remove_columns": "__all__"}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/ebft/ebft_opencode.py` around lines 18 - 20, The return value currently hardcodes a list of columns to drop which is brittle; update the second element of the returned tuple (the transformer config returned alongside transform_fn in ebft_opencode.py) to use "remove_columns": "__all__" instead of the explicit list so the transformer removes all original dataset columns and avoids schema drift while preserving transformed outputs from transform_fn.
src/axolotl/utils/schemas/config.py (3)
66-70: Consider using `Literal` type for `embed_method` validation.
The `embed_method` field accepts any string but the description lists specific valid values. Using a `Literal` type would provide compile-time validation and better IDE support.
♻️ Suggested improvement
+from typing import Literal
+
+EmbedMethod = Literal["last_token", "mean_pooling", "concat"]
+
-    embed_method: str = Field(
+    embed_method: EmbedMethod = Field(
         default="last_token",
         json_schema_extra={
             "description": "Embedding method: 'last_token', 'mean_pooling', or 'concat'"
         },
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/utils/schemas/config.py` around lines 66 - 70, The embed_method Field currently allows any string even though the docstring lists specific valid values; change its type from str to typing.Literal with the allowed options (e.g., Literal["last_token","mean_pooling","concat"]) and update the Field declaration for embed_method in the Config schema so Pydantic/typing enforces and IDEs surface the valid choices; keep the same json_schema_extra description but ensure the attribute name embed_method in the relevant class/schema is updated to use the Literal type.
132-134: Consider using `Literal` for `advantage_estimator` validation.
The `advantage_estimator` has three defined valid values that could benefit from type-level validation.
♻️ Suggested improvement
-    advantage_estimator: str = Field(
+    advantage_estimator: Literal["rloo", "group_norm", "reinforce"] = Field(
         default="rloo",
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/utils/schemas/config.py` around lines 132 - 134, Change the advantage_estimator field to use typing.Literal for strict validation: update the type annotation of advantage_estimator to Literal["rloo", "group_norm", "reinforce"] (keep the default "rloo"), add the Literal import, and leave the existing Field(json_schema_extra=...) intact so pydantic/JSON schema will enforce and document the allowed values; locate and modify the advantage_estimator declaration and add the import near other typing imports.
98-102: Consider using `Literal` for `mode` validation.
Similar to `embed_method`, the `mode` field has defined valid values that could be enforced with a `Literal` type.
♻️ Suggested improvement
-    mode: str = Field(
+    mode: Literal["structured", "strided"] = Field(
         default="structured",
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/utils/schemas/config.py` around lines 98 - 102, The mode field currently declared as mode: str = Field(...) should enforce allowed values using typing.Literal like embed_method does; update the type annotation for mode to use Literal["structured", "strided"] (and import Literal) and remove or keep the Field default/json_schema_extra as needed so Pydantic validates values at type level; target the mode declaration in src/axolotl/utils/schemas/config.py (near the existing embed_method pattern) to make this change.
examples/ebft/llama-1b-ebft-opencode.yaml (1)
64-64: Non-zero `lora_dropout` may disable LoRA kernel optimizations.
Setting `lora_dropout: 0.05` typically disables auto-enabled LoRA kernel optimizations (`lora_mlp_kernel`, `lora_qkv_kernel`, `lora_o_kernel`). If you want the kernel speedups, consider setting `lora_dropout: 0.0` or explicitly enabling kernels.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/ebft/llama-1b-ebft-opencode.yaml` at line 64, The lora_dropout: 0.05 setting will likely disable auto-enabled LoRA kernel optimizations; change the lora_dropout value to 0.0 in the configuration or explicitly enable the kernels (lora_mlp_kernel, lora_qkv_kernel, lora_o_kernel) so the LoRA kernel speedups remain active; update the lora_dropout entry or add boolean flags for the three kernel options to ensure optimizations are not disabled.
src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py (1)
42-43: Use defensive message-field access to avoid hard failures on bad rows.
These transforms directly index message keys. A single malformed record can crash `datasets.map` with `KeyError`. Prefer `dict.get(...)` + skip/fallback handling.
Example hardening pattern
-            if msg["role"] == "assistant" and not found_first:
-                first_gt = msg["content"]
+            role = msg.get("role") if isinstance(msg, dict) else None
+            content = msg.get("content", "") if isinstance(msg, dict) else ""
+            if role == "assistant" and not found_first:
+                first_gt = content
                 found_first = True
             elif found_first:
                 remaining.append(msg)
             else:
                 prompt_msgs.append(msg)
Also applies to: 56-57, 83-85, 116-118
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py` around lines 42 - 43, The loop in ebft_chat_multiturn.py is indexing message dicts directly (e.g., msg["role"], msg["content"]) which can raise KeyError on malformed rows; update the code in the functions/blocks that set first_gt/last_gt and append messages (references: variables msg, first_gt, last_gt, found_first, the message-processing loop) to use dict.get("role") and dict.get("content") and skip or continue when required keys are missing or not strings (e.g., if role is None or content is None: continue), ensuring all places noted in the review (around the current first_gt/last_gt assignments and the other mentioned blocks) perform defensive checks before accessing message fields.
src/axolotl/utils/schemas/validation.py (1)
1572-1575: Guard EBFT strided length math with explicit parameter validation.
Line 1574 can throw a raw `ZeroDivisionError` (or produce invalid block math) for bad configs. Prefer an explicit `ValueError` with clear guidance.
Proposed hardening
     stride = ebft.get("stride", 8)
     ctx_len = ebft.get("context_length", 8)
-    max_blocks = (seq_len - gen_len - ctx_len) // stride + 1
+    if stride <= 0:
+        raise ValueError("ebft.stride must be > 0 in strided mode")
+    if seq_len <= gen_len + ctx_len:
+        raise ValueError(
+            "sequence_len must be greater than ebft.generate_max_len + ebft.context_length"
+        )
+    max_blocks = (seq_len - gen_len - ctx_len) // stride + 1
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/utils/schemas/validation.py` around lines 1572 - 1575, The strided-length calculation using ebft.get("stride") and ebft.get("context_length") can raise ZeroDivisionError or produce invalid math; validate parameters first: ensure stride is an int > 0 and context_length (ctx_len) is an int >= 0 and less than (seq_len - gen_len) so the dividend (seq_len - gen_len - ctx_len) is positive; if any check fails raise a ValueError with a clear message referencing the offending keys ("stride" / "context_length") and the values (seq_len, gen_len) so callers know how to correct the EBFT config before computing max_blocks and full_seq.
src/axolotl/core/trainers/ebft/__init__.py (1)
60-166: Consider reducing repetition in kwargs mapping.
The `is not None` checks are repetitive. A helper function or loop over a mapping could reduce boilerplate, though this is a stylistic preference.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/core/trainers/ebft/__init__.py` around lines 60 - 166, The mapping in set_training_args_kwargs repeats many "if X is not None: kwargs[...]=X" blocks for ebft and trl; refactor by introducing a small helper (e.g., a local function map_if_present or apply_mapping) that accepts a source object and an iterable of (attr_name, kwarg_key) pairs and sets kwargs[kwarg_key] = getattr(source, attr_name) when the value is not None (or truthy where appropriate), then replace the repeated blocks for ebft and trl with calls to this helper and explicit handling only for special cases like vllm colocate logic, async_prefetch, vllm_server_host/port and vllm_enable_sleep_mode inside set_training_args_kwargs.
src/axolotl/core/trainers/ebft/strided.py (1)
314-314: Mutable class attribute detected by Ruff (RUF012).
Class-level mutable defaults can cause unexpected sharing. While `_tag_names` is unlikely to be mutated, using a tuple is safer.
♻️ Use tuple instead of list
-    _tag_names = ["ebft", "strided", "axolotl"]
+    _tag_names = ("ebft", "strided", "axolotl")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/core/trainers/ebft/strided.py` at line 314, Replace the mutable class attribute _tag_names (currently a list) with an immutable tuple to avoid accidental shared-state; locate the _tag_names declaration in the Strided trainer class in strided.py and change ["ebft", "strided", "axolotl"] to ("ebft", "strided", "axolotl") so the class-level constant is immutable.
examples/ebft/qwen35-4b-ebft-structured-async.yaml (1)
70-72: Complex regex for LoRA targeting: verify it matches intended layers.
The regex pattern targets full-attention layers (3,7,11,15,19,23,27,31) and MLP on all layers. This is a careful design for hybrid attention models, but the pattern complexity makes it easy to miss layers.
Consider adding a verification script or test to confirm the pattern matches exactly the intended modules when applied to the model.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/ebft/qwen35-4b-ebft-structured-async.yaml` around lines 70 - 72, The lora_target_modules regex is complex and may miss or overmatch intended module names; add a small verification helper that loads the model's state_dict or module names and tests each key against lora_target_modules to assert that exactly the intended layer indices (3,7,11,15,19,23,27,31) for self_attn.(q|k|v|o)_proj and all mlp.(gate|up|down)_proj are matched; implement this check as a unit test or a CLI validation (e.g., verify_lora_targets or validate_lora_regex) that prints unmatched expected modules and any unexpected matches so you can adjust the regex if it misfires.
src/axolotl/core/trainers/ebft/trainer.py (3)
44-44: Mutable class attribute (RUF012).
Same as strided.py: use a tuple for the immutable tag list.
♻️ Use tuple instead of list
-    _tag_names = ["trl", "ebft", "axolotl"]
+    _tag_names = ("trl", "ebft", "axolotl")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/core/trainers/ebft/trainer.py` at line 44, The class-level _tag_names is defined as a mutable list; change it to an immutable tuple to avoid RUF012. Replace _tag_names = ["trl", "ebft", "axolotl"] with _tag_names = ("trl", "ebft", "axolotl") in trainer.py (same pattern as in strided.py) so the attribute is immutable at class scope.
364-366: Bare `Exception` catch is overly broad (BLE001).
Catching all exceptions can mask unexpected errors. Consider catching specific vLLM client exceptions or at minimum re-raising after logging for non-recoverable errors.
♻️ Narrow exception handling
-            except Exception as e:
-                LOG.warning(f"Multi-turn rollout generation failed: {e}")
-                gen_text = ""
+            except (ConnectionError, TimeoutError, RuntimeError) as e:
+                LOG.warning(f"Multi-turn rollout generation failed: {e}")
+                gen_text = ""
Or if the vLLM client has specific exception types, catch those instead.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/core/trainers/ebft/trainer.py` around lines 364 - 366, The current broad except Exception block around the multi-turn rollout generation masks unexpected errors; replace it by catching concrete vLLM/client errors (e.g., vllm.client.exceptions.VLLMError or the client-specific exception types) and handle them by logging and setting gen_text as before, and add a separate fallback that logs with LOG.exception(...) and re-raises for any other unexpected exceptions; import the client exception types at top and update the try/except in the multi-turn rollout generation block (the one that currently sets gen_text = "") to use specific except clauses and a final except Exception to re-raise after logging.
60-70: Mypy errors on `super().__init__()` call are expected for mixin pattern.
The Mypy errors about "unexpected keyword arguments for `__init__` of object" occur because Mypy doesn't see the full MRO at the mixin level. This is a known limitation with mixin patterns. Consider adding a `TYPE_CHECKING` block with protocol hints or `# type: ignore` comments if the errors are noisy in CI.
♻️ Silence Mypy for mixin super() call
-        super().__init__(
+        super().__init__(  # type: ignore[call-arg]
             model=model,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/core/trainers/ebft/trainer.py` around lines 60 - 70, Mypy complains about unexpected keyword args on the super().__init__() call because of the mixin pattern; either silence it by adding a TYPE_CHECKING block that defines a minimal Protocol for the base initializer signature (import TYPE_CHECKING from typing and declare a Protocol with __init__(..., model, reward_funcs, args, train_dataset, eval_dataset, processing_class, callbacks, optimizers, peft_config) and use it only under TYPE_CHECKING), or add a scoped type ignore on the call (e.g., append # type: ignore[call-arg] to the super().__init__(...) line, matching the suggested diff) so the mixin pattern no longer fails CI; target the super().__init__ invocation in trainer.py (the call that passes model, reward_funcs=[self._feature_matching_reward], args, train_dataset, eval_dataset, processing_class, callbacks, optimizers, peft_config).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/ebft/ebft_pretrain.py`:
- Around line 27-31: The returned dict aliases labels to encoded["input_ids"],
which can create shared references when the tokenizer returns Python lists;
change the assignment so "labels" is a shallow copy of encoded["input_ids"]
(e.g., use list(...) or .copy()) instead of referencing the same object; update
the return in ebft_pretrain.py that currently returns {"input_ids":
encoded["input_ids"], "attention_mask": encoded["attention_mask"], "labels":
encoded["input_ids"]} to set "labels" to a copy of encoded["input_ids"] to avoid
accidental mutation of input_ids.
- Line 17: The variable pad_id assigned from tokenizer.pad_token_id or
tokenizer.eos_token_id is unused; remove the pad_id assignment in
examples/ebft/ebft_pretrain.py (the line creating pad_id) and rely on the
tokenizer's built-in padding (used later with padding="max_length"), ensuring no
other code references pad_id (search for pad_id to confirm) and run tests or
lint to verify no remaining usages.
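The aliasing concern in the first item can be illustrated with plain lists. This is a minimal sketch, assuming the tokenizer returns Python lists; `build_example` is a hypothetical helper and `encoded` is a hand-written stand-in for tokenizer output, not the code in `ebft_pretrain.py`.

```python
def build_example(encoded: dict) -> dict:
    """Return model inputs with labels as an independent copy of input_ids."""
    return {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "labels": list(encoded["input_ids"]),  # shallow copy, not an alias
    }


# Hand-written stand-in for tokenizer output (assumption: plain Python lists).
encoded = {"input_ids": [101, 7, 8, 102], "attention_mask": [1, 1, 1, 1]}
example = build_example(encoded)
example["labels"][0] = -100  # masking a label must not touch input_ids
```

With the copy in place, later label masking leaves `input_ids` unchanged; with the alias, the same assignment would overwrite the first input token as well.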
In `@examples/ebft/ebft_strided_structured.py`:
- Around line 76-77: The file examples/ebft/ebft_strided_structured.py is
missing a trailing newline; open the file and add a single newline character at
the end of the file (ensure it ends with exactly one '\n') so the pre-commit
end-of-file-fixer check passes.
In `@examples/ebft/llama-3b-ebft-strided-fft.yaml`:
- Around line 53-55: The EBFT validator fails because strided EBFT with
gradient_checkpointing enabled is incompatible with torch_compile=true and
requires reentrant checkpointing; update the YAML to set torch_compile: false
(change the torch_compile key) and set
gradient_checkpointing_kwargs.use_reentrant: true (change use_reentrant value)
so the configuration meets the validator's requirements for strided EBFT and
flex-attention checkpointing.
In `@examples/ebft/qwen35-4b-ebft-structured.yaml`:
- Around line 32-33: Replace the non-routable wildcard host value used for the
TRL vllm server with a loopback address: change the vllm_server_host setting
(trl.vllm_server_host) from "0.0.0.0" to "127.0.0.1" so clients connect to a
routable local address; update any corresponding examples or documentation in
the same YAML file to use 127.0.0.1 for vllm_server_host while leaving
vllm_server_port as-is (8000).
In `@examples/ebft/README.md`:
- Around line 187-188: Update the strided-mode performance guidance to stop
recommending "torch_compile: true" because EBFT validation now warns/errors when
torch_compile is enabled with strided mode and gradient checkpointing; replace
that sentence so it either recommends leaving torch_compile disabled for strided
configurations or documents the validation restriction, and keep the explanation
of "flex_attention" behavior and fallback as-is so users know flex_attention is
used when available.
In `@src/axolotl/cli/vllm_serve.py`:
- Around line 82-86: The current precedence uses Python's truthy "or" so an
explicit CLI False (cli_args.get("enforce_eager") == False) is ignored when
cfg.vllm.enforce_eager is truthy; change the logic to prefer an explicitly
provided CLI value by checking presence/None instead of truthiness: if
cli_args.get("enforce_eager") is not None use
bool(cli_args.get("enforce_eager")) else use getattr(cfg.vllm, "enforce_eager",
False). Replace the current assignment to enforce_eager with this pattern and
apply the same change to the other occurrence referenced (the second
enforce_eager assignment at the other location).
- Around line 109-110: The current check calls getattr(cfg.trl,
"vllm_lora_sync", False) and will raise AttributeError if cfg has no trl
attribute; change the guard to safely access trl first (e.g., use getattr(cfg,
"trl", None) and then getattr(..., "vllm_lora_sync", False) or check
hasattr(cfg, "trl") before reading vllm_lora_sync) so that when trl is absent
you fall back to False and still set lora_kwargs["enable_lora"] = False; update
the expression around cfg.trl, "vllm_lora_sync", False to use a safe nested
getattr or an existence check so the code never dereferences a missing trl.
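Both `vllm_serve.py` items come down to distinguishing "not provided" from "falsy". The sketch below shows one way to express that, assuming CLI arguments arrive as a mapping where an omitted flag is `None`; `resolve_flag` and `nested_getattr` are hypothetical helper names, and `cfg` is a `SimpleNamespace` stand-in rather than the real axolotl config object.

```python
from types import SimpleNamespace
from typing import Any, Mapping


def resolve_flag(cli_args: Mapping[str, Any], cfg_section: Any, name: str, default: bool = False) -> bool:
    """Prefer an explicitly provided CLI value (including False) over the config value."""
    cli_value = cli_args.get(name)
    if cli_value is not None:
        return bool(cli_value)
    return bool(getattr(cfg_section, name, default))


def nested_getattr(obj: Any, section: str, name: str, default: Any = None) -> Any:
    """Read obj.<section>.<name>, falling back to default if either level is missing or None."""
    return getattr(getattr(obj, section, None), name, default)


# Stand-in config: has a `vllm` section but deliberately no `trl` section.
cfg = SimpleNamespace(vllm=SimpleNamespace(enforce_eager=True))

explicit_false = resolve_flag({"enforce_eager": False}, cfg.vllm, "enforce_eager")
from_config = resolve_flag({}, cfg.vllm, "enforce_eager")
lora_sync = nested_getattr(cfg, "trl", "vllm_lora_sync", False)
```

With the truthy `or` pattern, `explicit_false` would have been overridden by the config value; here the explicit CLI `False` wins, and the missing `trl` section falls back to `False` without raising.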
In `@src/axolotl/common/datasets.py`:
- Around line 121-124: Normalize cfg.rl to an RLType before the membership check
so string values like "grpo"/"ebft" are treated the same as RLType.GRPO/EBFT;
update the block that computes total_num_steps (referencing cfg.rl and
total_num_steps) to first map/normalize cfg.rl into an RLType (or compare
against lowercased names) and then check membership against {RLType.GRPO,
RLType.EBFT} so GRPO/EBFT are properly excluded whether cfg.rl is provided as a
string or an enum.
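A small normalization step makes the membership check independent of how `cfg.rl` was supplied. The following is a sketch under assumptions: the `RLType` class here is a hypothetical stand-in with assumed member values, and `normalize_rl` / `skips_step_estimate` are illustrative names, not functions from `datasets.py`.

```python
from enum import Enum
from typing import Optional, Union


class RLType(str, Enum):
    """Hypothetical stand-in for axolotl's RLType; member values are assumed."""

    GRPO = "grpo"
    EBFT = "ebft"
    DPO = "dpo"


def normalize_rl(value: Union[str, RLType, None]) -> Optional[RLType]:
    """Map a string or enum to RLType; return None for unset or unknown values."""
    if value is None or isinstance(value, RLType):
        return value
    try:
        return RLType(str(value).lower())
    except ValueError:
        return None


def skips_step_estimate(rl: Union[str, RLType, None]) -> bool:
    """True when the RL type is one that should be excluded from the step estimate."""
    return normalize_rl(rl) in {RLType.GRPO, RLType.EBFT}
```

Normalizing once at the top of the block keeps the rest of the logic working purely with enum members.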
In `@src/axolotl/core/trainers/ebft/kernels.py`:
- Around line 256-261: The kernel divides by (N - 1) and when N == 1 this yields
a divide-by-zero; update the Python wrapper fused_diversity_penalty to guard
against N <= 1 by short-circuiting before launching the kernel: detect the input
size (N), and if N <= 1 return an appropriately shaped tensor of zeros (or the
intended neutral penalty) without calling the kernel; otherwise proceed to call
the existing kernel as before. Ensure the check references
fused_diversity_penalty and the N dimension used to compute (N - 1) so the
kernel never receives N == 1.
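The guard can be shown without Triton. This pure-Python sketch assumes the penalty is each sample's mean similarity to the other N - 1 samples; that averaging scheme is an assumption for illustration, and only the (N - 1) divisor comes from the comment above. `diversity_penalty` is a hypothetical name, not the `fused_diversity_penalty` wrapper itself.

```python
from typing import List


def diversity_penalty(sim: List[List[float]]) -> List[float]:
    """Mean similarity of each sample to the other N - 1 samples.

    `sim` is an N x N similarity matrix (assumed layout for this sketch).
    """
    n = len(sim)
    if n <= 1:
        # Nothing to compare against: return a neutral penalty
        # instead of dividing by (N - 1) == 0.
        return [0.0] * n
    return [
        sum(sim[i][j] for j in range(n) if j != i) / (n - 1)
        for i in range(n)
    ]
```

In the real wrapper the same early return would produce a zero tensor of the expected shape before any kernel launch.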
In `@src/axolotl/core/trainers/ebft/rewards.py`:
- Around line 205-213: The whitening currently builds W from left singular
vectors U (producing a (B,B) matrix) but must operate in feature space: use the
right singular vectors V and singular values S to build a (D,D) whitening matrix
W = V @ diag(inv_s) @ V.T (where inv_s is computed with whiten_tol as before),
then apply it to features via phi_w = phi_f @ W.T and phi_gt_w = phi_gt_f @ W.T
(ensure dtype casts remain as phi.dtype / phi_gt.dtype); replace usages of U, W
(B,B) with V and the new (D,D) W to correct the transform.
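The shape argument behind this item can be checked on paper. The derivation below is a sketch under assumptions: $\Phi$ stands for `phi_f` with shape $(B, D)$, the SVD is the thin one restricted to the $r$ singular values kept after applying `whiten_tol`, and any centering or batch-size normalization is left out.

```latex
\begin{align*}
\Phi &= U S V^\top, \qquad \Phi \in \mathbb{R}^{B \times D},\;
        U \in \mathbb{R}^{B \times r},\;
        S = \operatorname{diag}(s_1, \dots, s_r),\;
        V \in \mathbb{R}^{D \times r} \\
W &= V S^{-1} V^\top \in \mathbb{R}^{D \times D}, \qquad W^\top = W \\
\Phi_w &= \Phi W^\top = U S V^\top V S^{-1} V^\top = U V^\top \\
\Phi_w^\top \Phi_w &= V U^\top U V^\top = V V^\top
\end{align*}
```

Here $V V^\top$ is the orthogonal projector onto the retained feature directions, so the whitened features have equal scale along each of them. A matrix built from $U$ is $(B, B)$ and acts on the batch axis, so it cannot be applied as `phi_f @ W.T` unless $B = D$.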
In `@src/axolotl/core/trainers/ebft/strided.py`:
- Line 768: The return uses an undefined outputs when return_outputs=True;
update the function (the block that computes loss when backbone is not None) to
always set a value for outputs (e.g., outputs = None or the actual model
outputs) before the final return, or change the final return to return (loss,
None) when outputs were not produced; ensure references to outputs,
return_outputs, and backbone in this function (the strided trainer method
containing the loss computation) are adjusted so outputs is always defined when
return_outputs is True.
In `@src/axolotl/core/trainers/ebft/trainer.py`:
- Around line 181-190: The zip(prompts, ground_truth) usage in the loop that
builds gt_texts should be made strict to avoid silent length mismatches; update
the zip call inside the loop (the one iterating with for i, (p, gt) in
enumerate(...)) to zip(prompts, ground_truth, strict=True) so mismatched lengths
raise an error, keeping the rest of the logic that uses num_gens,
processing_class.apply_chat_template, and gt_texts unchanged.
- Around line 155-166: The loop in trainer.py that iterates over prompts and
completions uses zip(prompts, completions) without the strict parameter; change
it to zip(prompts, completions, strict=True) to enforce equal lengths and
surface mismatches early, keeping the existing handling of list vs. scalar
prompt/completion values and appending combined strings to gen_texts
(references: prompts, completions, gen_texts, and
processing_class.apply_chat_template).
In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py`:
- Around line 61-64: The code uses `"prompt_msgs_snapshot" in dir()` to detect
whether prompt_msgs_snapshot was assigned; replace this unreliable check by
explicitly tracking assignment: either initialize prompt_msgs_snapshot = None at
the top of each function (transform, transform_split_thinking,
transform_answer_only) and test `if prompt_msgs_snapshot is not None` before
using it, or set a boolean flag (e.g., has_prompt_snapshot = False -> True when
you create the snapshot) and test that flag in the return. Update the three
functions (transform, transform_split_thinking, transform_answer_only) to use
the sentinel/flag instead of the dir() check so the branch is deterministic.
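The sentinel approach can be sketched as follows. `split_prompt` is a hypothetical, simplified function, not one of the three transforms named above; it only illustrates initializing the snapshot to `None` and testing it explicitly.

```python
from typing import Dict, List, Optional

Message = Dict[str, str]


def split_prompt(messages: List[Message]) -> Dict[str, List[Message]]:
    """Snapshot the messages that precede the first assistant turn."""
    prompt_msgs: List[Message] = []
    prompt_msgs_snapshot: Optional[List[Message]] = None  # explicit sentinel
    for msg in messages:
        if msg.get("role") == "assistant" and prompt_msgs_snapshot is None:
            prompt_msgs_snapshot = list(prompt_msgs)
        prompt_msgs.append(msg)
    # Deterministic branch: no dir()/locals() introspection needed.
    prompt = prompt_msgs_snapshot if prompt_msgs_snapshot is not None else prompt_msgs
    return {"prompt": prompt}
```

Testing `is not None` rather than truthiness matters here: an empty snapshot (assistant speaks first) is a valid value and should not fall through to the fallback.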
In `@src/axolotl/scripts/vllm_serve_lora.py`:
- Around line 505-526: The http_update_weights handler is using torch
(torch.frombuffer and getattr(torch,...)) but torch is not imported, causing a
NameError; add an import for torch (preferably at module top) so
http_update_weights can reference it. Also address the unused results variable
returned from asyncio.gather: either consume/check results for worker
responses/errors (e.g., inspect the list returned by asyncio.gather) or remove
the assignment and await the gather call solely for synchronization. Target
symbols: http_update_weights, torch.frombuffer, getattr(torch,...), connections,
asyncio.gather, and results.
In `@src/axolotl/train.py`:
- Around line 141-142: The current condition accesses cfg.trl directly which can
raise if cfg has no trl attribute; change it to safely obtain trl first (e.g.
trl = getattr(cfg, 'trl', None)) and then check if trl is not None before
reading beta, e.g. if cfg.rl in {RLType.GRPO, RLType.EBFT} and trl and
getattr(trl, 'beta', 0) == 0: reference_model = False — this ensures accessing
beta is guarded and prevents attribute errors on cfg.trl.
In `@src/axolotl/utils/data/rl.py`:
- Around line 223-229: The removal of columns logic assumes a "train" split and
will break for DatasetDicts without that key; change the DatasetDict branch in
the remove_columns resolution to pick the first available split dynamically
(e.g., use next(iter(dataset)) or list(dataset.keys())[0]) and read its
.column_names instead of dataset["train"].column_names so it works for arbitrary
split names; update the code path that computes ds_columns (the block guarded by
isinstance(dataset, DatasetDict)) to use the first split's column_names and keep
the existing Dataset and fallback behaviors unchanged.
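Picking the first available split can be sketched without the `datasets` library. In this illustration a plain dict of `SimpleNamespace` objects stands in for a `DatasetDict` (which is assumed to behave like a mapping of split name to dataset), and `resolve_columns` is a hypothetical helper name.

```python
from types import SimpleNamespace
from typing import Any, List, Mapping


def resolve_columns(dataset: Any) -> List[str]:
    """Return column names from a single dataset or from the first split of a mapping."""
    if isinstance(dataset, Mapping):
        if not dataset:
            return []
        first_split = next(iter(dataset))  # whichever split comes first, not a hardcoded "train"
        return list(dataset[first_split].column_names)
    return list(getattr(dataset, "column_names", []))


# Plain-dict stand-in for a DatasetDict that has no "train" split.
splits = {"validation": SimpleNamespace(column_names=["id", "text"])}
columns = resolve_columns(splits)
```

This assumes all splits share the same schema, which is the same assumption the original `dataset["train"]` lookup made.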
In `@src/axolotl/utils/schemas/validation.py`:
- Line 1579: Replace the Unicode multiplication sign in the EBFT log message
with an ASCII character to satisfy Ruff RUF001: update the f-string containing
"EBFT strided: full_seq_len={full_seq} × n_samples={n_samples} = " to use "x"
(e.g. "full_seq_len={full_seq} x n_samples={n_samples}") or "*" instead; locate
the string that references variables full_seq and n_samples in the validation
code and make this simple character substitution.
---
Nitpick comments:
In `@examples/ebft/ebft_opencode.py`:
- Around line 18-20: The return value currently hardcodes a list of columns to
drop which is brittle; update the second element of the returned tuple (the
transformer config returned alongside transform_fn in ebft_opencode.py) to use
"remove_columns": "__all__" instead of the explicit list so the transformer
removes all original dataset columns and avoids schema drift while preserving
transformed outputs from transform_fn.
In `@examples/ebft/llama-1b-ebft-opencode.yaml`:
- Line 64: The lora_dropout: 0.05 setting will likely disable auto-enabled LoRA
kernel optimizations; change the lora_dropout value to 0.0 in the configuration
or explicitly enable the kernels (lora_mlp_kernel, lora_qkv_kernel,
lora_o_kernel) so the LoRA kernel speedups remain active—update the lora_dropout
entry or add boolean flags for the three kernel options to ensure optimizations
are not disabled.
In `@examples/ebft/qwen35-4b-ebft-structured-async.yaml`:
- Around line 70-72: The lora_target_modules regex is complex and may miss or
overmatch intended module names; add a small verification helper that loads the
model's state_dict or module names and tests each key against
lora_target_modules to assert that exactly the intended layer indices
(3,7,11,15,19,23,27,31) for self_attn.(q|k|v|o)_proj and all
mlp.(gate|up|down)_proj are matched; implement this check as a unit test or a
CLI validation (e.g., verify_lora_targets or validate_lora_regex) that prints
unmatched expected modules and any unexpected matches so you can adjust the
regex if it misfires.
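A verification helper along these lines could look like the sketch below. Everything specific in it is an assumption: `PATTERN` is a hypothetical regex (the real one lives in the YAML), the `model.layers.<i>...` naming is assumed, and full-string matching via `re.fullmatch` is assumed to mirror how a string `target_modules` is applied.

```python
import re
from typing import Iterable, Set, Tuple

# Hypothetical pattern for illustration; substitute the regex from the YAML config.
PATTERN = (
    r"model\.layers\.(?:(?:3|7|11|15|19|23|27|31)\.self_attn\.[qkvo]_proj"
    r"|\d+\.mlp\.(?:gate|up|down)_proj)"
)
FULL_ATTN_LAYERS = {3, 7, 11, 15, 19, 23, 27, 31}


def expected_targets(num_layers: int) -> Set[str]:
    """Module names that should receive LoRA adapters (assumed naming scheme)."""
    names = set()
    for i in range(num_layers):
        for proj in ("gate", "up", "down"):
            names.add(f"model.layers.{i}.mlp.{proj}_proj")
        if i in FULL_ATTN_LAYERS:
            for proj in "qkvo":
                names.add(f"model.layers.{i}.self_attn.{proj}_proj")
    return names


def verify_lora_targets(
    pattern: str, module_names: Iterable[str], num_layers: int
) -> Tuple[Set[str], Set[str]]:
    """Return (missing, unexpected) module names for the given regex."""
    matched = {n for n in module_names if re.fullmatch(pattern, n)}
    expected = expected_targets(num_layers)
    return expected - matched, matched - expected
```

In practice `module_names` would come from something like `[name for name, _ in model.named_modules()]`; both returned sets being empty means the regex matches exactly the intended modules.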
In `@src/axolotl/core/trainers/ebft/__init__.py`:
- Around line 60-166: The mapping in set_training_args_kwargs repeats many "if X
is not None: kwargs[...]=X" blocks for ebft and trl; refactor by introducing a
small helper (e.g., a local function map_if_present or apply_mapping) that
accepts a source object and an iterable of (attr_name, kwarg_key) pairs and sets
kwargs[kwarg_key] = getattr(source, attr_name) when the value is not None (or
truthy where appropriate), then replace the repeated blocks for ebft and trl
with calls to this helper and explicit handling only for special cases like vllm
colocate logic, async_prefetch, vllm_server_host/port and vllm_enable_sleep_mode
inside set_training_args_kwargs.
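The helper described above might look like this. It is a sketch, not the refactor itself: `map_if_present` follows the name suggested in the comment, while the `ebft` object and the kwarg keys in the usage lines are hypothetical stand-ins.

```python
from types import SimpleNamespace
from typing import Any, Dict, Iterable, Tuple


def map_if_present(source: Any, kwargs: Dict[str, Any], mapping: Iterable[Tuple[str, str]]) -> None:
    """Copy source.<attr> into kwargs[key] for every (attr, key) pair whose value is not None."""
    if source is None:
        return
    for attr_name, kwarg_key in mapping:
        value = getattr(source, attr_name, None)
        if value is not None:
            kwargs[kwarg_key] = value


# Hypothetical config section and kwarg keys, for illustration only.
ebft = SimpleNamespace(embed_method="last_token", stride=None, context_length=8)
kwargs: Dict[str, Any] = {}
map_if_present(
    ebft,
    kwargs,
    [
        ("embed_method", "ebft_embed_method"),
        ("stride", "ebft_stride"),
        ("context_length", "ebft_context_length"),
    ],
)
```

Because the check is `is not None` rather than truthiness, legitimate falsy values such as `0` or `False` are still forwarded; fields that need special handling stay as explicit code next to the helper call.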
In `@src/axolotl/core/trainers/ebft/strided.py`:
- Line 314: Replace the mutable class attribute _tag_names (currently a list)
with an immutable tuple to avoid accidental shared-state; locate the _tag_names
declaration in the Strided trainer class in strided.py and change ["ebft",
"strided", "axolotl"] to ("ebft", "strided", "axolotl") so the class-level
constant is immutable.
In `@src/axolotl/core/trainers/ebft/trainer.py`:
- Line 44: The class-level _tag_names is defined as a mutable list; change it to
an immutable tuple to avoid RUF012. Replace _tag_names = ["trl", "ebft",
"axolotl"] with _tag_names = ("trl", "ebft", "axolotl") in trainer.py (same
pattern as in strided.py) so the attribute is immutable at class scope.
- Around line 364-366: The current broad except Exception block around the
multi-turn rollout generation masks unexpected errors; replace it by catching
concrete vLLM/client errors (e.g., vllm.client.exceptions.VLLMError or the
client-specific exception types) and handle them by logging and setting gen_text
as before, and add a separate fallback that logs with LOG.exception(...) and
re-raises for any other unexpected exceptions; import the client exception types
at top and update the try/except in the multi-turn rollout generation block (the
one that currently sets gen_text = "") to use specific except clauses and a
final except Exception to re-raise after logging.
- Around line 60-70: Mypy complains about unexpected keyword args on the
super().__init__() call because of the mixin pattern; either silence it by
adding a TYPE_CHECKING block that defines a minimal Protocol for the base
initializer signature (import TYPE_CHECKING from typing and declare a Protocol
with __init__(..., model, reward_funcs, args, train_dataset, eval_dataset,
processing_class, callbacks, optimizers, peft_config) and use it only under
TYPE_CHECKING), or add a scoped type ignore on the call (e.g., append # type:
ignore[call-arg] to the super().__init__(...) line) so the mixin pattern no
longer fails CI; target the super().__init__ invocation in trainer.py (the call
that passes model, reward_funcs=[self._feature_matching_reward], args,
train_dataset, eval_dataset, processing_class, callbacks, optimizers,
peft_config).
In `@src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py`:
- Around line 42-43: The loop in ebft_chat_multiturn.py is indexing message
dicts directly (e.g., msg["role"], msg["content"]) which can raise KeyError on
malformed rows; update the code in the functions/blocks that set
first_gt/last_gt and append messages (references: variables msg, first_gt,
last_gt, found_first, the message-processing loop) to use dict.get("role") and
dict.get("content") and skip or continue when required keys are missing or not
strings (e.g., if role is None or content is None: continue), ensuring all
places noted in the review (around the current first_gt/last_gt assignments and
the other mentioned blocks) perform defensive checks before accessing message
fields.
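The defensive access described above might be sketched as follows. `iter_valid_messages` is a hypothetical helper, not a function from the PR, and the sample rows are made up.

```python
from typing import Any, Iterator


def iter_valid_messages(messages: list[Any]) -> Iterator[tuple[str, str]]:
    """Yield (role, content) pairs, skipping rows that are malformed."""
    for msg in messages:
        if not isinstance(msg, dict):
            continue
        role, content = msg.get("role"), msg.get("content")
        # Skip rows with missing or non-string fields instead of raising KeyError.
        if not isinstance(role, str) or not isinstance(content, str):
            continue
        yield role, content


rows = [
    {"role": "user", "content": "hi"},
    {"role": "assistant"},  # missing content
    {"content": "orphan"},  # missing role
    {"role": "assistant", "content": "hello"},
]
valid = list(iter_valid_messages(rows))
```

Whether malformed rows should be skipped silently or logged is a design choice for the dataset transform; the sketch only shows the skip path.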
In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py`:
- Around line 242-261: The current mask_thinking block assumes "<think>" and
"</think>" map to single token IDs (think_open_id/think_close_id) and silently
skips masking when they equal tokenizer.unk_token_id; update it to detect and
handle multi-token tags: if think_open_id or think_close_id ==
tokenizer.unk_token_id, fall back to scanning the decoded substring (e.g.,
tokenizer.decode(input_ids[scan_start:end]) or joining
tokenizer.convert_ids_to_tokens(input_ids[scan_start:end])) to locate "<think>"
and "</think>" text spans and then map those spans back to token index ranges to
set labels[i] = -100, and/or emit a warning when single-token IDs are
unavailable; keep the existing behavior when single-token IDs are present.
Ensure you modify the mask_thinking branch and use input_ids, labels, start,
end, tokenizer, think_open_id/think_close_id identifiers.
- Around line 35-40: The private helper function _extract_thinking is unused;
remove its definition from ebft_reasoning.py (delete the def
_extract_thinking(...) block) and ensure there are no remaining references to
_extract_thinking elsewhere; if the re import is now unused after removal, also
remove the import to avoid linter warnings and run tests/linters to confirm no
residual uses.
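One possible shape for the multi-token tag fallback is to match token-ID subsequences rather than decoded text. This is an alternative to the decode-and-map approach named above, shown only as a sketch: `find_subsequence` and `mask_think_span` are hypothetical helpers, the IDs are made up, and it assumes the tag tokenizes the same way in context as it does in isolation, which is not guaranteed for every tokenizer.

```python
def find_subsequence(ids, pattern, start=0, end=None):
    """Return the first index in [start, end) where `pattern` occurs in `ids`, else -1."""
    end = len(ids) if end is None else end
    n = len(pattern)
    if n == 0:
        return -1
    for i in range(start, end - n + 1):
        if ids[i : i + n] == pattern:
            return i
    return -1


def mask_think_span(input_ids, labels, open_ids, close_ids, start, end):
    """Set labels to -100 from the opening tag through the closing tag, if both are found."""
    open_at = find_subsequence(input_ids, open_ids, start, end)
    if open_at < 0:
        return False
    close_at = find_subsequence(input_ids, close_ids, open_at + len(open_ids), end)
    if close_at < 0:
        return False
    for i in range(open_at, close_at + len(close_ids)):
        labels[i] = -100
    return True


# Made-up IDs: [7, 8, 9] stands in for "<think>", [7, 10, 9] for "</think>".
input_ids = [1, 7, 8, 9, 50, 51, 7, 10, 9, 60, 2]
labels = list(input_ids)
masked = mask_think_span(input_ids, labels, [7, 8, 9], [7, 10, 9], 0, len(input_ids))
```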
In `@src/axolotl/utils/schemas/config.py`:
- Around line 66-70: The embed_method Field currently allows any string even
though the docstring lists specific valid values; change its type from str to
typing.Literal with the allowed options (e.g.,
Literal["last_token","mean_pooling","concat"]) and update the Field declaration
for embed_method in the Config schema so Pydantic/typing enforces and IDEs
surface the valid choices; keep the same json_schema_extra description but
ensure the attribute name embed_method in the relevant class/schema is updated
to use the Literal type.
- Around line 132-134: Change the advantage_estimator field to use
typing.Literal for strict validation: update the type annotation of
advantage_estimator to Literal["rloo", "group_norm", "reinforce"] (keep the
default "rloo"), add the Literal import, and leave the existing
Field(json_schema_extra=...) intact so pydantic/JSON schema will enforce and
document the allowed values; locate and modify the advantage_estimator
declaration and add the import near other typing imports.
- Around line 98-102: The mode field currently declared as mode: str =
Field(...) should enforce allowed values using typing.Literal like embed_method
does; update the type annotation for mode to use Literal["structured",
"strided"] (and import Literal) and remove or keep the Field
default/json_schema_extra as needed so Pydantic validates values at type level;
target the mode declaration in src/axolotl/utils/schemas/config.py (near the
existing embed_method pattern) to make this change.
In `@src/axolotl/utils/schemas/validation.py`:
- Around line 1572-1575: The strided-length calculation using ebft.get("stride")
and ebft.get("context_length") can raise ZeroDivisionError or produce invalid
math; validate parameters first: ensure stride is an int > 0 and context_length
(ctx_len) is an int >= 0 and less than (seq_len - gen_len) so the divisor
(seq_len - gen_len - ctx_len) is strictly positive; if any check fails raise a
ValueError with a clear message referencing the offending keys ("stride" /
"context_length") and the values (seq_len, gen_len) so callers know how to
correct the EBFT config before computing max_blocks and full_seq.
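The parameter checks above could be sketched like this. `validate_strided_params` is a hypothetical function and its return value is only the span `seq_len - gen_len - ctx_len`; the actual `max_blocks` and `full_seq` formulas are not reproduced here because they are not shown in this thread.

```python
def validate_strided_params(stride, context_length, seq_len, gen_len):
    """Validate strided-EBFT parameters and return the positive span left for blocks."""
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(stride, bool) or not isinstance(stride, int) or stride <= 0:
        raise ValueError(f"ebft.stride must be a positive integer, got {stride!r}")
    if (
        isinstance(context_length, bool)
        or not isinstance(context_length, int)
        or context_length < 0
    ):
        raise ValueError(
            f"ebft.context_length must be a non-negative integer, got {context_length!r}"
        )
    available = seq_len - gen_len
    if context_length >= available:
        raise ValueError(
            f"ebft.context_length ({context_length}) must be smaller than "
            f"seq_len - gen_len ({seq_len} - {gen_len} = {available})"
        )
    return available - context_length


span = validate_strided_params(stride=8, context_length=16, seq_len=128, gen_len=32)
```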
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: d773aa21-0e02-40e5-853d-1c6fb87e4935
📒 Files selected for processing (40)
- examples/ebft/README.md
- examples/ebft/ebft_opencode.py
- examples/ebft/ebft_pretrain.py
- examples/ebft/ebft_strided_structured.py
- examples/ebft/llama-1b-ebft-opencode-novllm.yaml
- examples/ebft/llama-1b-ebft-opencode.yaml
- examples/ebft/llama-1b-ebft-strided-structured.yaml
- examples/ebft/llama-1b-ebft-strided.yaml
- examples/ebft/llama-3b-ebft-strided-fft.yaml
- examples/ebft/llama-8b-ebft-strided-fft.yaml
- examples/ebft/qwen35-4b-ebft-structured-async.yaml
- examples/ebft/qwen35-4b-ebft-structured.yaml
- examples/ebft/qwen35-9b-ebft-structured.yaml
- src/axolotl/cli/vllm_serve.py
- src/axolotl/common/datasets.py
- src/axolotl/core/builders/rl.py
- src/axolotl/core/trainers/__init__.py
- src/axolotl/core/trainers/ebft/__init__.py
- src/axolotl/core/trainers/ebft/args.py
- src/axolotl/core/trainers/ebft/kernels.py
- src/axolotl/core/trainers/ebft/rewards.py
- src/axolotl/core/trainers/ebft/strided.py
- src/axolotl/core/trainers/ebft/trainer.py
- src/axolotl/core/trainers/grpo/async_trainer.py
- src/axolotl/monkeypatch/trainer/trl_vllm.py
- src/axolotl/prompt_strategies/ebft/__init__.py
- src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py
- src/axolotl/prompt_strategies/ebft/ebft_opencode.py
- src/axolotl/prompt_strategies/ebft/ebft_reasoning.py
- src/axolotl/prompt_strategies/ebft/ebft_strided_chat.py
- src/axolotl/prompt_strategies/ebft/ebft_strided_structured.py
- src/axolotl/scripts/vllm_serve_lora.py
- src/axolotl/scripts/vllm_worker_ext.py
- src/axolotl/train.py
- src/axolotl/utils/data/rl.py
- src/axolotl/utils/schemas/config.py
- src/axolotl/utils/schemas/enums.py
- src/axolotl/utils/schemas/trl.py
- src/axolotl/utils/schemas/validation.py
- src/axolotl/utils/schemas/vllm.py
torch_compile: true
gradient_checkpointing_kwargs:
  use_reentrant: false
This example currently fails EBFT validation at startup.
With strided EBFT + gradient_checkpointing: true, Line 53 (torch_compile: true) is rejected by the new validator. Also, Line 55 (use_reentrant: false) conflicts with the documented flex-attention checkpointing requirement.
Proposed fix
 gradient_checkpointing: true
-torch_compile: true
+# Keep torch_compile disabled in EBFT strided mode with gradient checkpointing
 gradient_checkpointing_kwargs:
-  use_reentrant: false
+  use_reentrant: true
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/ebft/llama-3b-ebft-strided-fft.yaml` around lines 53 - 55, The EBFT
validator fails because strided EBFT with gradient_checkpointing enabled is
incompatible with torch_compile=true and requires reentrant checkpointing;
update the YAML to set torch_compile: false (change the torch_compile key) and
set gradient_checkpointing_kwargs.use_reentrant: true (change use_reentrant
value) so the configuration meets the validator's requirements for strided EBFT
and flex-attention checkpointing.
@coderabbitai review
✅ Actions performed
Review triggered.
Actionable comments posted: 9
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/axolotl/utils/callbacks/generation.py (1)
28-60: ⚠️ Potential issue | 🔴 Critical
Unindent the sample generation block to fix unreachable code.
Lines 31–60 are incorrectly indented inside the `if not getattr(cfg, "generate_samples", False):` block. This means the generation code is unreachable regardless of the config value: when `generate_samples` is false, the `return` executes; when true, the entire if-block is skipped.
Dedent lines 31–60 by one level (4 spaces) so they execute when `generate_samples` is true.
Suggested fix
 if not getattr(cfg, "generate_samples", False):
     return
-    dataloader = None
-    try:
-        if getattr(self.trainer, "eval_dataset", None) is not None:
-            dataloader = self.trainer.get_eval_dataloader()
-            LOG.info(
-                f"Using eval dataloader for generation at step {state.global_step}"
-            )
-    except Exception as e:
-        LOG.warning(f"Could not get eval dataloader: {e}")
-        dataloader = None
-
-    if dataloader is None:
-        dataloader = self.trainer.get_train_dataloader()
-        LOG.info(
-            f"Using train dataloader for generation at step {state.global_step}"
-        )
-
-    samples = generate_samples(
-        model=self.trainer.model,
-        tokenizer=self.trainer.processing_class,
-        dataloader=dataloader,
-        num_generation_samples=getattr(cfg, "num_generation_samples", 3),
-        max_new_tokens=getattr(cfg, "generation_max_new_tokens", 50),
-        temperature=getattr(cfg, "generation_temperature", 0.7),
-        top_p=getattr(cfg, "generation_top_p", None),
-        top_k=getattr(cfg, "generation_top_k", None),
-        do_sample=getattr(cfg, "generation_do_sample", True),
-        prompt_ratio=getattr(cfg, "generation_prompt_ratio", 0.5),
-    )
-    self._log_samples(samples, state.global_step)
+dataloader = None
+try:
+    if getattr(self.trainer, "eval_dataset", None) is not None:
+        dataloader = self.trainer.get_eval_dataloader()
+        LOG.info(
+            f"Using eval dataloader for generation at step {state.global_step}"
+        )
+except Exception as e:
+    LOG.warning(f"Could not get eval dataloader: {e}")
+    dataloader = None
+
+if dataloader is None:
+    dataloader = self.trainer.get_train_dataloader()
+    LOG.info(
+        f"Using train dataloader for generation at step {state.global_step}"
+    )
+
+samples = generate_samples(
+    model=self.trainer.model,
+    tokenizer=self.trainer.processing_class,
+    dataloader=dataloader,
+    num_generation_samples=getattr(cfg, "num_generation_samples", 3),
+    max_new_tokens=getattr(cfg, "generation_max_new_tokens", 50),
+    temperature=getattr(cfg, "generation_temperature", 0.7),
+    top_p=getattr(cfg, "generation_top_p", None),
+    top_k=getattr(cfg, "generation_top_k", None),
+    do_sample=getattr(cfg, "generation_do_sample", True),
+    prompt_ratio=getattr(cfg, "generation_prompt_ratio", 0.5),
+)
+self._log_samples(samples, state.global_step)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/utils/callbacks/generation.py` around lines 28 - 60, The sample-generation block is incorrectly indented under the early-return for getattr(cfg, "generate_samples", False), making it unreachable; dedent the entire block that calls dataloader selection and generate_samples (the try/except fetching self.trainer.get_eval_dataloader(), the fallback to self.trainer.get_train_dataloader(), the call to generate_samples with model=self.trainer.model and tokenizer=self.trainer.processing_class, and the subsequent self._log_samples(samples, state.global_step)) so it runs only when generate_samples is True (i.e., move that block out of the if that contains the return).
src/axolotl/core/trainers/grpo/async_trainer.py (1)
645-670: ⚠️ Potential issue | 🟠 Major
Restore `_init_vllm` after this trainer finishes initialization.
This mutates `VLLMGeneration._init_vllm` at class scope and never puts the original method back. After one async/HTTP-only trainer is constructed, later trainers in the same process will also skip communicator init even when they need the stock behavior.
Suggested fix
-if _skip_nccl:
+restore_init_vllm = None
+if _skip_nccl:
     from trl.generation.vllm_generation import VLLMGeneration

     _orig_init_vllm = VLLMGeneration._init_vllm
+    restore_init_vllm = _orig_init_vllm
     ...
     VLLMGeneration._init_vllm = _init_vllm_no_communicator

-super().__init__(*args, **kwargs)
+try:
+    super().__init__(*args, **kwargs)
+finally:
+    if restore_init_vllm is not None:
+        VLLMGeneration._init_vllm = restore_init_vllm
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/core/trainers/grpo/async_trainer.py` around lines 645 - 670, The patch permanently replaces VLLMGeneration._init_vllm when _skip_nccl is true, causing later trainers to inherit the no-communicator behavior; instead, save _orig_init_vllm, assign VLLMGeneration._init_vllm = _init_vllm_no_communicator only for the duration of this trainer's initialization and restore the original in a finally/cleanup block (or use a context manager) so that the original _init_vllm is reinstated whether initialization succeeds or raises; reference VLLMGeneration._init_vllm, _orig_init_vllm, and _init_vllm_no_communicator and ensure restoration happens after the trainer finishes initialization.
♻️ Duplicate comments (2)
src/axolotl/prompt_strategies/ebft/ebft_reasoning.py (1)
148-151: ⚠️ Potential issue | 🟠 Major
Replace the remaining `dir()` guards with an explicit sentinel.
This is the same unresolved issue from the previous round: `transform_split_thinking()` and `transform_answer_only()` still use `if "prompt_msgs_snapshot" in dir()` to decide whether a local was assigned. That keeps the branch dependent on interpreter locals instead of explicit state. Initialize `prompt_msgs_snapshot = None` before the loop and test `is not None` in both functions.
Also applies to: 172-175
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py` around lines 148 - 151, The code uses if "prompt_msgs_snapshot" in dir() to detect whether prompt_msgs_snapshot was set; instead, initialize prompt_msgs_snapshot = None before the loop and change both guards in transform_split_thinking and transform_answer_only to explicit checks (prompt_msgs_snapshot is not None) so the branch depends on explicit state; update all occurrences (including the similar checks around lines 172-175) to use the sentinel instead of dir() and ensure the functions return prompt_msgs_snapshot when not None and fall back to split_messages[:-1] otherwise.
src/axolotl/core/trainers/ebft/rewards.py (1)
208-236: ⚠️ Potential issue | 🔴 Critical
Fix whitening to operate in feature space before enabling it anywhere.
This still builds `W` from `U`, so `W` is `(B, B)` instead of `(D, D)`. Besides the math bug from the earlier review, `EBFTMixin._feature_matching_reward()` calls this with `phi.shape == (num_generations, D)` and `phi_gt.shape == (1, D)`, so `W @ phi_gt_f` will shape-mismatch as soon as whitening is enabled with more than one generation.
Suggested fix
-U, S, _ = torch.linalg.svd(phi_f.unsqueeze(0), full_matrices=False)
+_, S, Vh = torch.linalg.svd(phi_f, full_matrices=False)
 ...
-U, S = U.squeeze(0), S.squeeze(0)
-
 # Safe inverse of singular values
 s_max = S.max()
 inv_s = torch.where(S > whiten_tol * s_max, 1.0 / (S + 1e-12), torch.zeros_like(S))

-# FIXME
-# W = U @ diag(inv_S) @ U^T
-W = (U * inv_s.unsqueeze(0)) @ U.T  # (B, B)
-phi_w = (W @ phi_f).to(phi.dtype)
-phi_gt_w = (W @ phi_gt_f).to(phi_gt.dtype)
+V = Vh.transpose(-2, -1)
+W = (V * inv_s.unsqueeze(0)) @ Vh  # (D, D)
+phi_w = (phi_f @ W.T).to(phi.dtype)
+phi_gt_w = (phi_gt_f @ W.T).to(phi_gt.dtype)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/core/trainers/ebft/rewards.py` around lines 208 - 236, The whitening builds W in sample space (B,B) because SVD was taken on phi_f (shape (B,D)), causing a shape mismatch when multiplying with phi_gt_f; fix by performing SVD in feature space so W is (D,D): compute SVD on phi_f.T (or equivalently compute eigendecomposition of phi_f.T @ phi_f) to produce feature-space vectors U (shape (D,K) with K = min(B,D) for a thin SVD, or (D,D) from the eigendecomposition), form inv_s from S and build W = U @ diag(inv_s) @ U.T, which is (D,D) in either case (use whiten_tol and small eps as before), then apply W @ phi_f.T (or transpose inputs appropriately) to get phi_w and phi_gt_w in the feature dimension; update the code paths around U, S, inv_s, W, phi_f, phi_gt_f and ensure EBFTMixin._feature_matching_reward() (which calls this) receives correctly-shaped outputs.
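For reference, the shapes behind the suggested fix can be written out as follows. This is a sketch under the assumption of a thin SVD with $K = \min(B, D)$; it ignores the small `1e-12` epsilon in the code, and the symbols $\Phi$ and $\tau$ (for `phi_f` and `whiten_tol`) are notation introduced here.

```latex
\Phi = U \Sigma V^{\top}, \qquad
\Phi \in \mathbb{R}^{B \times D},\;
U \in \mathbb{R}^{B \times K},\;
\Sigma \in \mathbb{R}^{K \times K},\;
V \in \mathbb{R}^{D \times K}

W = V \Sigma^{+} V^{\top} \in \mathbb{R}^{D \times D}, \qquad
\Sigma^{+}_{ii} =
\begin{cases}
1/\sigma_i & \text{if } \sigma_i > \tau\,\sigma_{\max} \\
0 & \text{otherwise}
\end{cases}

\Phi W^{\top}
  = U \Sigma V^{\top} V \Sigma^{+} V^{\top}
  = U \,(\Sigma \Sigma^{+})\, V^{\top} \in \mathbb{R}^{B \times D},
\qquad
\phi_{gt} W^{\top} \in \mathbb{R}^{1 \times D}
```

Because $W$ is $D \times D$, it applies to `phi` of shape `(num_generations, D)` and to `phi_gt` of shape `(1, D)` alike, which is what removes the shape mismatch.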
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/core/trainers/ebft/rewards.py`:
- Around line 92-120: Pooling assumes left-aligned tokens; change pooling to use
attention_mask-derived token positions per sample: in last_token use
attention_mask to compute per-sample last valid index
(torch.where(attention_mask.bool()) grouped by batch or
attention_mask.sum(dim=1)-1) and index hidden_states accordingly instead of raw
indices; in completion_mean build comp_mask by computing valid token positions
per sample from prompt_lengths and attention_mask (i.e., find positions >=
prompt_lengths AND attention_mask==1) before mean-pooling; in concat, for each
sample compute the list of valid token indices from attention_mask (or
prompt-aware valid positions), pick quartile positions relative to that
per-sample valid-length (e.g., floor((valid_len-1)*[0.25,0.5,0.75])) and gather
hidden_states at those indices before concatenation so padding/left-padding is
never sampled. Ensure all indexing handles batched gather safely and uses
hidden_states device/dtypes.
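The index arithmetic described above can be illustrated without tensors. This is a list-based sketch only: the helper names are hypothetical, plain Python lists stand in for `attention_mask` rows, and a real implementation would use batched tensor operations.

```python
import math


def valid_positions(mask):
    """Indices of non-padding tokens in one attention-mask row."""
    return [i for i, m in enumerate(mask) if m]


def last_token_index(mask):
    positions = valid_positions(mask)
    if not positions:
        raise ValueError("attention mask has no valid tokens")
    return positions[-1]


def quartile_indices(mask, fractions=(0.25, 0.5, 0.75)):
    """Pick positions at the given fractions of the valid length, never in padding."""
    positions = valid_positions(mask)
    if not positions:
        raise ValueError("attention mask has no valid tokens")
    return [positions[math.floor((len(positions) - 1) * f)] for f in fractions]


left_padded = [0, 0, 0, 1, 1, 1, 1, 1]
right_padded = [1, 1, 1, 1, 1, 0, 0, 0]
left_last, right_last = last_token_index(left_padded), last_token_index(right_padded)
left_q, right_q = quartile_indices(left_padded), quartile_indices(right_padded)
```

Note that `sum(mask) - 1` equals 4 for both rows, which is the last valid index only for the right-padded one; deriving positions from the mask itself handles both layouts.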
In `@src/axolotl/core/trainers/ebft/trainer.py`:
- Around line 395-478: _sequential_rollout currently uses
self.vllm_generation.client and returns only concatenated assistant text; change
it to use self.vllm_generation.vllm_client (replace vllm_client =
self.vllm_generation.client with vllm_client = self.vllm_generation.vllm_client)
and instead of appending full_gen_text to extended_completions append the full
conversation representation (the conv list or a fully rendered conversation that
includes both user and assistant turns) so downstream reward code sees prompt +
interleaved user/assistant messages; ensure the no-remaining-turns branch
returns a full-message list consistent with the new format (e.g., original
prompt_msgs + assistant first turn) and keep vllm_client.chat call and decoding
logic (result, gen_ids, self.processing_class.decode) unchanged.
In `@src/axolotl/core/trainers/grpo/async_trainer.py`:
- Around line 897-902: The code only syncs weights when the computed mod_path
exists in lora_info, which drops trainable parameters stored under
modules_to_save.default.* (e.g., lm_head, embed_tokens). Update the conditional
around vllm_name/mod_path so it also accepts entries that were prefixed by
"modules_to_save.default.": after computing mod_path from vllm_name (and after
calling fix_name with extra_prefixes), check both mod_path and
"modules_to_save.default."+mod_path (or the original un-fixed mod_path) against
lora_info, and only continue if neither is present; this ensures
modules_to_save.default.* parameters are included in the sync.
In `@src/axolotl/monkeypatch/trainer/trl_vllm.py`:
- Around line 78-97: The current fallback loop uses MAX_PARAMS_PER_REQUEST to
split by parameter count which can still produce huge base64 JSON bodies; change
the batching in the function that builds payloads (where MAX_PARAMS_PER_REQUEST,
chunk, payload, and the self.session.post call are used) to instead accumulate
parameters until a MAX_BYTES_PER_REQUEST threshold (e.g., ~10MB) would be
exceeded, then send that batch; additionally, detect individual tensors whose
serialized byte size exceeds MAX_BYTES_PER_REQUEST and split them into smaller
slices (preserving name and adding slice metadata such as a shard index or
byte/row range) so each slice is serialized, base64-encoded and included as
separate payload entries, and ensure the shape/dtype fields reflect the slice so
the server can reassemble; replace the fixed-count loop with this byte-aware
accumulator before calling self.session.post.
- Around line 58-63: The POST calls that sync weights (the session.post to
f"{self.base_url}/batch_update_named_params/" and the other session.post to
f"{self.base_url}/http_update_weights/") currently have no timeout and can block
forever; update both calls to pass an explicit timeout (e.g., timeout=30) to
session.post, and keep the existing status_code check/Exception behavior; also
import and optionally catch requests.exceptions.Timeout around the calls in the
same function (or let it propagate) so timeouts surface deterministically.
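The byte-aware accumulator could be sketched as below. `batch_by_bytes` is a hypothetical helper operating on `(name, nbytes)` pairs; it does not perform the tensor slicing described above, so an oversized parameter simply ends up alone in its own batch.

```python
def batch_by_bytes(params, max_bytes):
    """Group parameter names so each batch stays within `max_bytes` where possible.

    `params` is an iterable of (name, nbytes) pairs. A single parameter larger
    than `max_bytes` is emitted as its own batch; splitting it is not shown here.
    """
    batches, current, current_bytes = [], [], 0
    for name, nbytes in params:
        # Flush before adding if this parameter would push the batch over budget.
        if current and current_bytes + nbytes > max_bytes:
            batches.append(current)
            current, current_bytes = [], 0
        current.append(name)
        current_bytes += nbytes
    if current:
        batches.append(current)
    return batches


params = [("a", 4), ("b", 5), ("c", 3), ("d", 12), ("e", 1)]
batches = batch_by_bytes(params, max_bytes=10)
```

When budgeting, keep in mind that base64 encoding inflates the payload by roughly a third relative to the raw tensor bytes.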
In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py`:
- Around line 225-260: The loop in ebft_reasoning.py currently marks every
assistant turn as trainable (labels[]) which makes prompt_length be derived from
an earlier assistant turn; instead detect and record only the final assistant
turn span: when iterating messages use the existing tokenizer logic to compute
start/end but do not set labels for each assistant turn immediately—instead
store the last assistant span (final_start, final_end), then after the loop set
labels[i]=input_ids[i] only for i in range(final_start, min(final_end,
len(labels))) and compute prompt_length from final_start (before applying any
mask_thinking_ce modifications); apply the same change to the other occurrence
noted around the 285-290 region so only the final assistant span is treated as
the structured completion.
- Line 205: The code treats tokenizer.pad_token_id as missing when it equals 0
by using a falsy fallback expression; change the logic around pad_id in
ebft_reasoning.py so you only fall back to tokenizer.eos_token_id when
pad_token_id is actually None (or not set), e.g. replace the
`tokenizer.pad_token_id or tokenizer.eos_token_id` pattern with an explicit None
check (use tokenizer.pad_token_id if it is not None, otherwise
tokenizer.eos_token_id) where pad_id is assigned so PAD=0 remains respected.
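The difference between the two patterns is easy to reproduce with a stand-in object. The `SimpleNamespace` tokenizer and `resolve_pad_id` below are hypothetical and only mimic the two attributes involved.

```python
from types import SimpleNamespace


def resolve_pad_id(tokenizer):
    """Fall back to EOS only when no pad token is configured at all."""
    if tokenizer.pad_token_id is not None:
        return tokenizer.pad_token_id
    return tokenizer.eos_token_id


tok = SimpleNamespace(pad_token_id=0, eos_token_id=2)
buggy = tok.pad_token_id or tok.eos_token_id  # 0 is falsy, so EOS is chosen
fixed = resolve_pad_id(tok)  # PAD=0 is respected

no_pad = SimpleNamespace(pad_token_id=None, eos_token_id=2)
fallback = resolve_pad_id(no_pad)
```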
In `@src/axolotl/utils/schemas/config.py`:
- Around line 245-250: Add a pydantic root validator on the same config model
that defines ebft (the model containing the field "ebft: EBFTConfig | None =
Field(...)") to enforce that when rl is set to EBFT (check data.get("rl")
against the RL enum or the string "ebft"), the incoming data contains a non-null
"ebft" entry; if missing or None, raise a ValueError with a clear message (e.g.,
"ebft config is required when rl is EBFT") so parsing fails early instead of
letting downstream validators in validation.py silently bypass the requirement.
- Around line 67-157: Tighten validation on the EBFT config fields by adding
explicit enum/range constraints: restrict embed_method to the allowed values
(e.g., "last_token","mean_pooling","concat") and mode to
("structured","strided") and advantage_estimator to
("rloo","group_norm","reinforce") (use Literal or an Enum for
embed_method/mode/advantage_estimator), enforce top_p between 0.0 and 1.0
(le=1.0, ge=0.0), ensure temperature is non-negative (ge=0.0), and make integer
fields like stride, context_length, generate_max_len, n_samples_per_prompt, and
min_completion_prefix positive or non-negative as appropriate (ge=0 or ge=1).
Update the Field declarations for embed_method, mode, advantage_estimator,
top_p, temperature, stride, context_length, generate_max_len,
n_samples_per_prompt, and min_completion_prefix to include these constraints so
invalid values fail fast during model validation.
---
Outside diff comments:
In `@src/axolotl/core/trainers/grpo/async_trainer.py`:
- Around line 645-670: The patch permanently replaces VLLMGeneration._init_vllm
when _skip_nccl is true, causing later trainers to inherit the no-communicator
behavior; instead, save _orig_init_vllm, assign VLLMGeneration._init_vllm =
_init_vllm_no_communicator only for the duration of this trainer's
initialization and restore the original in a finally/cleanup block (or use a
context manager) so that the original _init_vllm is reinstated whether
initialization succeeds or raises; reference VLLMGeneration._init_vllm,
_orig_init_vllm, and _init_vllm_no_communicator and ensure restoration happens
after the trainer finishes initialization.
In `@src/axolotl/utils/callbacks/generation.py`:
- Around line 28-60: The sample-generation block is incorrectly indented under
the early-return for getattr(cfg, "generate_samples", False), making it
unreachable; dedent the entire block that calls dataloader selection and
generate_samples (the try/except fetching self.trainer.get_eval_dataloader(),
the fallback to self.trainer.get_train_dataloader(), the call to
generate_samples with model=self.trainer.model and
tokenizer=self.trainer.processing_class, and the subsequent
self._log_samples(samples, state.global_step)) so it runs only when
generate_samples is True (i.e., move that block out of the if that contains the
return).
---
Duplicate comments:
In `@src/axolotl/core/trainers/ebft/rewards.py`:
- Around line 208-236: The whitening builds W in sample space (B,B) because SVD
was taken on phi_f (shape (B,D)), causing a shape mismatch when multiplying with
phi_gt_f; fix by performing SVD in feature space so W is (D,D): compute SVD on
phi_f.T (or equivalently compute eigendecomposition of phi_f.T @ phi_f) to
produce feature-space vectors U (shape (D,K) with K = min(B,D) for a thin SVD,
or (D,D) from the eigendecomposition), form inv_s from S and build the (D,D)
matrix W = U @ diag(inv_s) @
U.T (use whiten_tol and small eps as before), then apply W @ phi_f.T (or
transpose inputs appropriately) to get phi_w and phi_gt_w in the feature
dimension; update the code paths around U, S, inv_s, W, phi_f, phi_gt_f and
ensure EBFTMixin._feature_matching_reward() (which calls this) receives
correctly-shaped outputs.
In `@src/axolotl/prompt_strategies/ebft/ebft_reasoning.py`:
- Around line 148-151: The code uses if "prompt_msgs_snapshot" in dir() to
detect whether prompt_msgs_snapshot was set; instead, initialize
prompt_msgs_snapshot = None before the loop and change both guards in
transform_split_thinking and transform_answer_only to explicit checks
(prompt_msgs_snapshot is not None) so the branch depends on explicit state;
update all occurrences (including the similar checks around lines 172-175) to
use the sentinel instead of dir() and ensure the functions return
prompt_msgs_snapshot when not None and fall back to split_messages[:-1]
otherwise.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 3bc4db9a-27a7-4043-a5b9-dba87868b9a1
📒 Files selected for processing (44)
- docker/Dockerfile-cloud-uv
- examples/ebft/README.md
- examples/ebft/ebft_opencode.py
- examples/ebft/ebft_pretrain.py
- examples/ebft/ebft_strided_structured.py
- examples/ebft/llama-1b-ebft-opencode-novllm.yaml
- examples/ebft/llama-1b-ebft-opencode.yaml
- examples/ebft/llama-1b-ebft-strided-structured.yaml
- examples/ebft/llama-1b-ebft-strided.yaml
- examples/ebft/llama-3b-ebft-strided-fft.yaml
- examples/ebft/llama-8b-ebft-strided-fft.yaml
- examples/ebft/qwen35-4b-ebft-structured-async.yaml
- examples/ebft/qwen35-4b-ebft-structured.yaml
- examples/ebft/qwen35-9b-ebft-structured.yaml
- src/axolotl/cli/vllm_serve.py
- src/axolotl/common/datasets.py
- src/axolotl/core/builders/rl.py
- src/axolotl/core/trainers/__init__.py
- src/axolotl/core/trainers/ebft/__init__.py
- src/axolotl/core/trainers/ebft/args.py
- src/axolotl/core/trainers/ebft/kernels.py
- src/axolotl/core/trainers/ebft/rewards.py
- src/axolotl/core/trainers/ebft/strided.py
- src/axolotl/core/trainers/ebft/trainer.py
- src/axolotl/core/trainers/grpo/async_trainer.py
- src/axolotl/integrations/diffusion/callbacks.py
- src/axolotl/monkeypatch/trainer/trl_vllm.py
- src/axolotl/prompt_strategies/ebft/__init__.py
- src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py
- src/axolotl/prompt_strategies/ebft/ebft_opencode.py
- src/axolotl/prompt_strategies/ebft/ebft_reasoning.py
- src/axolotl/prompt_strategies/ebft/ebft_strided_chat.py
- src/axolotl/prompt_strategies/ebft/ebft_strided_structured.py
- src/axolotl/scripts/vllm_serve_lora.py
- src/axolotl/scripts/vllm_worker_ext.py
- src/axolotl/train.py
- src/axolotl/utils/callbacks/__init__.py
- src/axolotl/utils/callbacks/generation.py
- src/axolotl/utils/data/rl.py
- src/axolotl/utils/schemas/config.py
- src/axolotl/utils/schemas/enums.py
- src/axolotl/utils/schemas/trl.py
- src/axolotl/utils/schemas/validation.py
- src/axolotl/utils/schemas/vllm.py
✅ Files skipped from review due to trivial changes (21)
- docker/Dockerfile-cloud-uv
- src/axolotl/integrations/diffusion/callbacks.py
- src/axolotl/utils/schemas/enums.py
- src/axolotl/prompt_strategies/ebft/__init__.py
- examples/ebft/ebft_pretrain.py
- examples/ebft/ebft_opencode.py
- src/axolotl/utils/callbacks/__init__.py
- examples/ebft/qwen35-4b-ebft-structured.yaml
- examples/ebft/qwen35-9b-ebft-structured.yaml
- examples/ebft/llama-1b-ebft-opencode.yaml
- src/axolotl/prompt_strategies/ebft/ebft_strided_structured.py
- examples/ebft/llama-1b-ebft-strided.yaml
- examples/ebft/README.md
- src/axolotl/prompt_strategies/ebft/ebft_chat_multiturn.py
- examples/ebft/llama-1b-ebft-strided-structured.yaml
- src/axolotl/cli/vllm_serve.py
- examples/ebft/qwen35-4b-ebft-structured-async.yaml
- examples/ebft/llama-3b-ebft-strided-fft.yaml
- examples/ebft/llama-8b-ebft-strided-fft.yaml
- examples/ebft/llama-1b-ebft-opencode-novllm.yaml
- src/axolotl/core/trainers/ebft/kernels.py
🚧 Files skipped from review as they are similar to previous changes (15)
- src/axolotl/common/datasets.py
- src/axolotl/utils/schemas/vllm.py
- src/axolotl/core/trainers/__init__.py
- src/axolotl/utils/data/rl.py
- src/axolotl/utils/schemas/trl.py
- src/axolotl/core/builders/rl.py
- examples/ebft/ebft_strided_structured.py
- src/axolotl/scripts/vllm_serve_lora.py
- src/axolotl/prompt_strategies/ebft/ebft_opencode.py
- src/axolotl/prompt_strategies/ebft/ebft_strided_chat.py
- src/axolotl/train.py
- src/axolotl/scripts/vllm_worker_ext.py
- src/axolotl/utils/schemas/validation.py
- src/axolotl/core/trainers/ebft/__init__.py
- src/axolotl/core/trainers/ebft/args.py
vllm_name = fix_name(vllm_name, extra_prefixes=["modules_to_save.default."])

# Only sync weights that have LoRA adapters
mod_path = vllm_name[: -len(".weight")]
if mod_path not in lora_info:
    continue
modules_to_save weights are still dropped from no-merge sync.
The new filter only syncs weights whose mod_path exists in lora_info. That excludes trainable modules_to_save.default.* parameters such as lm_head or embed_tokens, so server-side generation can drift from training whenever PEFT is configured with modules_to_save.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/core/trainers/grpo/async_trainer.py` around lines 897 - 902, The
code only syncs weights when the computed mod_path exists in lora_info, which
drops trainable parameters stored under modules_to_save.default.* (e.g.,
lm_head, embed_tokens). Update the conditional around vllm_name/mod_path so it
also accepts entries that were prefixed by "modules_to_save.default.": after
computing mod_path from vllm_name (and after calling fix_name with
extra_prefixes), check both mod_path and "modules_to_save.default."+mod_path (or
the original un-fixed mod_path) against lora_info, and only continue if neither
is present; this ensures modules_to_save.default.* parameters are included in
the sync.
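A minimal sketch of the filter described above, not the code in `async_trainer.py`. The helper name `select_weights_to_sync`, the constant `MODULES_TO_SAVE_PREFIX`, and the use of `str.replace` as a stand-in for `fix_name` are all assumptions made for illustration; the real name-normalization logic may differ.

```python
# Hypothetical constant: the PEFT wrapper segment for trainable non-LoRA modules.
MODULES_TO_SAVE_PREFIX = "modules_to_save.default."


def select_weights_to_sync(named_params, lora_info):
    """Return (vllm_name, tensor) pairs that should be pushed to the server.

    A weight is kept if its module has a LoRA adapter (its module path is a
    key of lora_info) OR if it came from a modules_to_save wrapper.
    """
    selected = []
    for name, tensor in named_params:
        # Record the origin before the wrapper segment is stripped from the name.
        is_saved_module = MODULES_TO_SAVE_PREFIX in name
        # Stand-in for fix_name(...): drop the wrapper segment only.
        vllm_name = name.replace(MODULES_TO_SAVE_PREFIX, "")
        if not vllm_name.endswith(".weight"):
            continue
        mod_path = vllm_name[: -len(".weight")]
        # Skip only when the weight is neither LoRA-adapted nor a saved module.
        if mod_path not in lora_info and not is_saved_module:
            continue
        selected.append((vllm_name, tensor))
    return selected
```

The point of the sketch is that the `modules_to_save` origin has to be captured before the name is normalized; once the prefix is stripped, `lm_head` or `embed_tokens` can no longer be told apart from a frozen base weight by looking at `lora_info` alone.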
    url = f"{self.base_url}/batch_update_named_params/"
    response = self.session.post(url, json={"params": param_metadata})
    if response.status_code != 200:
        raise Exception(
            f"Request failed: {response.status_code}, {response.text}"
        )
Add an explicit timeout to both weight-sync POSTs.
These requests currently inherit Requests' infinite timeout. If the vLLM server wedges during metadata sync or /http_update_weights/, training can block forever on the request thread.
Suggested fix
- response = self.session.post(url, json={"params": param_metadata})
+ response = self.session.post(
+ url,
+ json={"params": param_metadata},
+ timeout=getattr(self, "connection_timeout", None),
+ )
...
- response = self.session.post(url, json={"params": payload})
+ response = self.session.post(
+ url,
+ json={"params": payload},
+ timeout=getattr(self, "connection_timeout", None),
+ )
Also applies to: 96-101
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/monkeypatch/trainer/trl_vllm.py` around lines 58 - 63, The POST
calls that sync weights (the session.post to
f"{self.base_url}/batch_update_named_params/" and the other session.post to
f"{self.base_url}/http_update_weights/") currently have no timeout and can block
forever; update both calls to pass an explicit timeout (e.g., timeout=30) to
session.post, and keep the existing status_code check/Exception behavior; also
import and optionally catch requests.exceptions.Timeout around the calls in the
same function (or let it propagate) so timeouts surface deterministically.
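A minimal sketch of the bounded POST, assuming a `requests.Session`-like object that accepts `json=` and `timeout=` keyword arguments. The helper `post_named_params`, the `WeightSyncError` class, and the 30-second default are hypothetical and not part of `trl_vllm.py`.

```python
class WeightSyncError(RuntimeError):
    """Hypothetical error type for a failed weight-sync request."""


def post_named_params(session, base_url, params, timeout=30.0):
    """POST parameter metadata with an explicit, finite timeout.

    `session` is expected to behave like requests.Session: post() takes
    json= and timeout= and returns an object with status_code and text.
    """
    url = f"{base_url}/batch_update_named_params/"
    # A finite timeout keeps a wedged server from blocking this call forever.
    response = session.post(url, json={"params": params}, timeout=timeout)
    if response.status_code != 200:
        raise WeightSyncError(
            f"Request failed: {response.status_code}, {response.text}"
        )
    return response
```

Two details are worth checking when adapting this. In Requests, `timeout=None` means no timeout at all, so a fallback such as `getattr(self, "connection_timeout", None)` only bounds the call when the attribute is actually set to a number. Requests also accepts a `(connect, read)` tuple for `timeout`, which may suit large weight payloads better than a single value.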
* nemo gym integration with grpo wip
* mostly working
* cleanup
* simplify
* update docs
* nemo gym support wip
* cleanup
* chore: lint
* address PR review and add more tests
* chore: lint
* post merge lora fixes for CI (#3536) [skip ci]
* post merge lora fixes for CI
* handle lora kernel auto-enable for moe without grouped_mm
* prefer not to import torch in schema validation
* address pr comments, add timeout, add tests
* roundup_power2_divisions not needed with newer pytorch versions (#3540)
* roundup_power2_divisions not needed with newer pytorch versions
* remove typo
* update qwen3.5 moe 35b-a3b yaml for 5090
* more bug fixes
* fix tests to match updated trainer
* don't use fa2 for hooks test
* reset plugins on the instance
* retry download
* fix references to renamed axolotl_cfg property on trainer
* Fix ref to trainer cfg
* fix: robust handling of race condition on patching check (#3543) [skip ci]
* EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]
* EBFT wip
* fixes
* more fixeS
* add missing strided module
* ebft fixes for multi-turn
* make ebft work with async
* add example for ebft w qwen3.5
* fix for split thinking and update yaml for lora over linear attention only
* enforce_eager for vllm arg in schema
* fix sync weights
* fix multi-gpu
* handle updated sig for mm
* ddp fixes
* improve multi-gpu handling, don't calculate logits, adaptive completion length
* chore: lint
* chore: lint
* support completion_mean
* Address corereview feedback
* clamp min IS ratio
* Address PR code review
* more fixes identified
* address code review
* Fix property from rebase conflict
* fix for ebft sync and update docs
* make trainer loss patch check a solo test

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Summary by CodeRabbit
New Features
Chores
Documentation