Skip to content

Commit

Permalink
Fixes 5082 tsvector (#5104)
Browse files Browse the repository at this point in the history
* Add match operator for tsvector and jsonb

Shared with Jsonb and TsVector type
Add MatchOperatorExpressionMixin for SqlExp and annotations
Add PostgreSqlTypeResolver matchOperatorExpression boolean result types

Operator resolver type

e.g. the `||` concat operator

* Add tests

Integration tests for TsVector
|| concat test

* add ts_rank function

Useful text search function

ts_rank returns a REAL type

* Add function websearch_to_tsquery

Commonly used for searches - returns a tsquery value, which is resolved as a text type
Tested locally with

```sql
SELECT *
FROM Recipes WHERE search @@ websearch_to_tsquery('english', ?);
```
  • Loading branch information
griffio committed Apr 4, 2024
1 parent 8b4f448 commit 5628209
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ internal enum class PostgreSqlType(override val javaType: TypeName) : DialectTyp
UUID(ClassName("java.util", "UUID")),
NUMERIC(ClassName("java.math", "BigDecimal")),
JSON(STRING),
TSVECTOR(STRING),
;

override fun prepareStatementBinder(columnIndex: CodeBlock, value: CodeBlock): CodeBlock {
Expand All @@ -36,7 +37,7 @@ internal enum class PostgreSqlType(override val javaType: TypeName) : DialectTyp
)

NUMERIC -> CodeBlock.of("bindBigDecimal(%L, %L)\n", columnIndex, value)
JSON -> CodeBlock.of(
JSON, TSVECTOR -> CodeBlock.of(
"bindObject(%L, %L, %M)\n",
columnIndex,
value,
Expand All @@ -53,7 +54,7 @@ internal enum class PostgreSqlType(override val javaType: TypeName) : DialectTyp
BIG_INT -> "$cursorName.getLong($columnIndex)"
DATE, TIME, TIMESTAMP, TIMESTAMP_TIMEZONE, INTERVAL, UUID -> "$cursorName.getObject<%T>($columnIndex)"
NUMERIC -> "$cursorName.getBigDecimal($columnIndex)"
JSON -> "$cursorName.getString($columnIndex)"
JSON, TSVECTOR -> "$cursorName.getString($columnIndex)"
},
javaType,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class PostgreSqlTypeResolver(private val parentResolver: TypeResolver) : TypeRes
jsonDataType != null -> PostgreSqlType.JSON
booleanDataType != null -> BOOLEAN
blobDataType != null -> BLOB
tsvectorDataType != null -> PostgreSqlType.TSVECTOR
else -> throw IllegalArgumentException("Unknown kotlin type for sql type ${this.text}")
},
)
Expand Down Expand Up @@ -178,6 +179,10 @@ class PostgreSqlTypeResolver(private val parentResolver: TypeResolver) : TypeRes
"regexp_count", "regexp_instr" -> IntermediateType(INTEGER)
"regexp_like" -> IntermediateType(BOOLEAN)
"regexp_replace", "regexp_substr" -> IntermediateType(TEXT)
"to_tsquery" -> IntermediateType(TEXT)
"to_tsvector" -> IntermediateType(PostgreSqlType.TSVECTOR)
"ts_rank" -> encapsulatingType(exprList, REAL, TEXT)
"websearch_to_tsquery" -> IntermediateType(TEXT)
else -> null
}

Expand Down Expand Up @@ -240,6 +245,7 @@ class PostgreSqlTypeResolver(private val parentResolver: TypeResolver) : TypeRes
PostgreSqlType.TIMESTAMP_TIMEZONE,
PostgreSqlType.TIMESTAMP,
PostgreSqlType.JSON,
PostgreSqlType.TSVECTOR,
)
}
}
Expand Down Expand Up @@ -269,6 +275,9 @@ class PostgreSqlTypeResolver(private val parentResolver: TypeResolver) : TypeRes
IntermediateType(PostgreSqlType.JSON)
}
}
matchOperatorExpression != null -> {
IntermediateType(BOOLEAN)
}
else -> parentResolver.resolvedType(this)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ type_name ::= (
date_data_type |
boolean_data_type |
json_data_type |
blob_data_type
blob_data_type |
tsvector_data_type
) [ '[]' ] {
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlTypeNameImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlTypeName"
Expand Down Expand Up @@ -210,6 +211,8 @@ json_data_type ::= 'JSON' | 'JSONB'

blob_data_type ::= 'BYTEA'

tsvector_data_type ::= 'TSVECTOR'

interval_expression ::= 'INTERVAL' string_literal

with_clause_auxiliary_stmt ::= {compound_select_stmt} | delete_stmt_limited | insert_stmt | update_stmt_limited {
Expand Down Expand Up @@ -358,7 +361,7 @@ compound_select_stmt ::= [ {with_clause} ] {select_stmt} ( {compound_operator}
override = true
}

extension_expr ::= array_agg_stmt| string_agg_stmt | json_expression | boolean_literal | boolean_not_expression | window_function_expr {
extension_expr ::= match_operator_expression | array_agg_stmt| string_agg_stmt | json_expression | boolean_literal | boolean_not_expression | window_function_expr {
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlExtensionExprImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlExtensionExpr"
override = true
Expand All @@ -378,7 +381,13 @@ json_expression ::= {column_expr} ( jsona_binary_operator | jsonb_binary_operato
}
jsona_binary_operator ::= '->' | '->>' | '#>' | '#>>'
jsonb_binary_operator ::= '#-'
jsonb_boolean_operator ::= '@@' | '@>' | '<@' | '@?' | '??|' | '??&' | '??'
jsonb_boolean_operator ::= '@>' | '<@' | '@?' | '??|' | '??&' | '??'
// The '@@' token as its own rule, used by match_operator_expression below.
match_operator ::= '@@'

// A match expression: a function call or a column reference on the left, the '@@' operator,
// then a right-hand operand parsed through the external expr rule (<<expr '-1'>>).
// pin = 2 pins the rule on its second element, the match_operator.
// NOTE(review): presumably this commits the parser to the rule once '@@' is seen - confirm
// against the Grammar-Kit pin documentation.
// Column type validation for this expression lives in the MatchOperatorExpressionMixin class.
match_operator_expression ::= ( {function_expr} | {column_expr} ) match_operator <<expr '-1'>> {
mixin = "app.cash.sqldelight.dialects.postgresql.grammar.mixins.MatchOperatorExpressionMixin"
pin = 2
}

extension_stmt ::= create_sequence_stmt | copy_stdin | truncate_stmt | set_stmt | drop_sequence_stmt | alter_sequence_stmt | create_extension_stmt | drop_extension_stmt | alter_extension_stmt {
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlExtensionStmtImpl"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package app.cash.sqldelight.dialects.postgresql.grammar.mixins

import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlMatchOperatorExpression
import com.alecstrong.sql.psi.core.SqlAnnotationHolder
import com.alecstrong.sql.psi.core.psi.SqlBinaryExpr
import com.alecstrong.sql.psi.core.psi.SqlColumnDef
import com.alecstrong.sql.psi.core.psi.SqlColumnName
import com.alecstrong.sql.psi.core.psi.SqlCompositeElementImpl
import com.alecstrong.sql.psi.core.psi.SqlExpr
import com.intellij.lang.ASTNode

/**
 * PSI mixin for the PostgreSQL "@@" match operator expression.
 *
 * The "@@" match operator is used by both TsVector and Jsonb, so the type annotation of the
 * left-hand column is performed here for both types.
 * For other json operators see JsonExpressionMixin.
 */
internal abstract class MatchOperatorExpressionMixin(node: ASTNode) :
  SqlCompositeElementImpl(node),
  SqlBinaryExpr,
  PostgreSqlMatchOperatorExpression {

  /**
   * Reports an error on the left operand when it resolves to a column definition whose declared
   * type name is neither "JSONB" nor "TSVECTOR", then runs the default annotation exactly once.
   *
   * When the left operand does not resolve to a column definition, no type check is made.
   */
  override fun annotate(annotationHolder: SqlAnnotationHolder) {
    val leftOperand = firstChild.firstChild
    val columnType = ((leftOperand.reference?.resolve() as? SqlColumnName)?.parent as? SqlColumnDef)?.columnType?.typeName?.text
    when (columnType) {
      // Unresolved operand or a supported column type: nothing to report.
      null, "JSONB", "TSVECTOR" -> Unit
      "JSON" -> annotationHolder.createErrorAnnotation(leftOperand, "Left side of jsonb expression must be a jsonb column.")
      else -> annotationHolder.createErrorAnnotation(leftOperand, "Left side of match expression must be a tsvector column.")
    }
    // Delegate a single time for every branch; calling it inside the branches as well would
    // run the default annotation twice for the null and JSONB cases.
    super.annotate(annotationHolder)
  }

  /** Returns the direct children of this element that are [SqlExpr] instances. */
  override fun getExprList(): List<SqlExpr> = children.filterIsInstance<SqlExpr>()
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class PostgreSqlFixturesTest(name: String, fixtureRoot: File) : FixturesTest(nam
"BLOB" to "TEXT",
"id TEXT GENERATED ALWAYS AS (2) UNIQUE NOT NULL" to "id TEXT GENERATED ALWAYS AS (2) STORED UNIQUE NOT NULL",
"'(', ')', ',', '.', <binary like operator real>, BETWEEN or IN expected, got ','"
to "'#-', '(', ')', ',', '.', <binary like operator real>, <jsona binary operator real>, <jsonb boolean operator real>, BETWEEN or IN expected, got ','",
to "'#-', '(', ')', ',', '.', <binary like operator real>, <jsona binary operator real>, <jsonb boolean operator real>, '@@', BETWEEN or IN expected, got ','",
)

override fun setupDialect() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ SELECT
data @@ '$.b[*] > 0'
FROM myTable;

SELECT data ->> 'a', datab -> 'b', data #> '{aa}', datab #>> '{bb}', datab || datab, datab - 'b', datab - 1
SELECT data ->> 'a', datab -> 'b', data #> '{aa}', datab #>> '{bb}', datab || datab, datab - 'b', datab - 1, datab @@ '$.b[*] > 0'
FROM myTable;
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Fixture: declares a TSVECTOR column, inserts a string literal into it,
-- and applies the @@ match operator to that column in a SELECT.
CREATE TABLE t1 (
c1 TSVECTOR
);

-- A plain string literal is used as the value for the TSVECTOR column.
INSERT INTO t1 (c1) VALUES ('the rain in spain falls mainly on the plains') ;

-- Left side of @@ is the TSVECTOR column, right side is a string literal.
SELECT c1 @@ 'fail'
FROM t1;
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
-- Single-column table holding a non-null TSVECTOR value.
CREATE TABLE search(
content TSVECTOR NOT NULL
);

-- Inserts one bound value into the content column.
insertLiteral:
INSERT INTO search (content) VALUES (?);

-- For every row, selects the result of matching content against the bound value with @@.
contains:
SELECT content @@ ?
FROM search;

-- Selects the rows whose content matches the bound value with @@.
search:
SELECT *
FROM search WHERE content @@ ?;

-- Selects the result of to_tsquery applied to the bound value.
tsQuery:
SELECT to_tsquery(?);

-- Selects the result of to_tsvector applied to the bound value.
tsVector:
SELECT to_tsvector(?);

-- For every row, concatenates content with to_tsvector of the bound value using ||.
concat:
SELECT content || to_tsvector(?)
FROM search;

-- For every row, selects ts_rank of content against the bound value.
rank:
SELECT ts_rank(content, ?)
FROM search;
Original file line number Diff line number Diff line change
Expand Up @@ -782,4 +782,50 @@ class PostgreSqlTest {
assertThat(first().datab).isEqualTo("""{"b": 2}""")
}
}

@Test
fun testSelectTsVectorSearch() {
  database.textSearchQueries.insertLiteral("the rain in spain")

  // The search query filters rows with "content @@ ?", so the inserted row is expected back.
  val matchingRows = database.textSearchQueries.search("rain").executeAsList()

  assertThat(matchingRows.first()).isEqualTo("'in' 'rain' 'spain' 'the'")
}

@Test
fun testSelectTsVectorContains() {
  database.textSearchQueries.insertLiteral("the rain in spain")

  // The contains query selects the "content @@ ?" expression itself rather than the row.
  val matchResults = database.textSearchQueries.contains("rain").executeAsList()

  assertThat(matchResults.first()).isEqualTo(true)
}

@Test
fun testSelectTsQuery() {
  // The query text is a plain "a & b & c" expression; the stray trailing quote that used to
  // follow "spain" was a typo and is not part of the intended input.
  with(database.textSearchQueries.tsQuery("the & rain & spain").executeAsList()) {
    // The expected value no longer contains "the".
    assertThat(first()).isEqualTo("'rain' & 'spain'")
  }
}

@Test
fun testSelectTsVector() {
  val vectors = database.textSearchQueries.tsVector("the rain in spain").executeAsList()

  // Unlike the literal-insert tests, the expected to_tsvector output carries positions.
  assertThat(vectors.first()).isEqualTo("'rain':2 'spain':4")
}

@Test
fun testContactTsVector() {
  database.textSearchQueries.insertLiteral("the rain in spain")

  // Exercises the "content || to_tsvector(?)" query: stored content joined with a new vector.
  val concatenated = database.textSearchQueries.concat("falls mainly on the plains").executeAsList()

  assertThat(concatenated.first()).isEqualTo("'fall':1 'in' 'main':2 'plain':5 'rain' 'spain' 'the'")
}

@Test
fun testContactTsVectorRank() {
  database.textSearchQueries.insertLiteral("the rain in spain")

  // Exercises the "ts_rank(content, ?)" query; the expected rank is compared as text.
  val ranks = database.textSearchQueries.rank("rain | plain").executeAsList()

  assertThat(ranks.first()).isEqualTo("0.030396355")
}
}

0 comments on commit 5628209

Please sign in to comment.