diff --git a/build.mill b/build.mill index 592bbc31..ae3e7f46 100644 --- a/build.mill +++ b/build.mill @@ -54,6 +54,8 @@ trait CommonBase extends ScalaModule with PublishModule with ScalafixModule { co ivy"org.postgresql:postgresql:42.6.0", ivy"org.testcontainers:mysql:1.19.1", ivy"mysql:mysql-connector-java:8.0.33", + ivy"org.testcontainers:mssqlserver:1.19.1", + ivy"com.microsoft.sqlserver:mssql-jdbc:12.8.1.jre11", ivy"com.zaxxer:HikariCP:5.1.0" ) diff --git a/docs/reference.md b/docs/reference.md index 51f024b8..2b96e5ba 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -236,7 +236,10 @@ dbClient.transaction { db => LocalDate.parse("2000-01-01") ) ) - assert(generatedKeys == Seq(4, 5)) + if (!this.isInstanceOf[MsSqlSuite]) + assert(generatedKeys == Seq(4, 5)) + else + assert(generatedKeys == Seq(5)) db.run(Buyer.select) ==> List( Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), @@ -1997,7 +2000,6 @@ Buyer.select .leftJoin(ShippingInfo)(_.id `=` _.buyerId) .map { case (b, si) => (b.name, si.map(_.shippingDate)) } .sortBy(_._2) - .nullsFirst ``` @@ -2006,7 +2008,7 @@ Buyer.select SELECT buyer0.name AS res_0, shipping_info1.shipping_date AS res_1 FROM buyer buyer0 LEFT JOIN shipping_info shipping_info1 ON (buyer0.id = shipping_info1.buyer_id) - ORDER BY res_1 NULLS FIRST + ORDER BY res_1 ``` @@ -3395,7 +3397,7 @@ Purchase.delete(_ => true) * ```sql - DELETE FROM purchase WHERE ? + DELETE FROM purchase ``` @@ -4003,7 +4005,7 @@ Product.select ## UpdateJoin -`UPDATE` queries that use `JOIN`s +Basic `UPDATE` queries ### UpdateJoin.join ScalaSql supports performing `UPDATE`s with `FROM`/`JOIN` clauses using the @@ -6951,7 +6953,7 @@ Select.delete(_ => true) * ```sql - DELETE FROM "select" WHERE ? + DELETE FROM "select" ``` @@ -9774,7 +9776,7 @@ Expr(Bytes("Hello")).contains(Bytes("ll")) ## ExprMathOps -Math operations; supported by H2/Postgres/MySql, not supported by Sqlite +Math operations; supported by H2/Postgres/MySql/MsSql, not supported by Sqlite ### ExprMathOps.power @@ -10111,7 +10113,7 @@ val value = DataTypes[Sc]( myInt = 12345678, myBigInt = 12345678901L, myDouble = 3.14, - myBoolean = true, + myBoolean = false, myLocalDate = LocalDate.parse("2023-12-20"), myLocalTime = LocalTime.parse("10:15:30"), myLocalDateTime = LocalDateTime.parse("2011-12-03T10:15:30"), @@ -10122,6 +10124,23 @@ val value = DataTypes[Sc]( myEnum = MyEnum.bar ) +val value2 = DataTypes[Sc]( + 67.toByte, + mySmallInt = 32767.toShort, + myInt = 12345678, + myBigInt = 9876543210L, + myDouble = 2.71, + myBoolean = true, + myLocalDate = LocalDate.parse("2020-02-22"), + myLocalTime = LocalTime.parse("03:05:01"), + myLocalDateTime = LocalDateTime.parse("2021-06-07T02:01:03"), + myUtilDate = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS").parse("2021-06-07T02:01:03.000"), + myInstant = Instant.parse("2021-06-07T02:01:03Z"), + myVarBinary = new geny.Bytes(Array[Byte](9, 8, 7, 6, 5, 4, 3, 2)), + myUUID = new java.util.UUID(9876543210L, 1234567890L), + myEnum = MyEnum.baz +) + db.run( DataTypes.insert.columns( _.myTinyInt := value.myTinyInt, @@ -10140,8 +10159,26 @@ db.run( _.myEnum := value.myEnum ) ) ==> 1 +db.run( + DataTypes.insert.columns( + _.myTinyInt := value2.myTinyInt, + _.mySmallInt := value2.mySmallInt, + _.myInt := value2.myInt, + _.myBigInt := value2.myBigInt, + _.myDouble := value2.myDouble, + _.myBoolean := value2.myBoolean, + _.myLocalDate := value2.myLocalDate, + _.myLocalTime := value2.myLocalTime, + _.myLocalDateTime := value2.myLocalDateTime, + _.myUtilDate := value2.myUtilDate, + 
   _.myInstant := value2.myInstant,
+    _.myVarBinary := value2.myVarBinary,
+    _.myUUID := value2.myUUID,
+    _.myEnum := value2.myEnum
+  )
+) ==> 1
 
-db.run(DataTypes.select) ==> Seq(value)
+db.run(DataTypes.select) ==> Seq(value, value2)
 ```
 
 
@@ -11181,6 +11218,24 @@ val rowSome = OptDataTypes[Sc](
   myUUID = Some(new java.util.UUID(1234567890L, 9876543210L)),
   myEnum = Some(MyEnum.bar)
 )
+val rowSome2 = OptDataTypes[Sc](
+  myTinyInt = Some(67.toByte),
+  mySmallInt = Some(32767.toShort),
+  myInt = Some(23456789),
+  myBigInt = Some(9876543210L),
+  myDouble = Some(2.71),
+  myBoolean = Some(false),
+  myLocalDate = Some(LocalDate.parse("2020-02-22")),
+  myLocalTime = Some(LocalTime.parse("03:05:01")),
+  myLocalDateTime = Some(LocalDateTime.parse("2021-06-07T02:01:03")),
+  myUtilDate = Some(
+    new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS").parse("2021-06-07T02:01:03.000")
+  ),
+  myInstant = Some(Instant.parse("2021-06-07T02:01:03Z")),
+  myVarBinary = Some(new geny.Bytes(Array[Byte](9, 8, 7, 6, 5, 4, 3, 2))),
+  myUUID = Some(new java.util.UUID(9876543210L, 1234567890L)),
+  myEnum = Some(MyEnum.baz)
+)
 
 val rowNone = OptDataTypes[Sc](
   myTinyInt = None,
@@ -11198,12 +11253,11 @@ val rowNone = OptDataTypes[Sc](
   myUUID = None,
   myEnum = None
 )
-
 db.run(
-  OptDataTypes.insert.values(rowSome, rowNone)
-) ==> 2
+  OptDataTypes.insert.values(rowSome, rowSome2, rowNone)
+) ==> 3
 
-db.run(OptDataTypes.select) ==> Seq(rowSome, rowNone)
+db.run(OptDataTypes.select) ==> Seq(rowSome, rowSome2, rowNone)
 ```
 
 
@@ -12376,3 +12430,80 @@ db.concatWs(" ", "i", "am", "cow", 1337)
 ```
 
 
+
+## MsSqlDialect
+Operations specific to working with Microsoft SQL Databases
+### MsSqlDialect.top
+
+In ScalaSql's Microsoft SQL dialect, the `.take(n)` operator translates
+into a SQL `TOP(n)` clause
+
+```scala
+Buyer.select.take(0)
+```
+
+
+*
+  ```sql
+  SELECT TOP(?) buyer0.id AS id, buyer0.name AS name, buyer0.date_of_birth AS date_of_birth
+  FROM buyer buyer0
+  ```
+
+
+
+*
+  ```scala
+  Seq[Buyer[Sc]]()
+  ```
+
+
+
+### MsSqlDialect.bool vs bit
+
+Insert rows with BIT values
+
+```scala
+db.run(
+  BoolTypes.insert.columns(
+    _.nullable := value.nullable,
+    _.nonNullable := value.nonNullable,
+    _.a := value.a,
+    _.b := value.b,
+    _.comment := value.comment
+  )
+) ==> 1
+db.run(
+  BoolTypes.insert.columns(
+    _.nullable := value2.nullable,
+    _.nonNullable := value2.nonNullable,
+    _.a := value2.a,
+    _.b := value2.b,
+    _.comment := value2.comment
+  )
+) ==> 1
+```
+
+
+
+
+
+
+### MsSqlDialect.update BIT
+
+
+
+```scala
+BoolTypes
+  .update(_.a `=` 1)
+  .set(_.nonNullable := true)
+```
+
+
+*
+  ```sql
+  UPDATE bool_types SET non_nullable = ? WHERE (bool_types.a = ?)
+ ``` + + + + diff --git a/scalasql/core/src/Context.scala b/scalasql/core/src/Context.scala index 9af889cb..a6d5244b 100644 --- a/scalasql/core/src/Context.scala +++ b/scalasql/core/src/Context.scala @@ -19,6 +19,11 @@ trait Context { */ def exprNaming: Map[Expr.Identity, SqlStr] + /** + * Mark [[Expr]]s as a raw value for an INSERT or UPDATE context + */ + def valueMarker: Boolean + /** * The ScalaSql configuration */ @@ -28,9 +33,11 @@ trait Context { def withFromNaming(fromNaming: Map[Context.From, String]): Context def withExprNaming(exprNaming: Map[Expr.Identity, SqlStr]): Context + def markAsValue: Context } object Context { + trait From { /** @@ -58,6 +65,7 @@ object Context { case class Impl( fromNaming: Map[From, String], exprNaming: Map[Expr.Identity, SqlStr], + valueMarker: Boolean, config: Config, dialectConfig: DialectConfig ) extends Context { @@ -65,6 +73,10 @@ object Context { def withExprNaming(exprNaming: Map[Expr.Identity, SqlStr]): Context = copy(exprNaming = exprNaming) + + def markAsValue: Context = copy( + valueMarker = true + ) } /** @@ -96,7 +108,13 @@ object Context { .map { case (e, s) => (e, sql"${SqlStr.raw(newFromNaming(t), Array(e))}.$s") } } - Context.Impl(newFromNaming, newExprNaming, prevContext.config, prevContext.dialectConfig) + Context.Impl( + newFromNaming, + newExprNaming, + prevContext.valueMarker, + prevContext.config, + prevContext.dialectConfig + ) } } diff --git a/scalasql/core/src/DbApi.scala b/scalasql/core/src/DbApi.scala index d5981aa9..98b12ea4 100644 --- a/scalasql/core/src/DbApi.scala +++ b/scalasql/core/src/DbApi.scala @@ -131,7 +131,7 @@ object DbApi { config: Config, dialectConfig: DialectConfig ) = { - val ctx = Context.Impl(Map(), Map(), config, dialectConfig) + val ctx = Context.Impl(Map(), Map(), false, config, dialectConfig) val flattened = SqlStr.flatten(qr.renderSql(query, ctx)) flattened } @@ -583,7 +583,7 @@ object DbApi { try { val res = block(new DbApi.SavepointImpl(savepoint, () => rollbackSavepoint(savepoint))) - if (savepointStack.lastOption.exists(_ eq savepoint)) { + if (dialect.supportSavepointRelease && savepointStack.lastOption.exists(_ eq savepoint)) { // Only release if this savepoint has not been rolled back, // directly or indirectly connection.releaseSavepoint(savepoint) diff --git a/scalasql/core/src/DialectConfig.scala b/scalasql/core/src/DialectConfig.scala index 4c4f9930..9ee16313 100644 --- a/scalasql/core/src/DialectConfig.scala +++ b/scalasql/core/src/DialectConfig.scala @@ -3,11 +3,13 @@ package scalasql.core trait DialectConfig { that => def castParams: Boolean def escape(str: String): String + def supportSavepointRelease: Boolean def withCastParams(params: Boolean) = new DialectConfig { def castParams: Boolean = params - def escape(str: String): String = that.escape(str) + def supportSavepointRelease = that.supportSavepointRelease + def escape(str: String): String = that.escape(str) } } diff --git a/scalasql/query/src/CompoundSelect.scala b/scalasql/query/src/CompoundSelect.scala index b3ce890c..206527e2 100644 --- a/scalasql/query/src/CompoundSelect.scala +++ b/scalasql/query/src/CompoundSelect.scala @@ -112,7 +112,7 @@ object CompoundSelect { // columns are duplicates or not, and thus what final set of rows is returned lazy val preserveAll = query.compoundOps.exists(_.op != "UNION ALL") - def render(liveExprs: LiveExprs) = { + protected def prerender(liveExprs: LiveExprs) = { val innerLiveExprs = if (preserveAll) LiveExprs.none else liveExprs.map(_ ++ newReferencedExpressions) @@ -138,7 +138,14 @@ 
object CompoundSelect { SqlStr.join(compoundStrs) } - lhsStr + compound + sortOpt + limitOpt + offsetOpt + (lhsStr, compound, sortOpt, limitOpt, offsetOpt) + } + + def render(liveExprs: LiveExprs) = { + prerender(liveExprs) match { + case (lhsStr, compound, sortOpt, limitOpt, offsetOpt) => + lhsStr + compound + sortOpt + limitOpt + offsetOpt + } } def orderToSqlStr(newCtx: Context) = CompoundSelect.orderToSqlStr(query.orderBy, newCtx, gap = true) diff --git a/scalasql/query/src/Delete.scala b/scalasql/query/src/Delete.scala index 053981cf..fe905b57 100644 --- a/scalasql/query/src/Delete.scala +++ b/scalasql/query/src/Delete.scala @@ -1,8 +1,6 @@ package scalasql.query -import scalasql.core.DialectTypeMappers -import scalasql.core.Context -import scalasql.core.{Queryable, SqlStr, Expr} +import scalasql.core.{Context, DialectTypeMappers, Expr, ExprsToSql, Queryable, SqlStr} import scalasql.core.SqlStr.SqlStringSyntax /** @@ -26,6 +24,8 @@ object Delete { lazy val tableNameStr = SqlStr.raw(Table.fullIdentifier(table.value)) - def render() = sql"DELETE FROM $tableNameStr WHERE $expr" + lazy val filtersOpt = SqlStr.flatten(ExprsToSql.booleanExprs(sql" WHERE ", expr :: Nil)) + + def render() = sql"DELETE FROM $tableNameStr$filtersOpt" } } diff --git a/scalasql/query/src/InsertColumns.scala b/scalasql/query/src/InsertColumns.scala index 0a47f8a1..6a5b22db 100644 --- a/scalasql/query/src/InsertColumns.scala +++ b/scalasql/query/src/InsertColumns.scala @@ -24,7 +24,13 @@ object InsertColumns { protected def expr: V[Column] = WithSqlExpr.get(insert) private[scalasql] override def renderSql(ctx: Context) = - new Renderer(columns, ctx, valuesLists, Table.fullIdentifier(table.value)(ctx)).render() + new Renderer( + columns, + ctx.markAsValue, + valuesLists, + Table.fullIdentifier(table.value)(ctx) + ) + .render() override protected def queryConstruct(args: Queryable.ResultSetIterator): Int = args.get(IntType) diff --git a/scalasql/query/src/InsertSelect.scala b/scalasql/query/src/InsertSelect.scala index e59e2a4b..f0ed3adf 100644 --- a/scalasql/query/src/InsertSelect.scala +++ b/scalasql/query/src/InsertSelect.scala @@ -23,7 +23,7 @@ object InsertSelect { new Renderer( select, select.qr.walkExprs(columns), - ctx, + ctx.markAsValue, Table.fullIdentifier(table.value)(ctx) ) .render() diff --git a/scalasql/query/src/InsertValues.scala b/scalasql/query/src/InsertValues.scala index 5ec85ed6..95c5ad2d 100644 --- a/scalasql/query/src/InsertValues.scala +++ b/scalasql/query/src/InsertValues.scala @@ -29,7 +29,7 @@ object InsertValues { values, qr, skippedColumns - )(ctx).render() + )(ctx.markAsValue).render() } override def skipColumns(x: (V[Column] => Column[?])*): InsertValues[V, R] = { diff --git a/scalasql/query/src/Update.scala b/scalasql/query/src/Update.scala index 920cc3b3..cc1d6f47 100644 --- a/scalasql/query/src/Update.scala +++ b/scalasql/query/src/Update.scala @@ -75,7 +75,7 @@ object Update { } private[scalasql] override def renderSql(ctx: Context): SqlStr = - new Renderer(joins, table, set0, where, ctx).render() + new Renderer(joins, table, set0, where, ctx.markAsValue).render() override protected def queryConstruct(args: Queryable.ResultSetIterator): Int = { args.get(dialect.IntType) diff --git a/scalasql/src/dialects/H2Dialect.scala b/scalasql/src/dialects/H2Dialect.scala index 01441a9b..d0765c8a 100644 --- a/scalasql/src/dialects/H2Dialect.scala +++ b/scalasql/src/dialects/H2Dialect.scala @@ -29,7 +29,9 @@ trait H2Dialect extends Dialect { def castParams = true - def escape(str: String) = 
s"\"${str.toUpperCase()}\"" + def escape(str: String) = s""""${str.toUpperCase()}"""" + + def supportSavepointRelease = true override implicit def EnumType[T <: Enumeration#Value]( implicit constructor: String => T @@ -67,6 +69,9 @@ trait H2Dialect extends Dialect { } object H2Dialect extends H2Dialect { + + override def supportSavepointRelease: Boolean = true + class DbApiOps(dialect: DialectTypeMappers) extends scalasql.operations.DbApiOps(dialect) with ConcatOps diff --git a/scalasql/src/dialects/MsSqlDialect.scala b/scalasql/src/dialects/MsSqlDialect.scala new file mode 100644 index 00000000..513eefd6 --- /dev/null +++ b/scalasql/src/dialects/MsSqlDialect.scala @@ -0,0 +1,384 @@ +package scalasql.dialects + +import scalasql.query.{AscDesc, GroupBy, Join, Nulls, OrderBy, SubqueryRef, Table} +import scalasql.core.{ + Aggregatable, + Context, + DbApi, + DialectTypeMappers, + Expr, + ExprsToSql, + LiveExprs, + Queryable, + SqlStr, + TypeMapper +} +import scalasql.{Sc, operations} +import scalasql.core.SqlStr.{Renderable, SqlStringSyntax} +import scalasql.operations.{ConcatOps, MathOps, TrimOps} + +import java.time.{Instant, LocalDateTime, OffsetDateTime, ZoneId, ZonedDateTime} +import java.sql.{JDBCType, PreparedStatement, ResultSet} +import scalasql.query.Column + +trait MsSqlDialect extends Dialect { + def castParams = false + + def escape(str: String): String = + s"[$str]" + + def supportSavepointRelease = false + + override implicit def IntType: TypeMapper[Int] = new MsSqlIntType + class MsSqlIntType extends IntType { override def castTypeString = "INT" } + + override implicit def StringType: TypeMapper[String] = new MsSqlStringType + class MsSqlStringType extends StringType { override def castTypeString = "VARCHAR" } + + override implicit def BooleanType: TypeMapper[Boolean] = new MsSqlBooleanType + class MsSqlBooleanType extends BooleanType { override def castTypeString = "BIT" } + override implicit def from(x: Boolean): Expr[Boolean] = + if (x) { + Expr.apply0(x, x) + } else { + Expr { ctx => + if (ctx.valueMarker) { + sql"$x" + } else { + sql"1 = $x" + } + } + } + + override implicit def DoubleType: TypeMapper[Double] = new MsDoubleType + class MsDoubleType extends DoubleType { override def castTypeString = "FLOAT" } + + override implicit def OptionType[T](implicit inner: TypeMapper[T]): TypeMapper[Option[T]] = + new TypeMapper[Option[T]] { + def jdbcType: JDBCType = inner.jdbcType + + def get(r: ResultSet, idx: Int): Option[T] = { + if (r.getObject(idx) == null) None else Some(inner.get(r, idx)) + } + + def put(r: PreparedStatement, idx: Int, v: Option[T]): Unit = { + v match { + case None => r.setObject(idx, null, inner.jdbcType) + case Some(value) => inner.put(r, idx, value) + } + } + } + + override implicit def UtilDateType: TypeMapper[java.util.Date] = new MsSqlUtilDateType + class MsSqlUtilDateType extends UtilDateType { override def castTypeString = "DATETIME2" } + + override implicit def LocalDateTimeType: TypeMapper[LocalDateTime] = new MsSqlLocalDateTimeType + class MsSqlLocalDateTimeType extends LocalDateTimeType { + override def castTypeString = "DATETIME2" + } + + override implicit def InstantType: TypeMapper[Instant] = new MsSqlInstantType + class MsSqlInstantType extends InstantType { override def castTypeString = "DATETIME2" } + + override implicit def ZonedDateTimeType: TypeMapper[ZonedDateTime] = new MsSqlZonedDateTimeType + class MsSqlZonedDateTimeType extends ZonedDateTimeType { + override def castTypeString = "DATETIMEOFFSET" + override def get(r: ResultSet, 
idx: Int) = { + val odt = r.getObject(idx, classOf[OffsetDateTime]) + if (odt == null) null + else odt.toZonedDateTime + } + + override def put(r: PreparedStatement, idx: Int, v: ZonedDateTime) = { + val odt = if (v == null) null else v.toOffsetDateTime + r.setObject(idx, odt) + } + } + + override implicit def OffsetDateTimeType: TypeMapper[OffsetDateTime] = new MsSqlOffsetDateTimeType + class MsSqlOffsetDateTimeType extends OffsetDateTimeType { + override def castTypeString = "DATETIMEOFFSET" + + override def get(r: ResultSet, idx: Int) = { + r.getObject(idx, classOf[OffsetDateTime]) + } + + override def put(r: PreparedStatement, idx: Int, v: OffsetDateTime) = { + r.setObject(idx, v) + } + } + + override implicit def EnumType[T <: Enumeration#Value]( + implicit constructor: String => T + ): TypeMapper[T] = new MsSqlEnumType[T] + + class MsSqlEnumType[T](implicit constructor: String => T) extends EnumType[T] { + override def put(r: PreparedStatement, idx: Int, v: T): Unit = r.setString(idx, v.toString) + } + + override implicit def ExprStringOpsConv(v: Expr[String]): MsSqlDialect.ExprStringOps[String] = + new MsSqlDialect.ExprStringOps(v) + + override implicit def ExprBlobOpsConv( + v: Expr[geny.Bytes] + ): MsSqlDialect.ExprStringLikeOps[geny.Bytes] = + new MsSqlDialect.ExprStringLikeOps(v) + + override implicit def ExprNumericOpsConv[T: Numeric: TypeMapper]( + v: Expr[T] + ): MsSqlDialect.ExprNumericOps[T] = new MsSqlDialect.ExprNumericOps(v) + + override implicit def TableOpsConv[V[_[_]]](t: Table[V]): scalasql.dialects.TableOps[V] = + new MsSqlDialect.TableOps(t) + + implicit def ExprAggOpsConv[T](v: Aggregatable[Expr[T]]): operations.ExprAggOps[T] = + new MsSqlDialect.ExprAggOps(v) + + override implicit def DbApiOpsConv(db: => DbApi): MsSqlDialect.DbApiOps = + new MsSqlDialect.DbApiOps(this) + + override implicit def ExprQueryable[T](implicit mt: TypeMapper[T]): Queryable.Row[Expr[T], T] = { + new MsSqlDialect.ExprQueryable[Expr, T]() + } +} + +object MsSqlDialect extends MsSqlDialect { + class DbApiOps(dialect: DialectTypeMappers) + extends scalasql.operations.DbApiOps(dialect) + with ConcatOps + with MathOps { + override def ln[T: Numeric](v: Expr[T]): Expr[Double] = Expr { implicit ctx => + sql"LOG($v)" + } + + override def atan2[T: Numeric](v: Expr[T], y: Expr[T]): Expr[Double] = Expr { implicit ctx => + sql"ATN2($v, $y)" + } + } + + class ExprAggOps[T](v: Aggregatable[Expr[T]]) extends scalasql.operations.ExprAggOps[T](v) { + def mkString(sep: Expr[String] = null)(implicit tm: TypeMapper[T]): Expr[String] = { + val sepRender = Option(sep).getOrElse(sql"''") + v.aggregateExpr(expr => implicit ctx => sql"STRING_AGG($expr + '', $sepRender)") + } + } + + class ExprStringOps[T](v: Expr[T]) extends ExprStringLikeOps(v) with operations.ExprStringOps[T] + class ExprStringLikeOps[T](protected val v: Expr[T]) + extends operations.ExprStringLikeOps(v) + with TrimOps { + + override def +(x: Expr[T]): Expr[T] = Expr { implicit ctx => + sql"($v + $x)" + } + + override def startsWith(other: Expr[T]): Expr[Boolean] = Expr { implicit ctx => + sql"($v LIKE CAST($other AS VARCHAR(MAX)) + '%')" + } + + override def endsWith(other: Expr[T]): Expr[Boolean] = Expr { implicit ctx => + sql"($v LIKE '%' + CAST($other AS VARCHAR(MAX)))" + } + + override def contains(other: Expr[T]): Expr[Boolean] = Expr { implicit ctx => + sql"($v LIKE '%' + CAST($other AS VARCHAR(MAX)) + '%')" + } + + override def length: Expr[Int] = Expr { implicit ctx => sql"LEN($v)" } + + override def octetLength: Expr[Int] = Expr { 
implicit ctx => sql"DATALENGTH($v)" } + + def indexOf(x: Expr[T]): Expr[Int] = Expr { implicit ctx => sql"CHARINDEX($x, $v)" } + def reverse: Expr[T] = Expr { implicit ctx => sql"REVERSE($v)" } + } + + class ExprNumericOps[T: Numeric: TypeMapper](protected val v: Expr[T]) + extends operations.ExprNumericOps[T](v) { + override def %[V: Numeric](x: Expr[V]): Expr[T] = Expr { implicit ctx => + sql"$v % $x" + } + + override def mod[V: Numeric](x: Expr[V]): Expr[T] = Expr { implicit ctx => + sql"$v % $x" + } + + override def ceil: Expr[T] = Expr { implicit ctx => + sql"CEILING($v)" + } + } + + class TableOps[V[_[_]]](t: Table[V]) extends scalasql.dialects.TableOps[V](t) { + + protected override def joinableToSelect: Select[V[Expr], V[Sc]] = { + val ref = Table.ref(t) + new SimpleSelect( + Table.metadata(t).vExpr(ref, dialectSelf).asInstanceOf[V[Expr]], + None, + None, + false, + Seq(ref), + Nil, + Nil, + None + )( + t.containerQr + ) + } + + } + + trait Select[Q, R] extends scalasql.query.Select[Q, R] { + override def newCompoundSelect[Q, R]( + lhs: scalasql.query.SimpleSelect[Q, R], + compoundOps: Seq[scalasql.query.CompoundSelect.Op[Q, R]], + orderBy: Seq[OrderBy], + limit: Option[Int], + offset: Option[Int] + )( + implicit qr: Queryable.Row[Q, R], + dialect: scalasql.core.DialectTypeMappers + ): scalasql.query.CompoundSelect[Q, R] = { + new CompoundSelect(lhs, compoundOps, orderBy, limit, offset) + } + + override def newSimpleSelect[Q, R]( + expr: Q, + exprPrefix: Option[Context => SqlStr], + exprSuffix: Option[Context => SqlStr], + preserveAll: Boolean, + from: Seq[Context.From], + joins: Seq[Join], + where: Seq[Expr[?]], + groupBy0: Option[GroupBy] + )( + implicit qr: Queryable.Row[Q, R], + dialect: scalasql.core.DialectTypeMappers + ): scalasql.query.SimpleSelect[Q, R] = { + new SimpleSelect(expr, exprPrefix, exprSuffix, preserveAll, from, joins, where, groupBy0) + } + + } + + class SimpleSelect[Q, R]( + expr: Q, + exprPrefix: Option[Context => SqlStr], + exprSuffix: Option[Context => SqlStr], + preserveAll: Boolean, + from: Seq[Context.From], + joins: Seq[Join], + where: Seq[Expr[?]], + groupBy0: Option[GroupBy] + )(implicit qr: Queryable.Row[Q, R]) + extends scalasql.query.SimpleSelect( + expr, + exprPrefix, + exprSuffix, + preserveAll, + from, + joins, + where, + groupBy0 + ) + with Select[Q, R] { + + override def take(n: Int): scalasql.query.Select[Q, R] = { + selectWithExprPrefix(true, _ => sql"TOP($n)") + } + + override def drop(n: Int): scalasql.query.Select[Q, R] = throw new Exception( + ".drop must follow .sortBy" + ) + + } + + class CompoundSelect[Q, R]( + lhs: scalasql.query.SimpleSelect[Q, R], + compoundOps: Seq[scalasql.query.CompoundSelect.Op[Q, R]], + orderBy: Seq[OrderBy], + limit: Option[Int], + offset: Option[Int] + )(implicit qr: Queryable.Row[Q, R]) + extends scalasql.query.CompoundSelect(lhs, compoundOps, orderBy, limit, offset) + with Select[Q, R] { + + override def take(n: Int): scalasql.query.Select[Q, R] = copy( + limit = Some(limit.fold(n)(math.min(_, n))), + offset = offset.orElse(Some(0)) + ) + + protected override def selectRenderer(prevContext: Context): SubqueryRef.Wrapped.Renderer = + new CompoundSelectRenderer(this, prevContext) + } + + class CompoundSelectRenderer[Q, R]( + query: scalasql.query.CompoundSelect[Q, R], + prevContext: Context + ) extends scalasql.query.CompoundSelect.Renderer(query, prevContext) { + override lazy val limitOpt = SqlStr.flatten(SqlStr.opt(query.limit) { limit => + sql" FETCH FIRST $limit ROWS ONLY" + }) + + override lazy val 
offsetOpt = SqlStr.flatten( + SqlStr.opt(query.offset.orElse(Option.when(query.limit.nonEmpty)(0))) { offset => + sql" OFFSET $offset ROWS" + } + ) + + override def render(liveExprs: LiveExprs): SqlStr = { + prerender(liveExprs) match { + case (lhsStr, compound, sortOpt, limitOpt, offsetOpt) => + lhsStr + compound + sortOpt + offsetOpt + limitOpt + } + } + + override def orderToSqlStr(newCtx: Context) = { + SqlStr.optSeq(query.orderBy) { orderBys => + val orderStr = SqlStr.join( + orderBys.map { orderBy => + val exprStr = Renderable.renderSql(orderBy.expr)(newCtx) + + (orderBy.ascDesc, orderBy.nulls) match { + case (Some(AscDesc.Asc), None | Some(Nulls.First)) => sql"$exprStr ASC" + case (Some(AscDesc.Desc), Some(Nulls.First)) => + sql"IIF($exprStr IS NULL, 0, 1), $exprStr DESC" + case (Some(AscDesc.Asc), Some(Nulls.Last)) => + sql"IIF($exprStr IS NULL, 1, 0), $exprStr ASC" + case (Some(AscDesc.Desc), None | Some(Nulls.Last)) => sql"$exprStr DESC" + case (None, None) => exprStr + case (None, Some(Nulls.First)) => sql"IIF($exprStr IS NULL, 0, 1), $exprStr" + case (None, Some(Nulls.Last)) => sql"IIF($exprStr IS NULL, 1, 0), $exprStr" + } + }, + SqlStr.commaSep + ) + + sql" ORDER BY $orderStr" + } + } + } + + class ExprQueryable[E[_] <: Expr[?], T]( + implicit tm: TypeMapper[T] + ) extends Expr.ExprQueryable[E, T] { + + override def walkExprs(q: E[T]): Seq[Expr[?]] = + if (tm.jdbcType == JDBCType.BOOLEAN) { + q match { + // with the introduction of the value marker this is only necessary for Scala 3 + case _: Column[?] => + Seq(Expr[Boolean] { implicit ctx: Context => + sql"$q" + }) + case _ => + Seq(Expr[Boolean] { implicit ctx: Context => + if (ctx.valueMarker) { + sql"$q" + } else { + sql"CASE WHEN $q THEN 1 ELSE 0 END" + } + }) + } + } else + super.walkExprs(q) + } +} diff --git a/scalasql/src/dialects/MySqlDialect.scala b/scalasql/src/dialects/MySqlDialect.scala index 0c0a2f37..8c3ac92d 100644 --- a/scalasql/src/dialects/MySqlDialect.scala +++ b/scalasql/src/dialects/MySqlDialect.scala @@ -47,6 +47,8 @@ trait MySqlDialect extends Dialect { def escape(str: String) = s"`$str`" + def supportSavepointRelease = true + override implicit def ByteType: TypeMapper[Byte] = new MySqlByteType class MySqlByteType extends ByteType { override def castTypeString = "SIGNED" } @@ -158,6 +160,8 @@ trait MySqlDialect extends Dialect { object MySqlDialect extends MySqlDialect { + override def supportSavepointRelease: Boolean = true + class DbApiOps(dialect: DialectTypeMappers) extends scalasql.operations.DbApiOps(dialect) with ConcatOps diff --git a/scalasql/src/dialects/PostgresDialect.scala b/scalasql/src/dialects/PostgresDialect.scala index 5df74c9c..5a6e95c2 100644 --- a/scalasql/src/dialects/PostgresDialect.scala +++ b/scalasql/src/dialects/PostgresDialect.scala @@ -19,7 +19,9 @@ trait PostgresDialect extends Dialect with ReturningDialect with OnConflictOps { def castParams = false - def escape(str: String) = s"\"$str\"" + def escape(str: String) = s""""$str"""" + + def supportSavepointRelease = true override implicit def ByteType: TypeMapper[Byte] = new PostgresByteType class PostgresByteType extends ByteType { override def castTypeString = "INTEGER" } diff --git a/scalasql/src/dialects/SqliteDialect.scala b/scalasql/src/dialects/SqliteDialect.scala index 5673df1f..94539989 100644 --- a/scalasql/src/dialects/SqliteDialect.scala +++ b/scalasql/src/dialects/SqliteDialect.scala @@ -20,7 +20,9 @@ import java.time.{Instant, LocalDate, LocalDateTime} trait SqliteDialect extends Dialect with ReturningDialect 
with OnConflictOps { def castParams = false - def escape(str: String) = s"\"$str\"" + def escape(str: String) = s""""$str"""" + + def supportSavepointRelease = true override implicit def LocalDateTimeType: TypeMapper[LocalDateTime] = new SqliteLocalDateTimeType class SqliteLocalDateTimeType extends LocalDateTimeType { @@ -55,6 +57,9 @@ trait SqliteDialect extends Dialect with ReturningDialect with OnConflictOps { } object SqliteDialect extends SqliteDialect { + + override def supportSavepointRelease: Boolean = true + class DbApiOps(dialect: DialectTypeMappers) extends scalasql.operations.DbApiOps(dialect) { /** diff --git a/scalasql/src/package.scala b/scalasql/src/package.scala index 7a04ee2e..c018dd20 100644 --- a/scalasql/src/package.scala +++ b/scalasql/src/package.scala @@ -55,4 +55,7 @@ package object scalasql { val SqliteDialect = dialects.SqliteDialect type SqliteDialect = dialects.SqliteDialect + + val MsSqlDialect = dialects.MsSqlDialect + type MsSqlDialect = dialects.MsSqlDialect } diff --git a/scalasql/test/resources/mssql-customer-schema.sql b/scalasql/test/resources/mssql-customer-schema.sql new file mode 100644 index 00000000..9e8f9a8c --- /dev/null +++ b/scalasql/test/resources/mssql-customer-schema.sql @@ -0,0 +1,129 @@ +IF EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = N'shipping_info') + ALTER TABLE shipping_info DROP CONSTRAINT IF EXISTS fk_shipping_info_buyer; +IF EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = N'purchase') + ALTER TABLE purchase DROP CONSTRAINT IF EXISTS fk_purchase_shipping_info; +IF EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = N'purchase') + ALTER TABLE purchase DROP CONSTRAINT IF EXISTS fk_purchase_product; +DROP TABLE IF EXISTS buyer; +DROP TABLE IF EXISTS product; +DROP TABLE IF EXISTS shipping_info; +DROP TABLE IF EXISTS purchase; +DROP TABLE IF EXISTS data_types; +DROP TABLE IF EXISTS a; +DROP TABLE IF EXISTS b; +DROP TABLE IF EXISTS non_round_trip_types; +DROP TABLE IF EXISTS opt_cols; +DROP TABLE IF EXISTS nested; +DROP TABLE IF EXISTS enclosing; +DROP TABLE IF EXISTS otherschema.invoice; +DROP SCHEMA IF EXISTS otherschema; +DROP TABLE IF EXISTS bool_types; +DROP TABLE IF EXISTS [select]; + +CREATE TABLE buyer ( + id INT PRIMARY KEY IDENTITY(1, 1), + name VARCHAR(256), + date_of_birth DATE +); + +CREATE TABLE product ( + id INT PRIMARY KEY IDENTITY(1, 1), + kebab_case_name VARCHAR(256), + name VARCHAR(256), + price DECIMAL(20, 2) +); + +CREATE TABLE shipping_info ( + id INT PRIMARY KEY IDENTITY(1, 1), + buyer_id INT, + shipping_date DATE, + CONSTRAINT fk_shipping_info_buyer + FOREIGN KEY(buyer_id) REFERENCES buyer(id) +); + +CREATE TABLE purchase ( + id INT PRIMARY KEY IDENTITY(1, 1), + shipping_info_id INT, + product_id INT, + count INT, + total DECIMAL(20, 2), + CONSTRAINT fk_purchase_shipping_info + FOREIGN KEY(shipping_info_id) REFERENCES shipping_info(id), + CONSTRAINT fk_purchase_product + FOREIGN KEY(product_id) REFERENCES product(id) +); + +CREATE TABLE data_types ( + my_tiny_int TINYINT, + my_small_int SMALLINT, + my_int INT, + my_big_int BIGINT, + my_double FLOAT(53), + my_boolean BIT, + my_local_date DATE, + my_local_time TIME, + my_local_date_time DATETIME2, + my_util_date DATETIME2, + my_instant DATETIME2, + my_var_binary VARBINARY(256), + my_uuid UNIQUEIDENTIFIER, + my_enum VARCHAR(256) +-- my_offset_time TIME WITH TIME ZONE, + +); + +CREATE TABLE a ( + id INT, + b_id INT +); + +CREATE TABLE b ( + id INT, + custom VARCHAR(256) +); + +CREATE TABLE non_round_trip_types( 
+ my_zoned_date_time DATETIMEOFFSET, + my_offset_date_time DATETIMEOFFSET +); + +CREATE TABLE opt_cols( + my_int INT, + my_int2 INT +); + +CREATE TABLE nested( + foo_id INT, + my_boolean BIT +); + +CREATE TABLE enclosing( + bar_id INT, + my_string VARCHAR(256), + foo_id INT, + my_boolean BIT +); + +IF (NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'otherschema')) +BEGIN + EXEC ('CREATE SCHEMA otherschema') +END; + +CREATE TABLE otherschema.invoice( + id INT PRIMARY KEY IDENTITY(1, 1), + total DECIMAL(20, 2), + vendor_name VARCHAR(256) +); + +CREATE TABLE [select]( + id INT, + name VARCHAR(256) +); + +CREATE TABLE bool_types ( + nullable BIT, + non_nullable BIT NOT NULL, + a INT, + b INT, + comment VARCHAR(256) +); diff --git a/scalasql/test/src/ConcreteTestSuites.scala b/scalasql/test/src/ConcreteTestSuites.scala index 3eb10c21..758ecba5 100644 --- a/scalasql/test/src/ConcreteTestSuites.scala +++ b/scalasql/test/src/ConcreteTestSuites.scala @@ -37,7 +37,8 @@ import scalasql.dialects.{ MySqlDialectTests, PostgresDialectTests, SqliteDialectTests, - H2DialectTests + H2DialectTests, + MsSqlDialectTests } package postgres { @@ -283,3 +284,54 @@ package h2 { object H2DialectTests extends H2DialectTests } + +package mssql { + + import utils.MsSqlSuite + + object DbApiTests extends DbApiTests with MsSqlSuite + object TransactionTests extends TransactionTests with MsSqlSuite + + object SelectTests extends SelectTests with MsSqlSuite + object JoinTests extends JoinTests with MsSqlSuite + object FlatJoinTests extends FlatJoinTests with MsSqlSuite + object InsertTests extends InsertTests with MsSqlSuite + object UpdateTests extends UpdateTests with MsSqlSuite + object DeleteTests extends DeleteTests with MsSqlSuite + object CompoundSelectTests extends CompoundSelectTests with MsSqlSuite + object UpdateJoinTests extends UpdateJoinTests with MsSqlSuite + object UpdateSubQueryTests extends UpdateSubQueryTests with MsSqlSuite + // SQL Server does not support standalone VALUES + // object ValuesTests extends ValuesTests with MsSqlSuite + // SQL Server does not support RETURNING + // object ReturningTests extends ReturningTests with MsSqlSuite + // SQL Server does not support ON CONFLICT + // object OnConflictTests extends OnConflictTests with MsSqlSuite + // SQL Server does not support LATERAL JOIN + // object LateralJoinTests extends LateralJoinTests with MsSqlSuite + object WindowFunctionTests extends WindowFunctionTests with MsSqlSuite + object GetGeneratedKeysTests extends GetGeneratedKeysTests with MsSqlSuite + object SchemaTests extends SchemaTests with MsSqlSuite + + object SubQueryTests extends SubQueryTests with MsSqlSuite + object WithCteTests extends WithCteTests with MsSqlSuite + + object DbApiOpsTests extends DbApiOpsTests with MsSqlSuite + object ExprOpsTests extends ExprOpsTests with MsSqlSuite + // TODO these tests operate on raw Booleans, further disambiguation of BIT + // values and filter expressions is required + // object ExprBooleanOpsTests extends ExprBooleanOpsTests with MsSqlSuite + object ExprNumericOpsTests extends ExprNumericOpsTests with MsSqlSuite + object ExprSeqNumericOpsTests extends ExprAggNumericOpsTests with MsSqlSuite + object ExprSeqOpsTests extends ExprAggOpsTests with MsSqlSuite + object ExprStringOpsTests extends ExprStringOpsTests with MsSqlSuite + object ExprBlobOpsTests extends ExprBlobOpsTests with MsSqlSuite + object ExprMathOpsTests extends ExprMathOpsTests with MsSqlSuite + + object DataTypesTests extends datatypes.DataTypesTests with MsSqlSuite + + 
object OptionalTests extends datatypes.OptionalTests with MsSqlSuite + + object MsSqlDialectTests extends MsSqlDialectTests + +} diff --git a/scalasql/test/src/ExampleTests.scala b/scalasql/test/src/ExampleTests.scala index cecc902e..bf703dff 100644 --- a/scalasql/test/src/ExampleTests.scala +++ b/scalasql/test/src/ExampleTests.scala @@ -11,5 +11,6 @@ object ExampleTests extends TestSuite { test("h2") - example.H2Example.main(Array()) test("sqlite") - example.SqliteExample.main(Array()) test("hikari") - example.HikariCpExample.main(Array()) + test("mssql") - example.MsSqlExample.main(Array()) } } diff --git a/scalasql/test/src/api/DbApiTests.scala b/scalasql/test/src/api/DbApiTests.scala index dd840eac..65d992a1 100644 --- a/scalasql/test/src/api/DbApiTests.scala +++ b/scalasql/test/src/api/DbApiTests.scala @@ -3,9 +3,9 @@ package scalasql.api import geny.Generator import scalasql.core.SqlStr.SqlStringSyntax import scalasql.{Buyer, Sc} -import scalasql.utils.{MySqlSuite, ScalaSqlSuite, SqliteSuite} +import scalasql.utils.{MsSqlSuite, MySqlSuite, ScalaSqlSuite, SqliteSuite} import sourcecode.Text -import utest._ +import utest.* import java.time.LocalDate @@ -124,7 +124,7 @@ trait DbApiTests extends ScalaSqlSuite { } ) test("updateGetGeneratedKeysSql") - { - if (!this.isInstanceOf[SqliteSuite]) + if (!this.isInstanceOf[SqliteSuite] && !this.isInstanceOf[MsSqlSuite]) checker.recorded( """ Allows you to fetch the primary keys that were auto-generated for an INSERT @@ -209,7 +209,10 @@ trait DbApiTests extends ScalaSqlSuite { LocalDate.parse("2000-01-01") ) ) - assert(generatedKeys == Seq(4, 5)) + if (!this.isInstanceOf[MsSqlSuite]) + assert(generatedKeys == Seq(4, 5)) + else + assert(generatedKeys == Seq(5)) db.run(Buyer.select) ==> List( Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), diff --git a/scalasql/test/src/datatypes/DataTypesTests.scala b/scalasql/test/src/datatypes/DataTypesTests.scala index 677b2f06..655a977c 100644 --- a/scalasql/test/src/datatypes/DataTypesTests.scala +++ b/scalasql/test/src/datatypes/DataTypesTests.scala @@ -73,7 +73,7 @@ trait DataTypesTests extends ScalaSqlSuite { myInt = 12345678, myBigInt = 12345678901L, myDouble = 3.14, - myBoolean = true, + myBoolean = false, myLocalDate = LocalDate.parse("2023-12-20"), myLocalTime = LocalTime.parse("10:15:30"), myLocalDateTime = LocalDateTime.parse("2011-12-03T10:15:30"), @@ -85,6 +85,24 @@ trait DataTypesTests extends ScalaSqlSuite { myEnum = MyEnum.bar ) + val value2 = DataTypes[Sc]( + 67.toByte, + mySmallInt = 32767.toShort, + myInt = 12345678, + myBigInt = 9876543210L, + myDouble = 2.71, + myBoolean = true, + myLocalDate = LocalDate.parse("2020-02-22"), + myLocalTime = LocalTime.parse("03:05:01"), + myLocalDateTime = LocalDateTime.parse("2021-06-07T02:01:03"), + myUtilDate = + new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS").parse("2021-06-07T02:01:03.000"), + myInstant = Instant.parse("2021-06-07T02:01:03Z"), + myVarBinary = new geny.Bytes(Array[Byte](9, 8, 7, 6, 5, 4, 3, 2)), + myUUID = new java.util.UUID(9876543210L, 1234567890L), + myEnum = MyEnum.baz + ) + db.run( DataTypes.insert.columns( _.myTinyInt := value.myTinyInt, @@ -103,8 +121,26 @@ trait DataTypesTests extends ScalaSqlSuite { _.myEnum := value.myEnum ) ) ==> 1 + db.run( + DataTypes.insert.columns( + _.myTinyInt := value2.myTinyInt, + _.mySmallInt := value2.mySmallInt, + _.myInt := value2.myInt, + _.myBigInt := value2.myBigInt, + _.myDouble := value2.myDouble, + _.myBoolean := value2.myBoolean, + _.myLocalDate := value2.myLocalDate, + 
_.myLocalTime := value2.myLocalTime, + _.myLocalDateTime := value2.myLocalDateTime, + _.myUtilDate := value2.myUtilDate, + _.myInstant := value2.myInstant, + _.myVarBinary := value2.myVarBinary, + _.myUUID := value2.myUUID, + _.myEnum := value2.myEnum + ) + ) ==> 1 - db.run(DataTypes.select) ==> Seq(value) + db.run(DataTypes.select) ==> Seq(value, value2) } ) diff --git a/scalasql/test/src/datatypes/OptionalTests.scala b/scalasql/test/src/datatypes/OptionalTests.scala index c94c40a5..f709c0f9 100644 --- a/scalasql/test/src/datatypes/OptionalTests.scala +++ b/scalasql/test/src/datatypes/OptionalTests.scala @@ -412,6 +412,11 @@ trait OptionalTests extends ScalaSqlSuite { SELECT opt_cols0.my_int AS my_int, opt_cols0.my_int2 AS my_int2 FROM opt_cols opt_cols0 ORDER BY my_int IS NULL ASC, my_int + """, + """ + SELECT opt_cols0.my_int AS my_int, opt_cols0.my_int2 AS my_int2 + FROM opt_cols opt_cols0 + ORDER BY IIF(my_int IS NULL, 1, 0), my_int """ ), value = Seq( @@ -420,6 +425,14 @@ trait OptionalTests extends ScalaSqlSuite { OptCols[Sc](None, None), OptCols[Sc](None, Some(4)) ), + moreValues = Seq( + Seq( + OptCols[Sc](Some(1), Some(2)), + OptCols[Sc](Some(3), None), + OptCols[Sc](None, Some(4)), + OptCols[Sc](None, None) + ) + ), // the MSSQL workaround for NULLS FIRST/LAST does not guarantee the ordering of other columns docs = """ `.nullsLast` and `.nullsFirst` translate to SQL `NULLS LAST` and `NULLS FIRST` clauses """ @@ -436,6 +449,11 @@ trait OptionalTests extends ScalaSqlSuite { SELECT opt_cols0.my_int AS my_int, opt_cols0.my_int2 AS my_int2 FROM opt_cols opt_cols0 ORDER BY my_int IS NULL DESC, my_int + """, + """ + SELECT opt_cols0.my_int AS my_int, opt_cols0.my_int2 AS my_int2 + FROM opt_cols opt_cols0 + ORDER BY IIF(my_int IS NULL, 0, 1), my_int """ ), value = Seq( @@ -443,7 +461,15 @@ trait OptionalTests extends ScalaSqlSuite { OptCols[Sc](None, Some(4)), OptCols[Sc](Some(1), Some(2)), OptCols[Sc](Some(3), None) - ) + ), + moreValues = Seq( + Seq( + OptCols[Sc](None, Some(4)), + OptCols[Sc](None, None), + OptCols[Sc](Some(1), Some(2)), + OptCols[Sc](Some(3), None) + ) + ) // the MSSQL workaround for NULLS FIRST/LAST does not guarantee ordering of other columns ) test("ascNullsLast") - checker( query = Text { OptCols.select.sortBy(_.myInt).asc.nullsLast }, @@ -457,6 +483,11 @@ trait OptionalTests extends ScalaSqlSuite { SELECT opt_cols0.my_int AS my_int, opt_cols0.my_int2 AS my_int2 FROM opt_cols opt_cols0 ORDER BY my_int IS NULL ASC, my_int ASC + """, + """ + SELECT opt_cols0.my_int AS my_int, opt_cols0.my_int2 AS my_int2 + FROM opt_cols opt_cols0 + ORDER BY IIF(my_int IS NULL, 1, 0), my_int ASC """ ), value = Seq( @@ -464,7 +495,15 @@ trait OptionalTests extends ScalaSqlSuite { OptCols[Sc](Some(3), None), OptCols[Sc](None, None), OptCols[Sc](None, Some(4)) - ) + ), + moreValues = Seq( + Seq( + OptCols[Sc](Some(1), Some(2)), + OptCols[Sc](Some(3), None), + OptCols[Sc](None, Some(4)), + OptCols[Sc](None, None) + ) + ) // the MSSQL workaround for NULLS FIRST/LAST does not guarantee ordering of other columns ) test("ascNullsFirst") - checker( query = Text { OptCols.select.sortBy(_.myInt).asc.nullsFirst }, @@ -485,7 +524,15 @@ trait OptionalTests extends ScalaSqlSuite { OptCols[Sc](None, Some(4)), OptCols[Sc](Some(1), Some(2)), OptCols[Sc](Some(3), None) - ) + ), + moreValues = Seq( + Seq( + OptCols[Sc](None, None), + OptCols[Sc](None, Some(4)), + OptCols[Sc](Some(1), Some(2)), + OptCols[Sc](Some(3), None) + ) + ) // the MSSQL workaround for NULLS FIRST/LAST does not guarantee ordering 
of other columns ) test("descNullsLast") - checker( query = Text { OptCols.select.sortBy(_.myInt).desc.nullsLast }, @@ -506,7 +553,15 @@ trait OptionalTests extends ScalaSqlSuite { OptCols[Sc](Some(1), Some(2)), OptCols[Sc](None, None), OptCols[Sc](None, Some(4)) - ) + ), + moreValues = Seq( + Seq( + OptCols[Sc](Some(3), None), + OptCols[Sc](Some(1), Some(2)), + OptCols[Sc](None, None), + OptCols[Sc](None, Some(4)) + ) + ) // the MSSQL workaround for NULLS FIRST/LAST does not guarantee ordering of other columns ) test("descNullsFirst") - checker( query = Text { OptCols.select.sortBy(_.myInt).desc.nullsFirst }, @@ -520,6 +575,11 @@ trait OptionalTests extends ScalaSqlSuite { SELECT opt_cols0.my_int AS my_int, opt_cols0.my_int2 AS my_int2 FROM opt_cols opt_cols0 ORDER BY my_int IS NULL DESC, my_int DESC + """, + """ + SELECT opt_cols0.my_int AS my_int, opt_cols0.my_int2 AS my_int2 + FROM opt_cols opt_cols0 + ORDER BY IIF(my_int IS NULL, 0, 1), my_int DESC """ ), value = Seq( @@ -527,7 +587,15 @@ trait OptionalTests extends ScalaSqlSuite { OptCols[Sc](None, Some(4)), OptCols[Sc](Some(3), None), OptCols[Sc](Some(1), Some(2)) - ) + ), + moreValues = Seq( + Seq( + OptCols[Sc](None, Some(4)), + OptCols[Sc](None, None), + OptCols[Sc](Some(3), None), + OptCols[Sc](Some(1), Some(2)) + ) + ) // the MSSQL workaround for NULLS FIRST/LAST does not guarantee ordering of other columns ) test("roundTripOptionalValues") - checker.recorded( """ @@ -579,6 +647,24 @@ trait OptionalTests extends ScalaSqlSuite { myUUID = Some(new java.util.UUID(1234567890L, 9876543210L)), myEnum = Some(MyEnum.bar) ) + val rowSome2 = OptDataTypes[Sc]( + myTinyInt = Some(67.toByte), + mySmallInt = Some(32767.toShort), + myInt = Some(23456789), + myBigInt = Some(9876543210L), + myDouble = Some(2.71), + myBoolean = Some(false), + myLocalDate = Some(LocalDate.parse("2020-02-22")), + myLocalTime = Some(LocalTime.parse("03:05:01")), + myLocalDateTime = Some(LocalDateTime.parse("2021-06-07T02:01:03")), + myUtilDate = Some( + new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS").parse("2021-06-07T02:01:03.000") + ), + myInstant = Some(Instant.parse("2021-06-07T02:01:03Z")), + myVarBinary = Some(new geny.Bytes(Array[Byte](9, 8, 7, 6, 5, 4, 3, 2))), + myUUID = Some(new java.util.UUID(9876543210L, 1234567890L)), + myEnum = Some(MyEnum.baz) + ) val rowNone = OptDataTypes[Sc]( myTinyInt = None, @@ -596,12 +682,11 @@ trait OptionalTests extends ScalaSqlSuite { myUUID = None, myEnum = None ) - db.run( - OptDataTypes.insert.values(rowSome, rowNone) - ) ==> 2 + OptDataTypes.insert.values(rowSome, rowSome2, rowNone) + ) ==> 3 - db.run(OptDataTypes.select) ==> Seq(rowSome, rowNone) + db.run(OptDataTypes.select) ==> Seq(rowSome, rowSome2, rowNone) } ) diff --git a/scalasql/test/src/dialects/MsSqlDialectTests.scala b/scalasql/test/src/dialects/MsSqlDialectTests.scala new file mode 100644 index 00000000..72348203 --- /dev/null +++ b/scalasql/test/src/dialects/MsSqlDialectTests.scala @@ -0,0 +1,84 @@ +package scalasql.dialects + +import scalasql._ +import sourcecode.Text +import utest._ +import utils.MsSqlSuite + +trait MsSqlDialectTests extends MsSqlSuite { + def description = "Operations specific to working with Microsoft SQL Databases" + + case class BoolTypes[T[_]]( + nullable: T[Option[Boolean]], + nonNullable: T[Boolean], + a: T[Int], + b: T[Int], + comment: T[String] + ) + + object BoolTypes extends Table[BoolTypes] + val value = BoolTypes[Sc]( + nullable = Some(false), + nonNullable = true, + a = 1, + b = 2, + "first" + ) + val value2 = 
BoolTypes[Sc](
+    nullable = Some(true),
+    nonNullable = false,
+    a = 10,
+    b = 5,
+    "second"
+  )
+
+  def tests = {
+    Tests {
+
+      test("top") - checker(
+        query = Buyer.select.take(0),
+        sql = """
+          SELECT TOP(?) buyer0.id AS id, buyer0.name AS name, buyer0.date_of_birth AS date_of_birth
+          FROM buyer buyer0
+        """,
+        value = Seq[Buyer[Sc]](),
+        docs = """
+          In ScalaSql's Microsoft SQL dialect, the `.take(n)` operator translates
+          into a SQL `TOP(n)` clause
+        """
+      )
+
+      test("bool vs bit") - checker.recorded(
+        """Insert rows with BIT values""",
+        Text {
+          db.run(
+            BoolTypes.insert.columns(
+              _.nullable := value.nullable,
+              _.nonNullable := value.nonNullable,
+              _.a := value.a,
+              _.b := value.b,
+              _.comment := value.comment
+            )
+          ) ==> 1
+          db.run(
+            BoolTypes.insert.columns(
+              _.nullable := value2.nullable,
+              _.nonNullable := value2.nonNullable,
+              _.a := value2.a,
+              _.b := value2.b,
+              _.comment := value2.comment
+            )
+          ) ==> 1
+        }
+      )
+
+      test("update BIT") - checker(
+        query = BoolTypes
+          .update(_.a `=` 1)
+          .set(_.nonNullable := true),
+        sql = "UPDATE bool_types SET non_nullable = ? WHERE (bool_types.a = ?)"
+      )
+
+    }
+  }
+}
diff --git a/scalasql/test/src/example/MsSqlExample.scala b/scalasql/test/src/example/MsSqlExample.scala
new file mode 100644
index 00000000..6475d487
--- /dev/null
+++ b/scalasql/test/src/example/MsSqlExample.scala
@@ -0,0 +1,87 @@
+package scalasql.example
+
+import org.testcontainers.containers.MSSQLServerContainer
+import org.testcontainers.containers.output.WaitingConsumer
+import org.testcontainers.containers.output.OutputFrame.OutputType.STDOUT
+import scalasql.Table
+import scalasql.MsSqlDialect._
+import java.util.concurrent.TimeUnit
+
+object MsSqlExample {
+  case class ExampleProduct[T[_]](
+      id: T[Int],
+      kebabCaseName: T[String],
+      name: T[String],
+      price: T[Double]
+  )
+
+  object ExampleProduct extends Table[ExampleProduct]
+
+  lazy val mssql = {
+    println("Initializing MsSql")
+    val mssql = new MSSQLServerContainer("mcr.microsoft.com/mssql/server:2022-CU14-ubuntu-22.04")
+    mssql.acceptLicense()
+    mssql.addEnv("MSSQL_COLLATION", "Latin1_General_100_CS_AS_SC_UTF8")
+    mssql.start()
+
+    val consumer = new WaitingConsumer()
+    mssql.followOutput(consumer, STDOUT)
+    consumer.waitUntil(
+      frame => frame.getUtf8String().contains("The default collation was successfully changed."),
+      60,
+      TimeUnit.SECONDS
+    )
+
+    mssql
+  }
+
+  val dataSource = new com.microsoft.sqlserver.jdbc.SQLServerDataSource
+  dataSource.setURL(mssql.getJdbcUrl)
+  dataSource.setUser(mssql.getUsername)
+  dataSource.setPassword(mssql.getPassword)
+
+  lazy val mssqlClient = new scalasql.DbClient.DataSource(
+    dataSource,
+    config = new scalasql.Config {}
+  )
+
+  def main(args: Array[String]): Unit = {
+    mssqlClient.transaction { db =>
+      db.updateRaw("""
+        CREATE TABLE example_product (
+            id INT PRIMARY KEY IDENTITY(1, 1),
+            kebab_case_name VARCHAR(256),
+            name VARCHAR(256),
+            price DECIMAL(20, 2)
+        );
+      """)
+
+      val inserted = db.run(
+        ExampleProduct.insert.batched(_.kebabCaseName, _.name, _.price)(
+          ("face-mask", "Face Mask", 8.88),
+          ("guitar", "Guitar", 300),
+          ("socks", "Socks", 3.14),
+          ("skate-board", "Skate Board", 123.45),
+          ("camera", "Camera", 1000.00),
+          ("cookie", "Cookie", 0.10)
+        )
+      )
+
+      assert(inserted == 6)
+
+      val result =
+        db.run(ExampleProduct.select.filter(_.price > 10).sortBy(_.price).desc.map(_.name))
+
+      assert(result == Seq("Camera", "Guitar", "Skate Board"))
+
+      db.run(ExampleProduct.update(_.name === "Cookie").set(_.price := 11.0))
+
+ 
db.run(ExampleProduct.delete(_.name === "Guitar")) + + val result2 = + db.run(ExampleProduct.select.filter(_.price > 10).sortBy(_.price).desc.map(_.name)) + + assert(result2 == Seq("Camera", "Skate Board", "Cookie")) + } + } +} diff --git a/scalasql/test/src/operations/DbBlobOpsTests.scala b/scalasql/test/src/operations/DbBlobOpsTests.scala index 532a5c81..e52bb29c 100644 --- a/scalasql/test/src/operations/DbBlobOpsTests.scala +++ b/scalasql/test/src/operations/DbBlobOpsTests.scala @@ -10,32 +10,43 @@ trait ExprBlobOpsTests extends ScalaSqlSuite { def tests = Tests { test("plus") - checker( query = Expr(Bytes("hello")) + Expr(Bytes("world")), - sqls = Seq("SELECT (? || ?) AS res", "SELECT CONCAT(?, ?) AS res"), + sqls = Seq( + "SELECT (? || ?) AS res", + "SELECT CONCAT(?, ?) AS res", + "SELECT (? + ?) AS res" + ), value = Bytes("helloworld") ) test("like") - checker( query = Expr(Bytes("hello")).like(Bytes("he%")), - sql = "SELECT (? LIKE ?) AS res", + sqls = Seq( + "SELECT (? LIKE ?) AS res", + "SELECT CASE WHEN (? LIKE ?) THEN 1 ELSE 0 END AS res" + ), value = true ) test("length") - checker( query = Expr(Bytes("hello")).length, - sql = "SELECT LENGTH(?) AS res", + sqls = Seq("SELECT LENGTH(?) AS res", "SELECT LEN(?) AS res"), value = 5 ) test("octetLength") - checker( query = Expr(Bytes("叉烧包")).octetLength, - sql = "SELECT OCTET_LENGTH(?) AS res", + sqls = Seq("SELECT OCTET_LENGTH(?) AS res", "SELECT DATALENGTH(?) AS res"), value = 9, moreValues = Seq(6) // Not sure why HsqlExpr returns different value here ??? ) test("position") - checker( query = Expr(Bytes("hello")).indexOf(Bytes("ll")), - sqls = Seq("SELECT POSITION(? IN ?) AS res", "SELECT INSTR(?, ?) AS res"), + sqls = Seq( + "SELECT POSITION(? IN ?) AS res", + "SELECT INSTR(?, ?) AS res", + "SELECT CHARINDEX(?, ?) AS res" + ), value = 3 ) // Not supported by postgres @@ -62,7 +73,8 @@ trait ExprBlobOpsTests extends ScalaSqlSuite { query = Expr(Bytes("Hello")).startsWith(Bytes("Hel")), sqls = Seq( "SELECT (? LIKE ? || '%') AS res", - "SELECT (? LIKE CONCAT(?, '%')) AS res" + "SELECT (? LIKE CONCAT(?, '%')) AS res", + "SELECT CASE WHEN (? LIKE CAST(? AS VARCHAR(MAX)) + '%') THEN 1 ELSE 0 END AS res" ), value = true ) @@ -71,7 +83,8 @@ trait ExprBlobOpsTests extends ScalaSqlSuite { query = Expr(Bytes("Hello")).endsWith(Bytes("llo")), sqls = Seq( "SELECT (? LIKE '%' || ?) AS res", - "SELECT (? LIKE CONCAT('%', ?)) AS res" + "SELECT (? LIKE CONCAT('%', ?)) AS res", + "SELECT CASE WHEN (? LIKE '%' + CAST(? AS VARCHAR(MAX))) THEN 1 ELSE 0 END AS res" ), value = true ) @@ -80,7 +93,8 @@ trait ExprBlobOpsTests extends ScalaSqlSuite { query = Expr(Bytes("Hello")).contains(Bytes("ll")), sqls = Seq( "SELECT (? LIKE '%' || ? || '%') AS res", - "SELECT (? LIKE CONCAT('%', ?, '%')) AS res" + "SELECT (? LIKE CONCAT('%', ?, '%')) AS res", + "SELECT CASE WHEN (? LIKE '%' + CAST(? AS VARCHAR(MAX)) + '%') THEN 1 ELSE 0 END AS res" ), value = true ) diff --git a/scalasql/test/src/operations/DbMathOpsTests.scala b/scalasql/test/src/operations/DbMathOpsTests.scala index 6212a876..1bd76aec 100644 --- a/scalasql/test/src/operations/DbMathOpsTests.scala +++ b/scalasql/test/src/operations/DbMathOpsTests.scala @@ -6,7 +6,7 @@ import utest._ trait ExprMathOpsTests extends ScalaSqlSuite { override implicit def DbApiOpsConv(db: => DbApi): DbApiOps & MathOps = ??? 
- def description = "Math operations; supported by H2/Postgres/MySql, not supported by Sqlite" + def description = "Math operations; supported by H2/Postgres/MySql/MsSql, not supported by Sqlite" def tests = Tests { test("power") - checker( @@ -23,7 +23,7 @@ trait ExprMathOpsTests extends ScalaSqlSuite { test("ln") - checker( query = db.ln(16.0), - sql = "SELECT LN(?) AS res" + sqls = Seq("SELECT LN(?) AS res", "SELECT LOG(?) AS res") ) test("log") - checker( @@ -73,7 +73,7 @@ trait ExprMathOpsTests extends ScalaSqlSuite { test("atan2") - checker( query = db.atan2(16.0, 23.0), - sql = "SELECT ATAN2(?, ?) AS res" + sqls = Seq("SELECT ATAN2(?, ?) AS res", "SELECT ATN2(?, ?) AS res") ) test("pi") - checker( diff --git a/scalasql/test/src/operations/DbNumericOpsTests.scala b/scalasql/test/src/operations/DbNumericOpsTests.scala index 26d05474..80de928c 100644 --- a/scalasql/test/src/operations/DbNumericOpsTests.scala +++ b/scalasql/test/src/operations/DbNumericOpsTests.scala @@ -16,7 +16,14 @@ trait ExprNumericOpsTests extends ScalaSqlSuite { test("divide") - checker(query = Expr(6) / Expr(2), sql = "SELECT (? / ?) AS res", value = 3) - test("modulo") - checker(query = Expr(6) % Expr(2), sql = "SELECT MOD(?, ?) AS res", value = 0) + test("modulo") - checker( + query = Expr(6) % Expr(2), + sqls = Seq( + "SELECT MOD(?, ?) AS res", + "SELECT ? % ? AS res" + ), + value = 0 + ) test("bitwiseAnd") - checker( query = Expr(6) & Expr(2), @@ -32,7 +39,10 @@ trait ExprNumericOpsTests extends ScalaSqlSuite { test("between") - checker( query = Expr(4).between(Expr(2), Expr(6)), - sql = "SELECT ? BETWEEN ? AND ? AS res", + sqls = Seq( + "SELECT ? BETWEEN ? AND ? AS res", + "SELECT CASE WHEN ? BETWEEN ? AND ? THEN 1 ELSE 0 END AS res" + ), value = true ) @@ -49,9 +59,23 @@ trait ExprNumericOpsTests extends ScalaSqlSuite { test("abs") - checker(query = Expr(-4).abs, sql = "SELECT ABS(?) AS res", value = 4) - test("mod") - checker(query = Expr(8).mod(Expr(3)), sql = "SELECT MOD(?, ?) AS res", value = 2) + test("mod") - checker( + query = Expr(8).mod(Expr(3)), + sqls = Seq( + "SELECT MOD(?, ?) AS res", + "SELECT ? % ? AS res" + ), + value = 2 + ) - test("ceil") - checker(query = Expr(4.3).ceil, sql = "SELECT CEIL(?) AS res", value = 5.0) + test("ceil") - checker( + query = Expr(4.3).ceil, + sqls = Seq( + "SELECT CEIL(?) AS res", + "SELECT CEILING(?) AS res" + ), + value = 5.0 + ) test("floor") - checker(query = Expr(4.7).floor, sql = "SELECT FLOOR(?) AS res", value = 4.0) diff --git a/scalasql/test/src/operations/DbOpsTests.scala b/scalasql/test/src/operations/DbOpsTests.scala index 3cc5edff..594d2eef 100644 --- a/scalasql/test/src/operations/DbOpsTests.scala +++ b/scalasql/test/src/operations/DbOpsTests.scala @@ -4,7 +4,7 @@ import scalasql._ import scalasql.core.SqlStr.SqlStringSyntax import scalasql.core.Expr import utest._ -import utils.ScalaSqlSuite +import utils.{MsSqlSuite, ScalaSqlSuite} trait ExprOpsTests extends ScalaSqlSuite { def description = "Operations that can be performed on `Expr[T]` for any `T`" @@ -12,44 +12,128 @@ trait ExprOpsTests extends ScalaSqlSuite { test("numeric") { test("greaterThan") - - checker(query = Expr(6) > Expr(2), sql = "SELECT (? > ?) AS res", value = true) + checker( + query = Expr(6) > Expr(2), + sqls = Seq( + "SELECT (? > ?) AS res", + "SELECT CASE WHEN (? > ?) THEN 1 ELSE 0 END AS res" + ), + value = true + ) test("lessThan") - - checker(query = Expr(6) < Expr(2), sql = "SELECT (? < ?) 
AS res", value = false) + checker( + query = Expr(6) < Expr(2), + sqls = Seq( + "SELECT (? < ?) AS res", + "SELECT CASE WHEN (? < ?) THEN 1 ELSE 0 END AS res" + ), + value = false + ) test("greaterThanOrEquals") - - checker(query = Expr(6) >= Expr(2), sql = "SELECT (? >= ?) AS res", value = true) + checker( + query = Expr(6) >= Expr(2), + sqls = Seq( + "SELECT (? >= ?) AS res", + "SELECT CASE WHEN (? >= ?) THEN 1 ELSE 0 END AS res" + ), + value = true + ) test("lessThanOrEquals") - - checker(query = Expr(6) <= Expr(2), sql = "SELECT (? <= ?) AS res", value = false) + checker( + query = Expr(6) <= Expr(2), + sqls = Seq( + "SELECT (? <= ?) AS res", + "SELECT CASE WHEN (? <= ?) THEN 1 ELSE 0 END AS res" + ), + value = false + ) } test("string") { test("greaterThan") - - checker(query = Expr("A") > Expr("B"), sql = "SELECT (? > ?) AS res", value = false) + checker( + query = Expr("A") > Expr("B"), + sqls = Seq( + "SELECT (? > ?) AS res", + "SELECT CASE WHEN (? > ?) THEN 1 ELSE 0 END AS res" + ), + value = false + ) test("lessThan") - - checker(query = Expr("A") < Expr("B"), sql = "SELECT (? < ?) AS res", value = true) + checker( + query = Expr("A") < Expr("B"), + sqls = Seq( + "SELECT (? < ?) AS res", + "SELECT CASE WHEN (? < ?) THEN 1 ELSE 0 END AS res" + ), + value = true + ) test("greaterThanOrEquals") - - checker(query = Expr("A") >= Expr("B"), sql = "SELECT (? >= ?) AS res", value = false) + checker( + query = Expr("A") >= Expr("B"), + sqls = Seq( + "SELECT (? >= ?) AS res", + "SELECT CASE WHEN (? >= ?) THEN 1 ELSE 0 END AS res" + ), + value = false + ) test("lessThanOrEquals") - - checker(query = Expr("A") <= Expr("B"), sql = "SELECT (? <= ?) AS res", value = true) + checker( + query = Expr("A") <= Expr("B"), + sqls = Seq( + "SELECT (? <= ?) AS res", + "SELECT CASE WHEN (? <= ?) THEN 1 ELSE 0 END AS res" + ), + value = true + ) } test("boolean") { test("greaterThan") - - checker(query = Expr(true) > Expr(false), sql = "SELECT (? > ?) AS res", value = true) + checker( + query = Expr(true) > Expr(false), + sqls = Seq( + "SELECT (? > ?) AS res", + "SELECT CASE WHEN (? > ?) THEN 1 ELSE 0 END AS res" + ), + value = true + ) test("lessThan") - - checker(query = Expr(true) < Expr(true), sql = "SELECT (? < ?) AS res", value = false) + checker( + query = Expr(true) < Expr(true), + sqls = Seq( + "SELECT (? < ?) AS res", + "SELECT CASE WHEN (? < ?) THEN 1 ELSE 0 END AS res" + ), + value = false + ) test("greaterThanOrEquals") - - checker(query = Expr(true) >= Expr(true), sql = "SELECT (? >= ?) AS res", value = true) + checker( + query = Expr(true) >= Expr(true), + sqls = Seq( + "SELECT (? >= ?) AS res", + "SELECT CASE WHEN (? >= ?) THEN 1 ELSE 0 END AS res" + ), + value = true + ) test("lessThanOrEquals") - - checker(query = Expr(true) <= Expr(true), sql = "SELECT (? <= ?) AS res", value = true) + checker( + query = Expr(true) <= Expr(true), + sqls = Seq( + "SELECT (? <= ?) AS res", + "SELECT CASE WHEN (? <= ?) THEN 1 ELSE 0 END AS res" + ), + value = true + ) } test("cast") { @@ -76,6 +160,7 @@ trait ExprOpsTests extends ScalaSqlSuite { query = Expr(1234.1234).cast[Int], sqls = Seq( "SELECT CAST(? AS INTEGER) AS res", + "SELECT CAST(? AS INT) AS res", "SELECT CAST(? AS SIGNED) AS res" ), value = 1234 @@ -97,7 +182,8 @@ trait ExprOpsTests extends ScalaSqlSuite { "SELECT CAST(? AS VARCHAR) AS res", "SELECT CAST(? 
AS CHAR) AS res" ), - value = "1234.5678" + value = "1234.5678", + moreValues = Seq("1234.57") // MsSQL rounds to 2 decimal places ) test("localdate") - checker( @@ -113,6 +199,7 @@ trait ExprOpsTests extends ScalaSqlSuite { query = Expr("2023-11-12 03:22:41").cast[java.time.LocalDateTime], sqls = Seq( "SELECT CAST(? AS DATETIME) AS res", + "SELECT CAST(? AS DATETIME2) AS res", "SELECT CAST(? AS TIMESTAMP) AS res", "SELECT CAST(? AS VARCHAR) AS res" ), @@ -123,6 +210,7 @@ trait ExprOpsTests extends ScalaSqlSuite { query = Expr("2023-11-12 03:22:41").cast[java.util.Date], sqls = Seq( "SELECT CAST(? AS DATETIME) AS res", + "SELECT CAST(? AS DATETIME2) AS res", "SELECT CAST(? AS TIMESTAMP) AS res", "SELECT CAST(? AS VARCHAR) AS res" ), @@ -133,18 +221,23 @@ trait ExprOpsTests extends ScalaSqlSuite { query = Expr("2007-12-03 10:15:30.00").cast[java.time.Instant], sqls = Seq( "SELECT CAST(? AS DATETIME) AS res", + "SELECT CAST(? AS DATETIME2) AS res", "SELECT CAST(? AS TIMESTAMP) AS res", "SELECT CAST(? AS VARCHAR) AS res" ), value = java.time.Instant.parse("2007-12-03T02:15:30.00Z") ) - test("castNamed") - checker( - query = Expr(1234.5678).castNamed[String](sql"CHAR(3)"), - sql = "SELECT CAST(? AS CHAR(3)) AS res", - value = "123", - moreValues = Seq("1234.5678") // SQLITE doesn't truncate on cast - ) + test("castNamed") - { + // Microsoft SQL throws "Arithmetic overflow error for type varchar" + if (!this.isInstanceOf[MsSqlSuite]) + checker( + query = Expr(1234.5678).castNamed[String](sql"CHAR(3)"), + sql = "SELECT CAST(? AS CHAR(3)) AS res", + value = "123", + moreValues = Seq("1234.5678") // SQLITE doesn't truncate on cast + ) + } } } } diff --git a/scalasql/test/src/operations/DbStringOpsTests.scala b/scalasql/test/src/operations/DbStringOpsTests.scala index 436207ae..1098ab18 100644 --- a/scalasql/test/src/operations/DbStringOpsTests.scala +++ b/scalasql/test/src/operations/DbStringOpsTests.scala @@ -10,32 +10,43 @@ trait ExprStringOpsTests extends ScalaSqlSuite { def tests = Tests { test("plus") - checker( query = Expr("hello") + Expr("world"), - sqls = Seq("SELECT (? || ?) AS res", "SELECT CONCAT(?, ?) AS res"), + sqls = Seq( + "SELECT (? || ?) AS res", + "SELECT CONCAT(?, ?) AS res", + "SELECT (? + ?) AS res" + ), value = "helloworld" ) test("like") - checker( query = Expr("hello").like("he%"), - sql = "SELECT (? LIKE ?) AS res", + sqls = Seq( + "SELECT (? LIKE ?) AS res", + "SELECT CASE WHEN (? LIKE ?) THEN 1 ELSE 0 END AS res" + ), value = true ) test("length") - checker( query = Expr("hello").length, - sql = "SELECT LENGTH(?) AS res", + sqls = Seq("SELECT LENGTH(?) AS res", "SELECT LEN(?) AS res"), value = 5 ) test("octetLength") - checker( query = Expr("叉烧包").octetLength, - sql = "SELECT OCTET_LENGTH(?) AS res", + sqls = Seq("SELECT OCTET_LENGTH(?) AS res", "SELECT DATALENGTH(?) AS res"), value = 9, moreValues = Seq(6) // Not sure why HsqlExpr returns different value here ??? ) test("position") - checker( query = Expr("hello").indexOf("ll"), - sqls = Seq("SELECT POSITION(? IN ?) AS res", "SELECT INSTR(?, ?) AS res"), + sqls = Seq( + "SELECT POSITION(? IN ?) AS res", + "SELECT INSTR(?, ?) AS res", + "SELECT CHARINDEX(?, ?) AS res" + ), value = 3 ) @@ -73,7 +84,8 @@ trait ExprStringOpsTests extends ScalaSqlSuite { query = Expr("Hello").startsWith("Hel"), sqls = Seq( "SELECT (? LIKE ? || '%') AS res", - "SELECT (? LIKE CONCAT(?, '%')) AS res" + "SELECT (? LIKE CONCAT(?, '%')) AS res", + "SELECT CASE WHEN (? LIKE CAST(? 
AS VARCHAR(MAX)) + '%') THEN 1 ELSE 0 END AS res" ), value = true ) @@ -82,7 +94,8 @@ trait ExprStringOpsTests extends ScalaSqlSuite { query = Expr("Hello").endsWith("llo"), sqls = Seq( "SELECT (? LIKE '%' || ?) AS res", - "SELECT (? LIKE CONCAT('%', ?)) AS res" + "SELECT (? LIKE CONCAT('%', ?)) AS res", + "SELECT CASE WHEN (? LIKE '%' + CAST(? AS VARCHAR(MAX))) THEN 1 ELSE 0 END AS res" ), value = true ) @@ -91,7 +104,8 @@ trait ExprStringOpsTests extends ScalaSqlSuite { query = Expr("Hello").contains("ll"), sqls = Seq( "SELECT (? LIKE '%' || ? || '%') AS res", - "SELECT (? LIKE CONCAT('%', ?, '%')) AS res" + "SELECT (? LIKE CONCAT('%', ?, '%')) AS res", + "SELECT CASE WHEN (? LIKE '%' + CAST(? AS VARCHAR(MAX)) + '%') THEN 1 ELSE 0 END AS res" ), value = true ) diff --git a/scalasql/test/src/operations/DbAggOpsTests.scala b/scalasql/test/src/operations/ExprAggOpsTests.scala similarity index 75% rename from scalasql/test/src/operations/DbAggOpsTests.scala rename to scalasql/test/src/operations/ExprAggOpsTests.scala index cd49e6b0..b1702fa6 100644 --- a/scalasql/test/src/operations/DbAggOpsTests.scala +++ b/scalasql/test/src/operations/ExprAggOpsTests.scala @@ -1,7 +1,7 @@ package scalasql.operations import scalasql._ -import scalasql.H2Dialect +import scalasql.{H2Dialect, MsSqlDialect} import utest._ import utils.ScalaSqlSuite @@ -29,7 +29,10 @@ trait ExprAggOpsTests extends ScalaSqlSuite { test("none") - checker( query = Purchase.select.filter(_ => false).sumByOpt(_.count), - sql = "SELECT SUM(purchase0.count) AS res FROM purchase purchase0 WHERE ?", + sqls = Seq( + "SELECT SUM(purchase0.count) AS res FROM purchase purchase0 WHERE ?", + "SELECT SUM(purchase0.count) AS res FROM purchase purchase0 WHERE 1 = ?" + ), value = Option.empty[Int] ) } @@ -49,7 +52,10 @@ trait ExprAggOpsTests extends ScalaSqlSuite { test("none") - checker( query = Purchase.select.filter(_ => false).minByOpt(_.count), - sql = "SELECT MIN(purchase0.count) AS res FROM purchase purchase0 WHERE ?", + sqls = Seq( + "SELECT MIN(purchase0.count) AS res FROM purchase purchase0 WHERE ?", + "SELECT MIN(purchase0.count) AS res FROM purchase purchase0 WHERE 1 = ?" + ), value = Option.empty[Int] ) } @@ -69,7 +75,10 @@ trait ExprAggOpsTests extends ScalaSqlSuite { test("none") - checker( query = Purchase.select.filter(_ => false).maxByOpt(_.count), - sql = "SELECT MAX(purchase0.count) AS res FROM purchase purchase0 WHERE ?", + sqls = Seq( + "SELECT MAX(purchase0.count) AS res FROM purchase purchase0 WHERE ?", + "SELECT MAX(purchase0.count) AS res FROM purchase purchase0 WHERE 1 = ?" + ), value = Option.empty[Int] ) } @@ -89,7 +98,10 @@ trait ExprAggOpsTests extends ScalaSqlSuite { test("none") - checker( query = Purchase.select.filter(_ => false).avgByOpt(_.count), - sql = "SELECT AVG(purchase0.count) AS res FROM purchase purchase0 WHERE ?", + sqls = Seq( + "SELECT AVG(purchase0.count) AS res FROM purchase purchase0 WHERE ?", + "SELECT AVG(purchase0.count) AS res FROM purchase purchase0 WHERE 1 = ?" 
+ ), value = Option.empty[Int] ) } @@ -100,19 +112,21 @@ trait ExprAggOpsTests extends ScalaSqlSuite { "SELECT STRING_AGG(buyer0.name || '', '') AS res FROM buyer buyer0", "SELECT GROUP_CONCAT(buyer0.name || '', '') AS res FROM buyer buyer0", "SELECT LISTAGG(buyer0.name || '', '') AS res FROM buyer buyer0", - "SELECT GROUP_CONCAT(CONCAT(buyer0.name, '') SEPARATOR '') AS res FROM buyer buyer0" + "SELECT GROUP_CONCAT(CONCAT(buyer0.name, '') SEPARATOR '') AS res FROM buyer buyer0", + "SELECT STRING_AGG(buyer0.name + '', '') AS res FROM buyer buyer0" ), value = "James Bond叉烧包Li Haoyi" ) test("sep") - { - if (!this.isInstanceOf[H2Dialect]) + if (!this.isInstanceOf[H2Dialect] && !this.isInstanceOf[MsSqlDialect]) checker( query = Buyer.select.map(_.name).mkString(", "), sqls = Seq( "SELECT STRING_AGG(buyer0.name || '', ?) AS res FROM buyer buyer0", "SELECT GROUP_CONCAT(buyer0.name || '', ?) AS res FROM buyer buyer0", - "SELECT GROUP_CONCAT(CONCAT(buyer0.name, '') SEPARATOR ?) AS res FROM buyer buyer0" + "SELECT GROUP_CONCAT(CONCAT(buyer0.name, '') SEPARATOR ?) AS res FROM buyer buyer0", + "SELECT STRING_AGG(buyer0.name + '', ?) AS res FROM buyer buyer0" ), value = "James Bond, 叉烧包, Li Haoyi" ) diff --git a/scalasql/test/src/query/CompoundSelectTests.scala b/scalasql/test/src/query/CompoundSelectTests.scala index 7a984349..21397298 100644 --- a/scalasql/test/src/query/CompoundSelectTests.scala +++ b/scalasql/test/src/query/CompoundSelectTests.scala @@ -50,7 +50,10 @@ trait CompoundSelectTests extends ScalaSqlSuite { test("sortLimit") - checker( query = Text { Product.select.sortBy(_.price).map(_.name).take(2) }, - sql = "SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ?", + sqls = Seq( + "SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ?", + "SELECT product0.name AS res FROM product product0 ORDER BY product0.price OFFSET ? ROWS FETCH FIRST ? ROWS ONLY" + ), value = Seq("Cookie", "Socks"), docs = """ ScalaSql also supports various combinations of `.take` and `.drop`, translating to SQL @@ -61,14 +64,18 @@ trait CompoundSelectTests extends ScalaSqlSuite { query = Text { Product.select.sortBy(_.price).map(_.name).drop(2) }, sqls = Seq( "SELECT product0.name AS res FROM product product0 ORDER BY product0.price OFFSET ?", - "SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ? OFFSET ?" + "SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ? OFFSET ?", + "SELECT product0.name AS res FROM product product0 ORDER BY product0.price OFFSET ? ROWS" ), value = Seq("Face Mask", "Skate Board", "Guitar", "Camera") ) test("sortLimitTwiceHigher") - checker( query = Text { Product.select.sortBy(_.price).map(_.name).take(2).take(3) }, - sql = "SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ?", + sqls = Seq( + "SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ?", + "SELECT product0.name AS res FROM product product0 ORDER BY product0.price OFFSET ? ROWS FETCH FIRST ? ROWS ONLY" + ), value = Seq("Cookie", "Socks"), docs = """ Note that `.drop` and `.take` follow Scala collections' semantics, so calling e.g. 
`.take` @@ -79,48 +86,68 @@ trait CompoundSelectTests extends ScalaSqlSuite { test("sortLimitTwiceLower") - checker( query = Text { Product.select.sortBy(_.price).map(_.name).take(2).take(1) }, - sql = "SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ?", + sqls = Seq( + "SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ?", + "SELECT product0.name AS res FROM product product0 ORDER BY product0.price OFFSET ? ROWS FETCH FIRST ? ROWS ONLY" + ), value = Seq("Cookie") ) test("sortLimitOffset") - checker( query = Text { Product.select.sortBy(_.price).map(_.name).drop(2).take(2) }, - sql = + sqls = Seq( "SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ? OFFSET ?", + "SELECT product0.name AS res FROM product product0 ORDER BY product0.price OFFSET ? ROWS FETCH FIRST ? ROWS ONLY" + ), value = Seq("Face Mask", "Skate Board") ) test("sortLimitOffsetTwice") - checker( query = Text { Product.select.sortBy(_.price).map(_.name).drop(2).drop(2).take(1) }, - sql = + sqls = Seq( "SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ? OFFSET ?", + "SELECT product0.name AS res FROM product product0 ORDER BY product0.price OFFSET ? ROWS FETCH FIRST ? ROWS ONLY" + ), value = Seq("Guitar") ) test("sortOffsetLimit") - checker( query = Text { Product.select.sortBy(_.price).map(_.name).drop(2).take(2) }, - sql = + sqls = Seq( "SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ? OFFSET ?", + "SELECT product0.name AS res FROM product product0 ORDER BY product0.price OFFSET ? ROWS FETCH FIRST ? ROWS ONLY" + ), value = Seq("Face Mask", "Skate Board") ) test("sortLimitOffset") - checker( query = Text { Product.select.sortBy(_.price).map(_.name).take(2).drop(1) }, - sql = + sqls = Seq( "SELECT product0.name AS res FROM product product0 ORDER BY product0.price LIMIT ? OFFSET ?", + "SELECT product0.name AS res FROM product product0 ORDER BY product0.price OFFSET ? ROWS FETCH FIRST ? ROWS ONLY" + ), value = Seq("Socks") ) } test("distinct") - checker( query = Text { Purchase.select.sortBy(_.total).desc.take(3).map(_.shippingInfoId).distinct }, - sql = """ - SELECT DISTINCT subquery0.res AS res - FROM (SELECT purchase0.shipping_info_id AS res - FROM purchase purchase0 - ORDER BY purchase0.total DESC - LIMIT ?) subquery0 - """, + sqls = Seq( + """ + SELECT DISTINCT subquery0.res AS res + FROM (SELECT purchase0.shipping_info_id AS res + FROM purchase purchase0 + ORDER BY purchase0.total DESC + LIMIT ?) subquery0 + """, + """ + SELECT DISTINCT subquery0.res AS res + FROM (SELECT purchase0.shipping_info_id AS res + FROM purchase purchase0 + ORDER BY purchase0.total DESC + OFFSET ? ROWS FETCH FIRST ? ROWS ONLY) subquery0 + """ + ), value = Seq(1, 2), normalize = (x: Seq[Int]) => x.sorted, docs = """ @@ -134,15 +161,26 @@ trait CompoundSelectTests extends ScalaSqlSuite { Product.crossJoin().filter(_.id === p.productId).map(_.name) } }, - sql = """ - SELECT product1.name AS res - FROM (SELECT purchase0.product_id AS product_id, purchase0.total AS total - FROM purchase purchase0 - ORDER BY total DESC - LIMIT ?) subquery0 - CROSS JOIN product product1 - WHERE (product1.id = subquery0.product_id) - """, + sqls = Seq( + """ + SELECT product1.name AS res + FROM (SELECT purchase0.product_id AS product_id, purchase0.total AS total + FROM purchase purchase0 + ORDER BY total DESC + LIMIT ?) 
subquery0 + CROSS JOIN product product1 + WHERE (product1.id = subquery0.product_id) + """, + """ + SELECT product1.name AS res + FROM (SELECT purchase0.product_id AS product_id, purchase0.total AS total + FROM purchase purchase0 + ORDER BY total DESC + OFFSET ? ROWS FETCH FIRST ? ROWS ONLY) subquery0 + CROSS JOIN product product1 + WHERE (product1.id = subquery0.product_id) + """ + ), value = Seq("Camera", "Face Mask", "Guitar"), normalize = (x: Seq[String]) => x.sorted, docs = """ @@ -155,13 +193,22 @@ trait CompoundSelectTests extends ScalaSqlSuite { test("sumBy") - checker( query = Text { Purchase.select.sortBy(_.total).desc.take(3).sumBy(_.total) }, - sql = """ - SELECT SUM(subquery0.total) AS res - FROM (SELECT purchase0.total AS total - FROM purchase purchase0 - ORDER BY total DESC - LIMIT ?) subquery0 - """, + sqls = Seq( + """ + SELECT SUM(subquery0.total) AS res + FROM (SELECT purchase0.total AS total + FROM purchase purchase0 + ORDER BY total DESC + LIMIT ?) subquery0 + """, + """ + SELECT SUM(subquery0.total) AS res + FROM (SELECT purchase0.total AS total + FROM purchase purchase0 + ORDER BY total DESC + OFFSET ? ROWS FETCH FIRST ? ROWS ONLY) subquery0 + """ + ), value = 11788.0, normalize = (x: Double) => x.round.toDouble ) @@ -174,13 +221,22 @@ trait CompoundSelectTests extends ScalaSqlSuite { .take(3) .aggregate(p => (p.sumBy(_.total), p.avgBy(_.total))) }, - sql = """ - SELECT SUM(subquery0.total) AS res_0, AVG(subquery0.total) AS res_1 - FROM (SELECT purchase0.total AS total - FROM purchase purchase0 - ORDER BY total DESC - LIMIT ?) subquery0 - """, + sqls = Seq( + """ + SELECT SUM(subquery0.total) AS res_0, AVG(subquery0.total) AS res_1 + FROM (SELECT purchase0.total AS total + FROM purchase purchase0 + ORDER BY total DESC + LIMIT ?) subquery0 + """, + """ + SELECT SUM(subquery0.total) AS res_0, AVG(subquery0.total) AS res_1 + FROM (SELECT purchase0.total AS total + FROM purchase purchase0 + ORDER BY total DESC + OFFSET ? ROWS FETCH FIRST ? ROWS ONLY) subquery0 + """ + ), value = (11788.0, 3929.0), normalize = (x: (Double, Double)) => (x._1.round.toDouble, x._2.round.toDouble) ) @@ -325,19 +381,34 @@ trait CompoundSelectTests extends ScalaSqlSuite { .drop(4) .take(4) }, - sql = """ - SELECT LOWER(product0.name) AS res - FROM product product0 - UNION ALL - SELECT LOWER(buyer0.name) AS res - FROM buyer buyer0 - UNION - SELECT LOWER(product0.kebab_case_name) AS res - FROM product product0 - ORDER BY res - LIMIT ? - OFFSET ? - """, + sqls = Seq( + """ + SELECT LOWER(product0.name) AS res + FROM product product0 + UNION ALL + SELECT LOWER(buyer0.name) AS res + FROM buyer buyer0 + UNION + SELECT LOWER(product0.kebab_case_name) AS res + FROM product product0 + ORDER BY res + LIMIT ? + OFFSET ? + """, + """ + SELECT LOWER(product0.name) AS res + FROM product product0 + UNION ALL + SELECT LOWER(buyer0.name) AS res + FROM buyer buyer0 + UNION + SELECT LOWER(product0.kebab_case_name) AS res + FROM product product0 + ORDER BY res + OFFSET ? ROWS + FETCH FIRST ? 
ROWS ONLY + """ + ), value = Seq("guitar", "james bond", "li haoyi", "skate board") ) } diff --git a/scalasql/test/src/query/DeleteTests.scala b/scalasql/test/src/query/DeleteTests.scala index 5431c74c..0e4c91ec 100644 --- a/scalasql/test/src/query/DeleteTests.scala +++ b/scalasql/test/src/query/DeleteTests.scala @@ -53,7 +53,7 @@ trait DeleteTests extends ScalaSqlSuite { test("all") { checker( query = Purchase.delete(_ => true), - sql = "DELETE FROM purchase WHERE ?", + sql = "DELETE FROM purchase", value = 7, docs = """ If you actually want to delete all rows in the table, you can explicitly diff --git a/scalasql/test/src/query/EscapedTableNameTests.scala b/scalasql/test/src/query/EscapedTableNameTests.scala index 34b7ff8f..b0058467 100644 --- a/scalasql/test/src/query/EscapedTableNameTests.scala +++ b/scalasql/test/src/query/EscapedTableNameTests.scala @@ -50,7 +50,7 @@ trait EscapedTableNameTests extends ScalaSqlSuite { query = Text { Select.delete(_ => true) }, - sql = s"DELETE FROM $tableNameEscaped WHERE ?", + sql = s"DELETE FROM $tableNameEscaped", value = 0, docs = "" ) diff --git a/scalasql/test/src/query/FlatJoinTests.scala b/scalasql/test/src/query/FlatJoinTests.scala index 455cb6fe..fd5c9d89 100644 --- a/scalasql/test/src/query/FlatJoinTests.scala +++ b/scalasql/test/src/query/FlatJoinTests.scala @@ -267,22 +267,42 @@ trait FlatJoinTests extends ScalaSqlSuite { si <- ShippingInfo.select.sortBy(_.id).asc.take(1).crossJoin() } yield (b.name, si.shippingDate) }, - sql = """ - SELECT - subquery0.name AS res_0, - subquery1.shipping_date AS res_1 - FROM - (SELECT buyer0.id AS id, buyer0.name AS name - FROM buyer buyer0 - ORDER BY id ASC - LIMIT ?) subquery0 - CROSS JOIN (SELECT - shipping_info1.id AS id, - shipping_info1.shipping_date AS shipping_date - FROM shipping_info shipping_info1 - ORDER BY id ASC - LIMIT ?) subquery1 - """, + sqls = Seq( + """ + SELECT + subquery0.name AS res_0, + subquery1.shipping_date AS res_1 + FROM + (SELECT buyer0.id AS id, buyer0.name AS name + FROM buyer buyer0 + ORDER BY id ASC + LIMIT ?) subquery0 + CROSS JOIN (SELECT + shipping_info1.id AS id, + shipping_info1.shipping_date AS shipping_date + FROM shipping_info shipping_info1 + ORDER BY id ASC + LIMIT ?) subquery1 + """, + """ + SELECT + subquery0.name AS res_0, + subquery1.shipping_date AS res_1 + FROM + (SELECT buyer0.id AS id, buyer0.name AS name + FROM buyer buyer0 + ORDER BY id ASC + OFFSET ? ROWS + FETCH FIRST ? ROWS ONLY) subquery0 + CROSS JOIN (SELECT + shipping_info1.id AS id, + shipping_info1.shipping_date AS shipping_date + FROM shipping_info shipping_info1 + ORDER BY id ASC + OFFSET ? ROWS + FETCH FIRST ? 
ROWS ONLY) subquery1 + """ + ), value = Seq( ("James Bond", LocalDate.parse("2010-02-03")) ), diff --git a/scalasql/test/src/query/GetGeneratedKeysTests.scala b/scalasql/test/src/query/GetGeneratedKeysTests.scala index d30801bb..e4685714 100644 --- a/scalasql/test/src/query/GetGeneratedKeysTests.scala +++ b/scalasql/test/src/query/GetGeneratedKeysTests.scala @@ -1,7 +1,8 @@ package scalasql.query import scalasql._ -import scalasql.utils.ScalaSqlSuite +import scalasql.core.SqlStr.SqlStringSyntax +import scalasql.utils.{MsSqlSuite, ScalaSqlSuite} import utest._ import java.time.LocalDate @@ -14,6 +15,8 @@ trait GetGeneratedKeysTests extends ScalaSqlSuite { test("single") { test("values") - { checker( + preQuery = + Option.when(this.isInstanceOf[MsSqlSuite])(sql"SET IDENTITY_INSERT buyer ON").orNull, query = Buyer.insert .values( Buyer[Sc](17, "test buyer", LocalDate.parse("2023-09-09")) @@ -35,6 +38,8 @@ trait GetGeneratedKeysTests extends ScalaSqlSuite { test("columns") - { checker( + preQuery = + Option.when(this.isInstanceOf[MsSqlSuite])(sql"SET IDENTITY_INSERT buyer ON").orNull, query = Buyer.insert .columns( _.name := "test buyer", @@ -94,7 +99,8 @@ trait GetGeneratedKeysTests extends ScalaSqlSuite { INSERT INTO buyer (name, date_of_birth) VALUES (?, ?), (?, ?), (?, ?) """, - value = Seq(4, 5, 6), + // https://github.com/microsoft/mssql-jdbc/issues/245 + value = if (this.isInstanceOf[MsSqlSuite]) Seq(6) else Seq(4, 5, 6), docs = """ `getGeneratedKeys` can return multiple generated primary key values for a batch insert statement @@ -133,7 +139,8 @@ trait GetGeneratedKeysTests extends ScalaSqlSuite { FROM buyer buyer0 WHERE (buyer0.name <> ?) """, - value = Seq(4, 5), + // https://github.com/microsoft/mssql-jdbc/issues/245 + value = if (this.isInstanceOf[MsSqlSuite]) Seq(5) else Seq(4, 5), docs = """ `getGeneratedKeys` can return multiple generated primary key values for an `insert` based on a `select` diff --git a/scalasql/test/src/query/InsertTests.scala b/scalasql/test/src/query/InsertTests.scala index afd2c942..386f9b4e 100644 --- a/scalasql/test/src/query/InsertTests.scala +++ b/scalasql/test/src/query/InsertTests.scala @@ -1,8 +1,9 @@ package scalasql.query import scalasql._ +import scalasql.core.SqlStr.SqlStringSyntax import utest._ -import utils.ScalaSqlSuite +import utils.{MsSqlSuite, ScalaSqlSuite} import java.time.LocalDate @@ -13,6 +14,8 @@ trait InsertTests extends ScalaSqlSuite { test("single") { test("values") - { checker( + preQuery = + Option.when(this.isInstanceOf[MsSqlSuite])(sql"SET IDENTITY_INSERT buyer ON").orNull, query = Buyer.insert.values( Buyer[Sc](4, "test buyer", LocalDate.parse("2023-09-09")) ), @@ -52,6 +55,8 @@ trait InsertTests extends ScalaSqlSuite { test("columns") - { checker( + preQuery = + Option.when(this.isInstanceOf[MsSqlSuite])(sql"SET IDENTITY_INSERT buyer ON").orNull, query = Buyer.insert.columns( _.name := "test buyer", _.dateOfBirth := LocalDate.parse("2023-09-09"), @@ -90,6 +95,8 @@ trait InsertTests extends ScalaSqlSuite { test("conflict") - intercept[Exception] { checker( + preQuery = + Option.when(this.isInstanceOf[MsSqlSuite])(sql"SET IDENTITY_INSERT buyer ON").orNull, query = Buyer.insert.columns( _.name := "test buyer", _.dateOfBirth := LocalDate.parse("2023-09-09"), @@ -102,6 +109,8 @@ trait InsertTests extends ScalaSqlSuite { test("batch") { test("values") - { checker( + preQuery = + Option.when(this.isInstanceOf[MsSqlSuite])(sql"SET IDENTITY_INSERT buyer ON").orNull, query = Buyer.insert.values( Buyer[Sc](4, "test buyer A", 
LocalDate.parse("2001-04-07")), Buyer[Sc](5, "test buyer B", LocalDate.parse("2002-05-08")), @@ -164,6 +173,8 @@ trait InsertTests extends ScalaSqlSuite { test("select") { test("caseclass") { checker( + preQuery = + Option.when(this.isInstanceOf[MsSqlSuite])(sql"SET IDENTITY_INSERT buyer ON").orNull, query = Buyer.insert.select( identity, Buyer.select diff --git a/scalasql/test/src/query/JoinTests.scala b/scalasql/test/src/query/JoinTests.scala index b4054d42..2d890117 100644 --- a/scalasql/test/src/query/JoinTests.scala +++ b/scalasql/test/src/query/JoinTests.scala @@ -251,21 +251,13 @@ trait JoinTests extends ScalaSqlSuite { .leftJoin(ShippingInfo)(_.id `=` _.buyerId) .map { case (b, si) => (b.name, si.map(_.shippingDate)) } .sortBy(_._2) - .nullsFirst }, sqls = Seq( """ SELECT buyer0.name AS res_0, shipping_info1.shipping_date AS res_1 FROM buyer buyer0 LEFT JOIN shipping_info shipping_info1 ON (buyer0.id = shipping_info1.buyer_id) - ORDER BY res_1 NULLS FIRST - """, - // MySQL doesn't support NULLS FIRST syntax and needs a workaround - """ - SELECT buyer0.name AS res_0, shipping_info1.shipping_date AS res_1 - FROM buyer buyer0 - LEFT JOIN shipping_info shipping_info1 ON (buyer0.id = shipping_info1.buyer_id) - ORDER BY res_1 IS NULL DESC, res_1 + ORDER BY res_1 """ ), value = Seq[(String, Option[LocalDate])]( @@ -274,6 +266,14 @@ trait JoinTests extends ScalaSqlSuite { ("James Bond", Some(LocalDate.parse("2012-04-05"))), ("叉烧包", Some(LocalDate.parse("2012-05-06"))) ), + moreValues = Seq( + Seq[(String, Option[LocalDate])]( + ("叉烧包", Some(LocalDate.parse("2010-02-03"))), + ("James Bond", Some(LocalDate.parse("2012-04-05"))), + ("叉烧包", Some(LocalDate.parse("2012-05-06"))), + ("Li Haoyi", None) + ) + ), docs = """ `JoinNullable[Expr[T]]`s can be implicitly used as `Expr[Option[T]]`s. 
This allows them to participate in any database query logic than any other `Expr[Option[T]]`s @@ -289,12 +289,20 @@ trait JoinTests extends ScalaSqlSuite { .distinct .sortBy(_._1) }, - sql = """ + sqls = Seq( + """ SELECT DISTINCT buyer0.name AS res_0, (shipping_info1.id IS NOT NULL) AS res_1 FROM buyer buyer0 LEFT JOIN shipping_info shipping_info1 ON (buyer0.id = shipping_info1.buyer_id) ORDER BY res_0 - """, + """, + """ + SELECT DISTINCT buyer0.name AS res_0, CASE WHEN (shipping_info1.id IS NOT NULL) THEN 1 ELSE 0 END AS res_1 + FROM buyer buyer0 + LEFT JOIN shipping_info shipping_info1 ON (buyer0.id = shipping_info1.buyer_id) + ORDER BY res_0 + """ + ), value = Seq( ("James Bond", true), ("Li Haoyi", false), @@ -312,13 +320,22 @@ trait JoinTests extends ScalaSqlSuite { .leftJoin(ShippingInfo)(_.id `=` _.buyerId) .map { case (b, si) => (b.name, si.map(_.shippingDate) > b.dateOfBirth) } }, - sql = """ + sqls = Seq( + """ SELECT buyer0.name AS res_0, (shipping_info1.shipping_date > buyer0.date_of_birth) AS res_1 FROM buyer buyer0 LEFT JOIN shipping_info shipping_info1 ON (buyer0.id = shipping_info1.buyer_id) - """, + """, + """ + SELECT + buyer0.name AS res_0, + CASE WHEN (shipping_info1.shipping_date > buyer0.date_of_birth) THEN 1 ELSE 0 END AS res_1 + FROM buyer buyer0 + LEFT JOIN shipping_info shipping_info1 ON (buyer0.id = shipping_info1.buyer_id) + """ + ), value = Seq( ("James Bond", true), ("Li Haoyi", false), @@ -335,13 +352,22 @@ trait JoinTests extends ScalaSqlSuite { (b.name, JoinNullable.toExpr(si.map(_.shippingDate)) > b.dateOfBirth) } }, - sql = """ + sqls = Seq( + """ SELECT buyer0.name AS res_0, (shipping_info1.shipping_date > buyer0.date_of_birth) AS res_1 FROM buyer buyer0 LEFT JOIN shipping_info shipping_info1 ON (buyer0.id = shipping_info1.buyer_id) - """, + """, + """ + SELECT + buyer0.name AS res_0, + CASE WHEN (shipping_info1.shipping_date > buyer0.date_of_birth) THEN 1 ELSE 0 END AS res_1 + FROM buyer buyer0 + LEFT JOIN shipping_info shipping_info1 ON (buyer0.id = shipping_info1.buyer_id) + """ + ), value = Seq( ("James Bond", true), ("Li Haoyi", false), diff --git a/scalasql/test/src/query/SelectTests.scala b/scalasql/test/src/query/SelectTests.scala index 8a26ef29..d2c0f89f 100644 --- a/scalasql/test/src/query/SelectTests.scala +++ b/scalasql/test/src/query/SelectTests.scala @@ -310,15 +310,28 @@ trait SelectTests extends ScalaSqlSuite { ) ) }, - sql = """ - SELECT - product0.name AS res_0, - (SELECT purchase1.total AS res - FROM purchase purchase1 - WHERE (purchase1.product_id = product0.id) - ORDER BY res DESC - LIMIT ?) AS res_1 - FROM product product0""", + sqls = Seq( + """ + SELECT + product0.name AS res_0, + (SELECT purchase1.total AS res + FROM purchase purchase1 + WHERE (purchase1.product_id = product0.id) + ORDER BY res DESC + LIMIT ?) AS res_1 + FROM product product0 + """, + """ + SELECT + product0.name AS res_0, + (SELECT purchase1.total AS res + FROM purchase purchase1 + WHERE (purchase1.product_id = product0.id) + ORDER BY res DESC + OFFSET ? ROWS FETCH FIRST ? 
ROWS ONLY) AS res_1 + FROM product product0 + """ + ), value = Seq( ("Face Mask", 888.0), ("Guitar", 900.0), @@ -531,39 +544,44 @@ trait SelectTests extends ScalaSqlSuite { """ ) - test("containsMultiple") - checker( - query = Text { - Buyer.select.filter(b => - ShippingInfo.select - .map(s => (s.buyerId, s.shippingDate)) - .contains((b.id, LocalDate.parse("2010-02-03"))) + test("containsMultiple") - { + // Microsoft SQL Server does not support tuple IN + if (!this.isInstanceOf[MsSqlDialect]) + checker( + query = Text { + Buyer.select.filter(b => + ShippingInfo.select + .map(s => (s.buyerId, s.shippingDate)) + .contains((b.id, LocalDate.parse("2010-02-03"))) + ) + }, + sql = """ + SELECT buyer0.id AS id, buyer0.name AS name, buyer0.date_of_birth AS date_of_birth + FROM buyer buyer0 + WHERE ((buyer0.id, ?) IN (SELECT + shipping_info1.buyer_id AS res_0, + shipping_info1.shipping_date AS res_1 + FROM shipping_info shipping_info1)) + """, + value = Seq( + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")) + ), + docs = """ + ScalaSql's `.contains` can take a compound Scala value, which translates into + SQL's `IN` syntax on a tuple with multiple columns. e.g. this query uses that ability + to find the `Buyer` which has a shipment on a specific date, as an alternative + to doing a `JOIN`. + """ ) - }, - sql = """ - SELECT buyer0.id AS id, buyer0.name AS name, buyer0.date_of_birth AS date_of_birth - FROM buyer buyer0 - WHERE ((buyer0.id, ?) IN (SELECT - shipping_info1.buyer_id AS res_0, - shipping_info1.shipping_date AS res_1 - FROM shipping_info shipping_info1)) - """, - value = Seq( - Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")) - ), - docs = """ - ScalaSql's `.contains` can take a compound Scala value, which translates into - SQL's `IN` syntax on a tuple with multiple columns. e.g. this query uses that ability - to find the `Buyer` which has a shipment on a specific date, as an alternative - to doing a `JOIN`. 
- """ - ) + } test("nonEmpty") - checker( query = Text { Buyer.select .map(b => (b.name, ShippingInfo.select.filter(_.buyerId `=` b.id).map(_.id).nonEmpty)) }, - sql = """ + sqls = Seq( + """ SELECT buyer0.name AS res_0, (EXISTS (SELECT @@ -571,7 +589,17 @@ trait SelectTests extends ScalaSqlSuite { FROM shipping_info shipping_info1 WHERE (shipping_info1.buyer_id = buyer0.id))) AS res_1 FROM buyer buyer0 - """, + """, + """ + SELECT + buyer0.name AS res_0, + CASE WHEN (EXISTS (SELECT + shipping_info1.id AS res + FROM shipping_info shipping_info1 + WHERE (shipping_info1.buyer_id = buyer0.id))) THEN 1 ELSE 0 END AS res_1 + FROM buyer buyer0 + """ + ), value = Seq(("James Bond", true), ("叉烧包", true), ("Li Haoyi", false)), docs = """ ScalaSql's `.nonEmpty` and `.isEmpty` translates to SQL's `EXISTS` and `NOT EXISTS` syntax @@ -583,7 +611,8 @@ trait SelectTests extends ScalaSqlSuite { Buyer.select .map(b => (b.name, ShippingInfo.select.filter(_.buyerId `=` b.id).map(_.id).isEmpty)) }, - sql = """ + sqls = Seq( + """ SELECT buyer0.name AS res_0, (NOT EXISTS (SELECT @@ -591,7 +620,17 @@ trait SelectTests extends ScalaSqlSuite { FROM shipping_info shipping_info1 WHERE (shipping_info1.buyer_id = buyer0.id))) AS res_1 FROM buyer buyer0 - """, + """, + """ + SELECT + buyer0.name AS res_0, + CASE WHEN (NOT EXISTS (SELECT + shipping_info1.id AS res + FROM shipping_info shipping_info1 + WHERE (shipping_info1.buyer_id = buyer0.id))) THEN 1 ELSE 0 END AS res_1 + FROM buyer buyer0 + """ + ), value = Seq(("James Bond", false), ("叉烧包", false), ("Li Haoyi", true)) ) @@ -639,6 +678,31 @@ trait SelectTests extends ScalaSqlSuite { ) ) ), + moreValues = Seq[Seq[(Int, (Buyer[Sc], (Int, ShippingInfo[Sc])))]]( + Seq( + ( + 1, + ( + Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), + (2, ShippingInfo[Sc](2, 1, LocalDate.parse("2012-04-05"))) + ) + ), + ( + 2, + ( + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + (3, ShippingInfo[Sc](3, 2, LocalDate.parse("2012-05-06"))) + ) + ), + ( + 2, + ( + Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), + (1, ShippingInfo[Sc](1, 2, LocalDate.parse("2010-02-03"))) + ) + ) + ) + ), docs = """ Queries can output arbitrarily nested tuples of `Expr[T]` and `case class` instances of `Foo[Expr]`, which will be de-serialized into nested tuples @@ -676,6 +740,15 @@ trait SelectTests extends ScalaSqlSuite { WHEN (product0.price <= ?) THEN CONCAT(product0.name, ?) END AS res FROM product product0 + """, + """ + SELECT + CASE + WHEN (product0.price > ?) THEN (product0.name + ?) + WHEN (product0.price > ?) THEN (product0.name + ?) + WHEN (product0.price <= ?) THEN (product0.name + ?) + END AS res + FROM product product0 """ ), value = Seq( @@ -719,6 +792,15 @@ trait SelectTests extends ScalaSqlSuite { ELSE CONCAT(product0.name, ?) END AS res FROM product product0 + """, + """ + SELECT + CASE + WHEN (product0.price > ?) THEN (product0.name + ?) + WHEN (product0.price > ?) THEN (product0.name + ?) + ELSE (product0.name + ?) 
+ END AS res + FROM product product0 """ ), value = Seq( diff --git a/scalasql/test/src/query/SubQueryTests.scala b/scalasql/test/src/query/SubQueryTests.scala index ee3a762a..65b3e3ce 100644 --- a/scalasql/test/src/query/SubQueryTests.scala +++ b/scalasql/test/src/query/SubQueryTests.scala @@ -19,7 +19,8 @@ trait SubQueryTests extends ScalaSqlSuite { .join(Product.select.sortBy(_.price).desc.take(1))(_.productId `=` _.id) .map { case (purchase, product) => purchase.total } }, - sql = """ + sqls = Seq( + """ SELECT purchase0.total AS res FROM purchase purchase0 JOIN (SELECT product1.id AS id, product1.price AS price @@ -27,7 +28,18 @@ trait SubQueryTests extends ScalaSqlSuite { ORDER BY price DESC LIMIT ?) subquery1 ON (purchase0.product_id = subquery1.id) - """, + """, + """ + SELECT purchase0.total AS res + FROM purchase purchase0 + JOIN (SELECT product1.id AS id, product1.price AS price + FROM product product1 + ORDER BY price DESC + OFFSET ? ROWS + FETCH FIRST ? ROWS ONLY) subquery1 + ON (purchase0.product_id = subquery1.id) + """ + ), value = Seq(10000.0), docs = """ A ScalaSql `.join` referencing a `.select` translates straightforwardly @@ -41,14 +53,25 @@ trait SubQueryTests extends ScalaSqlSuite { case (product, purchase) => purchase.total } }, - sql = """ - SELECT purchase1.total AS res - FROM (SELECT product0.id AS id, product0.price AS price - FROM product product0 - ORDER BY price DESC - LIMIT ?) subquery0 - JOIN purchase purchase1 ON (subquery0.id = purchase1.product_id) - """, + sqls = Seq( + """ + SELECT purchase1.total AS res + FROM (SELECT product0.id AS id, product0.price AS price + FROM product product0 + ORDER BY price DESC + LIMIT ?) subquery0 + JOIN purchase purchase1 ON (subquery0.id = purchase1.product_id) + """, + """ + SELECT purchase1.total AS res + FROM (SELECT product0.id AS id, product0.price AS price + FROM product product0 + ORDER BY price DESC + OFFSET ? ROWS + FETCH FIRST ? ROWS ONLY) subquery0 + JOIN purchase purchase1 ON (subquery0.id = purchase1.product_id) + """ + ), value = Seq(10000.0), docs = """ Some sequences of operations cannot be expressed as a single SQL query, @@ -68,25 +91,48 @@ trait SubQueryTests extends ScalaSqlSuite { .join(Purchase.select.sortBy(_.count).desc.take(3))(_.id `=` _.productId) .map { case (product, purchase) => (product.name, purchase.count) } }, - sql = """ - SELECT - subquery0.name AS res_0, - subquery1.count AS res_1 - FROM (SELECT - product0.id AS id, - product0.name AS name, - product0.price AS price - FROM product product0 - ORDER BY price DESC - LIMIT ?) subquery0 - JOIN (SELECT - purchase1.product_id AS product_id, - purchase1.count AS count - FROM purchase purchase1 - ORDER BY count DESC - LIMIT ?) subquery1 - ON (subquery0.id = subquery1.product_id) - """, + sqls = Seq( + """ + SELECT + subquery0.name AS res_0, + subquery1.count AS res_1 + FROM (SELECT + product0.id AS id, + product0.name AS name, + product0.price AS price + FROM product product0 + ORDER BY price DESC + LIMIT ?) subquery0 + JOIN (SELECT + purchase1.product_id AS product_id, + purchase1.count AS count + FROM purchase purchase1 + ORDER BY count DESC + LIMIT ?) subquery1 + ON (subquery0.id = subquery1.product_id) + """, + """ + SELECT + subquery0.name AS res_0, + subquery1.count AS res_1 + FROM (SELECT + product0.id AS id, + product0.name AS name, + product0.price AS price + FROM product product0 + ORDER BY price DESC + OFFSET ? ROWS + FETCH FIRST ? 
ROWS ONLY) subquery0 + JOIN (SELECT + purchase1.product_id AS product_id, + purchase1.count AS count + FROM purchase purchase1 + ORDER BY count DESC + OFFSET ? ROWS + FETCH FIRST ? ROWS ONLY) subquery1 + ON (subquery0.id = subquery1.product_id) + """ + ), value = Seq(("Camera", 10)), docs = """ This example shows a ScalaSql query that results in a subquery in both @@ -98,17 +144,32 @@ trait SubQueryTests extends ScalaSqlSuite { query = Text { Product.select.sortBy(_.price).desc.take(4).sortBy(_.price).asc.take(2).map(_.name) }, - sql = """ - SELECT subquery0.name AS res - FROM (SELECT - product0.name AS name, - product0.price AS price - FROM product product0 - ORDER BY price DESC - LIMIT ?) subquery0 - ORDER BY subquery0.price ASC - LIMIT ? - """, + sqls = Seq( + """ + SELECT subquery0.name AS res + FROM (SELECT + product0.name AS name, + product0.price AS price + FROM product product0 + ORDER BY price DESC + LIMIT ?) subquery0 + ORDER BY subquery0.price ASC + LIMIT ? + """, + """ + SELECT subquery0.name AS res + FROM (SELECT + product0.name AS name, + product0.price AS price + FROM product product0 + ORDER BY price DESC + OFFSET ? ROWS + FETCH FIRST ? ROWS ONLY) subquery0 + ORDER BY subquery0.price ASC + OFFSET ? ROWS + FETCH FIRST ? ROWS ONLY + """ + ), value = Seq("Face Mask", "Skate Board"), docs = """ Performing multiple sorts with `.take`s in between is also something @@ -121,17 +182,31 @@ trait SubQueryTests extends ScalaSqlSuite { query = Text { Purchase.select.sortBy(_.count).take(5).groupBy(_.productId)(_.sumBy(_.total)) }, - sql = """ - SELECT subquery0.product_id AS res_0, SUM(subquery0.total) AS res_1 - FROM (SELECT - purchase0.product_id AS product_id, - purchase0.count AS count, - purchase0.total AS total - FROM purchase purchase0 - ORDER BY count - LIMIT ?) subquery0 - GROUP BY subquery0.product_id - """, + sqls = Seq( + """ + SELECT subquery0.product_id AS res_0, SUM(subquery0.total) AS res_1 + FROM (SELECT + purchase0.product_id AS product_id, + purchase0.count AS count, + purchase0.total AS total + FROM purchase purchase0 + ORDER BY count + LIMIT ?) subquery0 + GROUP BY subquery0.product_id + """, + """ + SELECT subquery0.product_id AS res_0, SUM(subquery0.total) AS res_1 + FROM (SELECT + purchase0.product_id AS product_id, + purchase0.count AS count, + purchase0.total AS total + FROM purchase purchase0 + ORDER BY count + OFFSET ? ROWS + FETCH FIRST ? ROWS ONLY) subquery0 + GROUP BY subquery0.product_id + """ + ), value = Seq((1, 44.4), (2, 900.0), (3, 15.7), (4, 493.8), (5, 10000.0)), normalize = (x: Seq[(Int, Double)]) => x.sorted ) @@ -216,7 +291,8 @@ trait SubQueryTests extends ScalaSqlSuite { query = Text { Buyer.select.map(c => (c, ShippingInfo.select.filter(p => c.id `=` p.buyerId).size `=` 1)) }, - sql = """ + sqls = Seq( + """ SELECT buyer0.id AS res_0_id, buyer0.name AS res_0_name, @@ -226,7 +302,19 @@ trait SubQueryTests extends ScalaSqlSuite { FROM shipping_info shipping_info1 WHERE (buyer0.id = shipping_info1.buyer_id)) = ?) AS res_1 FROM buyer buyer0 - """, + """, + """ + SELECT + buyer0.id AS res_0_id, + buyer0.name AS res_0_name, + buyer0.date_of_birth AS res_0_date_of_birth, + CASE WHEN ((SELECT + COUNT(1) AS res + FROM shipping_info shipping_info1 + WHERE (buyer0.id = shipping_info1.buyer_id)) = ?) 
THEN 1 ELSE 0 END AS res_1 + FROM buyer buyer0 + """ + ), value = Seq( (Buyer[Sc](1, "James Bond", LocalDate.parse("2001-02-03")), true), (Buyer[Sc](2, "叉烧包", LocalDate.parse("1923-11-12")), false), @@ -241,16 +329,25 @@ trait SubQueryTests extends ScalaSqlSuite { .take(2) .unionAll(Product.select.map(_.kebabCaseName.toLowerCase)) }, - sql = """ - SELECT subquery0.res AS res - FROM (SELECT - LOWER(buyer0.name) AS res + sqls = Seq( + """ + SELECT subquery0.res AS res + FROM (SELECT + LOWER(buyer0.name) AS res + FROM buyer buyer0 + LIMIT ?) subquery0 + UNION ALL + SELECT LOWER(product0.kebab_case_name) AS res + FROM product product0 + """, + """ + SELECT TOP(?) LOWER(buyer0.name) AS res FROM buyer buyer0 - LIMIT ?) subquery0 - UNION ALL - SELECT LOWER(product0.kebab_case_name) AS res - FROM product product0 - """, + UNION ALL + SELECT LOWER(product0.kebab_case_name) AS res + FROM product product0 + """ + ), value = Seq("james bond", "叉烧包", "face-mask", "guitar", "socks", "skate-board", "camera", "cookie") ) @@ -261,16 +358,25 @@ trait SubQueryTests extends ScalaSqlSuite { .map(_.name.toLowerCase) .unionAll(Product.select.map(_.kebabCaseName.toLowerCase).take(2)) }, - sql = """ - SELECT LOWER(buyer0.name) AS res - FROM buyer buyer0 - UNION ALL - SELECT subquery0.res AS res - FROM (SELECT - LOWER(product0.kebab_case_name) AS res + sqls = Seq( + """ + SELECT LOWER(buyer0.name) AS res + FROM buyer buyer0 + UNION ALL + SELECT subquery0.res AS res + FROM (SELECT + LOWER(product0.kebab_case_name) AS res + FROM product product0 + LIMIT ?) subquery0 + """, + """ + SELECT LOWER(buyer0.name) AS res + FROM buyer buyer0 + UNION ALL + SELECT TOP(?) LOWER(product0.kebab_case_name) AS res FROM product product0 - LIMIT ?) subquery0 - """, + """ + ), value = Seq("james bond", "叉烧包", "li haoyi", "face-mask", "guitar") ) @@ -351,26 +457,51 @@ trait SubQueryTests extends ScalaSqlSuite { .toExpr } }, - sql = """ - SELECT - buyer0.name AS res_0, - (SELECT - (SELECT - (SELECT product3.price AS res - FROM product product3 - WHERE (product3.id = purchase2.product_id) + sqls = Seq( + """ + SELECT + buyer0.name AS res_0, + (SELECT + (SELECT + (SELECT product3.price AS res + FROM product product3 + WHERE (product3.id = purchase2.product_id) + ORDER BY res DESC + LIMIT ?) AS res + FROM purchase purchase2 + WHERE (purchase2.shipping_info_id = shipping_info1.id) + ORDER BY res DESC + LIMIT ?) AS res + FROM shipping_info shipping_info1 + WHERE (shipping_info1.buyer_id = buyer0.id) ORDER BY res DESC - LIMIT ?) AS res - FROM purchase purchase2 - WHERE (purchase2.shipping_info_id = shipping_info1.id) - ORDER BY res DESC - LIMIT ?) AS res - FROM shipping_info shipping_info1 - WHERE (shipping_info1.buyer_id = buyer0.id) - ORDER BY res DESC - LIMIT ?) AS res_1 - FROM buyer buyer0 - """, + LIMIT ?) AS res_1 + FROM buyer buyer0 + """, + """ + SELECT + buyer0.name AS res_0, + (SELECT + (SELECT + (SELECT product3.price AS res + FROM product product3 + WHERE (product3.id = purchase2.product_id) + ORDER BY res DESC + OFFSET ? ROWS + FETCH FIRST ? ROWS ONLY) AS res + FROM purchase purchase2 + WHERE (purchase2.shipping_info_id = shipping_info1.id) + ORDER BY res DESC + OFFSET ? ROWS + FETCH FIRST ? ROWS ONLY) AS res + FROM shipping_info shipping_info1 + WHERE (shipping_info1.buyer_id = buyer0.id) + ORDER BY res DESC + OFFSET ? ROWS + FETCH FIRST ? 
ROWS ONLY) AS res_1 + FROM buyer buyer0 + """ + ), value = Seq( ("James Bond", 1000.0), ("叉烧包", 300.0), diff --git a/scalasql/test/src/query/UpdateJoinTests.scala b/scalasql/test/src/query/UpdateJoinTests.scala index 3091cd28..b6331536 100644 --- a/scalasql/test/src/query/UpdateJoinTests.scala +++ b/scalasql/test/src/query/UpdateJoinTests.scala @@ -120,6 +120,18 @@ trait UpdateJoinTests extends ScalaSqlSuite { LIMIT ?) subquery0 ON (buyer.id = subquery0.buyer_id) SET buyer.date_of_birth = subquery0.shipping_date WHERE (buyer.name = ?) + """, + """ + UPDATE buyer SET date_of_birth = subquery0.shipping_date + FROM (SELECT + shipping_info0.id AS id, + shipping_info0.buyer_id AS buyer_id, + shipping_info0.shipping_date AS shipping_date + FROM shipping_info shipping_info0 + ORDER BY id ASC + OFFSET ? ROWS + FETCH FIRST ? ROWS ONLY) subquery0 + WHERE (buyer.id = subquery0.buyer_id) AND (buyer.name = ?) """ ), value = 1, @@ -166,6 +178,17 @@ trait UpdateJoinTests extends ScalaSqlSuite { LIMIT ?) subquery0 ON (buyer.id = subquery0.buyer_id) SET buyer.date_of_birth = ? WHERE (buyer.name = ?) + """, + """ + UPDATE buyer SET date_of_birth = ? + FROM (SELECT + shipping_info0.id AS id, + shipping_info0.buyer_id AS buyer_id + FROM shipping_info shipping_info0 + ORDER BY id ASC + OFFSET ? ROWS + FETCH FIRST ? ROWS ONLY) subquery0 + WHERE (buyer.id = subquery0.buyer_id) AND (buyer.name = ?) """ ), value = 1 diff --git a/scalasql/test/src/query/WindowFunctionTests.scala b/scalasql/test/src/query/WindowFunctionTests.scala index ef721e3f..147994ec 100644 --- a/scalasql/test/src/query/WindowFunctionTests.scala +++ b/scalasql/test/src/query/WindowFunctionTests.scala @@ -1,7 +1,7 @@ package scalasql.query import scalasql._ -import scalasql.MySqlDialect +import scalasql.{MsSqlDialect, MySqlDialect} import sourcecode.Text import utest._ import utils.ScalaSqlSuite @@ -334,34 +334,38 @@ trait WindowFunctionTests extends ScalaSqlSuite { normalize = (x: Seq[(Int, Double, Double)]) => x.sorted ) - test("nthValue") - checker( - query = Text { - Purchase.select.map(p => - ( - p.shippingInfoId, - p.total, - db.nthValue(p.total, 2).over.partitionBy(p.shippingInfoId).sortBy(p.total).asc - ) + test("nthValue") - { + // Microsoft SQL don't support `.nthValue` + if (!this.isInstanceOf[MsSqlDialect]) + checker( + query = Text { + Purchase.select.map(p => + ( + p.shippingInfoId, + p.total, + db.nthValue(p.total, 2).over.partitionBy(p.shippingInfoId).sortBy(p.total).asc + ) + ) + }, + sql = """ + SELECT + purchase0.shipping_info_id AS res_0, + purchase0.total AS res_1, + NTH_VALUE(purchase0.total, ?) OVER (PARTITION BY purchase0.shipping_info_id ORDER BY purchase0.total ASC) AS res_2 + FROM purchase purchase0 + """, + value = Seq[(Int, Double, Double)]( + (1, 15.7, 0.0), + (1, 888.0, 888.0), + (1, 900.0, 888.0), + (2, 493.8, 0.0), + (2, 10000.0, 10000.0), + (3, 1.3, 0.0), + (3, 44.4, 44.4) + ), + normalize = (x: Seq[(Int, Double, Double)]) => x.sorted ) - }, - sql = """ - SELECT - purchase0.shipping_info_id AS res_0, - purchase0.total AS res_1, - NTH_VALUE(purchase0.total, ?) 
OVER (PARTITION BY purchase0.shipping_info_id ORDER BY purchase0.total ASC) AS res_2 - FROM purchase purchase0 - """, - value = Seq[(Int, Double, Double)]( - (1, 15.7, 0.0), - (1, 888.0, 888.0), - (1, 900.0, 888.0), - (2, 493.8, 0.0), - (2, 10000.0, 10000.0), - (3, 1.3, 0.0), - (3, 44.4, 44.4) - ), - normalize = (x: Seq[(Int, Double, Double)]) => x.sorted - ) + } } test("aggregate") { @@ -471,8 +475,8 @@ trait WindowFunctionTests extends ScalaSqlSuite { ) } test("frames") - { - // MySql doesn't support `.exclude` - if (!this.isInstanceOf[MySqlDialect]) + // MySql and Microsoft SQL don't support `.exclude` + if (!(this.isInstanceOf[MySqlDialect] | this.isInstanceOf[MsSqlDialect])) checker( query = Text { Purchase.select.mapAggregate((p, ps) => @@ -522,8 +526,8 @@ trait WindowFunctionTests extends ScalaSqlSuite { } test("filter") - { - // MySql doesn't support FILTER - if (!this.isInstanceOf[MySqlDialect]) + // MySql and Microsoft SQL don't support FILTER + if (!(this.isInstanceOf[MySqlDialect] | this.isInstanceOf[MsSqlDialect])) checker( query = Text { Purchase.select.mapAggregate((p, ps) => diff --git a/scalasql/test/src/query/WithCteTests.scala b/scalasql/test/src/query/WithCteTests.scala index 93c837e1..8c85cc8a 100644 --- a/scalasql/test/src/query/WithCteTests.scala +++ b/scalasql/test/src/query/WithCteTests.scala @@ -28,6 +28,11 @@ trait WithCteTests extends ScalaSqlSuite { WITH cte0 (res) AS (SELECT buyer0.name AS res FROM buyer buyer0) SELECT CONCAT(cte0.res, ?) AS res FROM cte0 + """, + """ + WITH cte0 (res) AS (SELECT buyer0.name AS res FROM buyer buyer0) + SELECT (cte0.res + ?) AS res + FROM cte0 """ ), value = Seq("James Bond-suffix", "叉烧包-suffix", "Li Haoyi-suffix"), @@ -85,6 +90,11 @@ trait WithCteTests extends ScalaSqlSuite { WITH cte0 (name) AS (SELECT buyer0.name AS name FROM buyer buyer0) SELECT CONCAT(cte0.name, ?) AS res FROM cte0 + """, + """ + WITH cte0 (name) AS (SELECT buyer0.name AS name FROM buyer buyer0) + SELECT (cte0.name + ?) 
AS res + FROM cte0 """ ), value = Seq("James Bond-suffix", "叉烧包-suffix", "Li Haoyi-suffix"), @@ -94,55 +104,59 @@ trait WithCteTests extends ScalaSqlSuite { """ ) - test("subquery") - checker( - query = Text { - db.withCte(Buyer.select) { bs => - db.withCte(ShippingInfo.select) { sis => - bs.join(sis)(_.id === _.buyerId) - } - }.join( - db.withCte(Product.select) { prs => - Purchase.select.join(prs)(_.productId === _.id) - } - )(_._2.id === _._1.shippingInfoId) - .map { case (b, s, (pu, pr)) => (b.name, pr.name) } - }, - sql = """ - SELECT subquery0.res_0_name AS res_0, subquery1.res_1_name AS res_1 - FROM (WITH - cte0 (id, name) - AS (SELECT buyer0.id AS id, buyer0.name AS name FROM buyer buyer0), - cte1 (id, buyer_id) - AS (SELECT shipping_info1.id AS id, shipping_info1.buyer_id AS buyer_id - FROM shipping_info shipping_info1) - SELECT cte0.name AS res_0_name, cte1.id AS res_1_id - FROM cte0 - JOIN cte1 ON (cte0.id = cte1.buyer_id)) subquery0 - JOIN (WITH - cte1 (id, name) - AS (SELECT product1.id AS id, product1.name AS name FROM product product1) - SELECT - purchase2.shipping_info_id AS res_0_shipping_info_id, - cte1.name AS res_1_name - FROM purchase purchase2 - JOIN cte1 ON (purchase2.product_id = cte1.id)) subquery1 - ON (subquery0.res_1_id = subquery1.res_0_shipping_info_id) - """, - value = Seq[(String, String)]( - ("James Bond", "Camera"), - ("James Bond", "Skate Board"), - ("叉烧包", "Cookie"), - ("叉烧包", "Face Mask"), - ("叉烧包", "Face Mask"), - ("叉烧包", "Guitar"), - ("叉烧包", "Socks") - ), - docs = """ - ScalaSql's `withCte` can be used anywhere a `.select` operator can be used. The - generated `WITH` clauses may be wrapped in sub-queries in scenarios where they - cannot be easily combined into a single query - """, - normalize = (x: Seq[(String, String)]) => x.sorted - ) + test("subquery") - { + // Microsoft SQL does not support CTEs in subqueries + if (!this.isInstanceOf[MsSqlDialect]) + checker( + query = Text { + db.withCte(Buyer.select) { bs => + db.withCte(ShippingInfo.select) { sis => + bs.join(sis)(_.id === _.buyerId) + } + }.join( + db.withCte(Product.select) { prs => + Purchase.select.join(prs)(_.productId === _.id) + } + )(_._2.id === _._1.shippingInfoId) + .map { case (b, s, (pu, pr)) => (b.name, pr.name) } + }, + sql = """ + SELECT subquery0.res_0_name AS res_0, subquery1.res_1_name AS res_1 + FROM (WITH + cte0 (id, name) + AS (SELECT buyer0.id AS id, buyer0.name AS name FROM buyer buyer0), + cte1 (id, buyer_id) + AS (SELECT shipping_info1.id AS id, shipping_info1.buyer_id AS buyer_id + FROM shipping_info shipping_info1) + SELECT cte0.name AS res_0_name, cte1.id AS res_1_id + FROM cte0 + JOIN cte1 ON (cte0.id = cte1.buyer_id)) subquery0 + JOIN (WITH + cte1 (id, name) + AS (SELECT product1.id AS id, product1.name AS name FROM product product1) + SELECT + purchase2.shipping_info_id AS res_0_shipping_info_id, + cte1.name AS res_1_name + FROM purchase purchase2 + JOIN cte1 ON (purchase2.product_id = cte1.id)) subquery1 + ON (subquery0.res_1_id = subquery1.res_0_shipping_info_id) + """, + value = Seq[(String, String)]( + ("James Bond", "Camera"), + ("James Bond", "Skate Board"), + ("叉烧包", "Cookie"), + ("叉烧包", "Face Mask"), + ("叉烧包", "Face Mask"), + ("叉烧包", "Guitar"), + ("叉烧包", "Socks") + ), + docs = """ + ScalaSql's `withCte` can be used anywhere a `.select` operator can be used. 
The + generated `WITH` clauses may be wrapped in sub-queries in scenarios where they + cannot be easily combined into a single query + """, + normalize = (x: Seq[(String, String)]) => x.sorted + ) + } } } diff --git a/scalasql/test/src/utils/ScalaSqlSuite.scala b/scalasql/test/src/utils/ScalaSqlSuite.scala index 7f84c2bb..1043bdc0 100644 --- a/scalasql/test/src/utils/ScalaSqlSuite.scala +++ b/scalasql/test/src/utils/ScalaSqlSuite.scala @@ -83,3 +83,16 @@ trait MySqlSuite extends ScalaSqlSuite with MySqlDialect { checker.reset() } + +trait MsSqlSuite extends ScalaSqlSuite with MsSqlDialect { + val checker = new TestChecker( + scalasql.example.MsSqlExample.mssqlClient, + "mssql-customer-schema.sql", + "customer-data-plus-schema.sql", + getClass.getName, + suiteLine.value, + description + ) + + checker.reset() +} diff --git a/scalasql/test/src/utils/TestChecker.scala b/scalasql/test/src/utils/TestChecker.scala index 665d1f3e..529223c3 100644 --- a/scalasql/test/src/utils/TestChecker.scala +++ b/scalasql/test/src/utils/TestChecker.scala @@ -2,6 +2,7 @@ package scalasql.utils import com.github.vertical_blank.sqlformatter.SqlFormatter import pprint.PPrinter +import scalasql.core.SqlStr import scalasql.query.SubqueryRef import scalasql.{DbClient, Queryable, Expr, UtestFramework} @@ -46,6 +47,7 @@ class TestChecker( res } def apply[T, V]( + preQuery: SqlStr = null, query: sourcecode.Text[T], sql: String = null, sqls: Seq[String] = Nil, @@ -86,7 +88,10 @@ class TestChecker( assert(matchedSql.nonEmpty, pprint.apply(SqlFormatter.format(sqlResult))) } - val result = autoCommitConnection.run(query.value) + val result = dbClient.transaction { db => + Option(preQuery).foreach(q => db.updateSql(q)) + db.run(query.value) + } val values = Option(value).map(_.value) ++ moreValues val normalized = normalize(result)
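// --- Illustrative sketch (not part of the patch above) ----------------------
// The following hypothetical test trait shows how the new `preQuery` hook in
// `TestChecker` is intended to be used together with `MsSqlSuite`: Microsoft
// SQL Server rejects explicit values for IDENTITY columns unless
// IDENTITY_INSERT is switched on first, so the setup statement is run in the
// same transaction as the checked query and skipped (null, a no-op) for every
// other database. The trait name `ExampleIdentityInsertTests` and the exact
// `value` are assumptions for illustration only; the `checker(preQuery = ...)`
// pattern itself mirrors the InsertTests and GetGeneratedKeysTests changes
// earlier in this diff.
package scalasql.query

import scalasql._
import scalasql.core.SqlStr.SqlStringSyntax
import utest._
import utils.{MsSqlSuite, ScalaSqlSuite}

import java.time.LocalDate

trait ExampleIdentityInsertTests extends ScalaSqlSuite {
  def description = "illustrates guarding an explicit-id insert for Microsoft SQL"
  def tests = Tests {
    test("explicitId") - checker(
      // Only Microsoft SQL needs IDENTITY_INSERT enabled before inserting an
      // explicit primary key; for every other suite `preQuery` stays null.
      preQuery =
        Option.when(this.isInstanceOf[MsSqlSuite])(sql"SET IDENTITY_INSERT buyer ON").orNull,
      query = Buyer.insert.values(
        Buyer[Sc](4, "test buyer", LocalDate.parse("2023-09-09"))
      ),
      value = 1
    )
  }
}
// Because `preQuery` defaults to null and is wrapped in `Option(...)` inside
// `TestChecker.apply`, existing suites are unaffected; only the Microsoft SQL
// runs execute the extra setup statement.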