marshal.go raw
1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 // This file contains a modified copy of the encoding/xml encoder.
// All dynamic behavior has been removed, and reflection has been replaced with go/types.
// This allows us to statically find unmarshalable types
8 // with the same rules for tags, shadowing and addressability as encoding/xml.
9 // This is used for SA1026 and SA5008.
10
11 // NOTE(dh): we do not check CanInterface in various places, which means we'll accept more marshaler implementations than encoding/xml does. This will lead to a small amount of false negatives.
12
13 package fakexml
14
15 import (
16 "fmt"
17 "go/types"
18
19 "honnef.co/go/tools/go/types/typeutil"
20 "honnef.co/go/tools/knowledge"
21 "honnef.co/go/tools/staticcheck/fakereflect"
22 )
23
// Marshal statically checks whether a value of type v could be marshaled by
// encoding/xml. It returns the first problem found, or nil if none was found.
func Marshal(v types.Type) error {
	enc := NewEncoder()
	return enc.Encode(v)
}
27
28 type Encoder struct {
29 // TODO we track addressable and non-addressable instances separately out of an abundance of caution. We don't know
30 // if this is actually required for correctness.
31 seenCanAddr typeutil.Map[struct{}]
32 seenCantAddr typeutil.Map[struct{}]
33 }
34
35 func NewEncoder() *Encoder {
36 e := &Encoder{}
37 return e
38 }
39
// Encode statically checks whether a value of type v could be marshaled by
// encoding/xml.
//
// The root value is built with only its Type set, so it is presumably treated
// as not addressable. Problems are reported relative to the root path "x".
func (e *Encoder) Encode(v types.Type) error {
	return e.marshalValue(fakereflect.TypeAndCanAddr{Type: v}, nil, nil, "x")
}
44
// implementsMarshaler reports whether v.Type, looked up as a non-addressable
// type, has a method with the shape of encoding/xml.Marshaler:
//
//	MarshalXML(*xml.Encoder, xml.StartElement) error
//
// A field that happens to be named MarshalXML does not count.
func implementsMarshaler(v fakereflect.TypeAndCanAddr) bool {
	obj, _, _ := types.LookupFieldOrMethod(v.Type, false, nil, "MarshalXML")
	method, isMethod := obj.(*types.Func) // false for nil obj or a field
	if !isMethod {
		return false
	}
	sig := method.Type().(*types.Signature)
	params, results := sig.Params(), sig.Results()
	return params.Len() == 2 &&
		typeutil.IsPointerToTypeWithName(params.At(0).Type(), "encoding/xml.Encoder") &&
		typeutil.IsTypeWithName(params.At(1).Type(), "encoding/xml.StartElement") &&
		results.Len() == 1 &&
		typeutil.IsTypeWithName(results.At(0).Type(), "error")
}
74
// implementsMarshalerAttr reports whether v.Type, looked up as a
// non-addressable type, has a method with the shape of
// encoding/xml.MarshalerAttr:
//
//	MarshalXMLAttr(xml.Name) (xml.Attr, error)
//
// A field that happens to be named MarshalXMLAttr does not count.
func implementsMarshalerAttr(v fakereflect.TypeAndCanAddr) bool {
	obj, _, _ := types.LookupFieldOrMethod(v.Type, false, nil, "MarshalXMLAttr")
	method, isMethod := obj.(*types.Func) // false for nil obj or a field
	if !isMethod {
		return false
	}
	sig := method.Type().(*types.Signature)
	params, results := sig.Params(), sig.Results()
	return params.Len() == 1 &&
		typeutil.IsTypeWithName(params.At(0).Type(), "encoding/xml.Name") &&
		results.Len() == 2 &&
		typeutil.IsTypeWithName(results.At(0).Type(), "encoding/xml.Attr") &&
		typeutil.IsTypeWithName(results.At(1).Type(), "error")
}
104
105 type CyclicTypeError struct {
106 Type types.Type
107 Path string
108 }
109
110 func (err *CyclicTypeError) Error() string {
111 return "cyclic type"
112 }
113
114 // marshalValue writes one or more XML elements representing val.
115 // If val was obtained from a struct field, finfo must have its details.
116 func (e *Encoder) marshalValue(val fakereflect.TypeAndCanAddr, finfo *fieldInfo, startTemplate *StartElement, stack string) error {
117 var m *typeutil.Map[struct{}]
118 if val.CanAddr() {
119 m = &e.seenCanAddr
120 } else {
121 m = &e.seenCantAddr
122 }
123 if _, ok := m.At(val.Type); ok {
124 return nil
125 }
126 m.Set(val.Type, struct{}{})
127
128 // Drill into interfaces and pointers.
129 seen := map[fakereflect.TypeAndCanAddr]struct{}{}
130 for val.IsInterface() || val.IsPtr() {
131 if val.IsInterface() {
132 return nil
133 }
134 val = val.Elem()
135 if _, ok := seen[val]; ok {
136 // Loop in type graph, e.g. 'type P *P'
137 return &CyclicTypeError{val.Type, stack}
138 }
139 seen[val] = struct{}{}
140 }
141
142 // Check for marshaler.
143 if implementsMarshaler(val) {
144 return nil
145 }
146 if val.CanAddr() {
147 pv := fakereflect.PtrTo(val)
148 if implementsMarshaler(pv) {
149 return nil
150 }
151 }
152
153 // Check for text marshaler.
154 if val.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
155 return nil
156 }
157 if val.CanAddr() {
158 pv := fakereflect.PtrTo(val)
159 if pv.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
160 return nil
161 }
162 }
163
164 // Slices and arrays iterate over the elements. They do not have an enclosing tag.
165 if (val.IsSlice() || val.IsArray()) && !isByteArray(val) && !isByteSlice(val) {
166 if err := e.marshalValue(val.Elem(), finfo, startTemplate, stack+"[0]"); err != nil {
167 return err
168 }
169 return nil
170 }
171
172 tinfo, err := getTypeInfo(val)
173 if err != nil {
174 return err
175 }
176
177 // Create start element.
178 // Precedence for the XML element name is:
179 // 0. startTemplate
180 // 1. XMLName field in underlying struct;
181 // 2. field name/tag in the struct field; and
182 // 3. type name
183 var start StartElement
184
185 if startTemplate != nil {
186 start.Name = startTemplate.Name
187 start.Attr = append(start.Attr, startTemplate.Attr...)
188 } else if tinfo.xmlname != nil {
189 xmlname := tinfo.xmlname
190 if xmlname.name != "" {
191 start.Name.Space, start.Name.Local = xmlname.xmlns, xmlname.name
192 }
193 }
194
195 // Attributes
196 for i := range tinfo.fields {
197 finfo := &tinfo.fields[i]
198 if finfo.flags&fAttr == 0 {
199 continue
200 }
201 fv := finfo.value(val)
202
203 name := Name{Space: finfo.xmlns, Local: finfo.name}
204 if err := e.marshalAttr(&start, name, fv, stack+pathByIndex(val, finfo.idx)); err != nil {
205 return err
206 }
207 }
208
209 if val.IsStruct() {
210 return e.marshalStruct(tinfo, val, stack)
211 } else {
212 return e.marshalSimple(val, stack)
213 }
214 }
215
216 func isSlice(v fakereflect.TypeAndCanAddr) bool {
217 _, ok := v.Type.Underlying().(*types.Slice)
218 return ok
219 }
220
// isByteSlice reports whether the underlying type of v is a slice whose
// element type has uint8 (byte) as its underlying type. Named byte types
// therefore also count.
func isByteSlice(v fakereflect.TypeAndCanAddr) bool {
	if s, isSlice := v.Type.Underlying().(*types.Slice); isSlice {
		elem, isBasic := s.Elem().Underlying().(*types.Basic)
		return isBasic && elem.Kind() == types.Uint8
	}
	return false
}
232
// isByteArray reports whether the underlying type of v is an array whose
// element type has uint8 (byte) as its underlying type. Named byte types
// therefore also count.
func isByteArray(v fakereflect.TypeAndCanAddr) bool {
	if arr, isArray := v.Type.Underlying().(*types.Array); isArray {
		elem, isBasic := arr.Elem().Underlying().(*types.Basic)
		return isBasic && elem.Kind() == types.Uint8
	}
	return false
}
244
245 // marshalAttr marshals an attribute with the given name and value, adding to start.Attr.
246 func (e *Encoder) marshalAttr(start *StartElement, name Name, val fakereflect.TypeAndCanAddr, stack string) error {
247 if implementsMarshalerAttr(val) {
248 return nil
249 }
250
251 if val.CanAddr() {
252 pv := fakereflect.PtrTo(val)
253 if implementsMarshalerAttr(pv) {
254 return nil
255 }
256 }
257
258 if val.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
259 return nil
260 }
261
262 if val.CanAddr() {
263 pv := fakereflect.PtrTo(val)
264 if pv.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
265 return nil
266 }
267 }
268
269 // Dereference or skip nil pointer
270 if val.IsPtr() {
271 val = val.Elem()
272 }
273
274 // Walk slices.
275 if isSlice(val) && !isByteSlice(val) {
276 if err := e.marshalAttr(start, name, val.Elem(), stack+"[0]"); err != nil {
277 return err
278 }
279 return nil
280 }
281
282 if typeutil.IsTypeWithName(val.Type, "encoding/xml.Attr") {
283 return nil
284 }
285
286 return e.marshalSimple(val, stack)
287 }
288
// marshalSimple checks that a non-struct value could be written as XML
// character data. It mirrors encoding/xml's marshalSimple, which accepts:
//   - booleans, integers, floats and strings
//   - byte slices and byte arrays
//
// It has no textual form for complex numbers or unsafe.Pointer and rejects
// them. Interfaces are accepted because their dynamic type is unknown
// statically. Any other type yields an *UnsupportedTypeError with stack as
// its path.
func (e *Encoder) marshalSimple(val fakereflect.TypeAndCanAddr, stack string) error {
	switch u := val.Type.Underlying().(type) {
	case *types.Basic:
		switch u.Kind() {
		case types.Complex64, types.Complex128, types.UnsafePointer:
			// Previously accepted, which silently hid errors that
			// encoding/xml reports at run time.
			return &UnsupportedTypeError{val.Type, stack}
		}
		return nil
	case *types.Interface:
		return nil
	case *types.Slice, *types.Array:
		elem, ok := val.Elem().Type.Underlying().(*types.Basic)
		if !ok || elem.Kind() != types.Uint8 {
			return &UnsupportedTypeError{val.Type, stack}
		}
		return nil
	default:
		return &UnsupportedTypeError{val.Type, stack}
	}
}
303
304 func indirect(vf fakereflect.TypeAndCanAddr) fakereflect.TypeAndCanAddr {
305 for vf.IsPtr() {
306 vf = vf.Elem()
307 }
308 return vf
309 }
310
311 func pathByIndex(t fakereflect.TypeAndCanAddr, index []int) string {
312 path := ""
313 for _, i := range index {
314 if t.IsPtr() {
315 t = t.Elem()
316 }
317 path += "." + t.Field(i).Name
318 t = t.Field(i).Type
319 }
320 return path
321 }
322
// marshalStruct checks the non-attribute fields of the struct val, whose layout
// is described by tinfo. It mirrors encoding/xml's marshalStruct. Attribute
// fields are skipped because marshalValue checks them. stack is the access path
// of val, used in error messages.
func (e *Encoder) marshalStruct(tinfo *typeInfo, val fakereflect.TypeAndCanAddr, stack string) error {
	for i := range tinfo.fields {
		finfo := &tinfo.fields[i]
		if finfo.flags&fAttr != 0 {
			continue
		}
		vf := finfo.value(val)

		switch finfo.flags & fMode {
		case fCDATA, fCharData:
			// Character data never produces an error. encoding/xml writes
			// TextMarshalers and basic kinds and ignores every other type, so
			// there is nothing to check here.
			continue

		case fComment:
			// encoding/xml accepts only strings and byte slices (after
			// following pointers) as comment fields. The check previously
			// rejected strings and accepted byte arrays.
			vf = indirect(vf)
			basic, isBasic := vf.Type.Underlying().(*types.Basic)
			isString := isBasic && basic.Kind() == types.String
			if !isString && !isByteSlice(vf) {
				return fmt.Errorf("xml: bad type for comment field of %s", val.Type)
			}
			continue

		case fInnerXML:
			// Exactly []byte and string are written verbatim. Any other type
			// falls through to the regular element check below.
			vf = indirect(vf)
			if t, ok := vf.Type.(*types.Slice); (ok && types.Identical(t.Elem(), types.Typ[types.Byte])) || types.Identical(vf.Type, types.Typ[types.String]) {
				continue
			}

		case fElement, fElement | fAny:
		}
		if err := e.marshalValue(vf, finfo, nil, stack+pathByIndex(val, finfo.idx)); err != nil {
			return err
		}
	}
	return nil
}
365
366 // UnsupportedTypeError is returned when Marshal encounters a type
367 // that cannot be converted into XML.
368 type UnsupportedTypeError struct {
369 Type types.Type
370 Path string
371 }
372
373 func (e *UnsupportedTypeError) Error() string {
374 return fmt.Sprintf("xml: unsupported type %s, via %s ", e.Type, e.Path)
375 }
376