Skip to content

Commit

Permalink
Validation for OpenAPI generated endpoints (zio#2786)
Browse files Browse the repository at this point in the history
  • Loading branch information
987Nabil committed Aug 14, 2024
1 parent 1c939b3 commit a34b0cc
Show file tree
Hide file tree
Showing 13 changed files with 583 additions and 130 deletions.
2 changes: 1 addition & 1 deletion project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ object Dependencies {
val ZioVersion = "2.1.7"
val ZioCliVersion = "0.5.0"
val ZioJsonVersion = "0.7.1"
val ZioSchemaVersion = "1.3.0"
val ZioSchemaVersion = "1.4.0"
val SttpVersion = "3.3.18"
val ZioConfigVersion = "4.0.2"

Expand Down
272 changes: 205 additions & 67 deletions zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala

Large diffs are not rendered by default.

24 changes: 19 additions & 5 deletions zio-http-gen/src/main/scala/zio/http/gen/scala/Code.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import scala.meta.prettyprinters.XtensionSyntax

import zio.http.{Method, Status}

import com.sun.tools.javac.code.TypeMetadata.Annotations

sealed trait Code extends Product with Serializable

object Code {
Expand Down Expand Up @@ -79,17 +81,29 @@ object Code {
abstractMembers: List[Field] = Nil,
) extends ScalaType

sealed abstract case class Field private (name: String, fieldType: ScalaType) extends Code {
final case class Annotation(value: String)

sealed abstract case class Field private (name: String, fieldType: ScalaType, annotations: List[Annotation])
extends Code {
// only allow copy on fieldType, since name is mangled to be valid in smart constructor
def copy(fieldType: ScalaType): Field = new Field(name, fieldType) {}
def copy(fieldType: ScalaType = fieldType, annotations: List[Annotation] = annotations): Field =
new Field(name, fieldType, annotations) {}
}

object Field {

def apply(name: String): Field = apply(name, ScalaType.Inferred)
def apply(name: String, fieldType: ScalaType): Field = {
def apply(name: String): Field = apply(name, ScalaType.Inferred)
def apply(name: String, fieldType: ScalaType): Field = {
val validScalaTermName = Term.Name(name).syntax
new Field(validScalaTermName, fieldType, Nil) {}
}
def apply(name: String, fieldType: ScalaType, annotation: Annotation): Field = {
val validScalaTermName = Term.Name(name).syntax
new Field(validScalaTermName, fieldType, List(annotation)) {}
}
def apply(name: String, fieldType: ScalaType, annotations: List[Annotation]): Field = {
val validScalaTermName = Term.Name(name).syntax
new Field(validScalaTermName, fieldType) {}
new Field(validScalaTermName, fieldType, annotations) {}
}
}

Expand Down
11 changes: 8 additions & 3 deletions zio-http-gen/src/main/scala/zio/http/gen/scala/CodeGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,16 @@ object CodeGen {
val traitBodyBuilder = new StringBuilder().append(' ')
var pre = '{'
val imports = abstractMembers.foldLeft(List.empty[Code.Import]) {
case (importsAcc, Code.Field(name, fieldType)) =>
case (importsAcc, Code.Field(name, fieldType, annotations)) =>
val (imports, tpe) = render(basePackage)(fieldType)
if (tpe.isEmpty) importsAcc
else {
traitBodyBuilder += pre
pre = '\n'
annotations.foreach { annotation =>
traitBodyBuilder ++= annotation.value
traitBodyBuilder += '\n'
}
traitBodyBuilder ++= "def "
traitBodyBuilder ++= name
traitBodyBuilder ++= ": "
Expand Down Expand Up @@ -173,10 +177,11 @@ object CodeGen {
imports -> s"Option[$tpe]"
}

case Code.Field(name, fieldType) =>
case Code.Field(name, fieldType, annotations) =>
val (imports, tpe) = render(basePackage)(fieldType)
val annotationsStr = annotations.map(_.value).mkString("\n")
val content = if (tpe.isEmpty) s"val $name" else s"val $name: $tpe"
imports -> content
imports -> (annotationsStr + content)

case Code.Primitive.ScalaBoolean => Nil -> "Boolean"
case Code.Primitive.ScalaByte => Nil -> "Byte"
Expand Down
12 changes: 6 additions & 6 deletions zio-http-gen/src/test/resources/ComponentAnimal.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,19 @@ object Animal {

implicit val codec: Schema[Animal] = DeriveSchema.gen[Animal]
case class Alligator(
age: Int,
weight: Float,
num_teeth: Int,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) age: Int,
@zio.schema.annotation.validate[Float](zio.schema.validation.Validation.greaterThan(-1.0)) weight: Float,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) num_teeth: Int,
) extends Animal
object Alligator {

implicit val codec: Schema[Alligator] = DeriveSchema.gen[Alligator]

}
case class Zebra(
age: Int,
weight: Float,
num_stripes: Int,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) age: Int,
@zio.schema.annotation.validate[Float](zio.schema.validation.Validation.greaterThan(-1.0)) weight: Float,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) num_stripes: Int,
) extends Animal
object Zebra {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@ object Animal {

implicit val codec: Schema[Animal] = DeriveSchema.gen[Animal]
case class Alligator(
age: Int,
weight: Float,
num_teeth: Int,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) age: Int,
@zio.schema.annotation.validate[Float](zio.schema.validation.Validation.greaterThan(-1.0)) weight: Float,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) num_teeth: Int,
) extends Animal
object Alligator {

implicit val codec: Schema[Alligator] = DeriveSchema.gen[Alligator]

}
case class Zebra(
age: Int,
weight: Float,
num_stripes: Int,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) age: Int,
@zio.schema.annotation.validate[Float](zio.schema.validation.Validation.greaterThan(-1.0)) weight: Float,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) num_stripes: Int,
) extends Animal
object Zebra {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@ object Animal {

implicit val codec: Schema[Animal] = DeriveSchema.gen[Animal]
case class Alligator(
age: Int,
weight: Float,
num_teeth: Int,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) age: Int,
@zio.schema.annotation.validate[Float](zio.schema.validation.Validation.greaterThan(-1.0)) weight: Float,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) num_teeth: Int,
) extends Animal
object Alligator {

implicit val codec: Schema[Alligator] = DeriveSchema.gen[Alligator]

}
case class Zebra(
age: Int,
weight: Float,
num_stripes: Int,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) age: Int,
@zio.schema.annotation.validate[Float](zio.schema.validation.Validation.greaterThan(-1.0)) weight: Float,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) num_stripes: Int,
dazzle: Chunk[Zebra],
) extends Animal
object Zebra {
Expand Down
15 changes: 15 additions & 0 deletions zio-http-gen/src/test/resources/ValidatedData.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package test.component

import zio.schema._

case class ValidatedData(
@zio.schema.annotation.validate[String](zio.schema.validation.Validation.minLength(10)) name: String,
@zio.schema.annotation.validate[Int](
zio.schema.validation.Validation.greaterThan(0) && zio.schema.validation.Validation.lessThan(100),
) age: Int,
)
object ValidatedData {

implicit val codec: Schema[ValidatedData] = DeriveSchema.gen[ValidatedData]

}
33 changes: 30 additions & 3 deletions zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@ import scala.meta._
import scala.meta.parsers._
import scala.util.{Failure, Success, Try}

import zio.Scope
import zio.json.{JsonDecoder, JsonEncoder}
import zio.test.Assertion.{hasSameElements, isFailure, isSuccess, throws}
import zio.test.Assertion.{hasSameElements, isFailure, isSuccess}
import zio.test.TestAspect.{blocking, flaky}
import zio.test.TestFailure.fail
import zio.test._
import zio.{Scope, ZIO}

import zio.schema.annotation.validate
import zio.schema.codec.JsonCodec
import zio.schema.validation.Validation
import zio.schema.{DeriveSchema, Schema}

import zio.http._
import zio.http.codec._
Expand All @@ -27,6 +29,14 @@ import zio.http.gen.openapi.{Config, EndpointGen}
@nowarn("msg=missing interpolator")
object CodeGenSpec extends ZIOSpecDefault {

case class ValidatedData(
@validate(Validation.maxLength(10))
name: String,
@validate(Validation.greaterThan(0) && Validation.lessThan(100))
age: Int,
)
implicit val validatedDataSchema: Schema[ValidatedData] = DeriveSchema.gen[ValidatedData]

private def fileShouldBe(dir: java.nio.file.Path, subPath: String, expectedFile: String): TestResult = {
val filePath = dir.resolve(Paths.get(subPath))
val generated = Files.readAllLines(filePath).asScala.mkString("\n")
Expand Down Expand Up @@ -791,5 +801,22 @@ object CodeGenSpec extends ZIOSpecDefault {
"/AnimalWithMap.scala",
)
},
test("Endpoint with data validation") {
val endpoint = Endpoint(Method.POST / "api" / "v1" / "users").in[ValidatedData]
val openAPIJson = OpenAPIGen.fromEndpoints(endpoint).toJson
val openAPI = OpenAPI.fromJson(openAPIJson).getOrElse(OpenAPI.empty)
val code = EndpointGen.fromOpenAPI(openAPI)

val tempDir = Files.createTempDirectory("codegen")

CodeGen.writeFiles(code, java.nio.file.Paths.get(tempDir.toString, "test"), "test", Some(scalaFmtPath))

fileShouldBe(
tempDir,
"test/component/ValidatedData.scala",
"/ValidatedData.scala",
)

},
) @@ java11OrNewer @@ flaky @@ blocking // Downloading scalafmt on CI is flaky
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ object AuthSpec extends ZIOSpecDefault {
val endpoint = Endpoint(Method.GET / "test").out[String](MediaType.text.`plain`)
val routes =
Routes(
endpoint.implementHandler(handler((_: Unit) => ZIO.serviceWith[AuthContext](_.value))),
endpoint.implementHandler(handler((_: Unit) => withContext((ctx: AuthContext) => ctx.value))),
) @@ basicAuthContext
val response = routes.run(
Request(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ final case class HttpEndpoint(
case JsonSchema.OneOfSchema(_) => throw new Exception("OneOfSchema not supported")
case JsonSchema.AllOfSchema(_) => throw new Exception("AllOfSchema not supported")
case JsonSchema.AnyOfSchema(_) => throw new Exception("AnyOfSchema not supported")
case JsonSchema.Number(_) => s""""${getName(name)}": {{${getName(name)}}}"""
case JsonSchema.Integer(_) => s""""${getName(name)}": {{${getName(name)}}}"""
case JsonSchema.String(_, _) => s""""${getName(name)}": {{${getName(name)}}}"""
case JsonSchema.Number(_, _, _, _, _, _) => s""""${getName(name)}": {{${getName(name)}}}"""
case JsonSchema.Integer(_, _, _, _, _, _) => s""""${getName(name)}": {{${getName(name)}}}"""
case JsonSchema.String(_, _, _, _) => s""""${getName(name)}": {{${getName(name)}}}"""
case JsonSchema.Boolean => s""""${getName(name)}": {{${getName(name)}}}"""
case JsonSchema.ArrayType(_) => s""""${getName(name)}": {{${getName(name)}}}"""
case JsonSchema.Object(properties, _, _) =>
Expand Down
25 changes: 14 additions & 11 deletions zio-http/shared/src/main/scala/zio/http/endpoint/http/HttpGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,28 @@ object HttpGen {
val bodySchema0 = bodySchema(inAtoms)

def loop(schema: JsonSchema, name: Option[String]): Seq[HttpVariable] = schema match {
case JsonSchema.AnnotatedSchema(schema, _) => loop(schema, name)
case JsonSchema.RefSchema(_) => throw new Exception("RefSchema not supported")
case JsonSchema.OneOfSchema(_) => throw new Exception("OneOfSchema not supported")
case JsonSchema.AllOfSchema(_) => throw new Exception("AllOfSchema not supported")
case JsonSchema.AnyOfSchema(_) => throw new Exception("AnyOfSchema not supported")
case JsonSchema.Number(format) =>
case JsonSchema.AnnotatedSchema(schema, _) => loop(schema, name)
case JsonSchema.RefSchema(_) => throw new Exception("RefSchema not supported")
case JsonSchema.OneOfSchema(_) => throw new Exception("OneOfSchema not supported")
case JsonSchema.AllOfSchema(_) => throw new Exception("AllOfSchema not supported")
case JsonSchema.AnyOfSchema(_) => throw new Exception("AnyOfSchema not supported")
// TODO: add comments for validation restrictions
case JsonSchema.Number(format, _, _, _, _, _) =>
val typeHint = format match {
case JsonSchema.NumberFormat.Float => "type: Float"
case JsonSchema.NumberFormat.Double => "type: Double"
}
Seq(HttpVariable(getName(name), None, Some(typeHint)))
case JsonSchema.Integer(format) =>
// TODO: add comments for validation restrictions
case JsonSchema.Integer(format, _, _, _, _, _) =>
val typeHint = format match {
case JsonSchema.IntegerFormat.Int32 => "type: Int"
case JsonSchema.IntegerFormat.Int64 => "type: Long"
case JsonSchema.IntegerFormat.Timestamp => "type: Timestamp in milliseconds"
}
Seq(HttpVariable(getName(name), None, Some(typeHint)))
case JsonSchema.String(format, pattern) =>
// TODO: add comments for validation restrictions
case JsonSchema.String(format, pattern, _, _) =>
val formatHint: String = format match {
case Some(value) => s" format: ${value.value}"
case None => ""
Expand All @@ -79,8 +82,8 @@ object HttpGen {
case None => ""
}
Seq(HttpVariable(getName(name), None, Some(s"type: String$formatHint$patternHint")))
case JsonSchema.Boolean => Seq(HttpVariable(getName(name), None, Some("type: Boolean")))
case JsonSchema.ArrayType(items) =>
case JsonSchema.Boolean => Seq(HttpVariable(getName(name), None, Some("type: Boolean")))
case JsonSchema.ArrayType(items) =>
val typeHint =
items match {
case Some(schema) =>
Expand All @@ -90,7 +93,7 @@ object HttpGen {
}

Seq(HttpVariable(getName(name), None, Some(s"type: array of $typeHint")))
case JsonSchema.Object(properties, _, _) =>
case JsonSchema.Object(properties, _, _) =>
properties.flatMap { case (key, value) => loop(value, Some(key)) }.toSeq
case JsonSchema.Enum(values) => Seq(HttpVariable(getName(name), None, Some(s"enum: ${values.mkString(",")}")))
case JsonSchema.Null => Seq.empty
Expand Down
Loading

0 comments on commit a34b0cc

Please sign in to comment.