marshal.go raw

   1  // Copyright 2011 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  // This file contains a modified copy of the encoding/xml encoder.
   6  // All dynamic behavior has been removed, and reflection has been replaced with go/types.
   7  // This allows us to statically find unmarshalable types
   8  // with the same rules for tags, shadowing and addressability as encoding/xml.
   9  // This is used for SA1026 and SA5008.
  10  
  11  // NOTE(dh): we do not check CanInterface in various places, which means we'll accept more marshaler implementations than encoding/xml does. This will lead to a small number of false negatives.
  12  
  13  package fakexml
  14  
  15  import (
  16  	"fmt"
  17  	"go/types"
  18  
  19  	"honnef.co/go/tools/go/types/typeutil"
  20  	"honnef.co/go/tools/knowledge"
  21  	"honnef.co/go/tools/staticcheck/fakereflect"
  22  )
  23  
  24  func Marshal(v types.Type) error {
  25  	return NewEncoder().Encode(v)
  26  }
  27  
  28  type Encoder struct {
  29  	// TODO we track addressable and non-addressable instances separately out of an abundance of caution. We don't know
  30  	// if this is actually required for correctness.
  31  	seenCanAddr  typeutil.Map[struct{}]
  32  	seenCantAddr typeutil.Map[struct{}]
  33  }
  34  
  35  func NewEncoder() *Encoder {
  36  	e := &Encoder{}
  37  	return e
  38  }
  39  
  40  func (enc *Encoder) Encode(v types.Type) error {
  41  	rv := fakereflect.TypeAndCanAddr{Type: v}
  42  	return enc.marshalValue(rv, nil, nil, "x")
  43  }
  44  
  45  func implementsMarshaler(v fakereflect.TypeAndCanAddr) bool {
  46  	t := v.Type
  47  	obj, _, _ := types.LookupFieldOrMethod(t, false, nil, "MarshalXML")
  48  	if obj == nil {
  49  		return false
  50  	}
  51  	fn, ok := obj.(*types.Func)
  52  	if !ok {
  53  		return false
  54  	}
  55  	params := fn.Type().(*types.Signature).Params()
  56  	if params.Len() != 2 {
  57  		return false
  58  	}
  59  	if !typeutil.IsPointerToTypeWithName(params.At(0).Type(), "encoding/xml.Encoder") {
  60  		return false
  61  	}
  62  	if !typeutil.IsTypeWithName(params.At(1).Type(), "encoding/xml.StartElement") {
  63  		return false
  64  	}
  65  	rets := fn.Type().(*types.Signature).Results()
  66  	if rets.Len() != 1 {
  67  		return false
  68  	}
  69  	if !typeutil.IsTypeWithName(rets.At(0).Type(), "error") {
  70  		return false
  71  	}
  72  	return true
  73  }
  74  
  75  func implementsMarshalerAttr(v fakereflect.TypeAndCanAddr) bool {
  76  	t := v.Type
  77  	obj, _, _ := types.LookupFieldOrMethod(t, false, nil, "MarshalXMLAttr")
  78  	if obj == nil {
  79  		return false
  80  	}
  81  	fn, ok := obj.(*types.Func)
  82  	if !ok {
  83  		return false
  84  	}
  85  	params := fn.Type().(*types.Signature).Params()
  86  	if params.Len() != 1 {
  87  		return false
  88  	}
  89  	if !typeutil.IsTypeWithName(params.At(0).Type(), "encoding/xml.Name") {
  90  		return false
  91  	}
  92  	rets := fn.Type().(*types.Signature).Results()
  93  	if rets.Len() != 2 {
  94  		return false
  95  	}
  96  	if !typeutil.IsTypeWithName(rets.At(0).Type(), "encoding/xml.Attr") {
  97  		return false
  98  	}
  99  	if !typeutil.IsTypeWithName(rets.At(1).Type(), "error") {
 100  		return false
 101  	}
 102  	return true
 103  }
 104  
 105  type CyclicTypeError struct {
 106  	Type types.Type
 107  	Path string
 108  }
 109  
 110  func (err *CyclicTypeError) Error() string {
 111  	return "cyclic type"
 112  }
 113  
 114  // marshalValue writes one or more XML elements representing val.
 115  // If val was obtained from a struct field, finfo must have its details.
 116  func (e *Encoder) marshalValue(val fakereflect.TypeAndCanAddr, finfo *fieldInfo, startTemplate *StartElement, stack string) error {
 117  	var m *typeutil.Map[struct{}]
 118  	if val.CanAddr() {
 119  		m = &e.seenCanAddr
 120  	} else {
 121  		m = &e.seenCantAddr
 122  	}
 123  	if _, ok := m.At(val.Type); ok {
 124  		return nil
 125  	}
 126  	m.Set(val.Type, struct{}{})
 127  
 128  	// Drill into interfaces and pointers.
 129  	seen := map[fakereflect.TypeAndCanAddr]struct{}{}
 130  	for val.IsInterface() || val.IsPtr() {
 131  		if val.IsInterface() {
 132  			return nil
 133  		}
 134  		val = val.Elem()
 135  		if _, ok := seen[val]; ok {
 136  			// Loop in type graph, e.g. 'type P *P'
 137  			return &CyclicTypeError{val.Type, stack}
 138  		}
 139  		seen[val] = struct{}{}
 140  	}
 141  
 142  	// Check for marshaler.
 143  	if implementsMarshaler(val) {
 144  		return nil
 145  	}
 146  	if val.CanAddr() {
 147  		pv := fakereflect.PtrTo(val)
 148  		if implementsMarshaler(pv) {
 149  			return nil
 150  		}
 151  	}
 152  
 153  	// Check for text marshaler.
 154  	if val.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
 155  		return nil
 156  	}
 157  	if val.CanAddr() {
 158  		pv := fakereflect.PtrTo(val)
 159  		if pv.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
 160  			return nil
 161  		}
 162  	}
 163  
 164  	// Slices and arrays iterate over the elements. They do not have an enclosing tag.
 165  	if (val.IsSlice() || val.IsArray()) && !isByteArray(val) && !isByteSlice(val) {
 166  		if err := e.marshalValue(val.Elem(), finfo, startTemplate, stack+"[0]"); err != nil {
 167  			return err
 168  		}
 169  		return nil
 170  	}
 171  
 172  	tinfo, err := getTypeInfo(val)
 173  	if err != nil {
 174  		return err
 175  	}
 176  
 177  	// Create start element.
 178  	// Precedence for the XML element name is:
 179  	// 0. startTemplate
 180  	// 1. XMLName field in underlying struct;
 181  	// 2. field name/tag in the struct field; and
 182  	// 3. type name
 183  	var start StartElement
 184  
 185  	if startTemplate != nil {
 186  		start.Name = startTemplate.Name
 187  		start.Attr = append(start.Attr, startTemplate.Attr...)
 188  	} else if tinfo.xmlname != nil {
 189  		xmlname := tinfo.xmlname
 190  		if xmlname.name != "" {
 191  			start.Name.Space, start.Name.Local = xmlname.xmlns, xmlname.name
 192  		}
 193  	}
 194  
 195  	// Attributes
 196  	for i := range tinfo.fields {
 197  		finfo := &tinfo.fields[i]
 198  		if finfo.flags&fAttr == 0 {
 199  			continue
 200  		}
 201  		fv := finfo.value(val)
 202  
 203  		name := Name{Space: finfo.xmlns, Local: finfo.name}
 204  		if err := e.marshalAttr(&start, name, fv, stack+pathByIndex(val, finfo.idx)); err != nil {
 205  			return err
 206  		}
 207  	}
 208  
 209  	if val.IsStruct() {
 210  		return e.marshalStruct(tinfo, val, stack)
 211  	} else {
 212  		return e.marshalSimple(val, stack)
 213  	}
 214  }
 215  
 216  func isSlice(v fakereflect.TypeAndCanAddr) bool {
 217  	_, ok := v.Type.Underlying().(*types.Slice)
 218  	return ok
 219  }
 220  
 221  func isByteSlice(v fakereflect.TypeAndCanAddr) bool {
 222  	slice, ok := v.Type.Underlying().(*types.Slice)
 223  	if !ok {
 224  		return false
 225  	}
 226  	basic, ok := slice.Elem().Underlying().(*types.Basic)
 227  	if !ok {
 228  		return false
 229  	}
 230  	return basic.Kind() == types.Uint8
 231  }
 232  
 233  func isByteArray(v fakereflect.TypeAndCanAddr) bool {
 234  	slice, ok := v.Type.Underlying().(*types.Array)
 235  	if !ok {
 236  		return false
 237  	}
 238  	basic, ok := slice.Elem().Underlying().(*types.Basic)
 239  	if !ok {
 240  		return false
 241  	}
 242  	return basic.Kind() == types.Uint8
 243  }
 244  
 245  // marshalAttr marshals an attribute with the given name and value, adding to start.Attr.
 246  func (e *Encoder) marshalAttr(start *StartElement, name Name, val fakereflect.TypeAndCanAddr, stack string) error {
 247  	if implementsMarshalerAttr(val) {
 248  		return nil
 249  	}
 250  
 251  	if val.CanAddr() {
 252  		pv := fakereflect.PtrTo(val)
 253  		if implementsMarshalerAttr(pv) {
 254  			return nil
 255  		}
 256  	}
 257  
 258  	if val.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
 259  		return nil
 260  	}
 261  
 262  	if val.CanAddr() {
 263  		pv := fakereflect.PtrTo(val)
 264  		if pv.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
 265  			return nil
 266  		}
 267  	}
 268  
 269  	// Dereference or skip nil pointer
 270  	if val.IsPtr() {
 271  		val = val.Elem()
 272  	}
 273  
 274  	// Walk slices.
 275  	if isSlice(val) && !isByteSlice(val) {
 276  		if err := e.marshalAttr(start, name, val.Elem(), stack+"[0]"); err != nil {
 277  			return err
 278  		}
 279  		return nil
 280  	}
 281  
 282  	if typeutil.IsTypeWithName(val.Type, "encoding/xml.Attr") {
 283  		return nil
 284  	}
 285  
 286  	return e.marshalSimple(val, stack)
 287  }
 288  
 289  func (e *Encoder) marshalSimple(val fakereflect.TypeAndCanAddr, stack string) error {
 290  	switch val.Type.Underlying().(type) {
 291  	case *types.Basic, *types.Interface:
 292  		return nil
 293  	case *types.Slice, *types.Array:
 294  		basic, ok := val.Elem().Type.Underlying().(*types.Basic)
 295  		if !ok || basic.Kind() != types.Uint8 {
 296  			return &UnsupportedTypeError{val.Type, stack}
 297  		}
 298  		return nil
 299  	default:
 300  		return &UnsupportedTypeError{val.Type, stack}
 301  	}
 302  }
 303  
 304  func indirect(vf fakereflect.TypeAndCanAddr) fakereflect.TypeAndCanAddr {
 305  	for vf.IsPtr() {
 306  		vf = vf.Elem()
 307  	}
 308  	return vf
 309  }
 310  
 311  func pathByIndex(t fakereflect.TypeAndCanAddr, index []int) string {
 312  	path := ""
 313  	for _, i := range index {
 314  		if t.IsPtr() {
 315  			t = t.Elem()
 316  		}
 317  		path += "." + t.Field(i).Name
 318  		t = t.Field(i).Type
 319  	}
 320  	return path
 321  }
 322  
 323  func (e *Encoder) marshalStruct(tinfo *typeInfo, val fakereflect.TypeAndCanAddr, stack string) error {
 324  	for i := range tinfo.fields {
 325  		finfo := &tinfo.fields[i]
 326  		if finfo.flags&fAttr != 0 {
 327  			continue
 328  		}
 329  		vf := finfo.value(val)
 330  
 331  		switch finfo.flags & fMode {
 332  		case fCDATA, fCharData:
 333  			if vf.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
 334  				continue
 335  			}
 336  			if vf.CanAddr() {
 337  				pv := fakereflect.PtrTo(vf)
 338  				if pv.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
 339  					continue
 340  				}
 341  			}
 342  			continue
 343  
 344  		case fComment:
 345  			vf = indirect(vf)
 346  			if !(isByteSlice(vf) || isByteArray(vf)) {
 347  				return fmt.Errorf("xml: bad type for comment field of %s", val)
 348  			}
 349  			continue
 350  
 351  		case fInnerXML:
 352  			vf = indirect(vf)
 353  			if t, ok := vf.Type.(*types.Slice); (ok && types.Identical(t.Elem(), types.Typ[types.Byte])) || types.Identical(vf.Type, types.Typ[types.String]) {
 354  				continue
 355  			}
 356  
 357  		case fElement, fElement | fAny:
 358  		}
 359  		if err := e.marshalValue(vf, finfo, nil, stack+pathByIndex(val, finfo.idx)); err != nil {
 360  			return err
 361  		}
 362  	}
 363  	return nil
 364  }
 365  
 366  // UnsupportedTypeError is returned when Marshal encounters a type
 367  // that cannot be converted into XML.
 368  type UnsupportedTypeError struct {
 369  	Type types.Type
 370  	Path string
 371  }
 372  
 373  func (e *UnsupportedTypeError) Error() string {
 374  	return fmt.Sprintf("xml: unsupported type %s, via %s ", e.Type, e.Path)
 375  }
 376