Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for lifted values #11

Merged
merged 2 commits into from
Jul 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
22 changes: 22 additions & 0 deletions src/it/scala/zio/cassandra/session/cql/CqlSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package zio.cassandra.session.cql
import com.datastax.oss.driver.api.core.ConsistencyLevel
import zio.cassandra.session.{ Session, ZIOCassandraSpec }
import zio.cassandra.session.cql.codec.UnexpectedNullValue
import zio.cassandra.session.cql.unsafe.lift
import zio.stream.ZStream
import zio.test.Assertion._
import zio.test.TestAspect.ignore
Expand Down Expand Up @@ -448,6 +449,27 @@ object CqlSpec extends ZIOCassandraSpec {
}
}
}
),
suite("put lifted values as is")(
test("should handle lifted values (though redundant) in cqlConst") {
val tableName = "tests.test_data"
for {
results <- cqlConst"select data from ${lift(tableName)} where id in (1)".as[String].select.runCollect
} yield assertTrue(results == Chunk("one"))
},
test("should handle lifted values in cql") {
val tableName = "tests.test_data"
for {
results <- cql"select data from ${lift(tableName)} where id in ${List(1L)}".as[String].select.runCollect
} yield assertTrue(results == Chunk("one"))
},
test("should handle lifted values in cqlt") {
val tableName = "tests.test_data"
for {
query <- cqlt"select data from ${lift(tableName)} where id in ${Put[List[Long]]}".as[String].prepare
results <- query(List(1L)).select.runCollect
} yield assertTrue(results == Chunk("one"))
}
)
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ import shapeless.HNil
import zio.cassandra.session.cql.query.{ ParameterizedQuery, QueryTemplate }

/** Provides a way to lift arbitrary strings into CQL so you can parameterize on values that are not valid CQL
* parameters Please note that this is not escaped so do not use this with user-supplied input for your application
* (only use cqlConst for input that you as the application author control)
* parameters <br> Please note that this is not escaped so do not use this with user-supplied input for your
* application (only use cqlConst for input that you as the application author control)
*/
class CqlConstInterpolator(ctx: StringContext) {
def apply(args: Any*): ParameterizedQuery[HNil, Row] =
def apply(args: Any*): SimpleQuery[Row] =
ParameterizedQuery(QueryTemplate(ctx.s(args: _*), identity), HNil)
}
117 changes: 66 additions & 51 deletions src/main/scala/zio/cassandra/session/cql/CqlStringInterpolator.scala
Original file line number Diff line number Diff line change
@@ -1,19 +1,53 @@
package zio.cassandra.session.cql

import com.datastax.oss.driver.api.core.cql.{ BoundStatement, Row }
import shapeless.ops.hlist.ToList
import shapeless.{ ::, HList, HNil, ProductArgs }
import zio.cassandra.session.cql.query.{ ParameterizedQuery, QueryTemplate }

import scala.annotation.{ nowarn, tailrec }
import scala.annotation.tailrec

class CqlTemplateStringInterpolator(ctx: StringContext) extends ProductArgs {
sealed abstract class CqlStringInterpolatorBase {

@tailrec
final protected def assembleQuery(
strings: Iterator[String],
expressions: Iterator[InterpolatorValue],
acc: String
): String =
if (strings.hasNext && expressions.hasNext) {
val str = strings.next()
val expr = expressions.next()

val placeholder = expr match {
case l: LiftedValue => l.toString
case _: BoundValue[_] | _: Put[_] => "?"
}

assembleQuery(
strings = strings,
expressions = expressions,
acc = acc + s"$str$placeholder"
)
} else if (strings.hasNext && !expressions.hasNext) {
val str = strings.next()
assembleQuery(
strings = strings,
expressions = expressions,
acc + str
)
} else acc

}

class CqlTemplateStringInterpolator(ctx: StringContext) extends CqlStringInterpolatorBase with ProductArgs {
import CqlTemplateStringInterpolator._
@nowarn("msg=is never used")
def applyProduct[P <: HList, V <: HList](params: P)(implicit
bb: BindableBuilder.Aux[P, V]
): QueryTemplate[V, Row] = {
def applyProduct[P <: HList, V <: HList](
params: P
)(implicit bb: BindableBuilder.Aux[P, V], toList: ToList[P, CqltValue]): QueryTemplate[V, Row] = {
implicit val binder: Binder[V] = bb.binder
QueryTemplate[V, Row](ctx.parts.mkString("?"), identity)
val queryWithQuestionMarks = assembleQuery(ctx.parts.iterator, params.toList[CqltValue].iterator, "")
QueryTemplate[V, Row](queryWithQuestionMarks, identity)
}
}

Expand All @@ -26,12 +60,15 @@ object CqlTemplateStringInterpolator {

object BindableBuilder {
type Aux[P, Repr0] = BindableBuilder[P] { type Repr = Repr0 }

def apply[P](implicit builder: BindableBuilder[P]): BindableBuilder.Aux[P, builder.Repr] = builder
implicit def hNilBindableBuilder: BindableBuilder.Aux[HNil, HNil] = new BindableBuilder[HNil] {

implicit def hNilBindableBuilder: BindableBuilder.Aux[HNil, HNil] = new BindableBuilder[HNil] {
override type Repr = HNil
override def binder: Binder[HNil] = Binder[HNil]
}
implicit def hConsBindableBuilder[PH <: Put[_], T: Binder, PT <: HList, RT <: HList](implicit

implicit def hConsPutBindableBuilder[T: Binder, PT <: HList, RT <: HList](implicit
f: BindableBuilder.Aux[PT, RT]
): BindableBuilder.Aux[Put[T] :: PT, T :: RT] = new BindableBuilder[Put[T] :: PT] {
override type Repr = T :: RT
Expand All @@ -40,57 +77,35 @@ object CqlTemplateStringInterpolator {
Binder[T :: RT]
}
}
}
}

/** BoundValue is used to capture the value inside the cql interpolated string along with evidence of its Binder so that
* a ParameterizedQuery can be built and the values can be bound to the BoundStatement internally
*/
private[cql] final case class BoundValue[A](value: A, ev: Binder[A])
implicit def hConsLiftedValueBindableBuilder[PT <: HList, RT <: HList](implicit
f: BindableBuilder.Aux[PT, RT]
): BindableBuilder.Aux[LiftedValue :: PT, RT] = new BindableBuilder[LiftedValue :: PT] {
override type Repr = RT
override def binder: Binder[RT] = f.binder
}

object BoundValue {
// This implicit conversion automatically captures the value and evidence of the Binder in a cql interpolated string
implicit def aToBoundValue[A](a: A)(implicit ev: Binder[A]): BoundValue[A] =
BoundValue(a, ev)
}
}

final class CqlStringInterpolator(ctx: StringContext) {
@tailrec
private def replaceValuesWithQuestionMark(
strings: Iterator[String],
expressions: Iterator[BoundValue[_]],
acc: String
): String =
if (strings.hasNext && expressions.hasNext) {
val str = strings.next()
val _ = expressions.next()
replaceValuesWithQuestionMark(
strings = strings,
expressions = expressions,
acc = acc + s"$str?"
)
} else if (strings.hasNext && !expressions.hasNext) {
val str = strings.next()
replaceValuesWithQuestionMark(
strings = strings,
expressions = expressions,
acc + str
)
} else acc
final class CqlStringInterpolator(ctx: StringContext) extends CqlStringInterpolatorBase {

def apply(values: CqlValue*): SimpleQuery[Row] = {
val queryWithQuestionMarks = assembleQuery(ctx.parts.iterator, values.iterator, "")

final def apply(values: BoundValue[_]*): SimpleQuery[Row] = {
val queryWithQuestionMark = replaceValuesWithQuestionMark(ctx.parts.iterator, values.iterator, "")
val assignValuesToStatement: BoundStatement => BoundStatement = { in: BoundStatement =>
val (configuredBoundStatement, _) =
values.foldLeft((in, 0)) { case ((current, index), bv: BoundValue[a]) =>
val binder: Binder[a] = bv.ev
val value: a = bv.value
val statement = binder.bind(current, index, value)
val nextIndex = binder.nextIndex(index)
(statement, nextIndex)
values.foldLeft((in, 0)) { case ((current, index), qv: CqlValue) =>
qv match {
case _: LiftedValue => (current, index)
case BoundValue(value, binder) =>
val statement = binder.bind(current, index, value)
val nextIndex = binder.nextIndex(index)
(statement, nextIndex)
}
}
configuredBoundStatement
}
ParameterizedQuery(QueryTemplate[HNil, Row](queryWithQuestionMark, assignValuesToStatement), HNil)
ParameterizedQuery(QueryTemplate[HNil, Row](queryWithQuestionMarks, assignValuesToStatement), HNil)
}
}
40 changes: 40 additions & 0 deletions src/main/scala/zio/cassandra/session/cql/InterpolatorValue.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package zio.cassandra.session.cql

sealed trait InterpolatorValue

sealed trait CqltValue extends InterpolatorValue

sealed trait CqlValue extends InterpolatorValue

object CqlValue {

// This implicit conversion automatically captures the value and evidence of the Binder in a cql interpolated string
implicit def aToBoundValue[A](a: A)(implicit ev: Binder[A]): BoundValue[A] =
BoundValue(a, ev)

}

/** BoundValue is used to capture the value inside the cql interpolated string along with evidence of its Binder so that
* a ParameterizedQuery can be built and the values can be bound to the BoundStatement internally
*/
final case class BoundValue[A](value: A, ev: Binder[A]) extends CqlValue

sealed trait Put[T] extends CqltValue

object Put {

def apply[T: Binder]: Put[T] = put.asInstanceOf[Put[T]]

private val put: Put[Any] = new Put[Any] {}

}

/** LiftedValue is useful when you want to inject a value into cql query as is without escaping (similar to
* [[zio.cassandra.session.cql.CqlConstInterpolator]], but on a lower lever). <br> Please only use LiftedValue for
* input that you as the application author control.
*/
final case class LiftedValue(value: Any) extends CqltValue with CqlValue {

override def toString: String = value.toString // to keep things simple with cqlConst

}
6 changes: 0 additions & 6 deletions src/main/scala/zio/cassandra/session/cql/Put.scala

This file was deleted.

22 changes: 22 additions & 0 deletions src/main/scala/zio/cassandra/session/cql/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,26 @@ package object cql {
val cql = new CqlStringInterpolator(ctx)
val cqlConst = new CqlConstInterpolator(ctx)
}

object unsafe {

/** lifting a value is useful when you want to inject a value into cql query as is without escaping (similar to
* cqlConst interpolator, but on a lower lever). <br> Please only use `lift()` for input that you as the
* application author control. Example:
* {{{
* import unsafe._
*
* private val tableName = "my_table"
* def selectById(ids: Seq[Long) = cql"select id from ${lift(tableName)} where id in $ids".as[Int]
* }}}
* instead of
* {{{
* private val tableName = "my_table"
* def selectById(ids: Seq[Long) = (cqlConst"select id from $tableName" ++ cql"where id in $ids").as[Int]
* }}}
*/
def lift(value: Any): LiftedValue = LiftedValue(value)

}

}