encode.go raw

   1  // Copyright 2010 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
// This file contains a modified copy of the encoding/json encoder.
// All dynamic behavior has been removed, and reflection has been replaced with go/types.
// This allows us to statically find unmarshalable types
// with the same rules for tags, shadowing and addressability as encoding/json.
// This is used for SA1026.
  10  
  11  package fakejson
  12  
  13  import (
  14  	"go/types"
  15  	"sort"
  16  	"strings"
  17  	"unicode"
  18  
  19  	"golang.org/x/exp/typeparams"
  20  	"honnef.co/go/tools/go/types/typeutil"
  21  	"honnef.co/go/tools/knowledge"
  22  	"honnef.co/go/tools/staticcheck/fakereflect"
  23  )
  24  
  25  // parseTag splits a struct field's json tag into its name and
  26  // comma-separated options.
  27  func parseTag(tag string) string {
  28  	if idx := strings.Index(tag, ","); idx != -1 {
  29  		return tag[:idx]
  30  	}
  31  	return tag
  32  }
  33  
  34  func Marshal(v types.Type) *UnsupportedTypeError {
  35  	enc := encoder{}
  36  	return enc.newTypeEncoder(fakereflect.TypeAndCanAddr{Type: v}, "x")
  37  }
  38  
// An UnsupportedTypeError is returned by Marshal when attempting
// to encode an unsupported value type.
type UnsupportedTypeError struct {
	// Type is the type that was found to be unsupported.
	Type types.Type
	// Path is the expression leading to the unsupported type. It starts at
	// the root "x" passed in by Marshal and is extended with ".Field" for
	// struct fields, "[k]" for map elements and "[0]" for slice and array
	// elements.
	Path string
}
  45  
// encoder holds the state of a single Marshal check: the sets of types that
// newTypeEncoder has already visited. A type found in the relevant set is not
// checked again, which also stops the walk from recursing forever on
// self-referential types.
type encoder struct {
	// TODO we track addressable and non-addressable instances separately out of an abundance of caution. We don't know
	// if this is actually required for correctness.
	seenCanAddr  typeutil.Map[struct{}]
	seenCantAddr typeutil.Map[struct{}]
}
  52  
  53  func (enc *encoder) newTypeEncoder(t fakereflect.TypeAndCanAddr, stack string) *UnsupportedTypeError {
  54  	var m *typeutil.Map[struct{}]
  55  	if t.CanAddr() {
  56  		m = &enc.seenCanAddr
  57  	} else {
  58  		m = &enc.seenCantAddr
  59  	}
  60  	if _, ok := m.At(t.Type); ok {
  61  		return nil
  62  	}
  63  	m.Set(t.Type, struct{}{})
  64  
  65  	if t.Implements(knowledge.Interfaces["encoding/json.Marshaler"]) {
  66  		return nil
  67  	}
  68  	if !t.IsPtr() && t.CanAddr() && fakereflect.PtrTo(t).Implements(knowledge.Interfaces["encoding/json.Marshaler"]) {
  69  		return nil
  70  	}
  71  	if t.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
  72  		return nil
  73  	}
  74  	if !t.IsPtr() && t.CanAddr() && fakereflect.PtrTo(t).Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
  75  		return nil
  76  	}
  77  
  78  	switch t.Type.Underlying().(type) {
  79  	case *types.Basic, *types.Interface:
  80  		return nil
  81  	case *types.Struct:
  82  		return enc.typeFields(t, stack)
  83  	case *types.Map:
  84  		return enc.newMapEncoder(t, stack)
  85  	case *types.Slice:
  86  		return enc.newSliceEncoder(t, stack)
  87  	case *types.Array:
  88  		return enc.newArrayEncoder(t, stack)
  89  	case *types.Pointer:
  90  		// we don't have to express the pointer dereference in the path; x.f is syntactic sugar for (*x).f
  91  		return enc.newTypeEncoder(t.Elem(), stack)
  92  	default:
  93  		return &UnsupportedTypeError{t.Type, stack}
  94  	}
  95  }
  96  
  97  func (enc *encoder) newMapEncoder(t fakereflect.TypeAndCanAddr, stack string) *UnsupportedTypeError {
  98  	if typeparams.IsTypeParam(t.Key().Type) {
  99  		// We don't know enough about the concrete instantiation to say much about the key. The only time we could make
 100  		// a definite "this key is bad" statement is if the type parameter is constrained by type terms, none of which
 101  		// are tilde terms, none of which are a basic type. In all other cases, the key might implement TextMarshaler.
 102  		// It doesn't seem worth checking for that one single case.
 103  		return enc.newTypeEncoder(t.Elem(), stack+"[k]")
 104  	}
 105  
 106  	switch t.Key().Type.Underlying().(type) {
 107  	case *types.Basic:
 108  	default:
 109  		if !t.Key().Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
 110  			return &UnsupportedTypeError{
 111  				Type: t.Type,
 112  				Path: stack,
 113  			}
 114  		}
 115  	}
 116  	return enc.newTypeEncoder(t.Elem(), stack+"[k]")
 117  }
 118  
 119  func (enc *encoder) newSliceEncoder(t fakereflect.TypeAndCanAddr, stack string) *UnsupportedTypeError {
 120  	// Byte slices get special treatment; arrays don't.
 121  	basic, ok := t.Elem().Type.Underlying().(*types.Basic)
 122  	if ok && basic.Kind() == types.Uint8 {
 123  		p := fakereflect.PtrTo(t.Elem())
 124  		if !p.Implements(knowledge.Interfaces["encoding/json.Marshaler"]) && !p.Implements(knowledge.Interfaces["encoding.TextMarshaler"]) {
 125  			return nil
 126  		}
 127  	}
 128  	return enc.newArrayEncoder(t, stack)
 129  }
 130  
 131  func (enc *encoder) newArrayEncoder(t fakereflect.TypeAndCanAddr, stack string) *UnsupportedTypeError {
 132  	return enc.newTypeEncoder(t.Elem(), stack+"[0]")
 133  }
 134  
 135  func isValidTag(s string) bool {
 136  	if s == "" {
 137  		return false
 138  	}
 139  	for _, c := range s {
 140  		switch {
 141  		case strings.ContainsRune("!#$%&()*+-./:;<=>?@[]^_{|}~ ", c):
 142  			// Backslash and quote chars are reserved, but
 143  			// otherwise any punctuation chars are allowed
 144  			// in a tag name.
 145  		case !unicode.IsLetter(c) && !unicode.IsDigit(c):
 146  			return false
 147  		}
 148  	}
 149  	return true
 150  }
 151  
 152  func typeByIndex(t fakereflect.TypeAndCanAddr, index []int) fakereflect.TypeAndCanAddr {
 153  	for _, i := range index {
 154  		if t.IsPtr() {
 155  			t = t.Elem()
 156  		}
 157  		t = t.Field(i).Type
 158  	}
 159  	return t
 160  }
 161  
 162  func pathByIndex(t fakereflect.TypeAndCanAddr, index []int) string {
 163  	path := ""
 164  	for _, i := range index {
 165  		if t.IsPtr() {
 166  			t = t.Elem()
 167  		}
 168  		path += "." + t.Field(i).Name
 169  		t = t.Field(i).Type
 170  	}
 171  	return path
 172  }
 173  
// A field represents a single field found in a struct.
type field struct {
	// name is the JSON name of the field: the name from the json tag if it
	// has a valid one, otherwise the Go field name.
	name string

	// tag reports whether name came from a json tag.
	tag bool
	// index is the sequence of field indices leading from the root struct
	// to this field, one entry per level of embedding.
	index []int
	// typ is the field's type; typeFields stores it with an unnamed pointer
	// already replaced by its element type.
	typ fakereflect.TypeAndCanAddr
}
 182  
 183  // byIndex sorts field by index sequence.
 184  type byIndex []field
 185  
 186  func (x byIndex) Len() int { return len(x) }
 187  
 188  func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
 189  
 190  func (x byIndex) Less(i, j int) bool {
 191  	for k, xik := range x[i].index {
 192  		if k >= len(x[j].index) {
 193  			return false
 194  		}
 195  		if xik != x[j].index[k] {
 196  			return xik < x[j].index[k]
 197  		}
 198  	}
 199  	return len(x[i].index) < len(x[j].index)
 200  }
 201  
// typeFields determines the fields that JSON should recognize for the struct
// type t and checks the type of each of them with newTypeEncoder. It returns
// the first UnsupportedTypeError found, or nil if every field is supported.
// Unlike what the name suggests, the list of fields itself is not returned.
// stack is the path leading to t; each field's path is appended to it when
// checking that field.
//
// The algorithm is breadth-first search over the set of structs to include - the top struct
// and then any reachable anonymous structs.
func (enc *encoder) typeFields(t fakereflect.TypeAndCanAddr, stack string) *UnsupportedTypeError {
	// Anonymous fields to explore at the current level and the next.
	current := []field{}
	next := []field{{typ: t}}

	// Count of queued names for current level and the next.
	var count, nextCount map[fakereflect.TypeAndCanAddr]int

	// Types already visited at an earlier level.
	visited := map[fakereflect.TypeAndCanAddr]bool{}

	// Fields found.
	var fields []field

	for len(next) > 0 {
		// Advance one level: the queued structs become the current level, and
		// the old current slice's storage is reused for the next queue.
		current, next = next, current[:0]
		count, nextCount = nextCount, map[fakereflect.TypeAndCanAddr]int{}

		for _, f := range current {
			if visited[f.typ] {
				continue
			}
			visited[f.typ] = true

			// Scan f.typ for fields to include.
			for i := 0; i < f.typ.NumField(); i++ {
				sf := f.typ.Field(i)
				if sf.Anonymous {
					// Note: this t shadows the function parameter t within this block.
					t := sf.Type
					if t.IsPtr() {
						t = t.Elem()
					}
					if !sf.IsExported() && !t.IsStruct() {
						// Ignore embedded fields of unexported non-struct types.
						continue
					}
					// Do not ignore embedded fields of unexported struct types
					// since they may have exported fields.
				} else if !sf.IsExported() {
					// Ignore unexported non-embedded fields.
					continue
				}
				// A json tag of exactly "-" excludes the field.
				tag := sf.Tag.Get("json")
				if tag == "-" {
					continue
				}
				// A tag name that is not valid is treated like a missing one.
				name := parseTag(tag)
				if !isValidTag(name) {
					name = ""
				}
				// The field's index sequence is its parent's sequence plus its
				// own position; copy so that siblings don't share storage.
				index := make([]int, len(f.index)+1)
				copy(index, f.index)
				index[len(f.index)] = i

				ft := sf.Type
				if ft.Name() == "" && ft.IsPtr() {
					// Follow pointer.
					ft = ft.Elem()
				}

				// Record found field and index sequence.
				// Only untagged embedded structs fall through to be explored
				// at the next level; everything else is a field in its own right.
				if name != "" || !sf.Anonymous || !ft.IsStruct() {
					tagged := name != ""
					if name == "" {
						name = sf.Name
					}
					// Note: this variable shadows the type name field within this block.
					field := field{
						name:  name,
						tag:   tagged,
						index: index,
						typ:   ft,
					}

					fields = append(fields, field)
					if count[f.typ] > 1 {
						// If there were multiple instances, add a second,
						// so that the annihilation code will see a duplicate.
						// It only cares about the distinction between 1 or 2,
						// so don't bother generating any more copies.
						fields = append(fields, fields[len(fields)-1])
					}
					continue
				}

				// Record new anonymous struct to explore in next round.
				// Each struct type is queued once per level, but every
				// occurrence is counted so duplicates can be detected above.
				nextCount[ft]++
				if nextCount[ft] == 1 {
					next = append(next, field{name: ft.Name(), index: index, typ: ft})
				}
			}
		}
	}

	sort.Slice(fields, func(i, j int) bool {
		x := fields
		// sort field by name, breaking ties with depth, then
		// breaking ties with "name came from json tag", then
		// breaking ties with index sequence.
		if x[i].name != x[j].name {
			return x[i].name < x[j].name
		}
		if len(x[i].index) != len(x[j].index) {
			return len(x[i].index) < len(x[j].index)
		}
		if x[i].tag != x[j].tag {
			return x[i].tag
		}
		return byIndex(x).Less(i, j)
	})

	// Delete all fields that are hidden by the Go rules for embedded fields,
	// except that fields with JSON tags are promoted.

	// The fields are sorted in primary order of name, secondary order
	// of field index length. Loop over names; for each name, delete
	// hidden fields by choosing the one dominant field that survives.
	// out reuses the storage of fields; it never holds more elements than
	// have already been read from fields.
	out := fields[:0]
	for advance, i := 0, 0; i < len(fields); i += advance {
		// One iteration per name.
		// Find the sequence of fields with the name of this first field.
		fi := fields[i]
		name := fi.name
		for advance = 1; i+advance < len(fields); advance++ {
			fj := fields[i+advance]
			if fj.name != name {
				break
			}
		}
		if advance == 1 { // Only one field with this name
			out = append(out, fi)
			continue
		}
		// Several fields share this name; keep the dominant one, or none of
		// them if dominantField reports a conflict.
		dominant, ok := dominantField(fields[i : i+advance])
		if ok {
			out = append(out, dominant)
		}
	}

	fields = out
	sort.Sort(byIndex(fields))

	// Check the type of every surviving field, stopping at the first error.
	// The type and path are derived again from t via the index sequence
	// rather than taken from f.typ, which may have had a pointer removed above.
	for i := range fields {
		f := &fields[i]
		err := enc.newTypeEncoder(typeByIndex(t, f.index), stack+pathByIndex(t, f.index))
		if err != nil {
			return err
		}
	}
	return nil
}
 355  
 356  // dominantField looks through the fields, all of which are known to
 357  // have the same name, to find the single field that dominates the
 358  // others using Go's embedding rules, modified by the presence of
 359  // JSON tags. If there are multiple top-level fields, the boolean
 360  // will be false: This condition is an error in Go and we skip all
 361  // the fields.
 362  func dominantField(fields []field) (field, bool) {
 363  	// The fields are sorted in increasing index-length order, then by presence of tag.
 364  	// That means that the first field is the dominant one. We need only check
 365  	// for error cases: two fields at top level, either both tagged or neither tagged.
 366  	if len(fields) > 1 && len(fields[0].index) == len(fields[1].index) && fields[0].tag == fields[1].tag {
 367  		return field{}, false
 368  	}
 369  	return fields[0], true
 370  }
 371