Skip to content

Commit

Permalink
🔥 Feature: improve FromContext func and test
Browse files Browse the repository at this point in the history
  • Loading branch information
JIeJaitt committed Nov 12, 2024
1 parent 44479ad commit 69bd6ee
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 53 deletions.
26 changes: 13 additions & 13 deletions middleware/requestid/requestid.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"

"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/log"
)

// The contextKey type is unexported to prevent collisions with context keys defined in
Expand Down Expand Up @@ -49,19 +50,18 @@ func New(config ...Config) fiber.Handler {

// FromContext returns the request ID from context.
// If there is no request ID, an empty string is returned.
func FromContext(c fiber.Ctx) string {
if rid, ok := c.Locals(requestIDKey).(string); ok {
return rid
}
return ""
}

// FromUserContext returns the request ID from the UserContext.
// If there is no request ID, an empty string is returned.
// Compared to Local, UserContext is more suitable for transmitting requests between microservices
func FromUserContext(ctx context.Context) string {
if rid, ok := ctx.Value(requestIDKey).(string); ok {
return rid
func FromContext(c interface{}) string {
switch ctx := c.(type) {
case fiber.Ctx:
if rid, ok := ctx.Locals(requestIDKey).(string); ok {
return rid
}
case context.Context:
if rid, ok := ctx.Value(requestIDKey).(string); ok {
return rid
}
default:
log.Errorf("Unsupported context type: %T", c)
}
return ""
}
90 changes: 50 additions & 40 deletions middleware/requestid/requestid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,50 +51,60 @@ func Test_RequestID_Next(t *testing.T) {
require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}

// go test -run Test_RequestID_Locals
// go test -run Test_RequestID_FromContext
func Test_RequestID_FromContext(t *testing.T) {
t.Parallel()
reqID := "ThisIsARequestId"

app := fiber.New()
app.Use(New(Config{
Generator: func() string {
return reqID
},
}))

var ctxVal string

app.Use(func(c fiber.Ctx) error {
ctxVal = FromContext(c)
return c.Next()
})

_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
require.Equal(t, reqID, ctxVal)
}

// go test -run Test_RequestID_FromUserContext
func Test_RequestID_FromUserContext(t *testing.T) {
t.Parallel()
reqID := "ThisIsARequestId"

app := fiber.New()
app.Use(New(Config{
Generator: func() string {
return reqID
type args struct {
inputFunc func(c fiber.Ctx) interface{}
}

tests := []struct {
name string
args args
}{
{
name: "From fiber.Ctx",
args: args{
inputFunc: func(c fiber.Ctx) interface{} {
return c
},
},
},
}))

var ctxVal string

app.Use(func(c fiber.Ctx) error {
ctxVal = FromUserContext(c.Context())
return c.Next()
})

_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
require.Equal(t, reqID, ctxVal)
{
name: "From context.Context",
args: args{
inputFunc: func(c fiber.Ctx) interface{} {
return c.Context()
},
},
},
}

for _, tt := range tests {
tt := tt // Re bind variables
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

app := fiber.New()
app.Use(New(Config{
Generator: func() string {
return reqID
},
}))

var ctxVal string

app.Use(func(c fiber.Ctx) error {
ctxVal = FromContext(tt.args.inputFunc(c))
return c.Next()
})

_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
require.NoError(t, err)
require.Equal(t, reqID, ctxVal)
})
}
}

0 comments on commit 69bd6ee

Please sign in to comment.