diff --git a/x-pack/platform/plugins/shared/inference/common/tasks/nl_to_esql/non_ast/correct_common_esql_mistakes.test.ts b/x-pack/platform/plugins/shared/inference/common/tasks/nl_to_esql/non_ast/correct_common_esql_mistakes.test.ts index 9d4b064b35f61..523b2dc9a68e1 100644 --- a/x-pack/platform/plugins/shared/inference/common/tasks/nl_to_esql/non_ast/correct_common_esql_mistakes.test.ts +++ b/x-pack/platform/plugins/shared/inference/common/tasks/nl_to_esql/non_ast/correct_common_esql_mistakes.test.ts @@ -215,4 +215,19 @@ describe('correctCommonEsqlMistakes', () => { | STATS success_rate = AVG(successful)`, }); }); + + it('escapes special characters in column names', () => { + expectQuery({ + input: `FROM "custom-test" +| STATS + count = COUNT(*), + min = MIN("Total Bytes"), + max = MAX("Total Bytes"), + avg = AVG("Total Bytes"), + sum = SUM("Total Bytes") +`, + expectedOutput: `FROM "custom-test" +| STATS count = COUNT(*), min = MIN(\`Total Bytes\`), max = MAX(\`Total Bytes\`), avg = AVG(\`Total Bytes\`), sum = SUM(\`Total Bytes\`)`, + }); + }); }); diff --git a/x-pack/platform/plugins/shared/inference/common/tasks/nl_to_esql/non_ast/correct_common_esql_mistakes.ts b/x-pack/platform/plugins/shared/inference/common/tasks/nl_to_esql/non_ast/correct_common_esql_mistakes.ts index 0d4d13b61152a..00799d063f343 100644 --- a/x-pack/platform/plugins/shared/inference/common/tasks/nl_to_esql/non_ast/correct_common_esql_mistakes.ts +++ b/x-pack/platform/plugins/shared/inference/common/tasks/nl_to_esql/non_ast/correct_common_esql_mistakes.ts @@ -5,6 +5,11 @@ * 2.0. */ +import { scalarFunctionDefinitions } from '@kbn/esql-validation-autocomplete/src/definitions/generated/scalar_functions'; +import { aggregationFunctionDefinitions } from '@kbn/esql-validation-autocomplete/src/definitions/generated/aggregation_functions'; +import type { FunctionDefinition } from '@kbn/esql-validation-autocomplete'; +import { memoize } from 'lodash'; + const STRING_DELIMITER_TOKENS = ['`', "'", '"']; const ESCAPE_TOKEN = '\\\\'; @@ -94,6 +99,62 @@ function removeColumnQuotesAndEscape(column: string) { return '`' + plainColumnIdentifier + '`'; } +const getFunctionDefinitionMap = memoize(() => { + const functionDefinitionMap = new Map(); + const allFunctionDefinitions = [...scalarFunctionDefinitions, ...aggregationFunctionDefinitions]; + allFunctionDefinitions.forEach((definition) => { + const functionName = definition.name.toLowerCase(); + if (!functionDefinitionMap.has(functionName)) { + functionDefinitionMap.set(functionName, definition); + } + }); + return functionDefinitionMap; +}); + +/** + * Replaces quotes for fields in function argument if present. + * @example + * Example 1: Without quotes + * escapeColumnsInFunctions('MIN(total_bytes)'); // 'MIN(total_bytes)' + * + * @example + * Example 2: With quotes + * escapeColumnsInFunctions('MIN("Total Bytes")'); // 'MIN(`Total Bytes`)' + */ +function escapeColumnsInFunctions(string: string): string { + const regex = /([A-Za-z_]+)\s*\(([^()]*?)\)/g; + + return string.replace(regex, (match: string, functionName: string, args: string) => { + const functionDefinition = getFunctionDefinitionMap().get(functionName.toLowerCase()); + if (!functionDefinition) { + // function definition not found, return the original match + return match; + } + + const escapedArgs = args.length + ? args + .split(',') + .map((arg, index) => { + const trimmedArg = arg.trim(); + // Only escape field names + const paramName = functionDefinition.signatures[0].params[index]?.name; + if (paramName !== 'field' && paramName !== 'number') { + // It should be just a field, but some functions like SUM and AVG have a "number" name 🤷‍♂️ + return trimmedArg; + } + // If the string is not wrapped in quotes, return it as is + if (!trimmedArg.match(/^["'].*["']$/)) { + return trimmedArg; + } + return removeColumnQuotesAndEscape(trimmedArg); + }) + .join(', ') + : args; + + return `${functionName}(${escapedArgs})`; + }); +} + function replaceAsKeywordWithAssignments(command: string) { return command.replaceAll(/^STATS\s*(.*)/g, (__, statsOperations: string) => { return `STATS ${statsOperations.replaceAll( @@ -113,10 +174,12 @@ function escapeColumns(line: string) { const escapedBody = split(body.trim(), ',') .map((statement) => { const [lhs, rhs] = split(statement, '='); + if (!rhs) { return lhs; } - return `${removeColumnQuotesAndEscape(lhs)} = ${rhs}`; + const escapedRhs = escapeColumnsInFunctions(rhs); + return `${removeColumnQuotesAndEscape(lhs)} = ${escapedRhs}`; }) .join(', ');