diff --git a/openapi3gen/openapi3gen.go b/openapi3gen/openapi3gen.go index 7c321fe7a..c21b782aa 100644 --- a/openapi3gen/openapi3gen.go +++ b/openapi3gen/openapi3gen.go @@ -21,9 +21,17 @@ func (err *CycleError) Error() string { return "detected cycle" } // Option allows tweaking SchemaRef generation type Option func(*generatorOpt) +// SchemaCustomizerFn is a callback function, allowing +// the OpenAPI schema definition to be updated with additional +// properties during the generation process, based on the +// name of the field, the Go type, and the struct tags. +// name will be "_root" for the top level object, and tag will be "" +type SchemaCustomizerFn func(name string, t reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) error + type generatorOpt struct { useAllExportedFields bool throwErrorOnCycle bool + schemaCustomizer SchemaCustomizerFn } // UseAllExportedFields changes the default behavior of only @@ -38,6 +46,12 @@ func ThrowErrorOnCycle() Option { return func(x *generatorOpt) { x.throwErrorOnCycle = true } } +// SchemaCustomizer allows customization of the schema that is generated +// for a field, for example to support an additional tagging scheme +func SchemaCustomizer(sc SchemaCustomizerFn) Option { + return func(x *generatorOpt) { x.schemaCustomizer = sc } +} + // NewSchemaRefForValue uses reflection on the given value to produce a SchemaRef. func NewSchemaRefForValue(value interface{}, opts ...Option) (*openapi3.SchemaRef, map[*openapi3.SchemaRef]int, error) { g := NewGenerator(opts...) @@ -73,15 +87,15 @@ func NewGenerator(opts ...Option) *Generator { func (g *Generator) GenerateSchemaRef(t reflect.Type) (*openapi3.SchemaRef, error) { //check generatorOpt consistency here - return g.generateSchemaRefFor(nil, t) + return g.generateSchemaRefFor(nil, t, "_root", "") } -func (g *Generator) generateSchemaRefFor(parents []*jsoninfo.TypeInfo, t reflect.Type) (*openapi3.SchemaRef, error) { - if ref := g.Types[t]; ref != nil { +func (g *Generator) generateSchemaRefFor(parents []*jsoninfo.TypeInfo, t reflect.Type, name string, tag reflect.StructTag) (*openapi3.SchemaRef, error) { + if ref := g.Types[t]; ref != nil && g.opts.schemaCustomizer == nil { g.SchemaRefs[ref]++ return ref, nil } - ref, err := g.generateWithoutSaving(parents, t) + ref, err := g.generateWithoutSaving(parents, t, name, tag) if ref != nil { g.Types[t] = ref g.SchemaRefs[ref]++ @@ -89,7 +103,7 @@ func (g *Generator) generateSchemaRefFor(parents []*jsoninfo.TypeInfo, t reflect return ref, err } -func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflect.Type) (*openapi3.SchemaRef, error) { +func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflect.Type, name string, tag reflect.StructTag) (*openapi3.SchemaRef, error) { typeInfo := jsoninfo.GetTypeInfo(t) for _, parent := range parents { if parent == typeInfo { @@ -110,7 +124,7 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec _, a := t.FieldByName("Ref") v, b := t.FieldByName("Value") if a && b { - vs, err := g.generateSchemaRefFor(parents, v.Type) + vs, err := g.generateSchemaRefFor(parents, v.Type, name, tag) if err != nil { if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle { g.SchemaRefs[vs]++ @@ -195,7 +209,7 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec schema.Format = "byte" } else { schema.Type = "array" - items, err := g.generateSchemaRefFor(parents, t.Elem()) + items, err := g.generateSchemaRefFor(parents, t.Elem(), name, tag) if err != nil { if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle { items = g.generateCycleSchemaRef(t.Elem(), schema) @@ -211,7 +225,7 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec case reflect.Map: schema.Type = "object" - additionalProperties, err := g.generateSchemaRefFor(parents, t.Elem()) + additionalProperties, err := g.generateSchemaRefFor(parents, t.Elem(), name, tag) if err != nil { if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle { additionalProperties = g.generateCycleSchemaRef(t.Elem(), schema) @@ -235,11 +249,11 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec continue } // If asked, try to use yaml tag - name, fType := fieldInfo.JSONName, fieldInfo.Type + fieldName, fType := fieldInfo.JSONName, fieldInfo.Type if !fieldInfo.HasJSONTag && g.opts.useAllExportedFields { // Handle anonymous fields/embedded structs if t.Field(fieldInfo.Index[0]).Anonymous { - ref, err := g.generateSchemaRefFor(parents, fType) + ref, err := g.generateSchemaRefFor(parents, fType, fieldName, tag) if err != nil { if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle { ref = g.generateCycleSchemaRef(fType, schema) @@ -249,17 +263,24 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec } if ref != nil { g.SchemaRefs[ref]++ - schema.WithPropertyRef(name, ref) + schema.WithPropertyRef(fieldName, ref) } } else { ff := t.Field(fieldInfo.Index[len(fieldInfo.Index)-1]) if tag, ok := ff.Tag.Lookup("yaml"); ok && tag != "-" { - name, fType = tag, ff.Type + fieldName, fType = tag, ff.Type } } } - ref, err := g.generateSchemaRefFor(parents, fType) + // extract the field tag if we have a customizer + var fieldTag reflect.StructTag + if g.opts.schemaCustomizer != nil { + ff := t.Field(fieldInfo.Index[len(fieldInfo.Index)-1]) + fieldTag = ff.Tag + } + + ref, err := g.generateSchemaRefFor(parents, fType, fieldName, fieldTag) if err != nil { if _, ok := err.(*CycleError); ok && !g.opts.throwErrorOnCycle { ref = g.generateCycleSchemaRef(fType, schema) @@ -269,7 +290,7 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec } if ref != nil { g.SchemaRefs[ref]++ - schema.WithPropertyRef(name, ref) + schema.WithPropertyRef(fieldName, ref) } } @@ -280,6 +301,12 @@ func (g *Generator) generateWithoutSaving(parents []*jsoninfo.TypeInfo, t reflec } } + if g.opts.schemaCustomizer != nil { + if err := g.opts.schemaCustomizer(name, t, tag, schema); err != nil { + return nil, err + } + } + return openapi3.NewSchemaRef(t.Name(), schema), nil } diff --git a/openapi3gen/openapi3gen_test.go b/openapi3gen/openapi3gen_test.go index 0b7fde9e6..7988acf9e 100644 --- a/openapi3gen/openapi3gen_test.go +++ b/openapi3gen/openapi3gen_test.go @@ -1,7 +1,11 @@ package openapi3gen import ( + "encoding/json" + "fmt" "reflect" + "strconv" + "strings" "testing" "github.com/getkin/kin-openapi/openapi3" @@ -144,3 +148,71 @@ func TestCyclicReferences(t *testing.T) { require.Equal(t, "object", schemaRef.Value.Properties["MapCycle"].Value.Type) require.Equal(t, "#/components/schemas/ObjectDiff", schemaRef.Value.Properties["MapCycle"].Value.AdditionalProperties.Ref) } + +func TestSchemaCustomizer(t *testing.T) { + type Bla struct { + UntaggedStringField string + AnonStruct struct { + InnerFieldWithoutTag int + InnerFieldWithTag int `mymintag:"-1" mymaxtag:"50"` + } + EnumField string `json:"another" myenumtag:"a,b"` + } + + schemaRef, _, err := NewSchemaRefForValue(&Bla{}, UseAllExportedFields(), SchemaCustomizer(func(name string, ft reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) error { + t.Logf("Field=%s,Tag=%s", name, tag) + if tag.Get("mymintag") != "" { + minVal, _ := strconv.ParseFloat(tag.Get("mymintag"), 64) + schema.Min = &minVal + } + if tag.Get("mymaxtag") != "" { + maxVal, _ := strconv.ParseFloat(tag.Get("mymaxtag"), 64) + schema.Max = &maxVal + } + if tag.Get("myenumtag") != "" { + for _, s := range strings.Split(tag.Get("myenumtag"), ",") { + schema.Enum = append(schema.Enum, s) + } + } + return nil + })) + require.NoError(t, err) + jsonSchema, err := json.MarshalIndent(schemaRef, "", " ") + require.NoError(t, err) + require.JSONEq(t, `{ + "properties": { + "AnonStruct": { + "properties": { + "InnerFieldWithTag": { + "maximum": 50, + "minimum": -1, + "type": "integer" + }, + "InnerFieldWithoutTag": { + "type": "integer" + } + }, + "type": "object" + }, + "UntaggedStringField": { + "type": "string" + }, + "another": { + "enum": [ + "a", + "b" + ], + "type": "string" + } + }, + "type": "object" +}`, string(jsonSchema)) +} + +func TestSchemaCustomizerError(t *testing.T) { + type Bla struct{} + _, _, err := NewSchemaRefForValue(&Bla{}, UseAllExportedFields(), SchemaCustomizer(func(name string, ft reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) error { + return fmt.Errorf("test error") + })) + require.EqualError(t, err, "test error") +}