From ff3dabc75f9a03627caa988b85f88be04a6c70a4 Mon Sep 17 00:00:00 2001 From: tobymao Date: Fri, 14 Jun 2024 16:53:57 -0700 Subject: [PATCH] feat(tsql): index on closes #3658 --- sqlglot/expressions.py | 1 + sqlglot/generator.py | 4 +++- sqlglot/parser.py | 3 +++ tests/dialects/test_tsql.py | 7 +++++++ 4 files changed, 14 insertions(+), 1 deletion(-) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 9ce84c7676..b3e8fbb772 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -2127,6 +2127,7 @@ class IndexParameters(Expression): "partition_by": False, "tablespace": False, "where": False, + "on": False, } diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 9ef556d917..e3ff09670e 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1298,8 +1298,10 @@ def indexparameters_sql(self, expression: exp.IndexParameters) -> str: with_storage = f" WITH ({with_storage})" if with_storage else "" tablespace = self.sql(expression, "tablespace") tablespace = f" USING INDEX TABLESPACE {tablespace}" if tablespace else "" + on = self.sql(expression, "on") + on = f" ON {on}" if on else "" - return f"{using}{columns}{include}{with_storage}{tablespace}{partition_by}{where}" + return f"{using}{columns}{include}{with_storage}{tablespace}{partition_by}{where}{on}" def index_sql(self, expression: exp.Index) -> str: unique = "UNIQUE " if expression.args.get("unique") else "" diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 9108b3e6cf..ab961c7be0 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -3188,6 +3188,8 @@ def _parse_index_params(self) -> exp.IndexParameters: ) where = self._parse_where() + on = self._parse_field() if self._match(TokenType.ON) else None + return self.expression( exp.IndexParameters, using=using, @@ -3197,6 +3199,7 @@ def _parse_index_params(self) -> exp.IndexParameters: where=where, with_storage=with_storage, tablespace=tablespace, + on=on, ) def _parse_index( diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 7455650910..9240e64278 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -8,6 +8,13 @@ class TestTSQL(Validator): dialect = "tsql" def test_tsql(self): + self.validate_identity( + "CREATE INDEX [x] ON [y]([z] ASC) WITH (allow_page_locks=on) ON X([y])" + ) + self.validate_identity( + "CREATE INDEX [x] ON [y]([z] ASC) WITH (allow_page_locks=on) ON PRIMARY" + ) + self.assertEqual( annotate_types(self.validate_identity("SELECT 1 WHERE EXISTS(SELECT 1)")).sql("tsql"), "SELECT 1 WHERE EXISTS(SELECT 1)",