Skip to content

Commit

Permalink
Feat(mysql): support STRAIGHT_JOIN (#3623)
Browse files Browse the repository at this point in the history
* Feat(mysql): support STRAIGHT_JOIN

* PR feedback
  • Loading branch information
georgesittas authored Jun 11, 2024
1 parent caa3051 commit c49cefa
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 4 deletions.
8 changes: 8 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,14 @@ def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[

klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

if enum not in ("", "doris", "mysql"):
klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
TokenType.STRAIGHT_JOIN,
}
klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
TokenType.STRAIGHT_JOIN,
}

if not klass.SUPPORTS_SEMI_ANTI_JOIN:
klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
TokenType.ANTI,
Expand Down
4 changes: 3 additions & 1 deletion sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1970,7 +1970,9 @@ def join_sql(self, expression: exp.Join) -> str:

return f", {this_sql}"

op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
if op_sql != "STRAIGHT_JOIN":
op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"

return f"{self.seg(op_sql)} {this_sql}{match_cond}{on_sql}"

def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
Expand Down
7 changes: 4 additions & 3 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,11 +588,12 @@ class Parser(metaclass=_Parser):
}

JOIN_KINDS = {
TokenType.ANTI,
TokenType.CROSS,
TokenType.INNER,
TokenType.OUTER,
TokenType.CROSS,
TokenType.SEMI,
TokenType.ANTI,
TokenType.STRAIGHT_JOIN,
}

JOIN_HINTS: t.Set[str] = set()
Expand Down Expand Up @@ -3106,7 +3107,7 @@ def _parse_join(
index = self._index
method, side, kind = self._parse_join_parts()
hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None
join = self._match(TokenType.JOIN)
join = self._match(TokenType.JOIN) or (kind and kind.token_type == TokenType.STRAIGHT_JOIN)

if not skip_join_token and not join:
self._retreat(index)
Expand Down
2 changes: 2 additions & 0 deletions sqlglot/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ class TokenType(AutoName):
SORT_BY = auto()
START_WITH = auto()
STORAGE_INTEGRATION = auto()
STRAIGHT_JOIN = auto()
STRUCT = auto()
TABLE_SAMPLE = auto()
TAG = auto()
Expand Down Expand Up @@ -765,6 +766,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SOME": TokenType.SOME,
"SORT BY": TokenType.SORT_BY,
"START WITH": TokenType.START_WITH,
"STRAIGHT_JOIN": TokenType.STRAIGHT_JOIN,
"TABLE": TokenType.TABLE,
"TABLESAMPLE": TokenType.TABLE_SAMPLE,
"TEMP": TokenType.TEMPORARY,
Expand Down
7 changes: 7 additions & 0 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ def test_duckdb(self):
"WITH _data AS (SELECT [STRUCT(1 AS a, 2 AS b), STRUCT(2 AS a, 3 AS b)] AS col) SELECT col.b FROM _data, UNNEST(_data.col) AS col WHERE col.a = 1",
)

self.validate_all(
"SELECT straight_join",
write={
"duckdb": "SELECT straight_join",
"mysql": "SELECT `straight_join`",
},
)
self.validate_all(
"SELECT CAST('2020-01-01 12:05:01' AS TIMESTAMP)",
read={
Expand Down
1 change: 1 addition & 0 deletions tests/dialects/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def test_ddl(self):
)

def test_identity(self):
self.validate_identity("SELECT e.* FROM e STRAIGHT_JOIN p ON e.x = p.y")
self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1")
self.validate_identity("SELECT DATE_FORMAT(NOW(), '%Y-%m-%d %H:%i:00.0000')")
self.validate_identity("SELECT @var1 := 1, @var2")
Expand Down
1 change: 1 addition & 0 deletions tests/fixtures/identity.sql
Original file line number Diff line number Diff line change
Expand Up @@ -872,3 +872,4 @@ SELECT name
SELECT copy
SELECT rollup
SELECT unnest
SELECT * FROM a STRAIGHT_JOIN b

0 comments on commit c49cefa

Please sign in to comment.