Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Schema customization plug-point #411

Merged
merged 4 commits into from
Aug 28, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 34 additions & 14 deletions openapi3gen/openapi3gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@ func (err *CycleError) Error() string { return "detected cycle" }
// Option allows tweaking SchemaRef generation
type Option func(*generatorOpt)

type SchemaCustomizerFn func(name string, t reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) *openapi3.Schema
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

schema is already passed by pointer so let's have the return type only be error. The func in the tests you provide should then return on non-nil strconv errors.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good. Agree allowing the flexibility for someone to build a completely new schema object is unnecessary (which was the reason for the return previously).


type generatorOpt struct {
useAllExportedFields bool
throwErrorOnCycle bool
schemaCustomizer SchemaCustomizerFn
}

// UseAllExportedFields changes the default behavior of only
Expand All @@ -38,6 +41,12 @@ func ThrowErrorOnCycle() Option {
return func(x *generatorOpt) { x.throwErrorOnCycle = true }
}

// SchemaCustomizer allows customization of the schema that is generated
// for a field, for example to support an additional tagging scheme
fenollp marked this conversation as resolved.
Show resolved Hide resolved
func SchemaCustomizer(sc SchemaCustomizerFn) Option {
return func(x *generatorOpt) { x.schemaCustomizer = sc }
}

// NewSchemaRefForValue uses reflection on the given value to produce a SchemaRef.
func NewSchemaRefForValue(value interface{}, opts ...Option) (*openapi3.SchemaRef, map[*openapi3.SchemaRef]int, error) {
g := NewGenerator(opts...)
Expand Down Expand Up @@ -73,23 +82,23 @@ func NewGenerator(opts ...Option) *Generator {

func (g *Generator) GenerateSchemaRef(t reflect.Type) (*openapi3.SchemaRef, error) {
//check generatorOpt consistency here
return g.generateSchemaRefFor(nil, t)
return g.generateSchemaRefFor(nil, t, "_root", "")
}

func (g *Generator) generateSchemaRefFor(parents []*jsoninfo.TypeInfo, t reflect.Type) (*openapi3.SchemaRef, error) {
if ref := g.Types[t]; ref != nil {
func (g *Generator) generateSchemaRefFor(parents []*jsoninfo.TypeInfo, t reflect.Type, name string, tag reflect.StructTag) (*openapi3.SchemaRef, error) {
if ref := g.Types[t]; ref != nil && g.opts.schemaCustomizer == nil {
g.SchemaRefs[ref]++
return ref, nil
}
ref, err := g.generateWithoutSaving(parents, t)
ref, err := g.generateWithoutSaving(parents, t, name, tag)
if ref != nil {
g.Types[t] = ref
g.SchemaRefs[ref]++
}
return ref, err
}

func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflect.Type) (*openapi3.SchemaRef, error) {
func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflect.Type, name string, tag reflect.StructTag) (*openapi3.SchemaRef, error) {
typeInfo := jsoninfo.GetTypeInfo(t)
for _, parent := range parents {
if parent == typeInfo {
Expand All @@ -110,7 +119,7 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec
_, a := t.FieldByName("Ref")
v, b := t.FieldByName("Value")
if a && b {
vs, err := g.generateSchemaRefFor(parents, v.Type)
vs, err := g.generateSchemaRefFor(parents, v.Type, name, tag)
if err != nil {
if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle {
g.SchemaRefs[vs]++
Expand Down Expand Up @@ -195,7 +204,7 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec
schema.Format = "byte"
} else {
schema.Type = "array"
items, err := g.generateSchemaRefFor(parents, t.Elem())
items, err := g.generateSchemaRefFor(parents, t.Elem(), name, tag)
if err != nil {
if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle {
items = g.generateCycleSchemaRef(t.Elem(), schema)
Expand All @@ -211,7 +220,7 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec

case reflect.Map:
schema.Type = "object"
additionalProperties, err := g.generateSchemaRefFor(parents, t.Elem())
additionalProperties, err := g.generateSchemaRefFor(parents, t.Elem(), name, tag)
if err != nil {
if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle {
additionalProperties = g.generateCycleSchemaRef(t.Elem(), schema)
Expand All @@ -235,11 +244,11 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec
continue
}
// If asked, try to use yaml tag
name, fType := fieldInfo.JSONName, fieldInfo.Type
fieldName, fType := fieldInfo.JSONName, fieldInfo.Type
if !fieldInfo.HasJSONTag && g.opts.useAllExportedFields {
// Handle anonymous fields/embedded structs
if t.Field(fieldInfo.Index[0]).Anonymous {
ref, err := g.generateSchemaRefFor(parents, fType)
ref, err := g.generateSchemaRefFor(parents, fType, fieldName, tag)
if err != nil {
if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle {
ref = g.generateCycleSchemaRef(fType, schema)
Expand All @@ -249,17 +258,24 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec
}
if ref != nil {
g.SchemaRefs[ref]++
schema.WithPropertyRef(name, ref)
schema.WithPropertyRef(fieldName, ref)
}
} else {
ff := t.Field(fieldInfo.Index[len(fieldInfo.Index)-1])
if tag, ok := ff.Tag.Lookup("yaml"); ok && tag != "-" {
name, fType = tag, ff.Type
fieldName, fType = tag, ff.Type
}
}
}

ref, err := g.generateSchemaRefFor(parents, fType)
// extract the field tag if we have a customizer
var fieldTag reflect.StructTag
if g.opts.schemaCustomizer != nil {
ff := t.Field(fieldInfo.Index[len(fieldInfo.Index)-1])
fieldTag = ff.Tag
}

ref, err := g.generateSchemaRefFor(parents, fType, fieldName, fieldTag)
if err != nil {
if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle {
ref = g.generateCycleSchemaRef(fType, schema)
Expand All @@ -269,7 +285,7 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec
}
if ref != nil {
g.SchemaRefs[ref]++
schema.WithPropertyRef(name, ref)
schema.WithPropertyRef(fieldName, ref)
}
}

Expand All @@ -280,6 +296,10 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec
}
}

if g.opts.schemaCustomizer != nil {
schema = g.opts.schemaCustomizer(name, t, tag, schema)
}

return openapi3.NewSchemaRef(t.Name(), schema), nil
}

Expand Down
62 changes: 62 additions & 0 deletions openapi3gen/openapi3gen_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package openapi3gen

import (
"encoding/json"
"reflect"
"strconv"
"strings"
"testing"

"github.com/getkin/kin-openapi/openapi3"
Expand Down Expand Up @@ -144,3 +147,62 @@ func TestCyclicReferences(t *testing.T) {
require.Equal(t, "object", schemaRef.Value.Properties["MapCycle"].Value.Type)
require.Equal(t, "#/components/schemas/ObjectDiff", schemaRef.Value.Properties["MapCycle"].Value.AdditionalProperties.Ref)
}

func TestSchemaCustomizer(t *testing.T) {
type Bla struct {
UntaggedStringField string
AnonStruct struct {
InnerFieldWithoutTag int
InnerFieldWithTag int `mymintag:"-1" mymaxtag:"50"`
}
EnumField string `json:"another" myenumtag:"a,b"`
}

schemaRef, _, err := NewSchemaRefForValue(&Bla{}, UseAllExportedFields(), SchemaCustomizer(func(name string, ft reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) *openapi3.Schema {
t.Logf("Field=%s,Tag=%s", name, tag)
if tag.Get("mymintag") != "" {
minVal, _ := strconv.ParseFloat(tag.Get("mymintag"), 64)
schema.Min = &minVal
}
if tag.Get("mymaxtag") != "" {
maxVal, _ := strconv.ParseFloat(tag.Get("mymaxtag"), 64)
schema.Max = &maxVal
}
if tag.Get("myenumtag") != "" {
for _, s := range strings.Split(tag.Get("myenumtag"), ",") {
schema.Enum = append(schema.Enum, s)
}
}
return schema
}))
require.NoError(t, err)
jsonSchema, _ := json.MarshalIndent(schemaRef, "", " ")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
jsonSchema, _ := json.MarshalIndent(schemaRef, "", " ")
jsonSchema, err := json.MarshalIndent(schemaRef, "", " ")
require.NoError(t, err)

require.Equal(t, `{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
require.Equal(t, `{
require.Equal(t, `{
require.JSONEq(t, `{

"properties": {
"AnonStruct": {
"properties": {
"InnerFieldWithTag": {
"maximum": 50,
"minimum": -1,
"type": "integer"
},
"InnerFieldWithoutTag": {
"type": "integer"
}
},
"type": "object"
},
"UntaggedStringField": {
"type": "string"
},
"another": {
"enum": [
"a",
"b"
],
"type": "string"
}
},
"type": "object"
}`, string(jsonSchema))
}