diff --git a/build.sbt b/build.sbt index 1e78675a395..fd13e006d55 100644 --- a/build.sbt +++ b/build.sbt @@ -305,7 +305,6 @@ lazy val zioHttpGen = (project in file("zio-http-gen")) `zio`, `zio-test`, `zio-test-sbt`, - `zio-parser`, `zio-config`, scalafmt.cross(CrossVersion.for3Use2_13), scalametaParsers diff --git a/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala b/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala index d41345ecd52..42fe931438f 100644 --- a/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala +++ b/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala @@ -373,7 +373,7 @@ final case class EndpointGen(config: Config) { } private def fieldName(op: OpenAPI.Operation, fallback: String) = - Code.Field(op.operationId.getOrElse(fallback)) + Code.Field(op.operationId.getOrElse(fallback), config.fieldsNormalizationConf) private def endpoint( segments: List[Code.PathSegmentCode], @@ -1164,7 +1164,7 @@ final case class EndpointGen(config: Config) { case JsonSchema.AnnotatedSchema(s, _) => schemaToField(s.withoutAnnotations, openAPI, name, schema.annotations) case JsonSchema.RefSchema(SchemaRef(ref)) => - Some(Code.Field(name, Code.TypeRef(ref.capitalize))) + Some(Code.Field(name, Code.TypeRef(ref.capitalize), config.fieldsNormalizationConf)) case JsonSchema.RefSchema(ref) => throw new Exception(s" Not found: $ref. Only references to internal schemas are supported.") case JsonSchema.Integer( @@ -1189,7 +1189,7 @@ final case class EndpointGen(config: Config) { exclusiveMax.collect { case l if l <= Int.MaxValue => safeCastLongToInt(l) }, ) - Some(Code.Field(name, Code.Primitive.ScalaInt, annotations)) + Some(Code.Field(name, Code.Primitive.ScalaInt, annotations, config.fieldsNormalizationConf)) case JsonSchema.Integer( JsonSchema.IntegerFormat.Int64, minimum, @@ -1208,7 +1208,7 @@ final case class EndpointGen(config: Config) { else maximum.map(_ + 1) val annotations = addNumericValidations[Long](exclusiveMin, exclusiveMax) - Some(Code.Field(name, Code.Primitive.ScalaLong, annotations)) + Some(Code.Field(name, Code.Primitive.ScalaLong, annotations, config.fieldsNormalizationConf)) case JsonSchema.Integer( JsonSchema.IntegerFormat.Timestamp, minimum, @@ -1226,16 +1226,16 @@ final case class EndpointGen(config: Config) { else if (exclusiveMaximum.isDefined && exclusiveMaximum.get.isRight) exclusiveMaximum.get.toOption else maximum.map(_ + 1) val annotations = addNumericValidations[Long](exclusiveMin, exclusiveMax) - Some(Code.Field(name, Code.Primitive.ScalaLong, annotations)) + Some(Code.Field(name, Code.Primitive.ScalaLong, annotations, config.fieldsNormalizationConf)) case JsonSchema.String(Some(JsonSchema.StringFormat.UUID), _, maxLength, minLength) => val annotations = addStringValidations(minLength, maxLength) - Some(Code.Field(name, Code.Primitive.ScalaUUID, annotations)) + Some(Code.Field(name, Code.Primitive.ScalaUUID, annotations, config.fieldsNormalizationConf)) case JsonSchema.String(_, _, maxLength, minLength) => val annotations = addStringValidations(minLength, maxLength) - Some(Code.Field(name, Code.Primitive.ScalaString, annotations)) + Some(Code.Field(name, Code.Primitive.ScalaString, annotations, config.fieldsNormalizationConf)) case JsonSchema.Boolean => - Some(Code.Field(name, Code.Primitive.ScalaBoolean)) + Some(Code.Field(name, Code.Primitive.ScalaBoolean, config.fieldsNormalizationConf)) case JsonSchema.OneOfSchema(schemas) => val tpe = schemas @@ -1243,7 +1243,7 @@ final case class EndpointGen(config: Config) { .flatMap(schemaToField(_, openAPI, "unused", annotations)) .map(_.fieldType) .reduceLeft(Code.ScalaType.Or.apply) - Some(Code.Field(name, tpe)) + Some(Code.Field(name, tpe, config.fieldsNormalizationConf)) case JsonSchema.AllOfSchema(_) => throw new Exception("Inline allOf schemas are not supported for fields") case JsonSchema.AnyOfSchema(schemas) => @@ -1253,7 +1253,7 @@ final case class EndpointGen(config: Config) { .flatMap(schemaToField(_, openAPI, "unused", annotations)) .map(_.fieldType) .reduceLeft(Code.ScalaType.Or.apply) - Some(Code.Field(name, tpe)) + Some(Code.Field(name, tpe, config.fieldsNormalizationConf)) case JsonSchema.Number(JsonSchema.NumberFormat.Double, minimum, exclusiveMinimum, maximum, exclusiveMaximum, _) => val exclusiveMin = if (exclusiveMinimum.isDefined && exclusiveMinimum.get == Left(true)) minimum @@ -1265,7 +1265,7 @@ final case class EndpointGen(config: Config) { else maximum.map(_ + 1) val annotations = addNumericValidations[Double](exclusiveMin, exclusiveMax) - Some(Code.Field(name, Code.Primitive.ScalaDouble, annotations)) + Some(Code.Field(name, Code.Primitive.ScalaDouble, annotations, config.fieldsNormalizationConf)) case JsonSchema.Number(JsonSchema.NumberFormat.Float, minimum, exclusiveMinimum, maximum, exclusiveMaximum, _) => val exclusiveMin = if (exclusiveMinimum.isDefined && exclusiveMinimum.get == Left(true)) minimum @@ -1280,7 +1280,7 @@ final case class EndpointGen(config: Config) { exclusiveMin.collect { case l if l >= Float.MinValue => safeCastDoubleToFloat(l) }, exclusiveMax.collect { case l if l <= Float.MaxValue => safeCastDoubleToFloat(l) }, ) - Some(Code.Field(name, Code.Primitive.ScalaFloat, annotations)) + Some(Code.Field(name, Code.Primitive.ScalaFloat, annotations, config.fieldsNormalizationConf)) case JsonSchema.ArrayType(items, minItems, uniqueItems) => val nonEmpty = minItems.exists(_ > 1) val tpe = items @@ -1291,7 +1291,7 @@ final case class EndpointGen(config: Config) { if (uniqueItems) Code.Primitive.ScalaString.set(nonEmpty) else Code.Primitive.ScalaString.seq(nonEmpty) }, ) - tpe.map(Code.Field(name, _)) + tpe.map(Code.Field(name, _, config.fieldsNormalizationConf)) case JsonSchema.Object(properties, additionalProperties, _) if properties.nonEmpty && additionalProperties.isRight => // Can't be an object and a map at the same time @@ -1329,16 +1329,17 @@ final case class EndpointGen(config: Config) { ) }, ), + config.fieldsNormalizationConf, ), ) case JsonSchema.Object(_, _, _) => - Some(Code.Field(name, Code.TypeRef(name.capitalize))) + Some(Code.Field(name, Code.TypeRef(name.capitalize), config.fieldsNormalizationConf)) case JsonSchema.Enum(_) => - Some(Code.Field(name, Code.TypeRef(name.capitalize))) + Some(Code.Field(name, Code.TypeRef(name.capitalize), config.fieldsNormalizationConf)) case JsonSchema.Null => - Some(Code.Field(name, Code.ScalaType.Unit)) + Some(Code.Field(name, Code.ScalaType.Unit, config.fieldsNormalizationConf)) case JsonSchema.AnyJson => - Some(Code.Field(name, Code.ScalaType.JsonAST)) + Some(Code.Field(name, Code.ScalaType.JsonAST, config.fieldsNormalizationConf)) } } diff --git a/zio-http-gen/src/main/scala/zio/http/gen/scala/Code.scala b/zio-http-gen/src/main/scala/zio/http/gen/scala/Code.scala index 1b9c4e9f07f..100a7404311 100644 --- a/zio-http-gen/src/main/scala/zio/http/gen/scala/Code.scala +++ b/zio-http-gen/src/main/scala/zio/http/gen/scala/Code.scala @@ -3,11 +3,10 @@ package zio.http.gen.scala import scala.meta.Term import scala.meta.prettyprinters.XtensionSyntax +import zio.http.gen.openapi import zio.http.gen.openapi.Config.NormalizeFields import zio.http.{Method, Status} -import com.sun.tools.javac.code.TypeMetadata.Annotations - sealed trait Code extends Product with Serializable object Code { @@ -146,18 +145,61 @@ object Code { object Field { - def apply(name: String): Field = apply(name, ScalaType.Inferred) - def apply(name: String, fieldType: ScalaType): Field = { - val validScalaTermName = Term.Name(name).syntax - new Field(validScalaTermName, fieldType, Nil) {} - } - def apply(name: String, fieldType: ScalaType, annotation: Annotation): Field = { - val validScalaTermName = Term.Name(name).syntax - new Field(validScalaTermName, fieldType, List(annotation)) {} + def apply(name: String): Field = + apply(name, ScalaType.Inferred) + + def apply(name: String, conf: NormalizeFields): Field = + apply(name, ScalaType.Inferred, conf) + + def apply(name: String, fieldType: ScalaType): Field = + apply(name, fieldType, openapi.Config.default.fieldsNormalizationConf) + + def apply(name: String, fieldType: ScalaType, conf: NormalizeFields): Field = + apply(name, fieldType, Nil, conf) + + def apply(name: String, fieldType: ScalaType, annotation: Annotation, conf: NormalizeFields): Field = + apply(name, fieldType, List(annotation), conf) + + def apply(name: String, fieldType: ScalaType, annotations: List[Annotation], conf: NormalizeFields): Field = { + + def mkValidScalaTermName(term: String) = Term.Name(term).syntax + + val (validScalaTermName, originalFieldNameAnnotation) = conf.specialReplacements + .get(name) + .orElse(if (conf.enabled) normalize(name) else None) + .fold(mkValidScalaTermName(name) -> Option.empty[Annotation]) { maybeValidScala => + val valid = mkValidScalaTermName(maybeValidScala) + // if modified name is an invalid scala term, + // then no reason to use backticks wrapped non-original name. + // In this case we return the original name, + // after possibly wrapping with backticks. + if (valid != maybeValidScala) mkValidScalaTermName(name) -> Option.empty[Annotation] + else { + val annotationString = "@fieldName(\"" + name + "\")" + val annotationImport = List(Import("zio.schema.annotation.fieldName")) + maybeValidScala -> Some(Annotation(annotationString, annotationImport)) + } + } + + val allAnnotations = originalFieldNameAnnotation.fold(annotations)(annotations.::) + new Field(validScalaTermName, fieldType, allAnnotations) {} } - def apply(name: String, fieldType: ScalaType, annotations: List[Annotation]): Field = { - val validScalaTermName = Term.Name(name).syntax - new Field(validScalaTermName, fieldType, annotations) {} + + private val regex = "(?<=[a-z0-9])(?=[A-Z0-9])|(?<=[A-Z0-9])(?=[A-Z0-9][a-z0-9])|[^a-zA-Z0-9]+" + + def normalize(name: String): Option[String] = { + + name + .split(regex) + .toList match { + case Nil => None + case head :: tail => + val normalized = (head.toLowerCase :: tail.map(_.capitalize)).mkString + // no need to normalize if the name is already normalized + // returning None here will signal there's no need for annotation. + if (normalized == name) None + else Some(normalized) + } } } diff --git a/zio-http-gen/src/test/resources/ComponentOrderWithNormalizedFieldNames.scala b/zio-http-gen/src/test/resources/ComponentOrderWithNormalizedFieldNames.scala new file mode 100644 index 00000000000..86104b49bc1 --- /dev/null +++ b/zio-http-gen/src/test/resources/ComponentOrderWithNormalizedFieldNames.scala @@ -0,0 +1,22 @@ +package test.component + +import zio.schema._ +import zio.schema.annotation.fieldName +import zio.schema.annotation.validate +import zio.schema.validation.Validation +import java.util.UUID + +case class Order( + @fieldName("2nd item") secondItem: Option[String], + @fieldName("3rd item") thirdItem: Option[String], + @fieldName("num-of-items") + @validate[Int](Validation.greaterThan(0)) numOfItems: Int, + @fieldName("1st item") firstItem: String, + @fieldName("price in dollars") + @validate[Double](Validation.greaterThan(-1.0)) priceInDollars: Double, + @fieldName("PRODUCT_NAME") productNAME: String, + id: UUID, +) +object Order { + implicit val codec: Schema[Order] = DeriveSchema.gen[Order] +} \ No newline at end of file diff --git a/zio-http-gen/src/test/resources/inline_schema_weird_field_names.yaml b/zio-http-gen/src/test/resources/inline_schema_weird_field_names.yaml new file mode 100644 index 00000000000..b52c9ff73bc --- /dev/null +++ b/zio-http-gen/src/test/resources/inline_schema_weird_field_names.yaml @@ -0,0 +1,76 @@ +info: + title: Shop Service + version: 0.0.1 +servers: + - url: http://127.0.0.1:5000/ +tags: + - name: Order_API +paths: + /api/v1/shop/history/{id}: + get: + operationId: get_user_history + parameters: + - in: path + name: id + schema: + $ref: '#/components/schemas/UserId' + required: true + tags: + - Order_API + description: Get user order history by user id + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/UserOrderHistory' + description: OK +openapi: 3.0.3 +components: + schemas: + UserOrderHistory: + type: object + required: + - user_id + - history + properties: + user_id: + $ref: '#/components/schemas/UserId' + history: + type: object + additionalProperties: + $ref: '#/components/schemas/Order' + x-string-key-schema: + $ref: '#/components/schemas/OrderId' + Order: + type: object + required: + - id + - PRODUCT_NAME + - num-of-items + - price in dollars + - 1st item + properties: + id: + $ref: '#/components/schemas/OrderId' + PRODUCT_NAME: + type: string + num-of-items: + type: integer + format: int32 + minimum: 1 + price in dollars: + type: number + minimum: 0 + 1st item: + type: string + 2nd item: + type: string + 3rd item: + type: string + OrderId: + type: string + format: uuid + UserId: + type: string + format: uuid diff --git a/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala b/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala index c720f696f81..653c8e1eccd 100644 --- a/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala +++ b/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala @@ -24,6 +24,7 @@ import zio.http.codec._ import zio.http.endpoint.Endpoint import zio.http.endpoint.openapi.{OpenAPI, OpenAPIGen} import zio.http.gen.model._ +import zio.http.gen.openapi.Config.NormalizeFields import zio.http.gen.openapi.{Config, EndpointGen} @nowarn("msg=missing interpolator") @@ -867,5 +868,37 @@ object CodeGenSpec extends ZIOSpecDefault { } } }, - ) @@ java11OrNewer /*@@ flaky*/ @@ blocking // Downloading scalafmt on CI is flaky + test("Endpoint with normalized field names") { + val openAPIString = stringFromResource("/inline_schema_weird_field_names.yaml") + + openApiFromYamlString(openAPIString) { oapi => + codeGenFromOpenAPI( + oapi, + Config.default.copy( + fieldsNormalizationConf = NormalizeFields( + enabled = true, + specialReplacements = Map( + "1st item" -> "firstItem", + "2nd item" -> "secondItem", + "3rd item" -> "thirdItem", + ), + ), + ), + ) { testDir => + allFilesShouldBe( + testDir.toFile, + List( + "api/v1/shop/history/Id.scala", + "component/Order.scala", + "component/UserOrderHistory.scala", + ), + ) && fileShouldBe( + testDir, + "component/Order.scala", + "/ComponentOrderWithNormalizedFieldNames.scala", + ) + } + } + } @@ TestAspect.exceptScala3, + ) @@ java11OrNewer @@ flaky @@ blocking // Downloading scalafmt on CI is flaky } diff --git a/zio-http-gen/src/test/scala/zio/http/gen/scala/FieldNormalizationSpec.scala b/zio-http-gen/src/test/scala/zio/http/gen/scala/FieldNormalizationSpec.scala new file mode 100644 index 00000000000..b4022ceba5a --- /dev/null +++ b/zio-http-gen/src/test/scala/zio/http/gen/scala/FieldNormalizationSpec.scala @@ -0,0 +1,45 @@ +package zio.http.gen.scala + +import zio.Scope +import zio.test.Assertion.{equalTo, isNone, isSome} +import zio.test._ + +object FieldNormalizationSpec extends ZIOSpecDefault { + + override def spec: Spec[TestEnvironment with Scope, Any] = + suite("FieldNormalizationSpec")( + test("Simple lowercase (None signals no change)") { + assert(Code.Field.normalize("foo"))(isNone) + }, + test("Simple UPPERCASE") { + assert(Code.Field.normalize("FOO"))(isSome(equalTo("foo"))) + }, + test("preserve camelCase (None signals no change)") { + assert(Code.Field.normalize("fooBar"))(isNone) + }, + test("preserve camelCase with digits (None signals no change)") { + assert(Code.Field.normalize("fooBar42"))(isNone) + }, + test("preserve camelCase with digits #2 (None signals no change)") { + assert(Code.Field.normalize("foo42Bar"))(isNone) + }, + test("lowercase capitalized camelCase") { + assert(Code.Field.normalize("FooBar"))(isSome(equalTo("fooBar"))) + }, + test("preserve non-leading UPPERCASE (None signals no change)") { + assert(Code.Field.normalize("fooBAR"))(isNone) + }, + test("mixed camelSnake_case") { + assert(Code.Field.normalize("camelSnake_case"))(isSome(equalTo("camelSnakeCase"))) + }, + test("mixed snake_caseUPPERLower") { + assert(Code.Field.normalize("ARN_APIGateway"))(isSome(equalTo("arnAPIGateway"))) + }, + test("challenge with complex CamelCase") { + assert(Code.Field.normalize("UseWD40ToLossenBut3MToFasten"))(isSome(equalTo("useWD40ToLossenBut3MToFasten"))) + }, + test("with whitespaces") { + assert(Code.Field.normalize("white\tspace - as\nsep"))(isSome(equalTo("whiteSpaceAsSep"))) + }, + ) +}