Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions src/axolotl/prompt_strategies/chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,8 @@ def supports_batched(self) -> bool:

def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
try:
return all(isinstance(v, list) for v in prompt.values()) and all(
isinstance(v, list) for v in prompt[self.prompter.field_messages]
return all(isinstance(v, (str, list)) for v in prompt.values()) and all(
isinstance(v, (str, list)) for v in prompt[self.prompter.field_messages]
)
Comment on lines +397 to 399

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

is_prompt_batched can misclassify single-string prompts as batched.

At Line 397 and Line 398, a single prompt with messages as str now passes both checks (because iterating a string yields str chars), so tokenize_prompt zips characters instead of examples.

Proposed fix
 def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
     try:
-            return all(isinstance(v, (str, list)) for v in prompt.values()) and all(
-                isinstance(v, (str, list)) for v in prompt[self.prompter.field_messages]
-            )
+            messages = prompt[self.prompter.field_messages]
+            if not isinstance(messages, list):
+                return False
+            return all(isinstance(v, list) for v in prompt.values()) and all(
+                isinstance(v, (str, list)) for v in messages
+            )
     except KeyError:
         return False
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/prompt_strategies/chat_template.py` around lines 397 - 399,
is_prompt_batched currently misclassifies a single-string prompt as batched
because iterating a str yields chars; fix the predicate so it only treats
prompts as batched when the relevant fields are actual lists: in
is_prompt_batched, require that prompt[self.prompter.field_messages] is an
instance of list (not str) and that all prompt.values() are lists (or more
precisely that fields that should be batched are lists), then validate each item
in prompt[self.prompter.field_messages] with all(isinstance(m, (str, list)) for
m in ...) to avoid iterating characters; update the condition around
prompt[self.prompter.field_messages] and any uses in tokenize_prompt
accordingly.

except KeyError:
return False
Expand Down Expand Up @@ -1004,6 +1004,13 @@ def _get_tools(self, prompt) -> list[dict] | None:
if tools is None:
return None

# Some datasets have tools set to str
if isinstance(tools, str):
try:
tools = json.loads(tools)
except json.JSONDecodeError as e:
LOG.error(f"Error parsing tool parameters as JSON. Error: {e}")
raise
if isinstance(tools, list):
# Process each tool to handle JSON string parameters
for tool in tools:
Expand Down Expand Up @@ -1034,6 +1041,22 @@ def _get_messages(self, prompt):
if messages is None:
raise ValueError("Messages is null. Please check `field_messages`.")

if isinstance(messages, str):
try:
messages = json.loads(messages)
except json.JSONDecodeError as e:
LOG.error(f"Error parsing messages as JSON. Error: {e}")
raise
assert isinstance(messages, list), (
f"For SFT datasets that are stored in `str` format, the turns must be saved in a list of dictionaries, got {type(message)}"
)

# Extra check here to make sure decoded json is a list of dicts.
for i, message in enumerate(messages):
assert isinstance(message, dict), (
f"For SFT datasets that are stored in `str` format, each turns must be saved in a dictionary, got {type(message)} for the turn {i}"
)

Comment on lines +1044 to +1059

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verification: confirm the undefined name and assert-based runtime checks are present.
rg -n "type\(message\)|assert isinstance\(messages, list\)|assert isinstance\(message, dict\)" src/axolotl/prompt_strategies/chat_template.py

Repository: axolotl-ai-cloud/axolotl

Length of output: 1814


Replace runtime validation with explicit error handling and fix undefined variable name.

Line 1044 references type(message) which is not defined at that scope—the variable message is only introduced inside the loop at line 1047. This causes a NameError whenever messages is not a list. Additionally, using assert for runtime input validation is unreliable since assertions can be disabled with python -O, leaving malformed inputs undetected in production.

Proposed fix
         if isinstance(messages, str):
-            messages = json.loads(messages)
-            assert isinstance(messages, list), f"For SFT datasets that are stored in `str` format, the turns must be saved in a list of dictionaries, got {type(message)}"
-
-            # Extra check here to make sure decoded json is a list of dicts.
-            for i, message in enumerate(messages):
-                assert isinstance(message, dict), f"For SFT datasets that are stored in `str` format, each turns must be saved in a dictionary, got {type(message)} for the turn {i}"
+            try:
+                messages = json.loads(messages)
+            except json.JSONDecodeError as e:
+                raise ValueError(
+                    f"Invalid JSON in `{self.prompter.field_messages}` field."
+                ) from e
+
+            if not isinstance(messages, list):
+                raise ValueError(
+                    "For SFT datasets stored as `str`, decoded `messages` must be a list[dict], "
+                    f"got {type(messages)}."
+                )
+
+            # Extra check here to make sure decoded json is a list of dicts.
+            for i, message in enumerate(messages):
+                if not isinstance(message, dict):
+                    raise ValueError(
+                        "For SFT datasets stored as `str`, each turn must be a dict; "
+                        f"got {type(message)} at turn {i}."
+                    )
🧰 Tools
🪛 Ruff (0.15.10)

[error] 1044-1044: Undefined name message

(F821)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/prompt_strategies/chat_template.py` around lines 1042 - 1049, The
code block incorrectly uses assert (which can be disabled) and references an
undefined name type(message) before the loop; replace the asserts with explicit
runtime validation: after json.loads(messages) (wrap in try/except to catch
JSONDecodeError) validate that the result is a list and if not raise a
ValueError with a clear message (use type(messages) in that message), then
iterate over the list and validate each item is a dict, raising ValueError that
includes the index and the actual type of the offending element (use the loop
variable message for per-item checks); ensure no undefined variables are
referenced and that errors are descriptive for SFT dataset parsing.

if isinstance(messages, list):
return messages

Expand Down
67 changes: 67 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,3 +487,70 @@ def test_loading_local_dataset_folder(self, tokenizer):
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)

@enable_hf_offline
def test_load_dataset_with_str_json_data(self, tokenizer):
"""
Test loading datasets where data is stored as str JSON instead of list of dicts.
see: https://github.com/axolotl-ai-cloud/axolotl/pull/3607 for more details.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
import json

str_json_ds = Dataset.from_list(
[
{
"messages": json.dumps(
[
{"role": "user", "content": "Hello how are you?"},
{
"role": "assistant",
"content": "I am doing good thanks",
},
]
)
},
{
"messages": json.dumps(
[
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "2+2 equals 4."},
]
)
},
]
)

tmp_ds_path = Path(tmp_dir) / "str_json_dataset.parquet"
str_json_ds.to_parquet(tmp_ds_path)

prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 512,
"datasets": [
{
"path": str(tmp_ds_path),
"name": "test_str_json",
"type": "chat_template",
"field_messages": "messages",
"message_field_role": "role",
"message_field_content": "content",
},
],
"dataset_num_proc": 4,
}
)

with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)

assert len(dataset) == 2
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features

assert len(dataset[0]["input_ids"]) > 0
Loading