Skip to content

Commit

Permalink
Merge pull request #13 from narma/support_for_lifted_values_zio-1.x
Browse files Browse the repository at this point in the history
Support for lifted values zio 1.x
  • Loading branch information
myazinn authored Aug 22, 2022
2 parents b195cca + 28ee51e commit e105a17
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 63 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/CI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ name: CI

on:
push:
branches: [ master ]
branches: [ master, zio-1.x ]
pull_request:
branches: [ master ]
branches: [ master, zio-1.x ]

jobs:
build:
Expand All @@ -19,5 +19,7 @@ jobs:
with:
docker_channel: stable
docker_version: 20.10
- name: Validate Scaladoc
run: sbt doc
- name: Run tests
run: sbt test +IntegrationTest/test
2 changes: 1 addition & 1 deletion .github/workflows/release.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name: Release
on:
push:
branches: [master, main]
branches: [master, main, zio-1.x]
tags: ["*"]
jobs:
publish:
Expand Down
22 changes: 22 additions & 0 deletions src/it/scala/zio/cassandra/session/cql/CqlSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package zio.cassandra.session.cql
import com.datastax.oss.driver.api.core.ConsistencyLevel
import zio.cassandra.session.Session
import zio.cassandra.session.cql.codec.UnexpectedNullValue
import zio.cassandra.session.cql.unsafe.lift
import zio.duration._
import zio.stream.Stream
import zio.test.Assertion._
Expand Down Expand Up @@ -449,6 +450,27 @@ object CqlSpec {
}
}
}
),
suite("put lifted values as is")(
testM("should handle lifted values (though redundant) in cqlConst") {
val tableName = "tests.test_data"
for {
results <- cqlConst"select data from ${lift(tableName)} where id in (1)".as[String].select.runCollect
} yield assertTrue(results == Chunk("one"))
},
testM("should handle lifted values in cql") {
val tableName = "tests.test_data"
for {
results <- cql"select data from ${lift(tableName)} where id in ${List(1L)}".as[String].select.runCollect
} yield assertTrue(results == Chunk("one"))
},
testM("should handle lifted values in cqlt") {
val tableName = "tests.test_data"
for {
query <- cqlt"select data from ${lift(tableName)} where id in ${Put[List[Long]]}".as[String].prepare
results <- query(List(1L)).select.runCollect
} yield assertTrue(results == Chunk("one"))
}
)
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ import shapeless.HNil
import zio.cassandra.session.cql.query.{ ParameterizedQuery, QueryTemplate }

/** Provides a way to lift arbitrary strings into CQL so you can parameterize on values that are not valid CQL
* parameters Please note that this is not escaped so do not use this with user-supplied input for your application
* (only use cqlConst for input that you as the application author control)
* parameters <br> Please note that this is not escaped so do not use this with user-supplied input for your
* application (only use cqlConst for input that you as the application author control)
*/
class CqlConstInterpolator(ctx: StringContext) {
def apply(args: Any*): ParameterizedQuery[HNil, Row] =
def apply(args: Any*): SimpleQuery[Row] =
ParameterizedQuery(QueryTemplate(ctx.s(args: _*), identity), HNil)
}
117 changes: 66 additions & 51 deletions src/main/scala/zio/cassandra/session/cql/CqlStringInterpolator.scala
Original file line number Diff line number Diff line change
@@ -1,19 +1,53 @@
package zio.cassandra.session.cql

import com.datastax.oss.driver.api.core.cql.{ BoundStatement, Row }
import shapeless.ops.hlist.ToList
import shapeless.{ ::, HList, HNil, ProductArgs }
import zio.cassandra.session.cql.query.{ ParameterizedQuery, QueryTemplate }

import scala.annotation.{ nowarn, tailrec }
import scala.annotation.tailrec

class CqlTemplateStringInterpolator(ctx: StringContext) extends ProductArgs {
sealed abstract class CqlStringInterpolatorBase {

@tailrec
final protected def assembleQuery(
strings: Iterator[String],
expressions: Iterator[InterpolatorValue],
acc: String
): String =
if (strings.hasNext && expressions.hasNext) {
val str = strings.next()
val expr = expressions.next()

val placeholder = expr match {
case l: LiftedValue => l.toString
case _: BoundValue[_] | _: Put[_] => "?"
}

assembleQuery(
strings = strings,
expressions = expressions,
acc = acc + s"$str$placeholder"
)
} else if (strings.hasNext && !expressions.hasNext) {
val str = strings.next()
assembleQuery(
strings = strings,
expressions = expressions,
acc + str
)
} else acc

}

class CqlTemplateStringInterpolator(ctx: StringContext) extends CqlStringInterpolatorBase with ProductArgs {
import CqlTemplateStringInterpolator._
@nowarn("msg=is never used")
def applyProduct[P <: HList, V <: HList](params: P)(implicit
bb: BindableBuilder.Aux[P, V]
): QueryTemplate[V, Row] = {
def applyProduct[P <: HList, V <: HList](
params: P
)(implicit bb: BindableBuilder.Aux[P, V], toList: ToList[P, CqltValue]): QueryTemplate[V, Row] = {
implicit val binder: Binder[V] = bb.binder
QueryTemplate[V, Row](ctx.parts.mkString("?"), identity)
val queryWithQuestionMarks = assembleQuery(ctx.parts.iterator, params.toList[CqltValue].iterator, "")
QueryTemplate[V, Row](queryWithQuestionMarks, identity)
}
}

Expand All @@ -26,12 +60,15 @@ object CqlTemplateStringInterpolator {

object BindableBuilder {
type Aux[P, Repr0] = BindableBuilder[P] { type Repr = Repr0 }

def apply[P](implicit builder: BindableBuilder[P]): BindableBuilder.Aux[P, builder.Repr] = builder
implicit def hNilBindableBuilder: BindableBuilder.Aux[HNil, HNil] = new BindableBuilder[HNil] {

implicit def hNilBindableBuilder: BindableBuilder.Aux[HNil, HNil] = new BindableBuilder[HNil] {
override type Repr = HNil
override def binder: Binder[HNil] = Binder[HNil]
}
implicit def hConsBindableBuilder[PH <: Put[_], T: Binder, PT <: HList, RT <: HList](implicit

implicit def hConsPutBindableBuilder[T: Binder, PT <: HList, RT <: HList](implicit
f: BindableBuilder.Aux[PT, RT]
): BindableBuilder.Aux[Put[T] :: PT, T :: RT] = new BindableBuilder[Put[T] :: PT] {
override type Repr = T :: RT
Expand All @@ -40,57 +77,35 @@ object CqlTemplateStringInterpolator {
Binder[T :: RT]
}
}
}
}

/** BoundValue is used to capture the value inside the cql interpolated string along with evidence of its Binder so that
* a ParameterizedQuery can be built and the values can be bound to the BoundStatement internally
*/
private[cql] final case class BoundValue[A](value: A, ev: Binder[A])
implicit def hConsLiftedValueBindableBuilder[PT <: HList, RT <: HList](implicit
f: BindableBuilder.Aux[PT, RT]
): BindableBuilder.Aux[LiftedValue :: PT, RT] = new BindableBuilder[LiftedValue :: PT] {
override type Repr = RT
override def binder: Binder[RT] = f.binder
}

object BoundValue {
// This implicit conversion automatically captures the value and evidence of the Binder in a cql interpolated string
implicit def aToBoundValue[A](a: A)(implicit ev: Binder[A]): BoundValue[A] =
BoundValue(a, ev)
}
}

final class CqlStringInterpolator(ctx: StringContext) {
@tailrec
private def replaceValuesWithQuestionMark(
strings: Iterator[String],
expressions: Iterator[BoundValue[_]],
acc: String
): String =
if (strings.hasNext && expressions.hasNext) {
val str = strings.next()
val _ = expressions.next()
replaceValuesWithQuestionMark(
strings = strings,
expressions = expressions,
acc = acc + s"$str?"
)
} else if (strings.hasNext && !expressions.hasNext) {
val str = strings.next()
replaceValuesWithQuestionMark(
strings = strings,
expressions = expressions,
acc + str
)
} else acc
final class CqlStringInterpolator(ctx: StringContext) extends CqlStringInterpolatorBase {

def apply(values: CqlValue*): SimpleQuery[Row] = {
val queryWithQuestionMarks = assembleQuery(ctx.parts.iterator, values.iterator, "")

final def apply(values: BoundValue[_]*): SimpleQuery[Row] = {
val queryWithQuestionMark = replaceValuesWithQuestionMark(ctx.parts.iterator, values.iterator, "")
val assignValuesToStatement: BoundStatement => BoundStatement = { in: BoundStatement =>
val (configuredBoundStatement, _) =
values.foldLeft((in, 0)) { case ((current, index), bv: BoundValue[a]) =>
val binder: Binder[a] = bv.ev
val value: a = bv.value
val statement = binder.bind(current, index, value)
val nextIndex = binder.nextIndex(index)
(statement, nextIndex)
values.foldLeft((in, 0)) { case ((current, index), qv: CqlValue) =>
qv match {
case _: LiftedValue => (current, index)
case BoundValue(value, binder) =>
val statement = binder.bind(current, index, value)
val nextIndex = binder.nextIndex(index)
(statement, nextIndex)
}
}
configuredBoundStatement
}
ParameterizedQuery(QueryTemplate[HNil, Row](queryWithQuestionMark, assignValuesToStatement), HNil)
ParameterizedQuery(QueryTemplate[HNil, Row](queryWithQuestionMarks, assignValuesToStatement), HNil)
}
}
40 changes: 40 additions & 0 deletions src/main/scala/zio/cassandra/session/cql/InterpolatorValue.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package zio.cassandra.session.cql

sealed trait InterpolatorValue

sealed trait CqltValue extends InterpolatorValue

sealed trait CqlValue extends InterpolatorValue

object CqlValue {

// This implicit conversion automatically captures the value and evidence of the Binder in a cql interpolated string
implicit def aToBoundValue[A](a: A)(implicit ev: Binder[A]): BoundValue[A] =
BoundValue(a, ev)

}

/** BoundValue is used to capture the value inside the cql interpolated string along with evidence of its Binder so that
* a ParameterizedQuery can be built and the values can be bound to the BoundStatement internally
*/
final case class BoundValue[A](value: A, ev: Binder[A]) extends CqlValue

sealed trait Put[T] extends CqltValue

object Put {

def apply[T: Binder]: Put[T] = put.asInstanceOf[Put[T]]

private val put: Put[Any] = new Put[Any] {}

}

/** LiftedValue is useful when you want to inject a value into cql query as is without escaping (similar to
* [[zio.cassandra.session.cql.CqlConstInterpolator]], but on a lower lever). <br> Please only use LiftedValue for
* input that you as the application author control.
*/
final case class LiftedValue(value: Any) extends CqltValue with CqlValue {

override def toString: String = value.toString // to keep things simple with cqlConst

}
6 changes: 0 additions & 6 deletions src/main/scala/zio/cassandra/session/cql/Put.scala

This file was deleted.

22 changes: 22 additions & 0 deletions src/main/scala/zio/cassandra/session/cql/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,26 @@ package object cql {
val cql = new CqlStringInterpolator(ctx)
val cqlConst = new CqlConstInterpolator(ctx)
}

object unsafe {

/** lifting a value is useful when you want to inject a value into cql query as is without escaping (similar to
* cqlConst interpolator, but on a lower lever). <br> Please only use `lift()` for input that you as the
* application author control. Example:
* {{{
* import unsafe._
*
* private val tableName = "my_table"
* def selectById(ids: Seq[Long]) = cql"select id from \${lift(tableName)} where id in \$ids".as[Int]
* }}}
* instead of
* {{{
* private val tableName = "my_table"
* def selectById(ids: Seq[Long]) = (cqlConst"select id from \$tableName" ++ cql"where id in \$ids").as[Int]
* }}}
*/
def lift(value: Any): LiftedValue = LiftedValue(value)

}

}

0 comments on commit e105a17

Please sign in to comment.