Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion resources_servers/text_to_sql/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Each data sample should include:
},
{
"role": "user",
"content": "DIALECT: postgresql\n\nDATABASE CONTEXT:\nCREATE TABLE users (id SERIAL PRIMARY KEY, name VARCHAR(100));\nINSERT INTO users VALUES (1, 'Alice'), (2, 'Bob');\n\nQUESTION:\nList all user names ordered alphabetically"
"content": "<DIALECT>postgresql</DIALECT>\n\n<DATABASE_CONTEXT>\nCREATE TABLE users (id SERIAL PRIMARY KEY, name VARCHAR(100));\nINSERT INTO users VALUES (1, 'Alice'), (2, 'Bob');\n</DATABASE_CONTEXT>\n\n<QUESTION>\nList all user names ordered alphabetically\n</QUESTION>"
}
]
},
Expand Down
27 changes: 4 additions & 23 deletions resources_servers/text_to_sql/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,25 +97,6 @@ def extract_sql_from_response(text: str) -> Optional[str]:
return None


def _extract_question_text(params: NeMoGymResponseCreateParamsNonStreaming) -> str:
"""Extract the question text from the last user message."""
last_text: Optional[str] = None
for m in params.input or []:
if getattr(m, "role", None) == "user":
c = getattr(m, "content", None)
if isinstance(c, str):
last_text = c
return (last_text or "").strip()


def _extract_dialect_from_prompt(text: str) -> Optional[str]:
"""Extract SQL dialect from a structured prompt."""
if not text:
return None
match = re.search(r"^\s*DIALECT:\s*([a-zA-Z0-9_+-]+)\s*$", text, re.MULTILINE)
return match.group(1) if match else None


def _normalize_dialect(dialect: Optional[str]) -> Optional[str]:
if not dialect:
return None
Expand Down Expand Up @@ -171,7 +152,7 @@ class TextToSqlRunRequest(BaseRunRequest):
sql: str # Ground truth SQL query (required)
sql_dialect: str # SQL dialect: mysql, postgresql, sqlite (required)
sql_context: str = "" # Database schema (CREATE/INSERT statements)
sql_prompt: Optional[str] = None # Natural language question (optional, extracted from input if not provided)
sql_prompt: str # Natural language question (required)
metadata: Optional[dict[str, Any]] = None


Expand All @@ -196,7 +177,7 @@ class TextToSqlVerifyResponse(BaseVerifyResponse):
extracted_sql: Optional[str] = None
sql_dialect: str # SQL dialect used
sql_context: str # Database schema provided
sql_prompt: Optional[str] = None # May be extracted from input
sql_prompt: str # Natural language question
judge_passed: bool = False
failure_reason: Optional[FailureCode] = None
judge_evaluations: list[JudgeEvaluation] = []
Expand Down Expand Up @@ -235,8 +216,8 @@ async def verify(self, body: TextToSqlVerifyRequest) -> TextToSqlVerifyResponse:
if sql_dialect not in SUPPORTED_DIALECTS:
raise ValueError(f"Unsupported SQL dialect '{sql_dialect}'. Supported: {sorted(SUPPORTED_DIALECTS)}")

# Extract question from request field or from user message
sql_prompt = body.sql_prompt or _extract_question_text(body.responses_create_params)
# sql_prompt is a required field, validated by Pydantic
sql_prompt = body.sql_prompt

# Get model output text directly from response
generated = body.response.output_text or ""
Expand Down
Loading