From 8c929aeee4756296ea302b0b1b18fc8a1338bb10 Mon Sep 17 00:00:00 2001
From: Igor Gitman
Date: Thu, 14 Aug 2025 09:02:59 -0700
Subject: [PATCH] Enable system_message for openai prompt format

Signed-off-by: Igor Gitman
---
 nemo_skills/inference/generate.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/nemo_skills/inference/generate.py b/nemo_skills/inference/generate.py
index e6134f16fb..c517568157 100644
--- a/nemo_skills/inference/generate.py
+++ b/nemo_skills/inference/generate.py
@@ -178,7 +178,6 @@ def _post_init_validate_params(self):
         if self.prompt_format == "openai":
             assert self.prompt_config is None, "prompt_config is not supported for prompt_format == 'openai'"
             assert self.prompt_template is None, "prompt_template is not supported for prompt_format == 'openai'"
-            assert self.system_message is None, "system_message is not supported for prompt_format == 'openai'"
         else:
             assert self.prompt_config is not None, "prompt_config is required when prompt_format == 'ns'"
             for param, default_value in self._get_disallowed_params():
@@ -305,7 +304,7 @@ def log_example_prompt(self, data):

         if self.cfg.prompt_format == "openai":
             # print the prompt in openai format
-            LOG.info("Example prompt in OpenAI format: \nData dictionary: %s", data_point)
+            LOG.info("Example prompt in OpenAI format: %s", self.fill_prompt(data_point, data))
             return

         if self.cfg.multi_turn_key is None:
@@ -388,6 +387,11 @@ def fill_prompt(self, data_point, data):
         if self.cfg.prompt_format == "openai":
             if self.cfg.prompt_suffix:
                 data_point["messages"][-1]["content"] += self.cfg.prompt_suffix
+            if self.cfg.system_message:
+                if data_point["messages"][0]["role"] != "system":
+                    data_point["messages"].insert(0, {"role": "system", "content": self.cfg.system_message})
+                else:
+                    data_point["messages"][0]["content"] = self.cfg.system_message
             return data_point["messages"]

         total_code_executions_in_prompt = self.cfg.total_code_executions_in_prompt