-
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Feat: add support for datasets with str saved messages field
#3607
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -394,8 +394,8 @@ def supports_batched(self) -> bool: | |
|
|
||
| def is_prompt_batched(self, prompt: dict[str, Any]) -> bool: | ||
| try: | ||
| return all(isinstance(v, list) for v in prompt.values()) and all( | ||
| isinstance(v, list) for v in prompt[self.prompter.field_messages] | ||
| return all(isinstance(v, (str, list)) for v in prompt.values()) and all( | ||
| isinstance(v, (str, list)) for v in prompt[self.prompter.field_messages] | ||
| ) | ||
| except KeyError: | ||
| return False | ||
|
|
@@ -1004,6 +1004,13 @@ def _get_tools(self, prompt) -> list[dict] | None: | |
| if tools is None: | ||
| return None | ||
|
|
||
| # Some datasets have tools set to str | ||
| if isinstance(tools, str): | ||
| try: | ||
| tools = json.loads(tools) | ||
| except json.JSONDecodeError as e: | ||
| LOG.error(f"Error parsing tool parameters as JSON. Error: {e}") | ||
| raise | ||
| if isinstance(tools, list): | ||
| # Process each tool to handle JSON string parameters | ||
| for tool in tools: | ||
|
|
@@ -1034,6 +1041,22 @@ def _get_messages(self, prompt): | |
| if messages is None: | ||
| raise ValueError("Messages is null. Please check `field_messages`.") | ||
|
|
||
| if isinstance(messages, str): | ||
| try: | ||
| messages = json.loads(messages) | ||
| except json.JSONDecodeError as e: | ||
| LOG.error(f"Error parsing messages as JSON. Error: {e}") | ||
| raise | ||
| assert isinstance(messages, list), ( | ||
| f"For SFT datasets that are stored in `str` format, the turns must be saved in a list of dictionaries, got {type(message)}" | ||
| ) | ||
|
|
||
| # Extra check here to make sure decoded json is a list of dicts. | ||
| for i, message in enumerate(messages): | ||
| assert isinstance(message, dict), ( | ||
| f"For SFT datasets that are stored in `str` format, each turns must be saved in a dictionary, got {type(message)} for the turn {i}" | ||
| ) | ||
|
|
||
|
Comment on lines
+1044
to
+1059
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: #!/bin/bash
# Verification: confirm the undefined name and assert-based runtime checks are present.
rg -n "type\(message\)|assert isinstance\(messages, list\)|assert isinstance\(message, dict\)" src/axolotl/prompt_strategies/chat_template.pyRepository: axolotl-ai-cloud/axolotl Length of output: 1814 Replace runtime validation with explicit error handling and fix undefined variable name. Line 1044 references Proposed fix if isinstance(messages, str):
- messages = json.loads(messages)
- assert isinstance(messages, list), f"For SFT datasets that are stored in `str` format, the turns must be saved in a list of dictionaries, got {type(message)}"
-
- # Extra check here to make sure decoded json is a list of dicts.
- for i, message in enumerate(messages):
- assert isinstance(message, dict), f"For SFT datasets that are stored in `str` format, each turns must be saved in a dictionary, got {type(message)} for the turn {i}"
+ try:
+ messages = json.loads(messages)
+ except json.JSONDecodeError as e:
+ raise ValueError(
+ f"Invalid JSON in `{self.prompter.field_messages}` field."
+ ) from e
+
+ if not isinstance(messages, list):
+ raise ValueError(
+ "For SFT datasets stored as `str`, decoded `messages` must be a list[dict], "
+ f"got {type(messages)}."
+ )
+
+ # Extra check here to make sure decoded json is a list of dicts.
+ for i, message in enumerate(messages):
+ if not isinstance(message, dict):
+ raise ValueError(
+ "For SFT datasets stored as `str`, each turn must be a dict; "
+ f"got {type(message)} at turn {i}."
+ )🧰 Tools🪛 Ruff (0.15.10)[error] 1044-1044: Undefined name (F821) 🤖 Prompt for AI Agents |
||
| if isinstance(messages, list): | ||
| return messages | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is_prompt_batchedcan misclassify single-string prompts as batched.At Line 397 and Line 398, a single prompt with
messagesasstrnow passes both checks (because iterating a string yieldsstrchars), sotokenize_promptzips characters instead of examples.Proposed fix
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool: try: - return all(isinstance(v, (str, list)) for v in prompt.values()) and all( - isinstance(v, (str, list)) for v in prompt[self.prompter.field_messages] - ) + messages = prompt[self.prompter.field_messages] + if not isinstance(messages, list): + return False + return all(isinstance(v, list) for v in prompt.values()) and all( + isinstance(v, (str, list)) for v in messages + ) except KeyError: return False🤖 Prompt for AI Agents