Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 27 additions & 15 deletions core/conf/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,10 @@ func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo, fullName st
return nil
}

func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type, fullName string) error {
func buildAnonymousFieldInfoWithVisited(info *fieldInfo, lowerCaseName string, ft reflect.Type, fullName string, visited map[reflect.Type]bool) error {
switch ft.Kind() {
case reflect.Struct:
fields, err := buildFieldsInfo(ft, fullName)
fields, err := buildFieldsInfoWithVisited(ft, fullName, visited)
if err != nil {
return err
}
Expand All @@ -166,7 +166,7 @@ func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.T
}
}
case reflect.Map:
elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()), fullName)
elemField, err := buildFieldsInfoWithVisited(mapping.Deref(ft.Elem()), fullName, visited)
if err != nil {
return err
}
Expand All @@ -193,13 +193,25 @@ func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.T
}

func buildFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) {
return buildFieldsInfoWithVisited(tp, fullName, make(map[reflect.Type]bool))
}

func buildFieldsInfoWithVisited(tp reflect.Type, fullName string, visited map[reflect.Type]bool) (*fieldInfo, error) {
tp = mapping.Deref(tp)

if visited[tp] {
return &fieldInfo{
children: make(map[string]*fieldInfo),
}, nil
}

switch tp.Kind() {
case reflect.Struct:
return buildStructFieldsInfo(tp, fullName)
visited[tp] = true
defer delete(visited, tp)
return buildStructFieldsInfoWithVisited(tp, fullName, visited)
case reflect.Array, reflect.Slice, reflect.Map:
return buildFieldsInfo(mapping.Deref(tp.Elem()), fullName)
return buildFieldsInfoWithVisited(mapping.Deref(tp.Elem()), fullName, visited)
case reflect.Chan, reflect.Func:
return nil, fmt.Errorf("unsupported type: %s, fullName: %s", tp.Kind(), fullName)
default:
Expand All @@ -209,23 +221,23 @@ func buildFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) {
}
}

func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type, fullName string) error {
func buildNamedFieldInfoWithVisited(info *fieldInfo, lowerCaseName string, ft reflect.Type, fullName string, visited map[reflect.Type]bool) error {
var finfo *fieldInfo
var err error

switch ft.Kind() {
case reflect.Struct:
finfo, err = buildFieldsInfo(ft, fullName)
finfo, err = buildFieldsInfoWithVisited(ft, fullName, visited)
if err != nil {
return err
}
case reflect.Array, reflect.Slice:
finfo, err = buildFieldsInfo(ft.Elem(), fullName)
finfo, err = buildFieldsInfoWithVisited(ft.Elem(), fullName, visited)
if err != nil {
return err
}
case reflect.Map:
elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()), fullName)
elemInfo, err := buildFieldsInfoWithVisited(mapping.Deref(ft.Elem()), fullName, visited)
if err != nil {
return err
}
Expand All @@ -235,7 +247,7 @@ func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type,
mapField: elemInfo,
}
default:
finfo, err = buildFieldsInfo(ft, fullName)
finfo, err = buildFieldsInfoWithVisited(ft, fullName, visited)
if err != nil {
return err
}
Expand All @@ -244,7 +256,7 @@ func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type,
return addOrMergeFields(info, lowerCaseName, finfo, fullName)
}

func buildStructFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error) {
func buildStructFieldsInfoWithVisited(tp reflect.Type, fullName string, visited map[reflect.Type]bool) (*fieldInfo, error) {
info := &fieldInfo{
children: make(map[string]*fieldInfo),
}
Expand All @@ -260,12 +272,12 @@ func buildStructFieldsInfo(tp reflect.Type, fullName string) (*fieldInfo, error)
ft := mapping.Deref(field.Type)
// flatten anonymous fields
if field.Anonymous {
if err := buildAnonymousFieldInfo(info, lowerCaseName, ft,
getFullName(fullName, lowerCaseName)); err != nil {
if err := buildAnonymousFieldInfoWithVisited(info, lowerCaseName, ft,
getFullName(fullName, lowerCaseName), visited); err != nil {
return nil, err
}
} else if err := buildNamedFieldInfo(info, lowerCaseName, ft,
getFullName(fullName, lowerCaseName)); err != nil {
} else if err := buildNamedFieldInfoWithVisited(info, lowerCaseName, ft,
getFullName(fullName, lowerCaseName), visited); err != nil {
return nil, err
}
}
Expand Down
72 changes: 72 additions & 0 deletions core/conf/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1339,6 +1339,78 @@ func Test_buildFieldsInfo(t *testing.T) {
}
}

func Test_buildFieldsInfo_CircularReference(t *testing.T) {
type MySQLConfig struct {
Alias string `json:",optional"`
DSN string `json:",optional"`
Type string `json:",optional"`
MaxOpenConns int `json:",optional"`
MaxIdleConns int `json:",optional"`
Slave []*MySQLConfig `json:",optional"` // Self-reference slice
}

type CountryConfig struct {
MySQL MySQLConfig `json:",optional"`
}

type GlobalConfig struct {
CN CountryConfig `json:",optional"`
}

tests := []struct {
name string
t reflect.Type
}{
{
name: "direct circular reference",
t: reflect.TypeOf(MySQLConfig{}),
},
{
name: "nested circular reference",
t: reflect.TypeOf(GlobalConfig{}),
},
{
name: "pointer circular reference",
t: reflect.TypeOf(&MySQLConfig{}),
},
{
name: "slice of circular reference",
t: reflect.TypeOf([]MySQLConfig{}),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
info, err := buildFieldsInfo(tt.t, "test")
assert.NoError(t, err)
assert.NotNil(t, info)
assert.NotNil(t, info.children)
})
}
}

func Test_buildFieldsInfoWithVisited_CircularDetection(t *testing.T) {
type CircularStruct struct {
Name string `json:",optional"`
Children []*CircularStruct `json:",optional"`
}

visited := make(map[reflect.Type]bool)
tp := reflect.TypeOf(CircularStruct{})

info1, err1 := buildFieldsInfoWithVisited(tp, "test1", visited)
assert.NoError(t, err1)
assert.NotNil(t, info1)

visited[tp] = true

info2, err2 := buildFieldsInfoWithVisited(tp, "test2", visited)
assert.NoError(t, err2)
assert.NotNil(t, info2)
assert.NotNil(t, info2.children)
assert.Equal(t, 0, len(info2.children))
}

func createTempFile(t *testing.T, ext, text string) (string, error) {
tmpFile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
if err != nil {
Expand Down
21 changes: 20 additions & 1 deletion core/mapping/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,25 @@ func ensureValue(v reflect.Value) reflect.Value {
}

func implicitValueRequiredStruct(tag string, tp reflect.Type) (bool, error) {
return implicitValueRequiredStructWithDepth(tag, tp, 0, make(map[reflect.Type]bool))
}

func implicitValueRequiredStructWithDepth(tag string, tp reflect.Type, depth int, visited map[reflect.Type]bool) (bool, error) {

// max depth to avoid too deep recursion
const maxDepth = 100
if depth > maxDepth {
return false, nil
}

tp = Deref(tp)
if visited[tp] {
return false, nil
}

visited[tp] = true
defer delete(visited, tp)

numFields := tp.NumField()
for i := 0; i < numFields; i++ {
childField := tp.Field(i)
Expand All @@ -215,7 +234,7 @@ func implicitValueRequiredStruct(tag string, tp reflect.Type) (bool, error) {
return true, nil
}

if required, err := implicitValueRequiredStruct(tag, childField.Type); err != nil {
if required, err := implicitValueRequiredStructWithDepth(tag, childField.Type, depth+1, visited); err != nil {
return false, err
} else if required {
return true, nil
Expand Down
86 changes: 86 additions & 0 deletions core/mapping/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,89 @@ func TestValidateValueRange(t *testing.T) {
func TestSetMatchedPrimitiveValue(t *testing.T) {
assert.Error(t, setMatchedPrimitiveValue(reflect.Func, reflect.ValueOf(2), "1"))
}

func TestImplicitValueRequiredStruct_CircularReference(t *testing.T) {
type MySQLConfig struct {
Alias string `json:",optional"`
DSN string `json:",optional"`
Type string `json:",optional"`
MaxOpenConns int `json:",optional"`
MaxIdleConns int `json:",optional"`
Slave []*MySQLConfig `json:",optional"` // Self-reference slice
}

type CountryConfig struct {
MySQL MySQLConfig `json:",optional"`
}

type GlobalConfig struct {
CN CountryConfig `json:",optional"`
}

tests := []struct {
name string
tag string
tp reflect.Type
expected bool
}{
{
name: "direct circular reference - all optional",
tag: "json",
tp: reflect.TypeOf(MySQLConfig{}),
expected: false,
},
{
name: "nested circular reference - all optional",
tag: "json",
tp: reflect.TypeOf(GlobalConfig{}),
expected: false,
},
{
name: "pointer circular reference",
tag: "json",
tp: reflect.TypeOf(&MySQLConfig{}),
expected: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := implicitValueRequiredStruct(tt.tag, tt.tp)
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}

func TestImplicitValueRequiredStructWithDepth_MaxDepth(t *testing.T) {
type DeepStruct struct {
Child *DeepStruct `json:",optional"`
}

visited := make(map[reflect.Type]bool)
tp := reflect.TypeOf(DeepStruct{})

result, err := implicitValueRequiredStructWithDepth("json", tp, 150, visited)
assert.NoError(t, err)
assert.False(t, result)
}

func TestImplicitValueRequiredStructWithDepth_CircularDetection(t *testing.T) {
type CircularStruct struct {
Name string `json:",optional"`
Children []*CircularStruct `json:",optional"`
}

visited := make(map[reflect.Type]bool)
tp := reflect.TypeOf(CircularStruct{})

result1, err1 := implicitValueRequiredStructWithDepth("json", tp, 0, visited)
assert.NoError(t, err1)
assert.False(t, result1)

visited[Deref(tp)] = true

result2, err2 := implicitValueRequiredStructWithDepth("json", tp, 0, visited)
assert.NoError(t, err2)
assert.False(t, result2)
}
Loading