Skip to content

Commit 8f17441

Browse files
authored
Fix format with alias (#648)
* fix format with alias
1 parent c331468 commit 8f17441

File tree

3 files changed

+78
-7
lines changed

3 files changed

+78
-7
lines changed

decode.go

+1-6
Original file line numberDiff line numberDiff line change
@@ -758,12 +758,7 @@ func (d *Decoder) deleteStructKeys(structType reflect.Type, unknownFields map[st
758758
}
759759

760760
func (d *Decoder) unmarshalableDocument(node ast.Node) ([]byte, error) {
761-
var err error
762-
node, err = d.resolveAlias(node)
763-
if err != nil {
764-
return nil, err
765-
}
766-
doc := format.FormatNode(node)
761+
doc := format.FormatNodeWithResolvedAlias(node, d.anchorNodeMap)
767762
return []byte(doc), nil
768763
}
769764

decode_test.go

+35
Original file line numberDiff line numberDiff line change
@@ -3151,6 +3151,41 @@ func TestMapKeyCustomUnmarshaler(t *testing.T) {
31513151
}
31523152
}
31533153

3154+
type bytesUnmershalerWithMapAlias struct{}
3155+
3156+
func (*bytesUnmershalerWithMapAlias) UnmarshalYAML(b []byte) error {
3157+
expected := strings.TrimPrefix(`
3158+
stuff:
3159+
bar:
3160+
- one
3161+
- two
3162+
3163+
`, "\n")
3164+
if string(b) != expected {
3165+
return fmt.Errorf("failed to decode: expected:\n[%s]\nbut got:\n[%s]\n", expected, string(b))
3166+
}
3167+
return nil
3168+
}
3169+
3170+
func TestBytesUnmarshalerWithMapAlias(t *testing.T) {
3171+
yml := `
3172+
x-foo: &data
3173+
bar:
3174+
- one
3175+
- two
3176+
3177+
foo:
3178+
stuff: *data
3179+
`
3180+
type T struct {
3181+
Foo bytesUnmershalerWithMapAlias `yaml:"foo"`
3182+
}
3183+
var v T
3184+
if err := yaml.Unmarshal([]byte(yml), &v); err != nil {
3185+
t.Fatal(err)
3186+
}
3187+
}
3188+
31543189
func TestDecoderPreservesDefaultValues(t *testing.T) {
31553190
type nested struct {
31563191
Val string `yaml:"val"`

internal/format/format.go

+42-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,16 @@ import (
77
"github.com/goccy/go-yaml/token"
88
)
99

10+
func FormatNodeWithResolvedAlias(n ast.Node, anchorNodeMap map[string]ast.Node) string {
11+
tk := n.GetToken()
12+
if tk == nil {
13+
return ""
14+
}
15+
formatter := newFormatter(tk, hasComment(n))
16+
formatter.anchorNodeMap = anchorNodeMap
17+
return formatter.format(n)
18+
}
19+
1020
func FormatNode(n ast.Node) string {
1121
tk := n.GetToken()
1222
if tk == nil {
@@ -124,6 +134,7 @@ func hasComment(n ast.Node) bool {
124134
type Formatter struct {
125135
existsComment bool
126136
tokenToOriginMap map[*token.Token]string
137+
anchorNodeMap map[string]ast.Node
127138
}
128139

129140
func newFormatter(tk *token.Token, existsComment bool) *Formatter {
@@ -294,6 +305,19 @@ func (f *Formatter) formatAnchor(n *ast.AnchorNode) string {
294305
}
295306

296307
func (f *Formatter) formatAlias(n *ast.AliasNode) string {
308+
if f.anchorNodeMap != nil {
309+
node := f.anchorNodeMap[n.Value.GetToken().Value]
310+
if node != nil {
311+
formatted := f.formatNode(node)
312+
// If formatted text contains newline characters, indentation needs to be considered.
313+
if strings.Contains(formatted, "\n") {
314+
// If the first character is not a newline, the first line should be output without indentation.
315+
isIgnoredFirstLine := !strings.HasPrefix(formatted, "\n")
316+
formatted = f.addIndentSpace(n.GetToken().Position.IndentNum, formatted, isIgnoredFirstLine)
317+
}
318+
return formatted
319+
}
320+
}
297321
return f.origin(n.Start) + f.formatNode(n.Value)
298322
}
299323

@@ -385,7 +409,7 @@ func (f *Formatter) trimIndentSpace(trimIndentNum int, v string) string {
385409
}
386410
lines := strings.Split(normalizeNewLineChars(v), "\n")
387411
out := make([]string, 0, len(lines))
388-
for _, line := range strings.Split(v, "\n") {
412+
for _, line := range lines {
389413
var cnt int
390414
out = append(out, strings.TrimLeftFunc(line, func(r rune) bool {
391415
cnt++
@@ -395,6 +419,23 @@ func (f *Formatter) trimIndentSpace(trimIndentNum int, v string) string {
395419
return strings.Join(out, "\n")
396420
}
397421

422+
func (f *Formatter) addIndentSpace(indentNum int, v string, isIgnoredFirstLine bool) string {
423+
if indentNum == 0 {
424+
return v
425+
}
426+
indent := strings.Repeat(" ", indentNum)
427+
lines := strings.Split(normalizeNewLineChars(v), "\n")
428+
out := make([]string, 0, len(lines))
429+
for idx, line := range lines {
430+
if line == "" || (isIgnoredFirstLine && idx == 0) {
431+
out = append(out, line)
432+
continue
433+
}
434+
out = append(out, indent+line)
435+
}
436+
return strings.Join(out, "\n")
437+
}
438+
398439
// normalizeNewLineChars normalize CRLF and CR to LF.
399440
func normalizeNewLineChars(v string) string {
400441
return strings.ReplaceAll(strings.ReplaceAll(v, "\r\n", "\n"), "\r", "\n")

0 commit comments

Comments
 (0)