diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml index 267c752..17f5715 100644 --- a/.github/workflows/CI.yaml +++ b/.github/workflows/CI.yaml @@ -2,9 +2,9 @@ name: CI on: push: - branches: [ master ] + branches: [ master, zio-1.x ] pull_request: - branches: [ master ] + branches: [ master, zio-1.x ] jobs: build: @@ -19,5 +19,7 @@ jobs: with: docker_channel: stable docker_version: 20.10 + - name: Validate Scaladoc + run: sbt doc - name: Run tests run: sbt test +IntegrationTest/test diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 3b82ce8..7473e11 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -1,7 +1,7 @@ name: Release on: push: - branches: [master, main] + branches: [master, main, zio-1.x] tags: ["*"] jobs: publish: diff --git a/src/it/scala/zio/cassandra/session/cql/CqlSpec.scala b/src/it/scala/zio/cassandra/session/cql/CqlSpec.scala index adc9afa..405561b 100644 --- a/src/it/scala/zio/cassandra/session/cql/CqlSpec.scala +++ b/src/it/scala/zio/cassandra/session/cql/CqlSpec.scala @@ -3,6 +3,7 @@ package zio.cassandra.session.cql import com.datastax.oss.driver.api.core.ConsistencyLevel import zio.cassandra.session.Session import zio.cassandra.session.cql.codec.UnexpectedNullValue +import zio.cassandra.session.cql.unsafe.lift import zio.duration._ import zio.stream.Stream import zio.test.Assertion._ @@ -449,6 +450,27 @@ object CqlSpec { } } } + ), + suite("put lifted values as is")( + testM("should handle lifted values (though redundant) in cqlConst") { + val tableName = "tests.test_data" + for { + results <- cqlConst"select data from ${lift(tableName)} where id in (1)".as[String].select.runCollect + } yield assertTrue(results == Chunk("one")) + }, + testM("should handle lifted values in cql") { + val tableName = "tests.test_data" + for { + results <- cql"select data from ${lift(tableName)} where id in ${List(1L)}".as[String].select.runCollect + } yield assertTrue(results == Chunk("one")) + }, + testM("should handle lifted values in cqlt") { + val tableName = "tests.test_data" + for { + query <- cqlt"select data from ${lift(tableName)} where id in ${Put[List[Long]]}".as[String].prepare + results <- query(List(1L)).select.runCollect + } yield assertTrue(results == Chunk("one")) + } ) ) ) diff --git a/src/main/scala/zio/cassandra/session/cql/CqlConstInterpolator.scala b/src/main/scala/zio/cassandra/session/cql/CqlConstInterpolator.scala index 1548f51..83f9d60 100644 --- a/src/main/scala/zio/cassandra/session/cql/CqlConstInterpolator.scala +++ b/src/main/scala/zio/cassandra/session/cql/CqlConstInterpolator.scala @@ -5,10 +5,10 @@ import shapeless.HNil import zio.cassandra.session.cql.query.{ ParameterizedQuery, QueryTemplate } /** Provides a way to lift arbitrary strings into CQL so you can parameterize on values that are not valid CQL - * parameters Please note that this is not escaped so do not use this with user-supplied input for your application - * (only use cqlConst for input that you as the application author control) + * parameters
Please note that this is not escaped so do not use this with user-supplied input for your + * application (only use cqlConst for input that you as the application author control) */ class CqlConstInterpolator(ctx: StringContext) { - def apply(args: Any*): ParameterizedQuery[HNil, Row] = + def apply(args: Any*): SimpleQuery[Row] = ParameterizedQuery(QueryTemplate(ctx.s(args: _*), identity), HNil) } diff --git a/src/main/scala/zio/cassandra/session/cql/CqlStringInterpolator.scala b/src/main/scala/zio/cassandra/session/cql/CqlStringInterpolator.scala index 4740a27..f705fdc 100644 --- a/src/main/scala/zio/cassandra/session/cql/CqlStringInterpolator.scala +++ b/src/main/scala/zio/cassandra/session/cql/CqlStringInterpolator.scala @@ -1,19 +1,53 @@ package zio.cassandra.session.cql import com.datastax.oss.driver.api.core.cql.{ BoundStatement, Row } +import shapeless.ops.hlist.ToList import shapeless.{ ::, HList, HNil, ProductArgs } import zio.cassandra.session.cql.query.{ ParameterizedQuery, QueryTemplate } -import scala.annotation.{ nowarn, tailrec } +import scala.annotation.tailrec -class CqlTemplateStringInterpolator(ctx: StringContext) extends ProductArgs { +sealed abstract class CqlStringInterpolatorBase { + + @tailrec + final protected def assembleQuery( + strings: Iterator[String], + expressions: Iterator[InterpolatorValue], + acc: String + ): String = + if (strings.hasNext && expressions.hasNext) { + val str = strings.next() + val expr = expressions.next() + + val placeholder = expr match { + case l: LiftedValue => l.toString + case _: BoundValue[_] | _: Put[_] => "?" + } + + assembleQuery( + strings = strings, + expressions = expressions, + acc = acc + s"$str$placeholder" + ) + } else if (strings.hasNext && !expressions.hasNext) { + val str = strings.next() + assembleQuery( + strings = strings, + expressions = expressions, + acc + str + ) + } else acc + +} + +class CqlTemplateStringInterpolator(ctx: StringContext) extends CqlStringInterpolatorBase with ProductArgs { import CqlTemplateStringInterpolator._ - @nowarn("msg=is never used") - def applyProduct[P <: HList, V <: HList](params: P)(implicit - bb: BindableBuilder.Aux[P, V] - ): QueryTemplate[V, Row] = { + def applyProduct[P <: HList, V <: HList]( + params: P + )(implicit bb: BindableBuilder.Aux[P, V], toList: ToList[P, CqltValue]): QueryTemplate[V, Row] = { implicit val binder: Binder[V] = bb.binder - QueryTemplate[V, Row](ctx.parts.mkString("?"), identity) + val queryWithQuestionMarks = assembleQuery(ctx.parts.iterator, params.toList[CqltValue].iterator, "") + QueryTemplate[V, Row](queryWithQuestionMarks, identity) } } @@ -26,12 +60,15 @@ object CqlTemplateStringInterpolator { object BindableBuilder { type Aux[P, Repr0] = BindableBuilder[P] { type Repr = Repr0 } + def apply[P](implicit builder: BindableBuilder[P]): BindableBuilder.Aux[P, builder.Repr] = builder - implicit def hNilBindableBuilder: BindableBuilder.Aux[HNil, HNil] = new BindableBuilder[HNil] { + + implicit def hNilBindableBuilder: BindableBuilder.Aux[HNil, HNil] = new BindableBuilder[HNil] { override type Repr = HNil override def binder: Binder[HNil] = Binder[HNil] } - implicit def hConsBindableBuilder[PH <: Put[_], T: Binder, PT <: HList, RT <: HList](implicit + + implicit def hConsPutBindableBuilder[T: Binder, PT <: HList, RT <: HList](implicit f: BindableBuilder.Aux[PT, RT] ): BindableBuilder.Aux[Put[T] :: PT, T :: RT] = new BindableBuilder[Put[T] :: PT] { override type Repr = T :: RT @@ -40,57 +77,35 @@ object CqlTemplateStringInterpolator { Binder[T :: RT] } } - } -} -/** BoundValue is used to capture the value inside the cql interpolated string along with evidence of its Binder so that - * a ParameterizedQuery can be built and the values can be bound to the BoundStatement internally - */ -private[cql] final case class BoundValue[A](value: A, ev: Binder[A]) + implicit def hConsLiftedValueBindableBuilder[PT <: HList, RT <: HList](implicit + f: BindableBuilder.Aux[PT, RT] + ): BindableBuilder.Aux[LiftedValue :: PT, RT] = new BindableBuilder[LiftedValue :: PT] { + override type Repr = RT + override def binder: Binder[RT] = f.binder + } -object BoundValue { - // This implicit conversion automatically captures the value and evidence of the Binder in a cql interpolated string - implicit def aToBoundValue[A](a: A)(implicit ev: Binder[A]): BoundValue[A] = - BoundValue(a, ev) + } } -final class CqlStringInterpolator(ctx: StringContext) { - @tailrec - private def replaceValuesWithQuestionMark( - strings: Iterator[String], - expressions: Iterator[BoundValue[_]], - acc: String - ): String = - if (strings.hasNext && expressions.hasNext) { - val str = strings.next() - val _ = expressions.next() - replaceValuesWithQuestionMark( - strings = strings, - expressions = expressions, - acc = acc + s"$str?" - ) - } else if (strings.hasNext && !expressions.hasNext) { - val str = strings.next() - replaceValuesWithQuestionMark( - strings = strings, - expressions = expressions, - acc + str - ) - } else acc +final class CqlStringInterpolator(ctx: StringContext) extends CqlStringInterpolatorBase { + + def apply(values: CqlValue*): SimpleQuery[Row] = { + val queryWithQuestionMarks = assembleQuery(ctx.parts.iterator, values.iterator, "") - final def apply(values: BoundValue[_]*): SimpleQuery[Row] = { - val queryWithQuestionMark = replaceValuesWithQuestionMark(ctx.parts.iterator, values.iterator, "") val assignValuesToStatement: BoundStatement => BoundStatement = { in: BoundStatement => val (configuredBoundStatement, _) = - values.foldLeft((in, 0)) { case ((current, index), bv: BoundValue[a]) => - val binder: Binder[a] = bv.ev - val value: a = bv.value - val statement = binder.bind(current, index, value) - val nextIndex = binder.nextIndex(index) - (statement, nextIndex) + values.foldLeft((in, 0)) { case ((current, index), qv: CqlValue) => + qv match { + case _: LiftedValue => (current, index) + case BoundValue(value, binder) => + val statement = binder.bind(current, index, value) + val nextIndex = binder.nextIndex(index) + (statement, nextIndex) + } } configuredBoundStatement } - ParameterizedQuery(QueryTemplate[HNil, Row](queryWithQuestionMark, assignValuesToStatement), HNil) + ParameterizedQuery(QueryTemplate[HNil, Row](queryWithQuestionMarks, assignValuesToStatement), HNil) } } diff --git a/src/main/scala/zio/cassandra/session/cql/InterpolatorValue.scala b/src/main/scala/zio/cassandra/session/cql/InterpolatorValue.scala new file mode 100644 index 0000000..f4986c9 --- /dev/null +++ b/src/main/scala/zio/cassandra/session/cql/InterpolatorValue.scala @@ -0,0 +1,40 @@ +package zio.cassandra.session.cql + +sealed trait InterpolatorValue + +sealed trait CqltValue extends InterpolatorValue + +sealed trait CqlValue extends InterpolatorValue + +object CqlValue { + + // This implicit conversion automatically captures the value and evidence of the Binder in a cql interpolated string + implicit def aToBoundValue[A](a: A)(implicit ev: Binder[A]): BoundValue[A] = + BoundValue(a, ev) + +} + +/** BoundValue is used to capture the value inside the cql interpolated string along with evidence of its Binder so that + * a ParameterizedQuery can be built and the values can be bound to the BoundStatement internally + */ +final case class BoundValue[A](value: A, ev: Binder[A]) extends CqlValue + +sealed trait Put[T] extends CqltValue + +object Put { + + def apply[T: Binder]: Put[T] = put.asInstanceOf[Put[T]] + + private val put: Put[Any] = new Put[Any] {} + +} + +/** LiftedValue is useful when you want to inject a value into cql query as is without escaping (similar to + * [[zio.cassandra.session.cql.CqlConstInterpolator]], but on a lower lever).
Please only use LiftedValue for + * input that you as the application author control. + */ +final case class LiftedValue(value: Any) extends CqltValue with CqlValue { + + override def toString: String = value.toString // to keep things simple with cqlConst + +} diff --git a/src/main/scala/zio/cassandra/session/cql/Put.scala b/src/main/scala/zio/cassandra/session/cql/Put.scala deleted file mode 100644 index dea6c00..0000000 --- a/src/main/scala/zio/cassandra/session/cql/Put.scala +++ /dev/null @@ -1,6 +0,0 @@ -package zio.cassandra.session.cql - -trait Put[T] -object Put { - def apply[T: Binder]: Put[T] = new Put[T] {} -} diff --git a/src/main/scala/zio/cassandra/session/cql/package.scala b/src/main/scala/zio/cassandra/session/cql/package.scala index 223f49d..6fc1826 100644 --- a/src/main/scala/zio/cassandra/session/cql/package.scala +++ b/src/main/scala/zio/cassandra/session/cql/package.scala @@ -12,4 +12,26 @@ package object cql { val cql = new CqlStringInterpolator(ctx) val cqlConst = new CqlConstInterpolator(ctx) } + + object unsafe { + + /** lifting a value is useful when you want to inject a value into cql query as is without escaping (similar to + * cqlConst interpolator, but on a lower lever).
Please only use `lift()` for input that you as the + * application author control. Example: + * {{{ + * import unsafe._ + * + * private val tableName = "my_table" + * def selectById(ids: Seq[Long]) = cql"select id from \${lift(tableName)} where id in \$ids".as[Int] + * }}} + * instead of + * {{{ + * private val tableName = "my_table" + * def selectById(ids: Seq[Long]) = (cqlConst"select id from \$tableName" ++ cql"where id in \$ids").as[Int] + * }}} + */ + def lift(value: Any): LiftedValue = LiftedValue(value) + + } + }