Skip to content

Commit

Permalink
Fix MSSQL string literal escape (#1193)
Browse files Browse the repository at this point in the history
  • Loading branch information
mvorisek authored Apr 13, 2024
1 parent 2dd317b commit b0238b1
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 25 deletions.
4 changes: 3 additions & 1 deletion src/Persistence/Sql.php
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,10 @@ protected function initPersistence(Model $model): void
*/
public function expr(Model $model, string $template, array $arguments = []): Expression
{
$quotedTokenRegex = $this->getConnection()->expr()::QUOTED_TOKEN_REGEX;

preg_replace_callback(
'~(?!\[\w*\])' . Expression::QUOTED_TOKEN_REGEX . '\K|\[\w*\]|\{\w*\}~',
'~(?!\[\w*\])' . $quotedTokenRegex . '\K|\[\w*\]|\{\w*\}~',
static function ($matches) use ($model, &$arguments) {
if ($matches[0] === '') {
return '';
Expand Down
8 changes: 4 additions & 4 deletions src/Persistence/Sql/Expression.php
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ abstract class Expression implements Expressionable, \ArrayAccess

public const QUOTED_TOKEN_REGEX = <<<'EOF'
(?:(?sx)
'(?:[^'\\]+|\\.|'')*+'
|"(?:[^"\\]+|\\.|"")*+"
|`(?:[^`\\]+|\\.|``)*+`
|\[(?:[^\]\\]+|\\.|\]\])*+\]
'(?:[^']+|'')*+'
|"(?:[^"]+|"")*+"
|`(?:[^`]+|``)*+`
|\[(?:[^\]]+|\]\])*+\]
|(?:--|\#)[^\r\n]*+
|/\*(?:[^*]+|\*(?!/))*+\*/
)
Expand Down
6 changes: 4 additions & 2 deletions src/Persistence/Sql/Mssql/ExpressionTrait.php
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ protected function escapeStringLiteral(string $value): string
}

if ($v !== '') {
$parts[] = '\'' . str_replace('\'', '\'\'', $v) . '\'';
foreach (mb_str_split($v, 4000) as $v2) {
$parts[] = '\'' . str_replace('\'', '\'\'', $v2) . '\'';
}
}
}

Expand All @@ -49,7 +51,7 @@ protected function escapeStringLiteral(string $value): string
return reset($parts);
};

return $buildConcatSqlFx($parts);
return str_replace(["\\\n", "\\\r"], ["\\\\\n\n", "\\\\\r"], $buildConcatSqlFx($parts));
}

#[\Override]
Expand Down
11 changes: 11 additions & 0 deletions src/Persistence/Sql/Mysql/Expression.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,16 @@ class Expression extends BaseExpression
{
use ExpressionTrait;

public const QUOTED_TOKEN_REGEX = <<<'EOF'
(?:(?sx)
'(?:[^'\\]+|\\.|'')*+'
|"(?:[^"\\]+|\\.|"")*+"
|`(?:[^`]+|``)*+`
|\[(?:[^\]]+|\]\])*+\]
|(?:--|\#)[^\r\n]*+
|/\*(?:[^*]+|\*(?!/))*+\*/
)
EOF;

protected string $identifierEscapeChar = '`';
}
2 changes: 2 additions & 0 deletions src/Persistence/Sql/Mysql/Query.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ class Query extends BaseQuery
{
use ExpressionTrait;

public const QUOTED_TOKEN_REGEX = Expression::QUOTED_TOKEN_REGEX;

protected string $identifierEscapeChar = '`';
protected string $expressionClass = Expression::class;

Expand Down
6 changes: 4 additions & 2 deletions src/Schema/TestCase.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
use Atk4\Data\Model;
use Atk4\Data\Persistence;
use Atk4\Data\Persistence\Sql\Expression;
use Atk4\Data\Persistence\Sql\Sqlite\Expression as SqliteExpression;
use Atk4\Data\Reference;
use Doctrine\DBAL\ParameterType;
use Doctrine\DBAL\Platforms\AbstractPlatform;
Expand Down Expand Up @@ -97,8 +98,9 @@ protected function logQuery(string $sql, array $params, array $types): void

// needed for \Atk4\Data\Persistence\Sql\*\ExpressionTrait::updateRenderBeforeExecute() fixes
$i = 0;
$quotedTokenRegex = $this->getConnection()->expr()::QUOTED_TOKEN_REGEX;
$sql = preg_replace_callback(
'~' . Expression::QUOTED_TOKEN_REGEX . '\K|(\?)|cast\((\?|:\w+) as (BOOLEAN|INTEGER|BIGINT|DOUBLE PRECISION|BINARY_DOUBLE)\)|\((\?|:\w+) \+ 0\.00\)~',
'~' . $quotedTokenRegex . '\K|(\?)|cast\((\?|:\w+) as (BOOLEAN|INTEGER|BIGINT|DOUBLE PRECISION|BINARY_DOUBLE)\)|\((\?|:\w+) \+ 0\.00\)~',
static function ($matches) use (&$types, &$params, &$i) {
if ($matches[0] === '') {
return '';
Expand Down Expand Up @@ -162,7 +164,7 @@ private function convertSqlFromSqlite(string $sql): string
$platform = $this->getDatabasePlatform();

$convertedSql = preg_replace_callback(
'~(?![\'`])' . Expression::QUOTED_TOKEN_REGEX . '\K|' . Expression::QUOTED_TOKEN_REGEX . '|:(\w+)~',
'~(?![\'`])' . SqliteExpression::QUOTED_TOKEN_REGEX . '\K|' . SqliteExpression::QUOTED_TOKEN_REGEX . '|:(\w+)~',
static function ($matches) use ($platform) {
if ($matches[0] === '') {
return '';
Expand Down
2 changes: 1 addition & 1 deletion tests/Persistence/Sql/ExpressionTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ public static function provideNoTemplatingInSqlStringCases(): iterable
'\'{}\'',
'\'{{}}\'',
'\'[a]\'',
'\'\\\'[]\'',
'\'\[]\'',
'\'\\\[]\'',
'\'[\'\']\'',
'\'\'\'[]\'',
Expand Down
107 changes: 92 additions & 15 deletions tests/Persistence/Sql/WithDb/SelectTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -484,11 +484,91 @@ public function testExecuteException(): void
}
}

public function testQuotedTokenRegexConstant(): void
{
$hasBackslashSupport = $this->getDatabasePlatform() instanceof MySQLPlatform;

self::assertSame(
'(?:(?sx)' . "\n"
. ' \'(?:[^\'' . ($hasBackslashSupport ? '\\\\' : '') . ']+' . ($hasBackslashSupport ? '|\\\.' : '') . '|\'\')*+\'' . "\n"
. ' |"(?:[^"' . ($hasBackslashSupport ? '\\\\' : '') . ']+' . ($hasBackslashSupport ? '|\\\.' : '') . '|"")*+"' . "\n"
. ' |`(?:[^`]+|``)*+`' . "\n"
. ' |\[(?:[^\]]+|\]\])*+\]' . "\n"
. ' |(?:--|\#)[^\r\n]*+' . "\n"
. ' |/\*(?:[^*]+|\*(?!/))*+\*/' . "\n"
. ')',
$this->e()::QUOTED_TOKEN_REGEX
);

self::assertSame($this->e()::QUOTED_TOKEN_REGEX, $this->q()::QUOTED_TOKEN_REGEX);

$sqlTwoEscape = '\'\'\'\'';
$sqlBackslashEscape = '\'\\\'-- \'';
if ($this->getDatabasePlatform() instanceof OraclePlatform) {
$sqlBackslashEscape .= "\n/**/";
}

$query = $this->q()->field($this->e($sqlTwoEscape));
self::assertSame('\'', $query->getOne());

$query = $this->q()->field($this->e($sqlBackslashEscape));
self::assertSame($hasBackslashSupport ? '\'-- ' : '\\', $query->getOne());

foreach (['"', '`'] as $chr) {
if ($chr === '`' && ($this->getDatabasePlatform() instanceof PostgreSQLPlatform || $this->getDatabasePlatform() instanceof SQLServerPlatform || $this->getDatabasePlatform() instanceof OraclePlatform)) {
continue;
}

$replaceFx = static fn ($v) => str_replace('\'', $chr, $v);
$needsExplicitAs = $chr === '"' && $this->getDatabasePlatform() instanceof MySQLPlatform;

if ($chr !== '"' || !$this->getDatabasePlatform() instanceof OraclePlatform) {
$query = $this->q()->field($this->e('\'x\' ' . ($needsExplicitAs ? 'as ' : '') . $replaceFx($sqlTwoEscape)));
self::assertSame([$chr => 'x'], $query->getRow());
}

$query = $this->q()->field($this->e('\'x\' ' . ($needsExplicitAs ? 'as ' : '') . $replaceFx($sqlBackslashEscape)));
self::assertSame([$hasBackslashSupport && $chr === '"' ? $chr . '-- ' : '\\' => 'x'], $query->getRow());
}

if (!($this->getDatabasePlatform() instanceof MySQLPlatform || $this->getDatabasePlatform() instanceof PostgreSQLPlatform || $this->getDatabasePlatform() instanceof OraclePlatform)) {
$query = $this->q()->field($this->e('\'x\' [a*b]'));
self::assertSame(['a*b' => 'x'], $query->getRow());

$replaceFx = static fn ($v) => str_replace('\'', ']', preg_replace('~^\'~', '[a*b', $v));

if ($this->getDatabasePlatform() instanceof SQLServerPlatform) {
$query = $this->q()->field($this->e('\'x\' ' . $replaceFx($sqlTwoEscape)));
self::assertSame(['a*b]' => 'x'], $query->getRow());
}

$query = $this->q()->field($this->e('\'x\' ' . $replaceFx($sqlBackslashEscape)));
self::assertSame(['a*b\\' => 'x'], $query->getRow());
}
}

public function testEscapeStringLiteral(): void
{
// TODO full binary support
$maxOrd = $this->getDatabasePlatform() instanceof PostgreSQLPlatform
|| $this->getDatabasePlatform() instanceof SQLServerPlatform
? 0x7F
: 0xFF;

$str = '';
for ($i = 0; $i <= 0x7F; ++$i) {
$str .= chr($i);
for ($i = 0; $i <= $maxOrd; ++$i) {
$chr = chr($i);
for ($j = 1; $j <= 5; ++$j) { // TODO PostgreSQL/MSSQL is failing with "$j <= 1"
$str .= str_repeat($chr, $j) . '_';
for ($k = 1; $k <= 5; ++$k) {
$str .= str_repeat('\\', $k) . str_repeat($chr, $j) . '_';
}
}
}

// Oracle string literal is limited to 4000 bytes
if ($this->getDatabasePlatform() instanceof OraclePlatform) {
$str = substr($str, 0, 4000);
}

// PostgreSQL does not support \0 character
Expand All @@ -497,22 +577,19 @@ public function testEscapeStringLiteral(): void
? str_replace("\0", '-', $str)
: $str;

$dummyExpression = $this->getConnection()->expr();
$dummyExpression = $this->e();
$strSql = \Closure::bind(static fn () => $dummyExpression->escapeStringLiteral($str2), null, Expression::class)();
$query = $this->getConnection()->dsql()
->field($this->getConnection()->expr($strSql));
$res = $query->getOne();
self::assertSame(bin2hex($str2), bin2hex($res));

$strSql = \Closure::bind(static fn () => $dummyExpression->escapeStringLiteral($str), null, Expression::class)();
$query = $this->getConnection()->dsql()
->field($this->getConnection()->expr($strSql));
if ($this->getDatabasePlatform() instanceof PostgreSQLPlatform) {
$query = $this->q()->field($this->e($strSql));
self::assertSame(bin2hex($str2), bin2hex($query->getOne()));

if ($str2 !== $str) {
$strSql = \Closure::bind(static fn () => $dummyExpression->escapeStringLiteral($str), null, Expression::class)();
$query = $this->q()->field($this->e($strSql));

$this->expectException(ExecuteException::class);
$this->expectExceptionMessage('Character not in repertoire');
$query->getOne();
}
$res = $query->getOne();
self::assertSame(bin2hex($str), bin2hex($res));
}

public function testUtf8mb4Support(): void
Expand Down Expand Up @@ -552,7 +629,7 @@ public function testImportAndAutoincrement(): void
$pk = 'myid';
if ($this->getDatabasePlatform() instanceof MySQLPlatform) {
self::assertFalse($this->getConnection()->inTransaction());
$this->getConnection()->expr('analyze table {}', [$table])->executeStatement();
$this->e('analyze table {}', [$table])->executeStatement();
$query = $this->q()->table('INFORMATION_SCHEMA.TABLES')
->field($this->e('{} - 1', ['AUTO_INCREMENT']))
->where('TABLE_NAME', $table);
Expand Down

0 comments on commit b0238b1

Please sign in to comment.