From 312c0eb701ed446ef7218fe6356168b3c931084e Mon Sep 17 00:00:00 2001 From: brightwind26 Date: Thu, 16 Apr 2026 10:44:21 +0000 Subject: [PATCH 1/5] feat: support datasets saved in str format --- src/axolotl/prompt_strategies/chat_template.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index a943a14484..6d7b03f59c 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -394,8 +394,8 @@ def supports_batched(self) -> bool: def is_prompt_batched(self, prompt: dict[str, Any]) -> bool: try: - return all(isinstance(v, list) for v in prompt.values()) and all( - isinstance(v, list) for v in prompt[self.prompter.field_messages] + return all(isinstance(v, (str, list)) for v in prompt.values()) and all( + isinstance(v, (str, list)) for v in prompt[self.prompter.field_messages] ) except KeyError: return False @@ -1034,6 +1034,14 @@ def _get_messages(self, prompt): if messages is None: raise ValueError("Messages is null. Please check `field_messages`.") + if isinstance(messages, str): + messages = json.loads(messages) + assert isinstance(messages, list), f"For SFT datasets that are stored in `str` format, the turns must be saved in a list of dictionaries, got {type(message)}" + + # Extra check here to make sure decoded json is a list of dicts. + for i, message in enumerate(messages): + assert isinstance(message, dict), f"For SFT datasets that are stored in `str` format, each turns must be saved in a dictionary, got {type(message)} for the turn {i}" + if isinstance(messages, list): return messages From 023e42fc558e21bea2fc3d4c7105204e33e74258 Mon Sep 17 00:00:00 2001 From: brightwind26 Date: Thu, 16 Apr 2026 10:45:36 +0000 Subject: [PATCH 2/5] add also str for tools --- src/axolotl/prompt_strategies/chat_template.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 6d7b03f59c..db0f07836e 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -1004,6 +1004,11 @@ def _get_tools(self, prompt) -> list[dict] | None: if tools is None: return None + + # Some datasets have tools set to str + if isinstance(tools, str): + tools = json.loads(tools) + if isinstance(tools, list): # Process each tool to handle JSON string parameters for tool in tools: From dd44e2ff18d95fc421efb050753ccdbc6a2f54d0 Mon Sep 17 00:00:00 2001 From: brightwind26 Date: Thu, 16 Apr 2026 10:51:21 +0000 Subject: [PATCH 3/5] format --- src/axolotl/prompt_strategies/chat_template.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index db0f07836e..65f0cf4710 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -1004,7 +1004,6 @@ def _get_tools(self, prompt) -> list[dict] | None: if tools is None: return None - # Some datasets have tools set to str if isinstance(tools, str): tools = json.loads(tools) @@ -1041,11 +1040,15 @@ def _get_messages(self, prompt): if isinstance(messages, str): messages = json.loads(messages) - assert isinstance(messages, list), f"For SFT datasets that are stored in `str` format, the turns must be saved in a list of dictionaries, got {type(message)}" + assert isinstance(messages, list), ( + f"For SFT datasets that are stored in `str` format, the turns must be saved in a list of dictionaries, got {type(message)}" + ) # Extra check here to make sure decoded json is a list of dicts. for i, message in enumerate(messages): - assert isinstance(message, dict), f"For SFT datasets that are stored in `str` format, each turns must be saved in a dictionary, got {type(message)} for the turn {i}" + assert isinstance(message, dict), ( + f"For SFT datasets that are stored in `str` format, each turns must be saved in a dictionary, got {type(message)} for the turn {i}" + ) if isinstance(messages, list): return messages From 98a56bae727a85fce79c903b9c746e2b56bf4e8a Mon Sep 17 00:00:00 2001 From: brightwind26 Date: Fri, 17 Apr 2026 06:26:41 +0000 Subject: [PATCH 4/5] fix: address comments + add unit test --- .../prompt_strategies/chat_template.py | 19 +++++- tests/test_datasets.py | 60 +++++++++++++++++++ 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 65f0cf4710..e05ad748ed 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -1006,8 +1006,14 @@ def _get_tools(self, prompt) -> list[dict] | None: # Some datasets have tools set to str if isinstance(tools, str): - tools = json.loads(tools) - + try: + tools = json.loads(tools) + except json.JSONDecodeError as e: + LOG.error( + f"Error parsing tool parameters as JSON. " + f"Error: {e}" + ) + raise if isinstance(tools, list): # Process each tool to handle JSON string parameters for tool in tools: @@ -1039,7 +1045,14 @@ def _get_messages(self, prompt): raise ValueError("Messages is null. Please check `field_messages`.") if isinstance(messages, str): - messages = json.loads(messages) + try: + messages = json.loads(messages) + except json.JSONDecodeError as e: + LOG.error( + f"Error parsing messages as JSON. " + f"Error: {e}" + ) + raise assert isinstance(messages, list), ( f"For SFT datasets that are stored in `str` format, the turns must be saved in a list of dictionaries, got {type(message)}" ) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 3b24ad5805..4449ed6feb 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -487,3 +487,63 @@ def test_loading_local_dataset_folder(self, tokenizer): assert "attention_mask" in dataset.features assert "labels" in dataset.features shutil.rmtree(tmp_ds_path) + + @enable_hf_offline + def test_load_dataset_with_str_json_data(self, tokenizer): + """ + Test loading datasets where data is stored as str JSON instead of list of dicts. + see: https://github.com/axolotl-ai-cloud/axolotl/pull/3607 for more details. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + import json + + str_json_ds = Dataset.from_list( + [ + { + "messages": json.dumps([ + {"role": "user", "content": "Hello how are you?"}, + {"role": "assistant", "content": "I am doing good thanks"} + ]) + }, + { + "messages": json.dumps([ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "2+2 equals 4."} + ]) + } + ] + ) + + tmp_ds_path = Path(tmp_dir) / "str_json_dataset.parquet" + str_json_ds.to_parquet(tmp_ds_path) + + prepared_path = Path(tmp_dir) / "prepared" + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 512, + "datasets": [ + { + "path": str(tmp_ds_path), + "name": "test_str_json", + "type": "chat_template", + "field_messages": "messages", + "message_field_role": "role", + "message_field_content": "content", + }, + ], + "dataset_num_proc": 4, + } + ) + + with patch( + "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path) + ): + dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) + + assert len(dataset) == 2 + assert "input_ids" in dataset.features + assert "attention_mask" in dataset.features + assert "labels" in dataset.features + + assert len(dataset[0]["input_ids"]) > 0 \ No newline at end of file From ee77c6ef7d98f4f2ac5527512171c7887a3d271d Mon Sep 17 00:00:00 2001 From: brightwind26 Date: Fri, 17 Apr 2026 06:27:25 +0000 Subject: [PATCH 5/5] format --- .../prompt_strategies/chat_template.py | 10 +---- tests/test_datasets.py | 39 +++++++++++-------- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index e05ad748ed..be6a38800e 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -1009,10 +1009,7 @@ def _get_tools(self, prompt) -> list[dict] | None: try: tools = json.loads(tools) except json.JSONDecodeError as e: - LOG.error( - f"Error parsing tool parameters as JSON. " - f"Error: {e}" - ) + LOG.error(f"Error parsing tool parameters as JSON. Error: {e}") raise if isinstance(tools, list): # Process each tool to handle JSON string parameters @@ -1048,10 +1045,7 @@ def _get_messages(self, prompt): try: messages = json.loads(messages) except json.JSONDecodeError as e: - LOG.error( - f"Error parsing messages as JSON. " - f"Error: {e}" - ) + LOG.error(f"Error parsing messages as JSON. Error: {e}") raise assert isinstance(messages, list), ( f"For SFT datasets that are stored in `str` format, the turns must be saved in a list of dictionaries, got {type(message)}" diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 4449ed6feb..bdb795e136 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -496,27 +496,34 @@ def test_load_dataset_with_str_json_data(self, tokenizer): """ with tempfile.TemporaryDirectory() as tmp_dir: import json - + str_json_ds = Dataset.from_list( [ { - "messages": json.dumps([ - {"role": "user", "content": "Hello how are you?"}, - {"role": "assistant", "content": "I am doing good thanks"} - ]) + "messages": json.dumps( + [ + {"role": "user", "content": "Hello how are you?"}, + { + "role": "assistant", + "content": "I am doing good thanks", + }, + ] + ) }, { - "messages": json.dumps([ - {"role": "user", "content": "What is 2+2?"}, - {"role": "assistant", "content": "2+2 equals 4."} - ]) - } + "messages": json.dumps( + [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "2+2 equals 4."}, + ] + ) + }, ] ) - + tmp_ds_path = Path(tmp_dir) / "str_json_dataset.parquet" str_json_ds.to_parquet(tmp_ds_path) - + prepared_path = Path(tmp_dir) / "prepared" cfg = DictDefault( { @@ -535,15 +542,15 @@ def test_load_dataset_with_str_json_data(self, tokenizer): "dataset_num_proc": 4, } ) - + with patch( "axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path) ): dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg) - + assert len(dataset) == 2 assert "input_ids" in dataset.features assert "attention_mask" in dataset.features assert "labels" in dataset.features - - assert len(dataset[0]["input_ids"]) > 0 \ No newline at end of file + + assert len(dataset[0]["input_ids"]) > 0