Skip to content

Commit

Permalink
feat: add comment parsing to populate OpenAPI description fields
Browse files Browse the repository at this point in the history
  • Loading branch information
a-h committed Feb 13, 2023
1 parent 0c67623 commit fa9ef03
Show file tree
Hide file tree
Showing 13 changed files with 132 additions and 59 deletions.
6 changes: 5 additions & 1 deletion api.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ func NewAPI(name string, opts ...APIOpts) *API {
KnownTypes: defaultKnownTypes,
Routes: make(map[Pattern]MethodToRoute),
// map of model name to schema.
models: map[string]*openapi3.Schema{},
models: make(map[string]*openapi3.Schema),
comments: make(map[string]map[string]string),
}
}

Expand Down Expand Up @@ -91,6 +92,9 @@ type API struct {
// adjust the OpenAPI specification.
configureSpec func(spec *openapi3.T)

// comments from the package. This can be cleared once the spec has been created.
comments map[string]map[string]string

// handler is a HTTP handler that serves up the OpenAPI specification and Swagger UI.
handler http.Handler
configured bool
Expand Down
3 changes: 3 additions & 0 deletions examples/chiexample/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ require (
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
github.com/perimeterx/marshmallow v1.1.4 // indirect
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab // indirect
golang.org/x/mod v0.6.0 // indirect
golang.org/x/sys v0.1.0 // indirect
golang.org/x/tools v0.2.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
6 changes: 6 additions & 0 deletions examples/chiexample/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab h1:628ME69lBm9C6JY2wXhAph/yjN3jezx1z7BIDLUwxjo=
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/mod v0.6.0 h1:b9gGHsz9/HhJ3HF5DHQytPpuwocVTChQJK3AvoLRD5I=
golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI=
golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/tools v0.2.0 h1:G6AHpWxTMGY1KyEYoAQ5WTtIekUUvDNjan3ugu60JvE=
golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
Expand Down
43 changes: 10 additions & 33 deletions examples/chiexample/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,17 @@ import (
"github.com/a-h/respond"
"github.com/a-h/rest"
"github.com/a-h/rest/chiadapter"
"github.com/a-h/rest/examples/chiexample/models"
"github.com/getkin/kin-openapi/openapi3"
"github.com/go-chi/chi/v5"
)

type Topic struct {
Namespace string `json:"namespace"`
Topic string `json:"topic"`
Private bool `json:"private"`
ViewCount int64 `json:"viewCount"`
}

type TopicsPostRequest struct {
Topic
}

type TopicsPostResponse struct {
ID string `json:"id"`
}

type TopicsGetResponse struct {
Topics []TopicRecord `json:"topics"`
}

type TopicRecord struct {
ID string `json:"id"`
Topic
}

func main() {
// Define routes in any router.
router := chi.NewRouter()

router.Get("/topic/{id}", func(w http.ResponseWriter, r *http.Request) {
resp := Topic{
resp := models.Topic{
Namespace: "example",
Topic: "topic",
Private: false,
Expand All @@ -50,11 +27,11 @@ func main() {
})

router.Get("/topics", func(w http.ResponseWriter, r *http.Request) {
resp := TopicsGetResponse{
Topics: []TopicRecord{
resp := models.TopicsGetResponse{
Topics: []models.TopicRecord{
{
ID: "testId",
Topic: Topic{
Topic: models.Topic{
Namespace: "example",
Topic: "topic",
Private: false,
Expand All @@ -67,7 +44,7 @@ func main() {
})

router.Post("/topics", func(w http.ResponseWriter, r *http.Request) {
resp := TopicsPostResponse{ID: "123"}
resp := models.TopicsPostResponse{ID: "123"}
respond.WithJSON(w, resp, http.StatusOK)
})

Expand All @@ -90,16 +67,16 @@ func main() {

// Document the routes.
api.Get("/topic/{id}").
HasResponseModel(http.StatusOK, rest.ModelOf[TopicsGetResponse]()).
HasResponseModel(http.StatusOK, rest.ModelOf[models.TopicsGetResponse]()).
HasResponseModel(http.StatusInternalServerError, rest.ModelOf[respond.Error]())

api.Get("/topics").
HasResponseModel(http.StatusOK, rest.ModelOf[TopicsGetResponse]()).
HasResponseModel(http.StatusOK, rest.ModelOf[models.TopicsGetResponse]()).
HasResponseModel(http.StatusInternalServerError, rest.ModelOf[respond.Error]())

api.Post("/topics").
HasRequestModel(rest.ModelOf[TopicsPostRequest]()).
HasResponseModel(http.StatusOK, rest.ModelOf[TopicsPostResponse]()).
HasRequestModel(rest.ModelOf[models.TopicsPostRequest]()).
HasResponseModel(http.StatusOK, rest.ModelOf[models.TopicsPostResponse]()).
HasResponseModel(http.StatusInternalServerError, rest.ModelOf[respond.Error]())

api.ConfigureSpec(func(spec *openapi3.T) {
Expand Down
28 changes: 28 additions & 0 deletions examples/chiexample/models/models.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package models

type Topic struct {
Namespace string `json:"namespace"`
Topic string `json:"topic"`
Private bool `json:"private"`
ViewCount int64 `json:"viewCount"`
}

// TopicsPostRequest is the request to POST /topics.
type TopicsPostRequest struct {
Topic
}

type TopicsPostResponse struct {
ID string `json:"id"`
}

// TopicsGetResponse is the response to GET /topics.
type TopicsGetResponse struct {
Topics []TopicRecord `json:"topics"`
}

type TopicRecord struct {
// ID of the topic record.
ID string `json:"id"`
Topic
}
3 changes: 3 additions & 0 deletions examples/stdlib/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ require (
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
github.com/perimeterx/marshmallow v1.1.4 // indirect
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab // indirect
golang.org/x/mod v0.6.0 // indirect
golang.org/x/sys v0.1.0 // indirect
golang.org/x/tools v0.2.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
6 changes: 6 additions & 0 deletions examples/stdlib/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab h1:628ME69lBm9C6JY2wXhAph/yjN3jezx1z7BIDLUwxjo=
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/mod v0.6.0 h1:b9gGHsz9/HhJ3HF5DHQytPpuwocVTChQJK3AvoLRD5I=
golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI=
golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/tools v0.2.0 h1:G6AHpWxTMGY1KyEYoAQ5WTtIekUUvDNjan3ugu60JvE=
golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
Expand Down
1 change: 1 addition & 0 deletions examples/stdlib/models/models.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package models

// Topic of a thread.
type Topic struct {
Namespace string `json:"namespace"`
Topic string `json:"topic"`
Expand Down
10 changes: 9 additions & 1 deletion getcomments/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,23 @@ package main

import (
"encoding/json"
"flag"
"fmt"
"log"
"os"

"github.com/a-h/rest/getcomments/parser"
)

var flagPackage = flag.String("package", "", "The package to retrieve comments from, e.g. github.com/a-h/rest/getcomments/example")

func main() {
m, err := parser.Get("github.com/a-h/rest/getcomments/parser/example")
flag.Parse()
if *flagPackage == "" {
flag.Usage()
os.Exit(0)
}
m, err := parser.Get(*flagPackage)
if err != nil {
log.Fatalf("failed to parse: %v", err)
}
Expand Down
3 changes: 2 additions & 1 deletion getcomments/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import (

func Get(packageName string) (m map[string]string, err error) {
config := &packages.Config{
Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax,
Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax,
Tests: true,
}
pkgs, err := packages.Load(config, packageName)
if err != nil {
Expand Down
42 changes: 40 additions & 2 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"sort"
"strings"

"github.com/a-h/rest/getcomments/parser"
"github.com/getkin/kin-openapi/openapi3"
"golang.org/x/exp/constraints"
)
Expand Down Expand Up @@ -245,6 +246,9 @@ func (api *API) RegisterModel(model Model, opts ...ModelOpts) (name string, sche
schema.AdditionalProperties.Schema = getSchemaReferenceOrValue(elementName, elementSchema)
case reflect.Struct:
schema = openapi3.NewObjectSchema()
if schema.Description, err = api.getTypeComment(t.PkgPath(), t.Name()); err != nil {
return name, schema, fmt.Errorf("failed to get comments for type %q: %w", name, err)
}
schema.Properties = make(openapi3.Schemas)
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
Expand All @@ -260,7 +264,7 @@ func (api *API) RegisterModel(model Model, opts ...ModelOpts) (name string, sche
_, alreadyExists := api.models[api.getModelName(f.Type)]
fieldSchemaName, fieldSchema, err := api.RegisterModel(ModelFromType(f.Type))
if err != nil {
return fieldName, schema, fmt.Errorf("error getting schema for type %q, field %q, failed to get schema for embedded type %q: %w", t, fieldName, f.Type, err)
return name, schema, fmt.Errorf("error getting schema for type %q, field %q, failed to get schema for embedded type %q: %w", t, fieldName, f.Type, err)
}
if f.Anonymous {
// It's an anonymous type, no need for a reference to it,
Expand All @@ -274,7 +278,13 @@ func (api *API) RegisterModel(model Model, opts ...ModelOpts) (name string, sche
}
continue
}
schema.Properties[fieldName] = getSchemaReferenceOrValue(fieldSchemaName, fieldSchema)
ref := getSchemaReferenceOrValue(fieldSchemaName, fieldSchema)
if ref.Value != nil {
if ref.Value.Description, err = api.getTypeFieldComment(t.PkgPath(), t.Name(), fieldName); err != nil {
return name, schema, fmt.Errorf("failed to get comments for field %q in type %q: %w", fieldName, name, err)
}
}
schema.Properties[fieldName] = ref
}
}

Expand All @@ -295,6 +305,34 @@ func (api *API) RegisterModel(model Model, opts ...ModelOpts) (name string, sche
return
}

func (api *API) getCommentsForPackage(pkg string) (pkgComments map[string]string, err error) {
if pkgComments, loaded := api.comments[pkg]; loaded {
return pkgComments, nil
}
pkgComments, err = parser.Get(pkg)
if err != nil {
return
}
api.comments[pkg] = pkgComments
return
}

func (api *API) getTypeComment(pkg string, name string) (comment string, err error) {
pkgComments, err := api.getCommentsForPackage(pkg)
if err != nil {
return
}
return pkgComments[pkg+"."+name], nil
}

func (api *API) getTypeFieldComment(pkg string, name string, field string) (comment string, err error) {
pkgComments, err := api.getCommentsForPackage(pkg)
if err != nil {
return
}
return pkgComments[pkg+"."+name+"."+field], nil
}

func shouldBeReferenced(schema *openapi3.Schema) bool {
if schema.Type == openapi3.TypeObject && schema.AdditionalProperties.Schema == nil {
return true
Expand Down
32 changes: 14 additions & 18 deletions schema_test.go → tests/schema_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package rest_test
package test

import (
"embed"
Expand All @@ -20,14 +20,16 @@ import (
"gopkg.in/yaml.v2"
)

//go:embed tests/*
//go:embed *
var testFiles embed.FS

type TestRequestType struct {
IntField int
}

// TestResponseType description.
type TestResponseType struct {
// IntField description.
IntField int
}

Expand All @@ -45,13 +47,10 @@ type AllBasicDataTypes struct {
Uintptr uintptr
Float32 float32
Float64 float64
// Complex types are not supported by the Go JSON serializer.
//Complex64 complex64
//Complex128 complex128
Byte byte
Rune rune
String string
Bool bool
Byte byte
Rune rune
String string
Bool bool
}

type AllBasicDataTypesPointers struct {
Expand All @@ -68,13 +67,10 @@ type AllBasicDataTypesPointers struct {
Uintptr *uintptr
Float32 *float32
Float64 *float64
// Complex types are not supported by the Go JSON serializer.
//Complex64 *complex64
//Complex128 *complex128
Byte *byte
Rune *rune
String *string
Bool *bool
Byte *byte
Rune *rune
String *string
Bool *bool
}

type EmbeddedStructA struct {
Expand Down Expand Up @@ -276,7 +272,7 @@ func TestSchema(t *testing.T) {
go func() {
defer wg.Done()
// Load test file.
expectedYAML, err := testFiles.ReadFile("tests/" + test.name)
expectedYAML, err := testFiles.ReadFile(test.name)
if err != nil {
errs[0] = fmt.Errorf("could not read file %q: %v", test.name, err)
return
Expand Down Expand Up @@ -319,7 +315,7 @@ func TestSchema(t *testing.T) {
// Compare the JSON marshalled output to ignore unexported fields and internal state.
if diff := cmp.Diff(expected, actual); diff != "" {
t.Error(diff)
t.Error(string(actual))
t.Error("\n\n" + string(actual))
}
})
}
Expand Down
Loading

0 comments on commit fa9ef03

Please sign in to comment.