fields.go raw

   1  // Copyright 2013 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  package yaml
   5  
   6  import (
   7  	"bytes"
   8  	"encoding"
   9  	"encoding/json"
  10  	"reflect"
  11  	"sort"
  12  	"strings"
  13  	"sync"
  14  	"unicode"
  15  	"unicode/utf8"
  16  )
  17  
  18  // indirect walks down v allocating pointers as needed,
  19  // until it gets to a non-pointer.
  20  // if it encounters an Unmarshaler, indirect stops and returns that.
  21  // if decodingNull is true, indirect stops at the last pointer so it can be set to nil.
  22  func indirect(v reflect.Value, decodingNull bool) (json.Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
  23  	// If v is a named type and is addressable,
  24  	// start with its address, so that if the type has pointer methods,
  25  	// we find them.
  26  	if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() {
  27  		v = v.Addr()
  28  	}
  29  	for {
  30  		// Load value from interface, but only if the result will be
  31  		// usefully addressable.
  32  		if v.Kind() == reflect.Interface && !v.IsNil() {
  33  			e := v.Elem()
  34  			if e.Kind() == reflect.Ptr && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Ptr) {
  35  				v = e
  36  				continue
  37  			}
  38  		}
  39  
  40  		if v.Kind() != reflect.Ptr {
  41  			break
  42  		}
  43  
  44  		if v.Elem().Kind() != reflect.Ptr && decodingNull && v.CanSet() {
  45  			break
  46  		}
  47  		if v.IsNil() {
  48  			if v.CanSet() {
  49  				v.Set(reflect.New(v.Type().Elem()))
  50  			} else {
  51  				v = reflect.New(v.Type().Elem())
  52  			}
  53  		}
  54  		if v.Type().NumMethod() > 0 {
  55  			if u, ok := v.Interface().(json.Unmarshaler); ok {
  56  				return u, nil, reflect.Value{}
  57  			}
  58  			if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
  59  				return nil, u, reflect.Value{}
  60  			}
  61  		}
  62  		v = v.Elem()
  63  	}
  64  	return nil, nil, v
  65  }
  66  
  67  // A field represents a single field found in a struct.
  68  type field struct {
  69  	name      string
  70  	nameBytes []byte                 // []byte(name)
  71  	equalFold func(s, t []byte) bool // bytes.EqualFold or equivalent
  72  
  73  	tag       bool
  74  	index     []int
  75  	typ       reflect.Type
  76  	omitEmpty bool
  77  	quoted    bool
  78  }
  79  
  80  func fillField(f field) field {
  81  	f.nameBytes = []byte(f.name)
  82  	f.equalFold = foldFunc(f.nameBytes)
  83  	return f
  84  }
  85  
  86  // byName sorts field by name, breaking ties with depth,
  87  // then breaking ties with "name came from json tag", then
  88  // breaking ties with index sequence.
  89  type byName []field
  90  
  91  func (x byName) Len() int { return len(x) }
  92  
  93  func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
  94  
  95  func (x byName) Less(i, j int) bool {
  96  	if x[i].name != x[j].name {
  97  		return x[i].name < x[j].name
  98  	}
  99  	if len(x[i].index) != len(x[j].index) {
 100  		return len(x[i].index) < len(x[j].index)
 101  	}
 102  	if x[i].tag != x[j].tag {
 103  		return x[i].tag
 104  	}
 105  	return byIndex(x).Less(i, j)
 106  }
 107  
 108  // byIndex sorts field by index sequence.
 109  type byIndex []field
 110  
 111  func (x byIndex) Len() int { return len(x) }
 112  
 113  func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
 114  
 115  func (x byIndex) Less(i, j int) bool {
 116  	for k, xik := range x[i].index {
 117  		if k >= len(x[j].index) {
 118  			return false
 119  		}
 120  		if xik != x[j].index[k] {
 121  			return xik < x[j].index[k]
 122  		}
 123  	}
 124  	return len(x[i].index) < len(x[j].index)
 125  }
 126  
 127  // typeFields returns a list of fields that JSON should recognize for the given type.
 128  // The algorithm is breadth-first search over the set of structs to include - the top struct
 129  // and then any reachable anonymous structs.
 130  func typeFields(t reflect.Type) []field {
 131  	// Anonymous fields to explore at the current level and the next.
 132  	current := []field{}
 133  	next := []field{{typ: t}}
 134  
 135  	// Count of queued names for current level and the next.
 136  	count := map[reflect.Type]int{}
 137  	nextCount := map[reflect.Type]int{}
 138  
 139  	// Types already visited at an earlier level.
 140  	visited := map[reflect.Type]bool{}
 141  
 142  	// Fields found.
 143  	var fields []field
 144  
 145  	for len(next) > 0 {
 146  		current, next = next, current[:0]
 147  		count, nextCount = nextCount, map[reflect.Type]int{}
 148  
 149  		for _, f := range current {
 150  			if visited[f.typ] {
 151  				continue
 152  			}
 153  			visited[f.typ] = true
 154  
 155  			// Scan f.typ for fields to include.
 156  			for i := 0; i < f.typ.NumField(); i++ {
 157  				sf := f.typ.Field(i)
 158  				if sf.PkgPath != "" { // unexported
 159  					continue
 160  				}
 161  				tag := sf.Tag.Get("json")
 162  				if tag == "-" {
 163  					continue
 164  				}
 165  				name, opts := parseTag(tag)
 166  				if !isValidTag(name) {
 167  					name = ""
 168  				}
 169  				index := make([]int, len(f.index)+1)
 170  				copy(index, f.index)
 171  				index[len(f.index)] = i
 172  
 173  				ft := sf.Type
 174  				if ft.Name() == "" && ft.Kind() == reflect.Ptr {
 175  					// Follow pointer.
 176  					ft = ft.Elem()
 177  				}
 178  
 179  				// Record found field and index sequence.
 180  				if name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct {
 181  					tagged := name != ""
 182  					if name == "" {
 183  						name = sf.Name
 184  					}
 185  					fields = append(fields, fillField(field{
 186  						name:      name,
 187  						tag:       tagged,
 188  						index:     index,
 189  						typ:       ft,
 190  						omitEmpty: opts.Contains("omitempty"),
 191  						quoted:    opts.Contains("string"),
 192  					}))
 193  					if count[f.typ] > 1 {
 194  						// If there were multiple instances, add a second,
 195  						// so that the annihilation code will see a duplicate.
 196  						// It only cares about the distinction between 1 or 2,
 197  						// so don't bother generating any more copies.
 198  						fields = append(fields, fields[len(fields)-1])
 199  					}
 200  					continue
 201  				}
 202  
 203  				// Record new anonymous struct to explore in next round.
 204  				nextCount[ft]++
 205  				if nextCount[ft] == 1 {
 206  					next = append(next, fillField(field{name: ft.Name(), index: index, typ: ft}))
 207  				}
 208  			}
 209  		}
 210  	}
 211  
 212  	sort.Sort(byName(fields))
 213  
 214  	// Delete all fields that are hidden by the Go rules for embedded fields,
 215  	// except that fields with JSON tags are promoted.
 216  
 217  	// The fields are sorted in primary order of name, secondary order
 218  	// of field index length. Loop over names; for each name, delete
 219  	// hidden fields by choosing the one dominant field that survives.
 220  	out := fields[:0]
 221  	for advance, i := 0, 0; i < len(fields); i += advance {
 222  		// One iteration per name.
 223  		// Find the sequence of fields with the name of this first field.
 224  		fi := fields[i]
 225  		name := fi.name
 226  		for advance = 1; i+advance < len(fields); advance++ {
 227  			fj := fields[i+advance]
 228  			if fj.name != name {
 229  				break
 230  			}
 231  		}
 232  		if advance == 1 { // Only one field with this name
 233  			out = append(out, fi)
 234  			continue
 235  		}
 236  		dominant, ok := dominantField(fields[i : i+advance])
 237  		if ok {
 238  			out = append(out, dominant)
 239  		}
 240  	}
 241  
 242  	fields = out
 243  	sort.Sort(byIndex(fields))
 244  
 245  	return fields
 246  }
 247  
 248  // dominantField looks through the fields, all of which are known to
 249  // have the same name, to find the single field that dominates the
 250  // others using Go's embedding rules, modified by the presence of
 251  // JSON tags. If there are multiple top-level fields, the boolean
 252  // will be false: This condition is an error in Go and we skip all
 253  // the fields.
 254  func dominantField(fields []field) (field, bool) {
 255  	// The fields are sorted in increasing index-length order. The winner
 256  	// must therefore be one with the shortest index length. Drop all
 257  	// longer entries, which is easy: just truncate the slice.
 258  	length := len(fields[0].index)
 259  	tagged := -1 // Index of first tagged field.
 260  	for i, f := range fields {
 261  		if len(f.index) > length {
 262  			fields = fields[:i]
 263  			break
 264  		}
 265  		if f.tag {
 266  			if tagged >= 0 {
 267  				// Multiple tagged fields at the same level: conflict.
 268  				// Return no field.
 269  				return field{}, false
 270  			}
 271  			tagged = i
 272  		}
 273  	}
 274  	if tagged >= 0 {
 275  		return fields[tagged], true
 276  	}
 277  	// All remaining fields have the same length. If there's more than one,
 278  	// we have a conflict (two fields named "X" at the same level) and we
 279  	// return no field.
 280  	if len(fields) > 1 {
 281  		return field{}, false
 282  	}
 283  	return fields[0], true
 284  }
 285  
 286  var fieldCache struct {
 287  	sync.RWMutex
 288  	m map[reflect.Type][]field
 289  }
 290  
 291  // cachedTypeFields is like typeFields but uses a cache to avoid repeated work.
 292  func cachedTypeFields(t reflect.Type) []field {
 293  	fieldCache.RLock()
 294  	f := fieldCache.m[t]
 295  	fieldCache.RUnlock()
 296  	if f != nil {
 297  		return f
 298  	}
 299  
 300  	// Compute fields without lock.
 301  	// Might duplicate effort but won't hold other computations back.
 302  	f = typeFields(t)
 303  	if f == nil {
 304  		f = []field{}
 305  	}
 306  
 307  	fieldCache.Lock()
 308  	if fieldCache.m == nil {
 309  		fieldCache.m = map[reflect.Type][]field{}
 310  	}
 311  	fieldCache.m[t] = f
 312  	fieldCache.Unlock()
 313  	return f
 314  }
 315  
 316  func isValidTag(s string) bool {
 317  	if s == "" {
 318  		return false
 319  	}
 320  	for _, c := range s {
 321  		switch {
 322  		case strings.ContainsRune("!#$%&()*+-./:<=>?@[]^_{|}~ ", c):
 323  			// Backslash and quote chars are reserved, but
 324  			// otherwise any punctuation chars are allowed
 325  			// in a tag name.
 326  		default:
 327  			if !unicode.IsLetter(c) && !unicode.IsDigit(c) {
 328  				return false
 329  			}
 330  		}
 331  	}
 332  	return true
 333  }
 334  
 335  const (
 336  	caseMask     = ^byte(0x20) // Mask to ignore case in ASCII.
 337  	kelvin       = '\u212a'
 338  	smallLongEss = '\u017f'
 339  )
 340  
 341  // foldFunc returns one of four different case folding equivalence
 342  // functions, from most general (and slow) to fastest:
 343  //
 344  // 1) bytes.EqualFold, if the key s contains any non-ASCII UTF-8
 345  // 2) equalFoldRight, if s contains special folding ASCII ('k', 'K', 's', 'S')
 346  // 3) asciiEqualFold, no special, but includes non-letters (including _)
 347  // 4) simpleLetterEqualFold, no specials, no non-letters.
 348  //
 349  // The letters S and K are special because they map to 3 runes, not just 2:
 350  //  * S maps to s and to U+017F 'ſ' Latin small letter long s
 351  //  * k maps to K and to U+212A 'K' Kelvin sign
 352  // See http://play.golang.org/p/tTxjOc0OGo
 353  //
 354  // The returned function is specialized for matching against s and
 355  // should only be given s. It's not curried for performance reasons.
 356  func foldFunc(s []byte) func(s, t []byte) bool {
 357  	nonLetter := false
 358  	special := false // special letter
 359  	for _, b := range s {
 360  		if b >= utf8.RuneSelf {
 361  			return bytes.EqualFold
 362  		}
 363  		upper := b & caseMask
 364  		if upper < 'A' || upper > 'Z' {
 365  			nonLetter = true
 366  		} else if upper == 'K' || upper == 'S' {
 367  			// See above for why these letters are special.
 368  			special = true
 369  		}
 370  	}
 371  	if special {
 372  		return equalFoldRight
 373  	}
 374  	if nonLetter {
 375  		return asciiEqualFold
 376  	}
 377  	return simpleLetterEqualFold
 378  }
 379  
 380  // equalFoldRight is a specialization of bytes.EqualFold when s is
 381  // known to be all ASCII (including punctuation), but contains an 's',
 382  // 'S', 'k', or 'K', requiring a Unicode fold on the bytes in t.
 383  // See comments on foldFunc.
 384  func equalFoldRight(s, t []byte) bool {
 385  	for _, sb := range s {
 386  		if len(t) == 0 {
 387  			return false
 388  		}
 389  		tb := t[0]
 390  		if tb < utf8.RuneSelf {
 391  			if sb != tb {
 392  				sbUpper := sb & caseMask
 393  				if 'A' <= sbUpper && sbUpper <= 'Z' {
 394  					if sbUpper != tb&caseMask {
 395  						return false
 396  					}
 397  				} else {
 398  					return false
 399  				}
 400  			}
 401  			t = t[1:]
 402  			continue
 403  		}
 404  		// sb is ASCII and t is not. t must be either kelvin
 405  		// sign or long s; sb must be s, S, k, or K.
 406  		tr, size := utf8.DecodeRune(t)
 407  		switch sb {
 408  		case 's', 'S':
 409  			if tr != smallLongEss {
 410  				return false
 411  			}
 412  		case 'k', 'K':
 413  			if tr != kelvin {
 414  				return false
 415  			}
 416  		default:
 417  			return false
 418  		}
 419  		t = t[size:]
 420  
 421  	}
 422  	if len(t) > 0 {
 423  		return false
 424  	}
 425  	return true
 426  }
 427  
 428  // asciiEqualFold is a specialization of bytes.EqualFold for use when
 429  // s is all ASCII (but may contain non-letters) and contains no
 430  // special-folding letters.
 431  // See comments on foldFunc.
 432  func asciiEqualFold(s, t []byte) bool {
 433  	if len(s) != len(t) {
 434  		return false
 435  	}
 436  	for i, sb := range s {
 437  		tb := t[i]
 438  		if sb == tb {
 439  			continue
 440  		}
 441  		if ('a' <= sb && sb <= 'z') || ('A' <= sb && sb <= 'Z') {
 442  			if sb&caseMask != tb&caseMask {
 443  				return false
 444  			}
 445  		} else {
 446  			return false
 447  		}
 448  	}
 449  	return true
 450  }
 451  
 452  // simpleLetterEqualFold is a specialization of bytes.EqualFold for
 453  // use when s is all ASCII letters (no underscores, etc) and also
 454  // doesn't contain 'k', 'K', 's', or 'S'.
 455  // See comments on foldFunc.
 456  func simpleLetterEqualFold(s, t []byte) bool {
 457  	if len(s) != len(t) {
 458  		return false
 459  	}
 460  	for i, b := range s {
 461  		if b&caseMask != t[i]&caseMask {
 462  			return false
 463  		}
 464  	}
 465  	return true
 466  }
 467  
 468  // tagOptions is the string following a comma in a struct field's "json"
 469  // tag, or the empty string. It does not include the leading comma.
 470  type tagOptions string
 471  
 472  // parseTag splits a struct field's json tag into its name and
 473  // comma-separated options.
 474  func parseTag(tag string) (string, tagOptions) {
 475  	if idx := strings.Index(tag, ","); idx != -1 {
 476  		return tag[:idx], tagOptions(tag[idx+1:])
 477  	}
 478  	return tag, tagOptions("")
 479  }
 480  
 481  // Contains reports whether a comma-separated list of options
 482  // contains a particular substr flag. substr must be surrounded by a
 483  // string boundary or commas.
 484  func (o tagOptions) Contains(optionName string) bool {
 485  	if len(o) == 0 {
 486  		return false
 487  	}
 488  	s := string(o)
 489  	for s != "" {
 490  		var next string
 491  		i := strings.Index(s, ",")
 492  		if i >= 0 {
 493  			s, next = s[:i], s[i+1:]
 494  		}
 495  		if s == optionName {
 496  			return true
 497  		}
 498  		s = next
 499  	}
 500  	return false
 501  }
 502