-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathvalidator.go
289 lines (262 loc) · 7.84 KB
/
validator.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
package validator
import (
"bytes"
"encoding/xml"
"errors"
"fmt"
"io"
)
// XMLRoundtripError reports that a token did not survive being encoded and
// decoded again. Expected is the original token and Observed is the token
// that was read back. Overflow, when non-empty, holds encoded bytes that
// remained after the token had been read back.
type XMLRoundtripError struct {
	Expected xml.Token
	Observed xml.Token
	Overflow []byte
}

// Error implements the error interface. When Overflow is non-empty the
// message describes the leftover bytes; otherwise it describes the
// expected/observed token pair.
func (err XMLRoundtripError) Error() string {
	if len(err.Overflow) > 0 {
		return fmt.Sprintf("roundtrip error: unexpected overflow after token: %s", err.Overflow)
	}
	return fmt.Sprintf("roundtrip error: expected %v, observed %v", err.Expected, err.Observed)
}
// XMLValidationError is returned when validating an XML document fails.
// Start and End are the byte offsets at which the offending token begins and
// ends; Line and Column are the 1-based position of the token's first byte.
// The underlying cause is available through Unwrap.
type XMLValidationError struct {
	Start, End, Line, Column int64
	err                      error
}

// Error implements the error interface. The message names the line and column
// of the offending token followed by the cause, when there is one.
func (err XMLValidationError) Error() string {
	if err.err == nil {
		// A zero-value XMLValidationError (for example one declared as an
		// errors.As target) carries no cause; report the position alone
		// rather than calling Error on a nil error.
		return fmt.Sprintf("validator: in token starting at %d:%d", err.Line, err.Column)
	}
	return fmt.Sprintf("validator: in token starting at %d:%d: %s", err.Line, err.Column, err.err.Error())
}

// Unwrap returns the underlying cause so that errors.Is and errors.As can
// inspect it. It returns nil when no cause was recorded.
func (err XMLValidationError) Unwrap() error {
	return err.err
}
// Validate makes sure the given XML bytes survive round trips through
// encoding/xml without mutations.
//
// It reads raw tokens from xmlReader one at a time and passes each of them to
// CheckToken. It returns nil once the input is exhausted, an
// XMLValidationError describing the first token that fails the check, or the
// decoder's own error when a token cannot be read.
func Validate(xmlReader io.Reader) error {
	// Keep a copy of every byte handed to the decoder so the line and column
	// of a failing token can be derived from its byte offset.
	var consumed bytes.Buffer
	source := &byteReader{io.TeeReader(xmlReader, &consumed)}

	dec := xml.NewDecoder(source)
	dec.Strict = false
	// Accept any declared charset and pass the input through unchanged.
	dec.CharsetReader = func(charset string, input io.Reader) (io.Reader, error) { return input, nil }

	var tokenStart int64
	for {
		token, err := dec.RawToken()
		switch {
		case errors.Is(err, io.EOF):
			return nil
		case err != nil:
			return err
		}

		if checkErr := CheckToken(token); checkErr != nil {
			// Everything before the failing token determines its position.
			preceding := consumed.Bytes()[:tokenStart]
			lineStart := int64(bytes.LastIndexByte(preceding, '\n')) + 1
			return XMLValidationError{
				Start:  tokenStart,
				End:    dec.InputOffset(),
				Line:   int64(bytes.Count(preceding, []byte{'\n'}) + 1),
				Column: tokenStart - lineStart + 1,
				err:    checkErr,
			}
		}
		tokenStart = dec.InputOffset()
	}
}
// ValidateAll is like Validate, but instead of returning after the first
// error, it accumulates errors and validates the entire document.
//
// Validate is called repeatedly on the same reader. Each XMLValidationError
// it reports is rebased so that its offsets, line and column refer to the
// whole document before it is appended to the result. Any other error is
// appended as-is and ends the scan. The returned slice is empty, never nil,
// when no error was found.
func ValidateAll(xmlReader io.Reader) []error {
	// chunk records the bytes read during the current Validate call.
	var chunk bytes.Buffer
	tee := io.TeeReader(xmlReader, &chunk)

	found := make([]error, 0)
	// Document-wide position of the first byte the next Validate call reads.
	var (
		base    int64
		curLine int64 = 1
		curCol  int64 = 1
	)

	for {
		err := Validate(tee)
		if err == nil {
			// reached the end of the input without further errors
			return found
		}

		var vErr XMLValidationError
		if !errors.As(err, &vErr) {
			// not a validation error, so the XML is likely unparseable and
			// there is no point in trying to continue
			return append(found, err)
		}

		// Positions reported by Validate are relative to where that call
		// started reading; shift them to document coordinates. The column
		// only needs shifting while still on the line the call started on.
		vErr.Start += base
		vErr.End += base
		if vErr.Line == 1 {
			vErr.Column += curCol - 1
		}
		vErr.Line += curLine - 1
		found = append(found, vErr)

		// Advance the running position past everything this call consumed.
		read := chunk.Bytes()
		size := int64(len(read))
		base += size
		if lastNewline := bytes.LastIndexByte(read, '\n'); lastNewline >= 0 {
			curLine += int64(bytes.Count(read, []byte("\n")))
			curCol = size - int64(lastNewline)
		} else {
			curCol += size
		}
		chunk.Reset()
	}
}
// byteReader adapts an io.Reader to io.ByteReader without any buffering.
// bufio.Reader implements io.ByteReader as well, but it reads ahead, and here
// no more input may be pulled from the wrapped reader than was asked for.
type byteReader struct {
	r io.Reader
}

// ReadByte reads exactly one byte from the wrapped reader.
//
// The doc for the io.ByteReader interface states: if ReadByte returns an
// error, no input byte was consumed. So when the wrapped reader yields a byte
// together with an error, the byte is returned and the error is deliberately
// dropped; this byteReader is only used in the context of Validate.
//
// A read that returns no byte and no error means nothing happened and is
// retried; io.ErrNoProgress is returned if the wrapped reader keeps doing so.
func (r *byteReader) ReadByte() (byte, error) {
	// Upper bound on consecutive empty reads, so that a misbehaving reader
	// cannot spin forever; bufio applies the same limit.
	const maxEmptyReads = 100

	var p [1]byte
	for attempt := 0; attempt < maxEmptyReads; attempt++ {
		n, err := r.r.Read(p[:])
		if n > 0 {
			// return the byte extracted from the reader, ignoring err
			return p[0], nil
		}
		if err != nil {
			return 0, err
		}
		// (0, nil): no byte was produced; returning now would hand the
		// caller a zero byte that is not part of the input.
	}
	return 0, io.ErrNoProgress
}

// Read forwards to the wrapped reader unchanged.
func (r *byteReader) Read(p []byte) (int, error) {
	return r.r.Read(p)
}
// CheckToken computes a round trip for a given xml.Token and returns an
// error if the newly calculated token differs from the original.
//
// The token is written with an xml.Encoder and read back with a fresh
// xml.Decoder. Encoding and decoding errors are returned unchanged. An
// XMLRoundtripError is returned when the token read back is not equivalent
// to the original, or when encoded bytes remain after it has been read.
func CheckToken(before xml.Token) error {
	var encoded bytes.Buffer
	enc := xml.NewEncoder(&encoded)

	// xml.Encoder expects matching StartElements for all EndElements, so an
	// end element is preceded by a synthetic start element of the same name.
	endElement, isEnd := before.(xml.EndElement)
	if isEnd {
		if err := enc.EncodeToken(xml.StartElement{Name: endElement.Name}); err != nil {
			return err
		}
	}
	if err := enc.EncodeToken(before); err != nil {
		return err
	}
	if err := enc.Flush(); err != nil {
		return err
	}

	raw := encoded.Bytes()
	dec := xml.NewDecoder(bytes.NewReader(raw))
	dec.CharsetReader = func(charset string, input io.Reader) (io.Reader, error) { return input, nil }

	if isEnd {
		// skip over the synthetic start element written above
		if _, err := dec.RawToken(); err != nil {
			return err
		}
	}

	after, err := dec.RawToken()
	if err != nil {
		return err
	}
	if !tokenEquals(before, after) {
		return XMLRoundtripError{Expected: before, Observed: after}
	}
	if consumed := dec.InputOffset(); consumed != int64(len(raw)) {
		// this is likely unreachable, but just in case
		return XMLRoundtripError{Expected: before, Observed: after, Overflow: raw[consumed:]}
	}
	return nil
}
// tokenEquals reports whether after, the token read back from a round trip,
// is equivalent to before, the original token. Tokens of different kinds, and
// token kinds not listed below, are never considered equal.
func tokenEquals(before, after xml.Token) bool {
	switch want := before.(type) {
	case xml.CharData:
		got, ok := after.(xml.CharData)
		return ok && bytes.Equal(want, got)

	case xml.Comment:
		got, ok := after.(xml.Comment)
		return ok && bytes.Equal(want, got)

	case xml.Directive:
		got, ok := after.(xml.Directive)
		return ok && bytes.Equal(want, got)

	case xml.EndElement:
		got, ok := after.(xml.EndElement)
		// local name should equal; namespace prefixes get erased
		return ok && want.Name.Local == got.Name.Local && got.Name.Space == ""

	case xml.ProcInst:
		got, ok := after.(xml.ProcInst)
		return ok && want.Target == got.Target && bytes.Equal(want.Inst, got.Inst)

	case xml.StartElement:
		got, ok := after.(xml.StartElement)
		if !ok {
			return false
		}
		// encoding/xml messes up namespace prefixes on both tag and attribute
		// names; they need adjusting to make the comparison possible
		fixNamespacePrefixes(&want, &got)
		if want.Name != got.Name || len(want.Attr) != len(got.Attr) {
			return false
		}
		// after the adjustment all attributes should match pairwise, since
		// their ordering is preserved
		for i := range want.Attr {
			if want.Attr[i] != got.Attr[i] {
				return false
			}
		}
		return true

	default:
		return false
	}
}
// fixNamespacePrefixes rewrites after in place so that it can be compared
// with before. It only acts when after carries more attributes than before,
// which suggests the round trip introduced xmlns attributes. Those attributes
// are removed and their values are put back as the namespace of the tag or of
// the attributes they were declared for. before is only read.
func fixNamespacePrefixes(before, after *xml.StartElement) {
	if len(after.Attr) <= len(before.Attr) {
		return
	}

	// handle erased tag prefixes; the corresponding xmlns attribute is always
	// the first one
	if head := after.Attr[0]; before.Name.Space != "" && after.Name.Space == "" &&
		head.Name.Space == "" && head.Name.Local == "xmlns" {
		after.Name.Space = head.Value
		after.Attr = after.Attr[1:]
	}

	// handle attribute prefixes; the xmlns attribute always comes immediately
	// before the prefixed attribute
	for len(after.Attr) > len(before.Attr) && len(after.Attr) > 1 {
		// locate the first xmlns:<prefix> attribute directly followed by an
		// attribute that uses <prefix>
		declIdx := -1
		for k := 0; k+1 < len(after.Attr); k++ {
			decl, next := after.Attr[k].Name, after.Attr[k+1].Name
			if decl.Space == "xmlns" && decl.Local == next.Space {
				declIdx = k
				break
			}
		}
		if declIdx < 0 {
			break
		}

		// read the declaration before it is overwritten by the removal below
		prefix, space := after.Attr[declIdx].Name.Local, after.Attr[declIdx].Value
		after.Attr = append(after.Attr[:declIdx], after.Attr[declIdx+1:]...)

		// every attribute using the prefix gets the declared value instead
		for k := range after.Attr {
			if after.Attr[k].Name.Space == prefix {
				after.Attr[k].Name.Space = space
			}
		}
	}
}