diff --git a/docs/cli.qmd b/docs/cli.qmd
index 2bdaf90189..b44c5db23b 100644
--- a/docs/cli.qmd
+++ b/docs/cli.qmd
@@ -137,7 +137,8 @@ lora_alpha:
### inference
-Runs inference using your trained model in either CLI or Gradio interface mode.
+Runs inference using your trained model in CLI, interactive chat, or Gradio
+interface mode.
```bash
# CLI inference with LoRA
@@ -146,6 +147,10 @@ axolotl inference config.yml --lora-model-dir="./outputs/lora-out"
# CLI inference with full model
axolotl inference config.yml --base-model="./completed-model"
+# Interactive multi-turn chat (see the inference guide for commands)
+axolotl inference config.yml --chat \
+ --lora-model-dir="./outputs/lora-out"
+
# Gradio web interface
axolotl inference config.yml --gradio \
--lora-model-dir="./outputs/lora-out"
diff --git a/docs/inference.qmd b/docs/inference.qmd
index 6917d3c330..e977e3c3c5 100644
--- a/docs/inference.qmd
+++ b/docs/inference.qmd
@@ -35,6 +35,76 @@ axolotl inference your_config.yml --base-model="./completed-model"
:::
+### Interactive Chat {#sec-chat}
+
+For multi-turn testing of conversational models, use chat mode. The chat template
+is resolved exactly as it was during training and re-applied to the full
+conversation each turn:
+
+```{.bash}
+axolotl inference your_config.yml --chat
+```
+
+Type a message to chat. End a line with `\` to continue typing on the next line.
+Slash commands control the session:
+
+| Command | Aliases | Description |
+|---------|---------|-------------|
+| `/help` | `/?` | Show all commands |
+| `/new` | `/clear`, `/reset` | Clear the conversation (keeps system prompt and parameters) |
+| `/system [text\|clear]` | | Show, set, or clear the system prompt |
+| `/set ` | | Set a generation parameter |
+| `/status` | `/params` | Show model info and current settings |
+| `/history` | | Show the conversation so far |
+| `/retry` | `/regen` | Regenerate the last assistant reply |
+| `/undo` | | Remove the last exchange |
+| `/save [path]` | | Append the conversation as a `chat_template`-format JSONL sample |
+| `/quit` | `/exit`, `/q` | Exit |
+
+Generation parameters can also be set directly, e.g. `/temperature 0.7` (or
+`/temp 0.7`), `/top_p 0.9`, `/top_k 50`, `/max_tokens 512`, `/rep 1.05`,
+`/seed 42`. Setting `temperature` to `0` switches to greedy decoding.
+
+Press `Ctrl+C` during generation to stop the current reply; the partial response
+is kept in the conversation (diffusion replies denoise in one piece, so an
+interrupted diffusion turn is discarded instead).
+
+#### Thinking Models {#sec-chat-thinking}
+
+Thinking blocks (e.g. `...`) stream live in a small dim window,
+then collapse to a one-line summary — `/expand` shows the full reasoning of the
+last reply, and `/collapse off` switches to raw verbatim output. The per-turn
+stats split thinking from reply tokens. If the chat template supports a
+render-time thinking toggle (e.g. Qwen's `enable_thinking`), `/think off`
+disables thinking entirely from the next turn; `/think default` restores the
+template default.
+
+::: {.callout-note}
+Assistant turns are stored the way `transformers` recommends: special tokens
+are stripped and thinking is kept on a separate `reasoning_content` key (via
+the tokenizer's `parse_response` schema when it ships one, marker-splitting
+otherwise), so the chat template decides how prior-turn reasoning is
+re-rendered — matching what the model saw during training. The KV cache is
+re-used across turns whenever the rendered conversation extends the previous
+one, so long chats stay responsive.
+:::
+
+`/save` writes conversations in the `messages` format accepted by
+`type: chat_template` datasets, so a good interactive session can be turned
+directly into training data.
+
+#### Diffusion Models {#sec-chat-diffusion}
+
+With the diffusion plugin enabled, chat mode generates each reply by appending
+a masked block to the conversation and denoising it. Replies arrive in one
+piece (no token streaming), and the parameter set changes accordingly:
+`/tokens N` sets the completion block size, `/steps N` the number of denoising
+steps, and `/temperature` the denoising temperature. Defaults come from the
+`diffusion:` section of your config.
+
+Chat mode is not supported with `--prompter`; use the default inference mode
+for legacy prompters.
+
## Advanced Usage {#sec-advanced}
### Gradio Interface {#sec-gradio}
diff --git a/src/axolotl/cli/chat.py b/src/axolotl/cli/chat.py
new file mode 100644
index 0000000000..c54fe262d2
--- /dev/null
+++ b/src/axolotl/cli/chat.py
@@ -0,0 +1,1279 @@
+"""Interactive multi-turn chat CLI for a trained model."""
+
+import difflib
+import json
+import shlex
+import sys
+import threading
+import time
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Any, Callable
+
+import torch
+from rich.console import Console
+from rich.live import Live
+from rich.markup import escape
+from rich.text import Text
+from transformers import (
+ DynamicCache,
+ GenerationConfig,
+ StoppingCriteria,
+ StoppingCriteriaList,
+ TextIteratorStreamer,
+)
+
+from axolotl.cli.args import InferenceCliArgs
+from axolotl.cli.utils import load_model_and_tokenizer, resolve_chat_template_str
+from axolotl.integrations.base import PluginManager
+from axolotl.telemetry.errors import send_errors
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+USER_PROMPT = ">>> "
+CONTINUATION_PROMPT = "... "
+
+# marker pairs used by the bundled chat templates (qwen3/exaone4/phi_4 vs command_a)
+THINK_MARKER_PAIRS: tuple[tuple[str, str], ...] = (
+ ("", ""),
+ ("<|START_THINKING|>", "<|END_THINKING|>"),
+)
+DEFAULT_THINK_MARKERS = THINK_MARKER_PAIRS[0]
+
+
+def detect_think_markers(chat_template_str: str | None) -> tuple[str, str]:
+ """Picks the thinking marker pair the template works with. Called once at startup."""
+ if chat_template_str:
+ for pair in THINK_MARKER_PAIRS:
+ if pair[1] in chat_template_str:
+ return pair
+ return DEFAULT_THINK_MARKERS
+
+
+def detect_think_toggle_key(chat_template_str: str | None) -> str | None:
+ """
+ Finds the jinja variable the template uses to toggle thinking at render time
+ (`enable_thinking` in our bundled gemma4/qwen3_5 templates; `thinking` on some
+ hub templates). Called once at startup.
+ """
+ if not chat_template_str:
+ return None
+ if "enable_thinking" in chat_template_str:
+ return "enable_thinking"
+ if "thinking" in chat_template_str:
+ return "thinking"
+ return None
+
+
+@dataclass(frozen=True)
+class GenParamSpec:
+ """Specification of a runtime-adjustable generation parameter."""
+
+ key: str
+ cast: Callable
+ lo: float
+ hi: float
+ default: Any
+ aliases: tuple[str, ...] = ()
+ nullable: bool = False
+ help: str = ""
+
+
+GEN_PARAMS: tuple[GenParamSpec, ...] = (
+ GenParamSpec(
+ "temperature", float, 0.0, 5.0, 0.9, ("temp",), help="0 = greedy decoding"
+ ),
+ GenParamSpec("top_p", float, 0.0, 1.0, 0.95),
+ GenParamSpec("top_k", int, 0, 1000, 40),
+ GenParamSpec("min_p", float, 0.0, 1.0, None, nullable=True),
+ GenParamSpec("max_new_tokens", int, 1, 1_000_000, 1024, ("max_tokens", "max")),
+ GenParamSpec("repetition_penalty", float, 0.5, 3.0, 1.1, ("rep",)),
+ GenParamSpec(
+ "seed", int, 0, 2**32 - 1, None, nullable=True, help="`/set seed none` clears"
+ ),
+)
+
+
+DIFFUSION_GEN_PARAMS: tuple[GenParamSpec, ...] = (
+ GenParamSpec(
+ "temperature", float, 0.0, 5.0, 0.0, ("temp",), help="0 = greedy denoising"
+ ),
+ GenParamSpec(
+ "max_new_tokens",
+ int,
+ 1,
+ 100_000,
+ 256,
+ ("tokens", "max_tokens", "max"),
+ help="size of the denoised completion block",
+ ),
+ GenParamSpec("steps", int, 1, 10_000, 128, help="number of denoising steps"),
+ GenParamSpec(
+ "seed", int, 0, 2**32 - 1, None, nullable=True, help="`/set seed none` clears"
+ ),
+)
+
+
+def default_gen_params(
+ specs: tuple[GenParamSpec, ...] = GEN_PARAMS,
+) -> dict[str, Any]:
+ return {spec.key: spec.default for spec in specs}
+
+
+def resolve_gen_param(
+ name: str, specs: tuple[GenParamSpec, ...] = GEN_PARAMS
+) -> GenParamSpec | None:
+ for spec in specs:
+ if name == spec.key or name in spec.aliases:
+ return spec
+ return None
+
+
+def parse_gen_param_value(spec: GenParamSpec, raw: str) -> Any:
+ if spec.nullable and raw.lower() in ("none", "null", "off"):
+ return None
+ try:
+ value = spec.cast(raw)
+ except ValueError as err:
+ raise ValueError(f"{spec.key} expects a {spec.cast.__name__}") from err
+ if not spec.lo <= value <= spec.hi:
+ raise ValueError(f"{spec.key} must be in [{spec.lo}, {spec.hi}]")
+ return value
+
+
+def longest_common_prefix_len(a: list[int], b: list[int]) -> int:
+ n = min(len(a), len(b))
+ for i in range(n):
+ if a[i] != b[i]:
+ return i
+ return n
+
+
+def find_subsequence(haystack: list[int], needle: list[int], start: int = 0) -> int:
+ if not needle:
+ return -1
+ n = len(needle)
+ for i in range(start, len(haystack) - n + 1):
+ if haystack[i : i + n] == needle:
+ return i
+ return -1
+
+
+def partial_suffix_len(text: str, marker: str) -> int:
+ """Length of the longest suffix of `text` that is a proper prefix of `marker`."""
+ max_len = min(len(text), len(marker) - 1)
+ for k in range(max_len, 0, -1):
+ if text.endswith(marker[:k]):
+ return k
+ return 0
+
+
+def content_as_text(content: Any) -> str:
+ # parse_response may return content as a list of parts rather than a string
+ if isinstance(content, list):
+ return "".join(
+ part.get("text", "") for part in content if isinstance(part, dict)
+ )
+ return content or ""
+
+
+@dataclass
+class ChatSession:
+ """Holds the conversation state for a chat session."""
+
+ messages: list[dict] = field(default_factory=list)
+ system: str | None = None
+
+ def conversation(self) -> list[dict]:
+ prefix = [{"role": "system", "content": self.system}] if self.system else []
+ return prefix + self.messages
+
+ def add_user(self, content: str):
+ # merge into a trailing unanswered user message (e.g. after a failed
+ # generation) so strict templates never see consecutive user turns
+ if self.messages and self.messages[-1]["role"] == "user":
+ self.messages[-1]["content"] += "\n" + content
+ return
+ self.messages.append({"role": "user", "content": content})
+
+ def add_assistant(self, content: str):
+ self.add_assistant_message({"role": "assistant", "content": content})
+
+ def add_assistant_message(self, message: dict):
+ self.messages.append(message)
+
+ def clear(self):
+ self.messages = []
+
+ def undo(self) -> bool:
+ """Removes the last user/assistant exchange. Returns False if empty."""
+ if not self.messages:
+ return False
+ if self.messages[-1]["role"] == "assistant":
+ self.messages.pop()
+ if self.messages and self.messages[-1]["role"] == "user":
+ self.messages.pop()
+ return True
+
+ def drop_last_assistant(self) -> bool:
+ """Removes the trailing assistant message so the turn can be retried."""
+ if self.messages and self.messages[-1]["role"] == "assistant":
+ self.messages.pop()
+ return bool(self.messages) and self.messages[-1]["role"] == "user"
+
+ def save_jsonl(self, path: str):
+ # content-parts format: text-only today, but matches the multimodal
+ # dataset format so saved sessions stay usable as training data
+ messages = []
+ for message in self.conversation():
+ content = message.get("content")
+ parts = (
+ content
+ if isinstance(content, list)
+ else [{"type": "text", "text": content or ""}]
+ )
+ out = {
+ "role": message["role"],
+ "content": parts,
+ }
+ for key in ("reasoning_content", "thinking", "tool_calls"):
+ if message.get(key):
+ out[key] = message[key]
+ messages.append(out)
+ with open(path, "a", encoding="utf-8") as file:
+ file.write(json.dumps({"messages": messages}, ensure_ascii=False) + "\n")
+
+
+@dataclass
+class TurnResult:
+ """Result of generating a single assistant turn."""
+
+ content: str
+ message: dict | None = None
+ interrupted: bool = False
+ prompt_tokens: int = 0
+ reused_tokens: int = 0
+ new_tokens: int = 0
+ thinking_tokens: int = 0
+ response_tokens: int = 0
+ seconds: float = 0.0
+
+
+class _StopOnEvent(StoppingCriteria):
+ """Stops generation when the given event is set (e.g. on Ctrl+C)."""
+
+ def __init__(self, event: threading.Event):
+ self.event = event
+
+ def __call__(self, input_ids, scores, **kwargs) -> bool:
+ return self.event.is_set()
+
+
+class TurnGenerator:
+ """Base for assistant-turn generators: template rendering and EOS handling."""
+
+ def __init__(self, model, tokenizer, chat_template_str: str | None, device):
+ self.model = model
+ self.tokenizer = tokenizer
+ self.chat_template_str = chat_template_str
+ self.device = device
+ self.think_markers = detect_think_markers(
+ chat_template_str or getattr(tokenizer, "chat_template", None)
+ )
+ self._think_marker_ids: tuple[list[int], list[int]] | None = None
+ self._eos_strings: tuple[str, ...] | None = None
+
+ self.eos_token_ids: set[int] = set()
+ if tokenizer.eos_token_id is not None:
+ self.eos_token_ids.add(tokenizer.eos_token_id)
+ config_eos = getattr(model.generation_config, "eos_token_id", None)
+ if isinstance(config_eos, int):
+ self.eos_token_ids.add(config_eos)
+ elif isinstance(config_eos, (list, tuple)):
+ self.eos_token_ids.update(config_eos)
+
+ def render(
+ self, conversation: list[dict], render_kwargs: dict | None = None
+ ) -> list[int]:
+ kwargs = dict(render_kwargs or {})
+ if self.chat_template_str:
+ kwargs["chat_template"] = self.chat_template_str
+ batch = self.tokenizer.apply_chat_template(
+ conversation,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ **kwargs,
+ )
+ return list(batch["input_ids"])
+
+ def _split_think_token_ids(
+ self, generated: list[int]
+ ) -> tuple[list[int], list[int]]:
+ """Splits a generated sequence into (thinking, response) token ids."""
+ if self._think_marker_ids is None:
+ try:
+ self._think_marker_ids = (
+ self.tokenizer.encode(
+ self.think_markers[0], add_special_tokens=False
+ ),
+ self.tokenizer.encode(
+ self.think_markers[1], add_special_tokens=False
+ ),
+ )
+ except Exception: # pylint: disable=broad-exception-caught
+ self._think_marker_ids = ([], [])
+
+ open_ids, close_ids = self._think_marker_ids
+ i = find_subsequence(generated, open_ids)
+ if i < 0:
+ return [], generated
+ j = find_subsequence(generated, close_ids, i + len(open_ids))
+ if j < 0:
+ return generated[i + len(open_ids) :], generated[:i]
+ thinking = generated[i + len(open_ids) : j]
+ response = generated[:i] + generated[j + len(close_ids) :]
+ return thinking, response
+
+ def split_think_token_counts(self, generated: list[int]) -> tuple[int, int]:
+ """Returns (thinking, response) token counts for a generated sequence."""
+ thinking, response = self._split_think_token_ids(generated)
+ return len(thinking), len(response)
+
+ def build_assistant_message(
+ self,
+ generated: list[int],
+ split: tuple[list[int], list[int]] | None = None,
+ ) -> dict:
+ """
+ Parses generated token ids into an assistant message dict. Prefers the
+ tokenizer's own `parse_response` schema (transformers v5); otherwise splits
+ thinking out of the content by marker and stores it under
+ `reasoning_content`, the key the bundled chat templates read. Special
+ tokens are kept out of the stored text either way — the template re-adds
+ them on render.
+ """
+ if getattr(self.tokenizer, "response_schema", None):
+ try:
+ text = self.tokenizer.decode(generated, skip_special_tokens=False)
+ message = self.tokenizer.parse_response(text)
+ if isinstance(message, dict):
+ message.setdefault("role", "assistant")
+ message.setdefault("content", "")
+ return message
+ except Exception: # pylint: disable=broad-exception-caught
+ LOG.warning(
+ "tokenizer.parse_response failed; falling back to marker split",
+ exc_info=True,
+ )
+
+ thinking_ids, response_ids = (
+ split if split is not None else self._split_think_token_ids(generated)
+ )
+ parsed: dict[str, Any] = {
+ "role": "assistant",
+ "content": self.tokenizer.decode(
+ response_ids, skip_special_tokens=True
+ ).strip(),
+ }
+ if thinking_ids:
+ parsed["reasoning_content"] = self.tokenizer.decode(
+ thinking_ids, skip_special_tokens=True
+ ).strip()
+ return parsed
+
+ def eos_strings(self) -> tuple[str, ...]:
+ if self._eos_strings is None:
+ self._eos_strings = tuple(
+ text
+ for token_id in sorted(self.eos_token_ids)
+ if (text := self.tokenizer.decode([token_id]))
+ )
+ return self._eos_strings
+
+ def generate_turn(
+ self,
+ conversation: list[dict],
+ params: dict[str, Any],
+ on_text: Callable[[str], None],
+ render_kwargs: dict | None = None,
+ ) -> TurnResult:
+ raise NotImplementedError
+
+
+class EosTextTrimmer:
+ """
+ Filters streamed text so terminal EOS markers (e.g. `<|im_end|>`) never reach
+ the display. Text that could be the start of an EOS string is held back until
+ disambiguated by the next chunk.
+ """
+
+ def __init__(self, eos_strings: tuple[str, ...], emit: Callable[[str], None]):
+ self.eos_strings = tuple(s for s in eos_strings if s)
+ self.emit = emit
+ self.pending = ""
+ self.done = False
+
+ def feed(self, text: str):
+ if self.done or not text:
+ return
+ self.pending += text
+ positions = [
+ idx for s in self.eos_strings if (idx := self.pending.find(s)) >= 0
+ ]
+ if positions:
+ if min(positions) > 0:
+ self.emit(self.pending[: min(positions)])
+ self.pending = ""
+ self.done = True
+ return
+ hold = max(
+ (partial_suffix_len(self.pending, s) for s in self.eos_strings),
+ default=0,
+ )
+ if len(self.pending) > hold:
+ self.emit(self.pending[: len(self.pending) - hold])
+ self.pending = self.pending[len(self.pending) - hold :]
+
+ def finish(self):
+ if not self.done and self.pending:
+ self.emit(self.pending)
+ self.pending = ""
+
+
+class CausalTurnGenerator(TurnGenerator):
+ """
+ Generates assistant turns with `model.generate`, re-using the KV cache across
+ turns when the rendered conversation extends the previously cached tokens.
+ """
+
+ def __init__(self, model, tokenizer, chat_template_str: str | None, device):
+ super().__init__(model, tokenizer, chat_template_str, device)
+ self._cache: DynamicCache | None = None
+ self._cached_ids: list[int] = []
+
+ def reset_cache(self):
+ self._cache = None
+ self._cached_ids = []
+
+ def _new_cache(self) -> DynamicCache:
+ return DynamicCache(config=self.model.config)
+
+ def _prepare_cache(self, ids: list[int]) -> int:
+ """
+ Crops or resets the cross-turn cache so it holds a strict prefix of `ids`.
+ Chat templates may rewrite earlier turns when re-rendering (e.g. stripping
+ prior-turn thinking blocks), so reuse is gated on a token-level prefix
+ check rather than assumed.
+
+ Returns the number of re-used prefix tokens.
+ """
+ common = longest_common_prefix_len(self._cached_ids, ids)
+ # generate() needs at least one uncached input token
+ keep = min(common, len(ids) - 1)
+
+ if self._cache is None or keep <= 0:
+ self._cache = self._new_cache()
+ self._cached_ids = []
+ return 0
+
+ if keep < len(self._cached_ids):
+ try:
+ self._cache.crop(keep)
+ self._cached_ids = self._cached_ids[:keep]
+ except Exception: # pylint: disable=broad-exception-caught
+ # some cache layer types (e.g. sliding window) cannot crop
+ self._cache = self._new_cache()
+ self._cached_ids = []
+ return 0
+
+ return len(self._cached_ids)
+
+ def _build_generation_config(self, params: dict[str, Any]) -> GenerationConfig:
+ do_sample = params["temperature"] > 0
+ kwargs: dict[str, Any] = {
+ "max_new_tokens": params["max_new_tokens"],
+ "repetition_penalty": params["repetition_penalty"],
+ "do_sample": do_sample,
+ "bos_token_id": self.tokenizer.bos_token_id,
+ "eos_token_id": sorted(self.eos_token_ids) or None,
+ "pad_token_id": self.tokenizer.pad_token_id,
+ "use_cache": True,
+ "return_dict_in_generate": True,
+ }
+ if do_sample:
+ kwargs["temperature"] = params["temperature"]
+ kwargs["top_p"] = params["top_p"]
+ kwargs["top_k"] = params["top_k"]
+ if params["min_p"] is not None:
+ kwargs["min_p"] = params["min_p"]
+ return GenerationConfig(**kwargs)
+
+ def generate_turn(
+ self,
+ conversation: list[dict],
+ params: dict[str, Any],
+ on_text: Callable[[str], None],
+ render_kwargs: dict | None = None,
+ ) -> TurnResult:
+ ids = self.render(conversation, render_kwargs)
+ reused = self._prepare_cache(ids)
+ cache = self._cache
+ assert cache is not None
+
+ if params["seed"] is not None:
+ torch.manual_seed(params["seed"])
+
+ input_ids = torch.tensor([ids], dtype=torch.long, device=self.device)
+ attention_mask = torch.ones_like(input_ids)
+ generation_config = self._build_generation_config(params)
+
+ stop_event = threading.Event()
+ streamer = TextIteratorStreamer(
+ self.tokenizer, skip_prompt=True, skip_special_tokens=False
+ )
+ holder: dict[str, Any] = {}
+
+ def _worker():
+ try:
+ with torch.no_grad():
+ holder["output"] = self.model.generate(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ generation_config=generation_config,
+ streamer=streamer,
+ stopping_criteria=StoppingCriteriaList(
+ [_StopOnEvent(stop_event)]
+ ),
+ past_key_values=cache,
+ )
+ except Exception as err: # pylint: disable=broad-exception-caught
+ holder["error"] = err
+ streamer.end()
+
+ start = time.monotonic()
+ thread = threading.Thread(target=_worker, daemon=True)
+ thread.start()
+
+ trimmer = EosTextTrimmer(self.eos_strings(), on_text)
+ interrupted = False
+ try:
+ for text in streamer:
+ trimmer.feed(text)
+ except KeyboardInterrupt:
+ interrupted = True
+ stop_event.set()
+ for text in streamer:
+ trimmer.feed(text)
+ finally:
+ # a 2nd Ctrl+C can escape the drain; join so the worker stops writing the cache
+ stop_event.set()
+ thread.join()
+ trimmer.finish()
+ seconds = time.monotonic() - start
+
+ if "error" in holder:
+ self.reset_cache()
+ raise holder["error"]
+
+ sequence = holder["output"].sequences[0].tolist()
+ self._cached_ids = sequence[: cache.get_seq_length()]
+
+ generated = sequence[len(ids) :]
+ while generated and generated[-1] in self.eos_token_ids:
+ generated.pop()
+ thinking_ids, response_ids = self._split_think_token_ids(generated)
+ message = self.build_assistant_message(generated, (thinking_ids, response_ids))
+
+ return TurnResult(
+ content=message.get("content") or "",
+ message=message,
+ interrupted=interrupted,
+ prompt_tokens=len(ids),
+ reused_tokens=reused,
+ new_tokens=len(generated),
+ thinking_tokens=len(thinking_ids),
+ response_tokens=len(response_ids),
+ seconds=seconds,
+ )
+
+
+class DiffusionTurnGenerator(TurnGenerator):
+ """
+ Generates assistant turns for diffusion LMs by appending a masked completion
+ block to the rendered conversation and denoising it. The whole block resolves
+ at once, so the reply is emitted in one piece rather than streamed.
+ """
+
+ def __init__(
+ self, model, tokenizer, chat_template_str: str | None, device, mask_token_id
+ ):
+ super().__init__(model, tokenizer, chat_template_str, device)
+ self.mask_token_id = int(mask_token_id)
+
+ def generate_turn(
+ self,
+ conversation: list[dict],
+ params: dict[str, Any],
+ on_text: Callable[[str], None],
+ render_kwargs: dict | None = None,
+ ) -> TurnResult:
+ from axolotl.integrations.diffusion import generate as diffusion_generate
+
+ ids = self.render(conversation, render_kwargs)
+ if params["seed"] is not None:
+ torch.manual_seed(params["seed"])
+
+ sequence = torch.tensor([ids], dtype=torch.long, device=self.device)
+
+ start = time.monotonic()
+ with torch.no_grad():
+ result = diffusion_generate(
+ self.model,
+ self.tokenizer,
+ original_sequence=sequence,
+ num_diffusion_steps=params["steps"],
+ temperature=params["temperature"],
+ mask_token_id=self.mask_token_id,
+ mode="completion",
+ completion_tokens=params["max_new_tokens"],
+ )
+ seconds = time.monotonic() - start
+
+ generated = result["generated_ids"][len(ids) :]
+ for i, token_id in enumerate(generated):
+ if token_id in self.eos_token_ids:
+ generated = generated[:i]
+ break
+ content = self.tokenizer.decode(generated, skip_special_tokens=False)
+ on_text(content)
+ thinking_ids, response_ids = self._split_think_token_ids(generated)
+
+ return TurnResult(
+ content=content,
+ message=self.build_assistant_message(
+ generated, (thinking_ids, response_ids)
+ ),
+ prompt_tokens=len(ids),
+ new_tokens=len(generated),
+ thinking_tokens=len(thinking_ids),
+ response_tokens=len(response_ids),
+ seconds=seconds,
+ )
+
+
+class ThinkStreamRenderer:
+ """
+ Renders one streamed turn. When collapsing, thinking is shown as a rolling
+ dim tail in a live region (so it never enters scrollback) and replaced by a
+ one-line summary when the block closes; the reply streams normally. When
+ collapse is off, this is a plain passthrough print.
+ """
+
+ LIVE_FPS = 12
+
+ def __init__(
+ self,
+ console: Console,
+ collapse: bool,
+ markers: tuple[str, str] = DEFAULT_THINK_MARKERS,
+ tail_lines: int = 6,
+ ):
+ self.console = console
+ self.collapse = collapse
+ self.open_marker, self.close_marker = markers
+ self.tail_lines = tail_lines
+ self.think_text = ""
+ self._mode = "detect"
+ self._pending = ""
+ self._start = time.monotonic()
+ self._live: Live | None = None
+ self._last_live_update = 0.0
+
+ def feed(self, text: str):
+ if not self.collapse:
+ print(text, end="", flush=True)
+ return
+ self._pending += text
+ self._process()
+
+ def finish(self, interrupted: bool = False):
+ if not self.collapse:
+ return
+ if self._mode == "think":
+ self.think_text += self._pending
+ self._pending = ""
+ reason = "interrupted" if interrupted else f"no {self.close_marker}"
+ self._end_think(f" ({escape(reason)})")
+ elif self._pending:
+ print(self._pending, end="", flush=True)
+ self._pending = ""
+
+ def _process(self):
+ while True:
+ if self._mode == "detect":
+ stripped = self._pending.lstrip()
+ if stripped.startswith(self.open_marker):
+ idx = self._pending.find(self.open_marker)
+ self._pending = self._pending[idx + len(self.open_marker) :]
+ self._mode = "think"
+ self._live = Live(
+ Text(""),
+ console=self.console,
+ refresh_per_second=self.LIVE_FPS,
+ transient=True,
+ )
+ self._live.start()
+ continue
+ if not stripped or self.open_marker.startswith(stripped):
+ return # could still be a marker prefix; wait for more text
+ self._mode = "reply"
+ continue
+
+ if self._mode == "think":
+ idx = self._pending.find(self.close_marker)
+ if idx >= 0:
+ self.think_text += self._pending[:idx]
+ self._pending = self._pending[
+ idx + len(self.close_marker) :
+ ].lstrip("\n")
+ self._end_think()
+ self._mode = "reply"
+ continue
+ keep = partial_suffix_len(self._pending, self.close_marker)
+ emit_until = len(self._pending) - keep
+ self.think_text += self._pending[:emit_until]
+ self._pending = self._pending[emit_until:]
+ self._update_live()
+ return
+
+ # reply mode
+ if self._pending:
+ print(self._pending, end="", flush=True)
+ self._pending = ""
+ return
+
+ def _update_live(self):
+ if self._live is None:
+ return
+ # Live repaints at LIVE_FPS; building renderables faster than that is wasted
+ now = time.monotonic()
+ if now - self._last_live_update < 1 / self.LIVE_FPS:
+ return
+ self._last_live_update = now
+ lines = self.think_text.splitlines()[-self.tail_lines :]
+ self._live.update(Text("\n".join(lines), style="dim"))
+
+ def _end_think(self, note: str = ""):
+ if self._live is not None:
+ self._live.stop()
+ self._live = None
+ seconds = time.monotonic() - self._start
+ self.console.print(
+ f"[dim]▸ thought for {seconds:.1f}s{note} · /expand to view[/dim]"
+ )
+
+
+@dataclass(frozen=True)
+class Command:
+ """A slash command with its aliases and handler."""
+
+ name: str
+ handler: str
+ help: str
+ aliases: tuple[str, ...] = ()
+ usage: str = ""
+
+
+COMMANDS: tuple[Command, ...] = (
+ Command("help", "cmd_help", "show this help", ("?",)),
+ Command(
+ "new",
+ "cmd_new",
+ "clear the conversation (keeps system prompt and params)",
+ ("clear", "reset"),
+ ),
+ Command(
+ "system",
+ "cmd_system",
+ "show, set, or clear the system prompt",
+ usage="/system [text|clear]",
+ ),
+ Command(
+ "set",
+ "cmd_set",
+ "set a generation parameter",
+ usage="/set ",
+ ),
+ Command("status", "cmd_status", "show model and generation settings", ("params",)),
+ Command("history", "cmd_history", "show the conversation so far"),
+ Command("retry", "cmd_retry", "regenerate the last assistant reply", ("regen",)),
+ Command("undo", "cmd_undo", "remove the last exchange"),
+ Command(
+ "save",
+ "cmd_save",
+ "append conversation as a chat_template-format JSONL sample",
+ usage="/save [path]",
+ ),
+ Command(
+ "think",
+ "cmd_think",
+ "toggle template-level thinking, if the template supports it",
+ usage="/think [on|off|default]",
+ ),
+ Command(
+ "collapse",
+ "cmd_collapse",
+ "collapse thinking blocks in the display",
+ usage="/collapse [on|off]",
+ ),
+ Command("expand", "cmd_expand", "show the hidden thinking from the last reply"),
+ Command("quit", "cmd_quit", "exit chat", ("exit", "q")),
+)
+
+
+def resolve_command(name: str) -> Command | None:
+ for command in COMMANDS:
+ if name == command.name or name in command.aliases:
+ return command
+ return None
+
+
+class ChatRepl:
+ """Interactive chat loop: slash commands plus streamed model turns."""
+
+ def __init__(
+ self,
+ *,
+ generator: TurnGenerator,
+ session: ChatSession | None = None,
+ params: dict[str, Any] | None = None,
+ param_specs: tuple[GenParamSpec, ...] = GEN_PARAMS,
+ console: Console | None = None,
+ banner: dict[str, str] | None = None,
+ input_fn: Callable[[str], str] | None = None,
+ think_toggle_key: str | None = None,
+ collapse_thinking: bool = True,
+ ):
+ self.generator = generator
+ self.session = session or ChatSession()
+ self.param_specs = param_specs
+ self.params = params or default_gen_params(param_specs)
+ self.console = console or Console()
+ self.banner = banner or {}
+ self.input_fn = input_fn or input
+ self.think_toggle_key = think_toggle_key
+ self.collapse_thinking = collapse_thinking
+ self.render_kwargs: dict[str, Any] = {}
+ self.last_think_text: str | None = None
+
+ def run(self):
+ self._print_banner()
+ while True:
+ try:
+ line = self._read_line()
+ except EOFError:
+ break
+ except KeyboardInterrupt:
+ self.console.print("\n[dim]Use /quit to exit.[/dim]")
+ continue
+
+ line = line.strip()
+ if not line:
+ continue
+
+ if line.startswith("/"):
+ try:
+ action = self._dispatch(line)
+ except Exception as err: # pylint: disable=broad-exception-caught
+ self.console.print(f"[red]Command failed: {escape(str(err))}[/red]")
+ continue
+ if action == "quit":
+ break
+ if action == "regenerate":
+ self._generate_turn()
+ continue
+
+ self.session.add_user(line)
+ self._generate_turn()
+
+ def _read_line(self) -> str:
+ parts = []
+ prompt = USER_PROMPT
+ while True:
+ line = self.input_fn(prompt)
+ if line.endswith("\\"):
+ parts.append(line[:-1])
+ prompt = CONTINUATION_PROMPT
+ continue
+ parts.append(line)
+ break
+ return "\n".join(parts)
+
+ def _dispatch(self, line: str) -> str | None:
+ name, _, args = line[1:].partition(" ")
+ name = name.lower()
+ args = args.strip()
+
+ command = resolve_command(name)
+ if command:
+ return getattr(self, command.handler)(args)
+
+ # bare parameter shortcuts: /temp 0.7, /top_p 0.9, ...
+ spec = resolve_gen_param(name, self.param_specs)
+ if spec:
+ return self.cmd_set(f"{spec.key} {args}" if args else spec.key)
+
+ candidates = [
+ alias for command in COMMANDS for alias in (command.name, *command.aliases)
+ ] + [alias for spec in self.param_specs for alias in (spec.key, *spec.aliases)]
+ close = difflib.get_close_matches(name, candidates, n=1)
+ hint = f" Did you mean /{close[0]}?" if close else ""
+ self.console.print(f"[red]Unknown command /{name}.[/red]{hint}")
+ return None
+
+ def _generate_turn(self):
+ renderer = ThinkStreamRenderer(
+ self.console,
+ collapse=self.collapse_thinking,
+ markers=getattr(self.generator, "think_markers", DEFAULT_THINK_MARKERS),
+ )
+ try:
+ result = self.generator.generate_turn(
+ self.session.conversation(),
+ self.params,
+ renderer.feed,
+ render_kwargs=self.render_kwargs or None,
+ )
+ except KeyboardInterrupt:
+ # an escaped interrupt leaves the cache out of sync with _cached_ids
+ reset_cache = getattr(self.generator, "reset_cache", None)
+ if callable(reset_cache):
+ reset_cache()
+ renderer.finish(interrupted=True)
+ print()
+ self.console.print(
+ "[dim]Interrupted; reply discarded. Your message is kept —"
+ " /retry regenerates, /undo removes it.[/dim]"
+ )
+ return
+ except Exception as err: # pylint: disable=broad-exception-caught
+ renderer.finish(interrupted=True)
+ self.console.print(f"\n[red]Generation failed: {escape(str(err))}[/red]")
+ self.console.print(
+ "[dim]Your message is kept — /retry regenerates, /undo removes it.[/dim]"
+ )
+ return
+
+ renderer.finish(interrupted=result.interrupted)
+ message = result.message or {"role": "assistant", "content": result.content}
+ self.last_think_text = (
+ renderer.think_text.strip()
+ or message.get("reasoning_content")
+ or message.get("thinking")
+ or None
+ )
+
+ print()
+ self.session.add_assistant_message(message)
+ token_summary = f"{result.new_tokens} tokens"
+ if result.thinking_tokens:
+ token_summary += (
+ f" ({result.thinking_tokens} thinking · {result.response_tokens} reply)"
+ )
+ stats = (
+ f"{token_summary} · {result.seconds:.1f}s · "
+ f"{result.prompt_tokens} prompt ({result.reused_tokens} cached)"
+ )
+ if result.interrupted:
+ stats += " · interrupted (partial reply kept; /retry to regenerate)"
+ self.console.print(f"[dim]{stats}[/dim]")
+
+ def _print_banner(self):
+ self.console.print("[bold]axolotl chat[/bold]")
+ for key, value in self.banner.items():
+ self.console.print(f"[dim]{key}:[/dim] {escape(value)}")
+ self.console.print(
+ "[dim]Type a message to chat, /help for commands, \\ at line end to"
+ " continue on the next line.[/dim]"
+ )
+
+ # --- command handlers (return "quit", "regenerate", or None) ---
+
+ def cmd_help(self, _args: str) -> None:
+ for command in COMMANDS:
+ names = "/" + command.name
+ if command.aliases:
+ names += " (" + ", ".join("/" + a for a in command.aliases) + ")"
+ usage = f" — {command.usage}" if command.usage else ""
+ self.console.print(f" [bold]{names}[/bold]: {command.help}{usage}")
+ params = ", ".join(
+ "/" + s.key + ("".join(f" /{a}" for a in s.aliases))
+ for s in self.param_specs
+ )
+ self.console.print(f" parameter shortcuts: {params}")
+ return None
+
+ def cmd_new(self, _args: str) -> None:
+ self.session.clear()
+ self.last_think_text = None
+ reset_cache = getattr(self.generator, "reset_cache", None)
+ if callable(reset_cache):
+ reset_cache()
+ self.console.print("[dim]Conversation cleared.[/dim]")
+ return None
+
+ def cmd_system(self, args: str) -> None:
+ if not args:
+ if self.session.system:
+ self.console.print(escape(self.session.system))
+ else:
+ self.console.print("[dim]No system prompt set.[/dim]")
+ return None
+ if args.lower() == "clear":
+ self.session.system = None
+ self.console.print("[dim]System prompt cleared.[/dim]")
+ return None
+ self.session.system = args
+ self.console.print("[dim]System prompt set.[/dim]")
+ return None
+
+ def cmd_set(self, args: str) -> str | None:
+ tokens = args.replace("=", " ").split()
+ if len(tokens) != 2:
+ self.console.print("[red]Usage: /set [/red]")
+ return None
+ spec = resolve_gen_param(tokens[0].lower(), self.param_specs)
+ if not spec:
+ valid = ", ".join(s.key for s in self.param_specs)
+ self.console.print(f"[red]Unknown parameter. Valid: {valid}[/red]")
+ return None
+ try:
+ self.params[spec.key] = parse_gen_param_value(spec, tokens[1])
+ except ValueError as err:
+ self.console.print(f"[red]{err}[/red]")
+ return None
+ self.console.print(f"[dim]{spec.key} = {self.params[spec.key]}[/dim]")
+ return None
+
+ def cmd_status(self, _args: str) -> None:
+ for key, value in self.banner.items():
+ self.console.print(f"[dim]{key}:[/dim] {escape(value)}")
+ for spec in self.param_specs:
+ self.console.print(f"[dim]{spec.key}:[/dim] {self.params[spec.key]}")
+ n_messages = len(self.session.messages)
+ self.console.print(f"[dim]messages:[/dim] {n_messages}")
+ return None
+
+ def cmd_history(self, _args: str) -> None:
+ conversation = self.session.conversation()
+ if not conversation:
+ self.console.print("[dim]No messages yet.[/dim]")
+ return None
+ for message in conversation:
+ self.console.print(f"[bold]{message['role']}:[/bold]")
+ reasoning = message.get("reasoning_content") or message.get("thinking")
+ if reasoning:
+ self.console.print(escape(content_as_text(reasoning)), style="dim")
+ self.console.print(escape(content_as_text(message.get("content"))))
+ return None
+
+ def cmd_retry(self, _args: str) -> str | None:
+ if not self.session.drop_last_assistant():
+ self.console.print("[dim]Nothing to retry yet.[/dim]")
+ return None
+ return "regenerate"
+
+ def cmd_undo(self, _args: str) -> None:
+ if self.session.undo():
+ self.last_think_text = None
+ self.console.print("[dim]Removed last exchange.[/dim]")
+ else:
+ self.console.print("[dim]Nothing to undo.[/dim]")
+ return None
+
+ def cmd_save(self, args: str) -> None:
+ if not self.session.messages:
+ self.console.print("[dim]Nothing to save yet.[/dim]")
+ return None
+ path = (
+ shlex.split(args)[0]
+ if args
+ else f"chat-{datetime.now().strftime('%Y%m%d-%H%M%S')}.jsonl"
+ )
+ self.session.save_jsonl(path)
+ self.console.print(f"[dim]Saved conversation to {path}[/dim]")
+ return None
+
+ def cmd_think(self, args: str) -> None:
+ if not args:
+ current = self.render_kwargs.get(self.think_toggle_key or "", "default")
+ self.console.print(f"[dim]template thinking: {current}[/dim]")
+ return None
+ if not self.think_toggle_key:
+ self.console.print(
+ "[dim]This chat template has no thinking toggle; thinking is"
+ " controlled by the model/template itself.[/dim]"
+ )
+ return None
+ value = args.lower()
+ if value in ("on", "true"):
+ self.render_kwargs[self.think_toggle_key] = True
+ elif value in ("off", "false"):
+ self.render_kwargs[self.think_toggle_key] = False
+ elif value == "default":
+ self.render_kwargs.pop(self.think_toggle_key, None)
+ else:
+ self.console.print("[red]Usage: /think [on|off|default][/red]")
+ return None
+ current = self.render_kwargs.get(self.think_toggle_key, "default")
+ self.console.print(
+ f"[dim]{self.think_toggle_key} = {current} (applies from next turn)[/dim]"
+ )
+ return None
+
+ def cmd_collapse(self, args: str) -> None:
+ value = args.lower()
+ if value in ("on", "true", ""):
+ self.collapse_thinking = True
+ elif value in ("off", "false"):
+ self.collapse_thinking = False
+ else:
+ self.console.print("[red]Usage: /collapse [on|off][/red]")
+ return None
+ state = "collapsed" if self.collapse_thinking else "shown raw"
+ self.console.print(f"[dim]Thinking blocks will be {state}.[/dim]")
+ return None
+
+ def cmd_expand(self, _args: str) -> None:
+ if self.last_think_text:
+ self.console.print(escape(self.last_think_text), style="dim")
+ else:
+ self.console.print(
+ "[dim]No hidden thinking recorded for the last reply.[/dim]"
+ )
+ return None
+
+ def cmd_quit(self, _args: str) -> str:
+ return "quit"
+
+
+def _build_banner(cfg: DictDefault) -> dict[str, str]:
+ banner = {"model": str(cfg.base_model)}
+
+ if cfg.lora_model_dir:
+ banner["adapter"] = f"{cfg.adapter or 'lora'} from {cfg.lora_model_dir}"
+ elif cfg.adapter:
+ banner["adapter"] = str(cfg.adapter)
+
+ quant = []
+ if cfg.load_in_4bit:
+ quant.append("4-bit (bnb)")
+ if cfg.load_in_8bit:
+ quant.append("8-bit (bnb)")
+ if cfg.qat:
+ quant.append("QAT fake-quant active")
+ if quant:
+ banner["quantization"] = ", ".join(quant)
+
+ if cfg.chat_template:
+ template_name = getattr(cfg.chat_template, "value", cfg.chat_template)
+ banner["chat template"] = f"config ({template_name})"
+ elif cfg.datasets and cfg.datasets[0].type == "chat_template":
+ banner["chat template"] = "dataset config"
+ else:
+ banner["chat template"] = "tokenizer default"
+
+ return banner
+
+
+@send_errors
+def do_chat(
+ *,
+ cfg: DictDefault,
+ cli_args: InferenceCliArgs,
+):
+ """
+ Runs an interactive multi-turn chat session on the command line. The chat
+ template is applied to the full conversation each turn, and generation
+ parameters can be adjusted at runtime via slash commands.
+
+ Args:
+ cfg: Dictionary mapping `axolotl` config keys to values.
+ cli_args: Inference-specific CLI arguments.
+ """
+ if cli_args.prompter:
+ raise ValueError(
+ "--chat does not support --prompter; legacy prompters are single-turn."
+ " Use the default inference mode instead."
+ )
+
+ plugin_manager = PluginManager.get_instance()
+ is_diffusion = any(
+ plugin.__class__.__name__ == "DiffusionPlugin"
+ for plugin in plugin_manager.plugins.values()
+ )
+
+ if not sys.stdin.isatty():
+ raise ValueError(
+ "--chat requires an interactive terminal. For piped input, use the"
+ " default inference mode."
+ )
+
+ try:
+ import readline # noqa: F401 pylint: disable=unused-import
+ except ImportError:
+ pass
+
+ model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True)
+ if cfg.is_multimodal:
+ LOG.warning(
+ "Multimodal attachments are not supported in chat mode yet;"
+ " proceeding with text-only chat."
+ )
+
+ chat_template_str = resolve_chat_template_str(cfg, tokenizer)
+ if not chat_template_str and not tokenizer.chat_template:
+ raise ValueError(
+ "Chat mode requires a chat template. Set `chat_template` in your config"
+ " or use a tokenizer that provides one."
+ )
+
+ model = model.to(cfg.device, dtype=cfg.torch_dtype)
+ model.eval()
+
+ banner = _build_banner(cfg)
+ generator: TurnGenerator
+ param_specs = GEN_PARAMS
+
+ if is_diffusion:
+ from axolotl.integrations.diffusion import resolve_mask_token_id
+
+ mask_token_id = resolve_mask_token_id(tokenizer, cfg, allow_add=False)
+ generator = DiffusionTurnGenerator(
+ model, tokenizer, chat_template_str, cfg.device, mask_token_id
+ )
+ param_specs = DIFFUSION_GEN_PARAMS
+ params = default_gen_params(param_specs)
+ if cfg.diffusion.num_diffusion_steps:
+ params["steps"] = cfg.diffusion.num_diffusion_steps
+ if cfg.diffusion.generation_temperature is not None:
+ params["temperature"] = cfg.diffusion.generation_temperature
+ banner["mode"] = "diffusion (completion-block denoising)"
+ else:
+ generator = CausalTurnGenerator(model, tokenizer, chat_template_str, cfg.device)
+ params = default_gen_params(param_specs)
+
+ think_toggle_key = detect_think_toggle_key(
+ chat_template_str or tokenizer.chat_template
+ )
+ repl = ChatRepl(
+ generator=generator,
+ params=params,
+ param_specs=param_specs,
+ banner=banner,
+ think_toggle_key=think_toggle_key,
+ )
+ repl.run()
diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py
index 123e093c0c..aa9efe50fb 100644
--- a/src/axolotl/cli/config.py
+++ b/src/axolotl/cli/config.py
@@ -342,7 +342,7 @@ def compute_supports_fp8() -> bool:
try:
compute_capability = torch.cuda.get_device_capability()
return compute_capability >= (9, 0)
- except RuntimeError:
+ except (RuntimeError, AssertionError):
return False
diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py
index cafa0f4eff..c6fd8d1dd8 100644
--- a/src/axolotl/cli/inference.py
+++ b/src/axolotl/cli/inference.py
@@ -13,16 +13,13 @@
from axolotl.cli.args import InferenceCliArgs
from axolotl.cli.config import load_cfg
-from axolotl.cli.utils import load_model_and_tokenizer
+from axolotl.cli.utils import load_model_and_tokenizer, resolve_chat_template_str
from axolotl.cli.utils.diffusion import (
diffusion_inference,
launch_diffusion_gradio_ui,
)
from axolotl.integrations.base import PluginManager
from axolotl.telemetry.errors import send_errors
-from axolotl.utils.chat_templates import (
- get_chat_template_from_config,
-)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -70,14 +67,8 @@ def do_inference(
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
- elif cfg.chat_template:
- chat_template_str = get_chat_template_from_config(
- cfg, ds_cfg=None, tokenizer=tokenizer
- )
- elif cfg.datasets and cfg.datasets[0].type == "chat_template":
- chat_template_str = get_chat_template_from_config(
- cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
- )
+ else:
+ chat_template_str = resolve_chat_template_str(cfg, tokenizer)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -190,14 +181,8 @@ def do_inference_gradio(
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
- elif cfg.chat_template:
- chat_template_str = get_chat_template_from_config(
- cfg, ds_cfg=None, tokenizer=tokenizer
- )
- elif cfg.datasets and cfg.datasets[0].type == "chat_template":
- chat_template_str = get_chat_template_from_config(
- cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
- )
+ else:
+ chat_template_str = resolve_chat_template_str(cfg, tokenizer)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -297,13 +282,19 @@ def generate(instruction):
def do_cli(
- config: Union[Path, str] = Path("examples/"), gradio: bool = False, **kwargs
+ config: Union[Path, str] = Path("examples/"),
+ gradio: bool = False,
+ chat: bool = False,
+ **kwargs,
) -> None:
"""
- Parses axolotl config, CLI args, and calls `do_inference` or `do_inference_gradio`.
+ Parses axolotl config, CLI args, and calls `do_inference`, `do_inference_gradio`,
+ or `do_chat`.
Args:
config: Path to `axolotl` config YAML file.
+ gradio: Whether to launch the Gradio browser interface.
+ chat: Whether to launch the interactive multi-turn chat interface.
kwargs: Additional keyword arguments to override config file values.
"""
@@ -314,7 +305,14 @@ def do_cli(
return_remaining_strings=True
)
- if gradio:
+ if gradio and chat:
+ raise ValueError("--gradio and --chat are mutually exclusive.")
+
+ if chat:
+ from axolotl.cli.chat import do_chat
+
+ do_chat(cfg=parsed_cfg, cli_args=parsed_cli_args)
+ elif gradio:
do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py
index cca6481e6e..13dcc4b77b 100644
--- a/src/axolotl/cli/main.py
+++ b/src/axolotl/cli/main.py
@@ -191,11 +191,16 @@ def evaluate(ctx: click.Context, config: str, launcher: str, **kwargs):
help="Launcher to use for multi-GPU inference",
)
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
+@click.option(
+ "--chat", is_flag=True, help="Launch interactive multi-turn chat interface"
+)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
@click.pass_context
-def inference(ctx: click.Context, config: str, launcher: str, gradio: bool, **kwargs):
+def inference(
+ ctx: click.Context, config: str, launcher: str, gradio: bool, chat: bool, **kwargs
+):
"""
Run inference with a trained model.
@@ -204,9 +209,13 @@ def inference(ctx: click.Context, config: str, launcher: str, gradio: bool, **kw
config: Path to `axolotl` config YAML file.
launcher: Launcher to use for multi-GPU inference ("accelerate", "torchrun", or "python").
gradio: Whether to use Gradio browser interface or command line for inference.
+ chat: Whether to use the interactive multi-turn chat interface.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
+ if gradio and chat:
+ raise click.UsageError("--gradio and --chat are mutually exclusive.")
+
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
@@ -220,12 +229,14 @@ def inference(ctx: click.Context, config: str, launcher: str, gradio: bool, **kw
base_cmd.append(config)
if gradio:
base_cmd.append("--gradio")
+ if chat:
+ base_cmd.append("--chat")
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.inference import do_cli
- do_cli(config=config, gradio=gradio, **kwargs)
+ do_cli(config=config, gradio=gradio, chat=chat, **kwargs)
@cli.command(
diff --git a/src/axolotl/cli/utils/__init__.py b/src/axolotl/cli/utils/__init__.py
index 583130339a..f54a471029 100644
--- a/src/axolotl/cli/utils/__init__.py
+++ b/src/axolotl/cli/utils/__init__.py
@@ -6,7 +6,7 @@
filter_none_kwargs,
)
from .fetch import fetch_from_github
-from .load import load_model_and_tokenizer
+from .load import load_model_and_tokenizer, resolve_chat_template_str
from .sweeps import generate_sweep_configs
from .train import build_command, generate_config_files, launch_training
@@ -18,6 +18,7 @@
"generate_config_files",
"generate_sweep_configs",
"load_model_and_tokenizer",
+ "resolve_chat_template_str",
"launch_training",
"fetch_from_github",
]
diff --git a/src/axolotl/cli/utils/load.py b/src/axolotl/cli/utils/load.py
index 610a81306d..4fb6f06f32 100644
--- a/src/axolotl/cli/utils/load.py
+++ b/src/axolotl/cli/utils/load.py
@@ -11,12 +11,39 @@
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.loaders.model import ModelLoader
+from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
+def resolve_chat_template_str(
+ cfg: DictDefault,
+ tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
+) -> str | None:
+ """
+ Resolves the chat template string for inference from the `axolotl` config,
+ mirroring how it would be resolved at training time: an explicit
+ `chat_template` config takes precedence, then the first dataset's
+ `chat_template` if that dataset is of type `chat_template`.
+
+ Args:
+ cfg: Dictionary mapping `axolotl` config keys to values.
+ tokenizer: Tokenizer to fall back to for tokenizer-default templates.
+
+ Returns:
+ Chat template string, or None if the config does not specify one.
+ """
+ if cfg.chat_template:
+ return get_chat_template_from_config(cfg, ds_cfg=None, tokenizer=tokenizer)
+ if cfg.datasets and cfg.datasets[0].type == "chat_template":
+ return get_chat_template_from_config(
+ cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
+ )
+ return None
+
+
def load_model_and_tokenizer(
*,
cfg: DictDefault,
diff --git a/src/axolotl/logging_config.py b/src/axolotl/logging_config.py
index 67b1d32f18..e12a8be419 100644
--- a/src/axolotl/logging_config.py
+++ b/src/axolotl/logging_config.py
@@ -48,6 +48,17 @@ def filter(self, record: LogRecord) -> bool:
)
+class HubUnauthenticatedNagFilter(logging.Filter):
+ """
+ Drops the server-sent "sending unauthenticated requests" nag (an X-HF-Warning
+ response header that huggingface_hub logs with no env var to disable it).
+ Other hub warnings (retries, rate limits) pass through.
+ """
+
+ def filter(self, record: LogRecord) -> bool:
+ return "unauthenticated requests" not in record.getMessage()
+
+
class AxolotlLogger(Logger):
"""Logger that applies filtering to non-axolotl loggers."""
@@ -98,6 +109,9 @@ def format(self, record):
"ax_or_warn": {
"()": "axolotl.logging_config.AxolotlOrWarnErrorFilter",
},
+ "hub_unauthenticated_nag": {
+ "()": "axolotl.logging_config.HubUnauthenticatedNagFilter",
+ },
},
"handlers": {
"console": {
@@ -135,6 +149,11 @@ def format(self, record):
"level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL).upper(),
"propagate": False,
},
+ # filter at the emitting logger so the nag is dropped before it reaches
+ # both huggingface_hub's own handler and our root handlers
+ "huggingface_hub.utils._http": {
+ "filters": ["hub_unauthenticated_nag"],
+ },
},
}
diff --git a/tests/cli/test_chat_repl.py b/tests/cli/test_chat_repl.py
new file mode 100644
index 0000000000..9d6a1de646
--- /dev/null
+++ b/tests/cli/test_chat_repl.py
@@ -0,0 +1,668 @@
+"""pytest tests for the interactive chat REPL (no model required)."""
+
+import io
+import json
+
+import pytest
+from rich.console import Console
+
+from axolotl.cli.chat import (
+ CausalTurnGenerator,
+ ChatRepl,
+ ChatSession,
+ TurnResult,
+ default_gen_params,
+ longest_common_prefix_len,
+ parse_gen_param_value,
+ resolve_command,
+ resolve_gen_param,
+)
+
+
+class FakeCache:
+ """Stands in for DynamicCache in cache-planning tests."""
+
+ def __init__(self, length=0, croppable=True):
+ self.length = length
+ self.croppable = croppable
+
+ def crop(self, max_length):
+ if not self.croppable:
+ raise NotImplementedError("cannot crop")
+ self.length = max_length
+
+ def get_seq_length(self):
+ return self.length
+
+
+class FakeGenerator:
+ """Records conversations passed in and returns canned replies."""
+
+ def __init__(self, replies=None, messages=None):
+ self.replies = replies or ["canned reply"]
+ self.messages = messages
+ self.calls = []
+ self.render_kwargs_seen = []
+
+ def generate_turn(self, conversation, params, on_text, render_kwargs=None):
+ self.calls.append(([dict(m) for m in conversation], dict(params)))
+ self.render_kwargs_seen.append(
+ dict(render_kwargs) if render_kwargs is not None else None
+ )
+ index = min(len(self.calls) - 1, len(self.replies) - 1)
+ content = self.replies[index]
+ message = dict(self.messages[index]) if self.messages else None
+ on_text(content)
+ return TurnResult(
+ content=content, message=message, prompt_tokens=10, new_tokens=3
+ )
+
+
+def make_repl(inputs, generator=None, session=None):
+ lines = iter(inputs)
+
+ def input_fn(_prompt):
+ try:
+ return next(lines)
+ except StopIteration as err:
+ raise EOFError from err
+
+ generator = generator or FakeGenerator()
+ repl = ChatRepl(
+ generator=generator,
+ session=session,
+ console=Console(file=io.StringIO(), force_terminal=False),
+ input_fn=input_fn,
+ )
+ return repl, generator
+
+
+def cache_planner(cached_ids, cache):
+ generator = CausalTurnGenerator.__new__(CausalTurnGenerator)
+ generator._cache = cache # pylint: disable=protected-access
+ generator._cached_ids = cached_ids # pylint: disable=protected-access
+ generator._new_cache = FakeCache # pylint: disable=protected-access
+ return generator
+
+
+class TestGenParams:
+ def test_alias_resolution(self):
+ assert resolve_gen_param("temp").key == "temperature"
+ assert resolve_gen_param("max").key == "max_new_tokens"
+ assert resolve_gen_param("rep").key == "repetition_penalty"
+ assert resolve_gen_param("bogus") is None
+
+ def test_value_validation(self):
+ spec = resolve_gen_param("temperature")
+ assert parse_gen_param_value(spec, "0.7") == 0.7
+ with pytest.raises(ValueError):
+ parse_gen_param_value(spec, "100")
+ with pytest.raises(ValueError):
+ parse_gen_param_value(spec, "abc")
+
+ def test_nullable_params(self):
+ assert parse_gen_param_value(resolve_gen_param("seed"), "none") is None
+ assert parse_gen_param_value(resolve_gen_param("min_p"), "off") is None
+ with pytest.raises(ValueError):
+ parse_gen_param_value(resolve_gen_param("temperature"), "none")
+
+
+class TestChatSession:
+ def test_system_prompt_prepended(self):
+ session = ChatSession()
+ session.system = "be brief"
+ session.add_user("hi")
+ conversation = session.conversation()
+ assert conversation[0] == {"role": "system", "content": "be brief"}
+ assert conversation[1]["role"] == "user"
+
+ def test_undo_removes_exchange(self):
+ session = ChatSession()
+ session.add_user("q1")
+ session.add_assistant("a1")
+ session.add_user("q2")
+ session.add_assistant("a2")
+ assert session.undo()
+ assert [m["content"] for m in session.messages] == ["q1", "a1"]
+ assert session.undo()
+ assert not session.messages
+ assert not session.undo()
+
+ def test_drop_last_assistant_for_retry(self):
+ session = ChatSession()
+ session.add_user("q1")
+ session.add_assistant("a1")
+ assert session.drop_last_assistant()
+ assert session.messages[-1]["role"] == "user"
+ session.clear()
+ assert not session.drop_last_assistant()
+
+ def test_add_user_merges_consecutive_user_messages(self):
+ # a failed generation leaves a trailing user message; typing again must
+ # not create consecutive user turns (strict templates reject them)
+ session = ChatSession()
+ session.add_user("first try")
+ session.add_user("second try")
+ assert [m["role"] for m in session.messages] == ["user"]
+ assert session.messages[0]["content"] == "first try\nsecond try"
+
+ def test_save_jsonl_keeps_reasoning_content(self, tmp_path):
+ session = ChatSession()
+ session.add_user("q")
+ session.add_assistant_message(
+ {"role": "assistant", "content": "a", "reasoning_content": "hmm"}
+ )
+ path = tmp_path / "chat.jsonl"
+ session.save_jsonl(str(path))
+ sample = json.loads(path.read_text(encoding="utf-8"))
+ assistant = sample["messages"][1]
+ assert assistant["content"] == [{"type": "text", "text": "a"}]
+ assert assistant["reasoning_content"] == "hmm"
+
+ def test_save_jsonl_multimodal_parts_format(self, tmp_path):
+ session = ChatSession()
+ session.system = "sys"
+ session.add_user("q")
+ session.add_assistant("a")
+ path = tmp_path / "chat.jsonl"
+ session.save_jsonl(str(path))
+ session.save_jsonl(str(path))
+ lines = path.read_text(encoding="utf-8").strip().split("\n")
+ assert len(lines) == 2
+ sample = json.loads(lines[0])
+ assert [m["role"] for m in sample["messages"]] == [
+ "system",
+ "user",
+ "assistant",
+ ]
+ assert sample["messages"][1]["content"] == [{"type": "text", "text": "q"}]
+ assert sample["messages"][2]["content"] == [{"type": "text", "text": "a"}]
+
+
+class TestCachePlanning:
+ def test_prefix_extension_reuses_cache(self):
+ cache = FakeCache(length=5)
+ generator = cache_planner([1, 2, 3, 4, 5], cache)
+ assert generator._prepare_cache([1, 2, 3, 4, 5, 6, 7]) == 5
+ assert generator._cache is cache
+
+ def test_divergence_crops_to_common_prefix(self):
+ cache = FakeCache(length=5)
+ generator = cache_planner([1, 2, 3, 4, 5], cache)
+ assert generator._prepare_cache([1, 2, 3, 9, 9, 9]) == 3
+ assert cache.length == 3
+ assert generator._cached_ids == [1, 2, 3]
+
+ def test_no_overlap_resets_cache(self):
+ cache = FakeCache(length=3)
+ generator = cache_planner([1, 2, 3], cache)
+ assert generator._prepare_cache([7, 8, 9]) == 0
+ assert generator._cache is not cache
+ assert generator._cached_ids == []
+
+ def test_uncroppable_cache_resets(self):
+ cache = FakeCache(length=5, croppable=False)
+ generator = cache_planner([1, 2, 3, 4, 5], cache)
+ assert generator._prepare_cache([1, 2, 3, 9, 9]) == 0
+ assert generator._cache is not cache
+
+ def test_render_fully_cached_leaves_one_input_token(self):
+ # cache covering all input tokens would give generate() nothing to process
+ cache = FakeCache(length=5)
+ generator = cache_planner([1, 2, 3, 4, 5], cache)
+ assert generator._prepare_cache([1, 2, 3, 4, 5]) == 4
+ assert cache.length == 4
+
+
+class TestChatRepl:
+ def test_message_generates_turn_with_history(self):
+ repl, generator = make_repl(["hi", "again", "/quit"])
+ repl.run()
+ assert len(generator.calls) == 2
+ second_conversation = generator.calls[1][0]
+ assert [m["role"] for m in second_conversation] == [
+ "user",
+ "assistant",
+ "user",
+ ]
+ assert repl.session.messages[-1]["content"] == "canned reply"
+
+ def test_command_aliases(self):
+ assert resolve_command("clear").name == "new"
+ assert resolve_command("reset").name == "new"
+ assert resolve_command("regen").name == "retry"
+ assert resolve_command("q").name == "quit"
+ assert resolve_command("?").name == "help"
+
+ def test_new_clears_history_keeps_system_and_params(self):
+ repl, generator = make_repl(
+ ["/system be brief", "/temp 0.5", "hi", "/new", "next", "/quit"]
+ )
+ repl.run()
+ assert repl.session.system == "be brief"
+ assert repl.params["temperature"] == 0.5
+ last_conversation = generator.calls[-1][0]
+ assert [m["role"] for m in last_conversation] == ["system", "user"]
+ assert last_conversation[1]["content"] == "next"
+
+ def test_param_shortcut_and_set_forms(self):
+ repl, _ = make_repl(
+ ["/temp 0.3", "/set top_k 10", "/set max_tokens=64", "/quit"]
+ )
+ repl.run()
+ assert repl.params["temperature"] == 0.3
+ assert repl.params["top_k"] == 10
+ assert repl.params["max_new_tokens"] == 64
+
+ def test_invalid_param_value_not_applied(self):
+ repl, _ = make_repl(["/temp 100", "/quit"])
+ repl.run()
+ assert repl.params["temperature"] == default_gen_params()["temperature"]
+
+ def test_retry_regenerates_last_turn(self):
+ repl, generator = make_repl(
+ ["hi", "/retry", "/quit"], generator=FakeGenerator(["first", "second"])
+ )
+ repl.run()
+ assert len(generator.calls) == 2
+ retry_conversation = generator.calls[1][0]
+ assert retry_conversation[-1] == {"role": "user", "content": "hi"}
+ assert repl.session.messages[-1]["content"] == "second"
+
+ def test_undo_command(self):
+ repl, _ = make_repl(["hi", "/undo", "/quit"])
+ repl.run()
+ assert not repl.session.messages
+
+ def test_multiline_input(self):
+ repl, generator = make_repl(["first line\\", "second line", "/quit"])
+ repl.run()
+ assert generator.calls[0][0][0]["content"] == "first line\nsecond line"
+
+ def test_generator_message_stored_in_history(self):
+ message = {
+ "role": "assistant",
+ "content": "The answer is 4.",
+ "reasoning_content": "step by step",
+ }
+ repl, _ = make_repl(
+ ["what is 2+2?", "/quit"],
+ generator=FakeGenerator(["The answer is 4."], messages=[message]),
+ )
+ repl.run()
+ assert repl.session.messages[-1] == message
+ # renderer saw no think markers, so /expand falls back to the message
+ assert repl.last_think_text == "step by step"
+
+ def test_legacy_content_fallback_kept_verbatim(self):
+ # generators that return no message dict store their content as-is
+ reply = "step by step\nThe answer is 4."
+ repl, _ = make_repl(["what is 2+2?", "/quit"], generator=FakeGenerator([reply]))
+ repl.run()
+ assert repl.session.messages[-1]["content"] == reply
+
+ def test_unknown_command_does_not_generate(self):
+ repl, generator = make_repl(["/bogus", "/quit"])
+ repl.run()
+ assert not generator.calls
+
+ def test_command_handler_error_does_not_crash_repl(self):
+ # unclosed quote makes shlex raise inside /save
+ repl, generator = make_repl(["hi", '/save "unclosed', "again", "/quit"])
+ repl.run()
+ assert len(generator.calls) == 2
+
+ def test_keyboard_interrupt_keeps_session_alive(self):
+ class InterruptingGenerator(FakeGenerator):
+ def generate_turn(self, conversation, params, on_text, render_kwargs=None):
+ if not self.calls:
+ self.calls.append(None)
+ raise KeyboardInterrupt
+ return super().generate_turn(
+ conversation, params, on_text, render_kwargs
+ )
+
+ repl, generator = make_repl(
+ ["hi", "again", "/quit"], generator=InterruptingGenerator()
+ )
+ repl.run()
+ # interrupted turn keeps the user message; the next one merges into it
+ assert [m["role"] for m in repl.session.messages] == ["user", "assistant"]
+ assert repl.session.messages[0]["content"] == "hi\nagain"
+ assert len(generator.calls) == 2
+
+ def test_generation_failure_keeps_session_alive(self):
+ class FailingGenerator(FakeGenerator):
+ def generate_turn(self, conversation, params, on_text, render_kwargs=None):
+ if not self.calls:
+ self.calls.append(None)
+ raise RuntimeError("boom")
+ return super().generate_turn(
+ conversation, params, on_text, render_kwargs
+ )
+
+ repl, _ = make_repl(["hi", "/retry", "/quit"], generator=FailingGenerator())
+ repl.run()
+ assert [m["role"] for m in repl.session.messages] == ["user", "assistant"]
+ assert repl.session.messages[-1]["content"] == "canned reply"
+
+
+def test_longest_common_prefix_len():
+ assert longest_common_prefix_len([], [1, 2]) == 0
+ assert longest_common_prefix_len([1, 2], [1, 2]) == 2
+ assert longest_common_prefix_len([1, 2, 3], [1, 2]) == 2
+ assert longest_common_prefix_len([1, 9], [1, 2, 3]) == 1
+
+
+class TestDiffusionChat:
+ def test_diffusion_param_specs(self):
+ from axolotl.cli.chat import DIFFUSION_GEN_PARAMS
+
+ lines = iter(["/steps 32", "/tokens 64", "/top_p 0.9", "/quit"])
+
+ def input_fn(_prompt):
+ try:
+ return next(lines)
+ except StopIteration as err:
+ raise EOFError from err
+
+ repl = ChatRepl(
+ generator=FakeGenerator(),
+ param_specs=DIFFUSION_GEN_PARAMS,
+ console=Console(file=io.StringIO(), force_terminal=False),
+ input_fn=input_fn,
+ )
+ repl.run()
+ assert repl.params["steps"] == 32
+ assert repl.params["max_new_tokens"] == 64
+ assert "top_p" not in repl.params
+
+ def test_diffusion_turn_cuts_at_eos(self, monkeypatch):
+ from types import SimpleNamespace
+
+ import axolotl.integrations.diffusion as diffusion_module
+ from axolotl.cli.chat import (
+ DIFFUSION_GEN_PARAMS,
+ DiffusionTurnGenerator,
+ default_gen_params,
+ )
+
+ class FakeTokenizer:
+ eos_token_id = 2
+
+ def apply_chat_template(self, conversation, **kwargs):
+ return {"input_ids": [1, 5, 6]}
+
+ def decode(self, ids, **kwargs):
+ return ",".join(str(i) for i in ids)
+
+ fake_model = SimpleNamespace(
+ generation_config=SimpleNamespace(eos_token_id=None)
+ )
+
+ def fake_generate(model, tokenizer, **kwargs):
+ assert kwargs["mode"] == "completion"
+ assert kwargs["completion_tokens"] == 256
+ return {"generated_ids": [1, 5, 6, 7, 8, 2, 4]}
+
+ monkeypatch.setattr(diffusion_module, "generate", fake_generate)
+
+ generator = DiffusionTurnGenerator(
+ fake_model, FakeTokenizer(), None, "cpu", mask_token_id=9
+ )
+ chunks = []
+ result = generator.generate_turn(
+ [{"role": "user", "content": "hi"}],
+ default_gen_params(DIFFUSION_GEN_PARAMS),
+ chunks.append,
+ )
+ assert result.content == "7,8"
+ assert result.new_tokens == 2
+ assert result.prompt_tokens == 3
+ assert chunks == ["7,8"]
+
+
+def test_unknown_command_suggests_alias():
+ buf = io.StringIO()
+ repl = ChatRepl(
+ generator=FakeGenerator(),
+ console=Console(file=buf, force_terminal=False, width=200),
+ input_fn=lambda _p: "/quit",
+ )
+ repl._dispatch("/clea")
+ assert "Did you mean /clear?" in buf.getvalue()
+ repl._dispatch("/tem")
+ assert "Did you mean /temp?" in buf.getvalue()
+
+
+class TestThinkStreamRenderer:
+ def make_renderer(self, collapse=True, markers=("", "")):
+ from axolotl.cli.chat import ThinkStreamRenderer
+
+ buf = io.StringIO()
+ console = Console(file=buf, force_terminal=False, width=200)
+ return ThinkStreamRenderer(console, collapse=collapse, markers=markers), buf
+
+ def test_collapse_splits_thinking_from_reply(self, capsys):
+ renderer, buf = self.make_renderer()
+ for chunk in ["\nreasoning he", "re\n\nAnswer!"]:
+ renderer.feed(chunk)
+ renderer.finish()
+ assert renderer.think_text.strip() == "reasoning here"
+ assert capsys.readouterr().out == "Answer!"
+ assert "thought for" in buf.getvalue()
+
+ def test_no_thinking_passthrough(self, capsys):
+ renderer, buf = self.make_renderer()
+ renderer.feed("Just a plain reply")
+ renderer.finish()
+ assert renderer.think_text == ""
+ assert capsys.readouterr().out == "Just a plain reply"
+ assert "thought for" not in buf.getvalue()
+
+ def test_unterminated_thinking(self, capsys):
+ renderer, buf = self.make_renderer()
+ renderer.feed("partial reasoning")
+ renderer.finish()
+ assert renderer.think_text == "partial reasoning"
+ assert capsys.readouterr().out == ""
+ assert "no " in buf.getvalue()
+
+ def test_collapse_off_is_passthrough(self, capsys):
+ renderer, buf = self.make_renderer(collapse=False)
+ renderer.feed("abcreply")
+ renderer.finish()
+ assert capsys.readouterr().out == "abcreply"
+ assert buf.getvalue() == ""
+
+ def test_custom_markers(self, capsys):
+ renderer, _ = self.make_renderer(
+ markers=("<|START_THINKING|>", "<|END_THINKING|>")
+ )
+ renderer.feed("<|START_THINKING|>hmm<|END_THINKING|>ok")
+ renderer.finish()
+ assert renderer.think_text == "hmm"
+ assert capsys.readouterr().out == "ok"
+
+
+class TestThinkTokenSplit:
+ def make_generator(self, vocab):
+ from types import SimpleNamespace
+
+ from axolotl.cli.chat import TurnGenerator
+
+ class FakeTokenizer:
+ eos_token_id = 0
+ chat_template = None
+
+ def encode(self, text, **kwargs):
+ return vocab[text]
+
+ model = SimpleNamespace(generation_config=SimpleNamespace(eos_token_id=None))
+ return TurnGenerator(model, FakeTokenizer(), None, "cpu")
+
+ def test_split_counts(self):
+ generator = self.make_generator({"": [100], "": [101]})
+ assert generator.split_think_token_counts([100, 1, 2, 3, 101, 7, 8]) == (3, 2)
+ assert generator.split_think_token_counts([100, 1, 2]) == (2, 0)
+ assert generator.split_think_token_counts([5, 6]) == (0, 2)
+ assert generator.split_think_token_counts([]) == (0, 0)
+
+
+class TestBuildAssistantMessage:
+ VOCAB = {
+ 1: "step by step",
+ 2: "The answer is 4.",
+ 50: "",
+ 100: "",
+ 101: "",
+ }
+ SPECIAL = {50, 100, 101}
+
+ def make_generator(self, response_schema=None, parse_response=None):
+ from types import SimpleNamespace
+
+ from axolotl.cli.chat import TurnGenerator
+
+ vocab, special = self.VOCAB, self.SPECIAL
+
+ class FakeTokenizer:
+ eos_token_id = 50
+ chat_template = None
+
+ def encode(self, text, **kwargs):
+ return [token_id for token_id, t in vocab.items() if t == text]
+
+ def decode(self, ids, skip_special_tokens=False):
+ return "".join(
+ vocab[i] for i in ids if not (skip_special_tokens and i in special)
+ )
+
+ tokenizer = FakeTokenizer()
+ if response_schema is not None:
+ tokenizer.response_schema = response_schema
+ tokenizer.parse_response = parse_response
+ model = SimpleNamespace(generation_config=SimpleNamespace(eos_token_id=None))
+ return TurnGenerator(model, tokenizer, None, "cpu")
+
+ def test_thinking_split_into_reasoning_content(self):
+ generator = self.make_generator()
+ message = generator.build_assistant_message([100, 1, 101, 2])
+ assert message == {
+ "role": "assistant",
+ "content": "The answer is 4.",
+ "reasoning_content": "step by step",
+ }
+
+ def test_no_thinking_omits_reasoning_key(self):
+ generator = self.make_generator()
+ message = generator.build_assistant_message([2])
+ assert message == {"role": "assistant", "content": "The answer is 4."}
+
+ def test_special_tokens_stripped_from_content(self):
+ generator = self.make_generator()
+ message = generator.build_assistant_message([2, 50])
+ assert message["content"] == "The answer is 4."
+
+ def test_parse_response_schema_preferred(self):
+ generator = self.make_generator(
+ response_schema={"x": "regex"},
+ parse_response=lambda text: {"content": "parsed", "thinking": "hmm"},
+ )
+ message = generator.build_assistant_message([2])
+ assert message == {
+ "role": "assistant",
+ "content": "parsed",
+ "thinking": "hmm",
+ }
+
+ def test_parse_response_failure_falls_back_to_markers(self):
+ def boom(text):
+ raise ValueError("bad schema")
+
+ generator = self.make_generator(
+ response_schema={"x": "regex"}, parse_response=boom
+ )
+ message = generator.build_assistant_message([100, 1, 101, 2])
+ assert message["content"] == "The answer is 4."
+ assert message["reasoning_content"] == "step by step"
+
+
+class TestEosTextTrimmer:
+ def make_trimmer(self, eos_strings=("<|im_end|>",)):
+ from axolotl.cli.chat import EosTextTrimmer
+
+ chunks = []
+ return EosTextTrimmer(eos_strings, chunks.append), chunks
+
+ def test_eos_marker_never_emitted(self):
+ trimmer, chunks = self.make_trimmer()
+ trimmer.feed("Hello")
+ trimmer.feed(" world<|im_end|>")
+ trimmer.finish()
+ assert "".join(chunks) == "Hello world"
+
+ def test_eos_split_across_chunks(self):
+ trimmer, chunks = self.make_trimmer()
+ trimmer.feed("Hi<|im_")
+ trimmer.feed("end|>")
+ trimmer.finish()
+ assert "".join(chunks) == "Hi"
+
+ def test_false_partial_released(self):
+ trimmer, chunks = self.make_trimmer()
+ trimmer.feed("a<")
+ trimmer.feed("b")
+ trimmer.finish()
+ assert "".join(chunks) == "a", "")
+
+ command_a_like = "...<|START_THINKING|>...<|END_THINKING|>..."
+ assert detect_think_markers(command_a_like) == (
+ "<|START_THINKING|>",
+ "<|END_THINKING|>",
+ )
+ assert detect_think_toggle_key(command_a_like) is None
+ assert detect_think_toggle_key("{% if thinking %}x{% endif %}") == "thinking"
+ assert detect_think_toggle_key(None) is None
diff --git a/tests/cli/test_cli_inference.py b/tests/cli/test_cli_inference.py
index 807dc7fa35..8728db23d1 100644
--- a/tests/cli/test_cli_inference.py
+++ b/tests/cli/test_cli_inference.py
@@ -144,3 +144,42 @@ def test_inference_backward_compatibility_no_launcher_args(cli_runner, config_pa
# Should not contain any extra launcher args
launcher_section = called_cmd[2 : called_cmd.index("-m")]
assert len(launcher_section) == 0 # No launcher args between 'launch' and '-m'
+
+
+def test_inference_chat(cli_runner, config_path):
+ """Test basic inference (chat path)"""
+ with patch("axolotl.cli.chat.do_chat") as mock:
+ result = cli_runner.invoke(
+ cli,
+ ["inference", str(config_path), "--launcher", "python", "--chat"],
+ catch_exceptions=False,
+ )
+
+ assert mock.called
+ assert result.exit_code == 0
+
+
+def test_inference_chat_with_launcher(cli_runner, config_path):
+ """Test chat flag is forwarded through the launcher command"""
+ with patch("subprocess.run") as mock_subprocess:
+ result = cli_runner.invoke(
+ cli,
+ ["inference", str(config_path), "--launcher", "accelerate", "--chat"],
+ catch_exceptions=False,
+ )
+
+ assert result.exit_code == 0
+ called_cmd = mock_subprocess.call_args.args[0]
+ assert "--chat" in called_cmd
+ assert "axolotl.cli.inference" in called_cmd
+
+
+def test_inference_chat_gradio_mutually_exclusive(cli_runner, config_path):
+ """Test that --chat and --gradio cannot be combined"""
+ result = cli_runner.invoke(
+ cli,
+ ["inference", str(config_path), "--chat", "--gradio"],
+ )
+
+ assert result.exit_code != 0
+ assert "mutually exclusive" in result.output
diff --git a/tests/test_logging_config_file_capture.py b/tests/test_logging_config_file_capture.py
index 44b0ee5e62..11e8d7c645 100644
--- a/tests/test_logging_config_file_capture.py
+++ b/tests/test_logging_config_file_capture.py
@@ -101,3 +101,28 @@ def test_prepare_debug_log_idempotent_and_no_duplicate(monkeypatch):
# Ensure the marker appears once (not duplicated via propagation)
assert data.count(marker) == 1
tee.close_debug_log()
+
+
+def test_hub_unauthenticated_nag_suppressed(monkeypatch):
+ from axolotl.logging_config import configure_logging
+ from axolotl.utils import tee
+
+ with tempfile.TemporaryDirectory() as td:
+ monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0")
+ configure_logging()
+ path = tee.prepare_debug_log(
+ type("Cfg", (), {"output_dir": td, "get": lambda *_: False})
+ )
+
+ hub_http = logging.getLogger("huggingface_hub.utils._http")
+ hub_http.warning(
+ "Warning: You are sending unauthenticated requests to the HF Hub."
+ " Please set a HF_TOKEN to enable higher rate limits and faster downloads."
+ )
+ hub_http.warning("Retrying in 2s [Retry 1/5].")
+ tee.file_only_stream.flush()
+
+ data = read(path)
+ assert "unauthenticated requests" not in data
+ assert "Retrying in 2s" in data
+ tee.close_debug_log()