struct.go raw

   1  // Copyright (c) Microsoft Corporation.
   2  // Licensed under the MIT license.
   3  
   4  package json
   5  
   6  import (
   7  	"encoding/json"
   8  	"fmt"
   9  	"reflect"
  10  	"strings"
  11  )
  12  
  13  func unmarshalStruct(jdec *json.Decoder, i interface{}) error {
  14  	v := reflect.ValueOf(i)
  15  	if v.Kind() != reflect.Ptr {
  16  		return fmt.Errorf("Unmarshal() received type %T, which is not a *struct", i)
  17  	}
  18  	v = v.Elem()
  19  	if v.Kind() != reflect.Struct {
  20  		return fmt.Errorf("Unmarshal() received type %T, which is not a *struct", i)
  21  	}
  22  
  23  	if hasUnmarshalJSON(v) {
  24  		// Indicates that this type has a custom Unmarshaler.
  25  		return jdec.Decode(v.Addr().Interface())
  26  	}
  27  
  28  	f := v.FieldByName(addField)
  29  	if f.Kind() == reflect.Invalid {
  30  		return fmt.Errorf("Unmarshal(%T) only supports structs that have the field AdditionalFields or implements json.Unmarshaler", i)
  31  	}
  32  
  33  	if f.Kind() != reflect.Map || !f.Type().AssignableTo(mapStrInterType) {
  34  		return fmt.Errorf("type %T has field 'AdditionalFields' that is not a map[string]interface{}", i)
  35  	}
  36  
  37  	dec := newDecoder(jdec, v)
  38  	return dec.run()
  39  }
  40  
  41  type decoder struct {
  42  	dec        *json.Decoder
  43  	value      reflect.Value // This will be a reflect.Struct
  44  	translator translateFields
  45  	key        string
  46  }
  47  
  48  func newDecoder(dec *json.Decoder, value reflect.Value) *decoder {
  49  	return &decoder{value: value, dec: dec}
  50  }
  51  
  52  // run runs our decoder state machine.
  53  func (d *decoder) run() error {
  54  	var state = d.start
  55  	var err error
  56  	for {
  57  		state, err = state()
  58  		if err != nil {
  59  			return err
  60  		}
  61  		if state == nil {
  62  			return nil
  63  		}
  64  	}
  65  }
  66  
  67  // start looks for our opening delimeter '{' and then transitions to looping through our fields.
  68  func (d *decoder) start() (stateFn, error) {
  69  	var err error
  70  	d.translator, err = findFields(d.value)
  71  	if err != nil {
  72  		return nil, err
  73  	}
  74  
  75  	delim, err := d.dec.Token()
  76  	if err != nil {
  77  		return nil, err
  78  	}
  79  	if !delimIs(delim, '{') {
  80  		return nil, fmt.Errorf("Unmarshal expected opening {, received %v", delim)
  81  	}
  82  
  83  	return d.next, nil
  84  }
  85  
  86  // next gets the next struct field name from the raw json or stops the machine if we get our closing }.
  87  func (d *decoder) next() (stateFn, error) {
  88  	if !d.dec.More() {
  89  		// Remove the closing }.
  90  		if _, err := d.dec.Token(); err != nil {
  91  			return nil, err
  92  		}
  93  		return nil, nil
  94  	}
  95  
  96  	key, err := d.dec.Token()
  97  	if err != nil {
  98  		return nil, err
  99  	}
 100  
 101  	d.key = key.(string)
 102  	return d.storeValue, nil
 103  }
 104  
 105  // storeValue takes the next value and stores it our struct. If the field can't be found
 106  // in the struct, it pushes the operation to storeAdditional().
 107  func (d *decoder) storeValue() (stateFn, error) {
 108  	goName := d.translator.goName(d.key)
 109  	if goName == "" {
 110  		goName = d.key
 111  	}
 112  
 113  	// We don't have the field in the struct, so it goes in AdditionalFields.
 114  	f := d.value.FieldByName(goName)
 115  	if f.Kind() == reflect.Invalid {
 116  		return d.storeAdditional, nil
 117  	}
 118  
 119  	// Indicates that this type has a custom Unmarshaler.
 120  	if hasUnmarshalJSON(f) {
 121  		err := d.dec.Decode(f.Addr().Interface())
 122  		if err != nil {
 123  			return nil, err
 124  		}
 125  		return d.next, nil
 126  	}
 127  
 128  	t, isPtr, err := fieldBaseType(d.value, goName)
 129  	if err != nil {
 130  		return nil, fmt.Errorf("type(%s) had field(%s) %w", d.value.Type().Name(), goName, err)
 131  	}
 132  
 133  	switch t.Kind() {
 134  	// We need to recursively call ourselves on any *struct or struct.
 135  	case reflect.Struct:
 136  		if isPtr {
 137  			if f.IsNil() {
 138  				f.Set(reflect.New(t))
 139  			}
 140  		} else {
 141  			f = f.Addr()
 142  		}
 143  		if err := unmarshalStruct(d.dec, f.Interface()); err != nil {
 144  			return nil, err
 145  		}
 146  		return d.next, nil
 147  	case reflect.Map:
 148  		v := reflect.MakeMap(f.Type())
 149  		ptr := newValue(f.Type())
 150  		ptr.Elem().Set(v)
 151  		if err := unmarshalMap(d.dec, ptr); err != nil {
 152  			return nil, err
 153  		}
 154  		f.Set(ptr.Elem())
 155  		return d.next, nil
 156  	case reflect.Slice:
 157  		v := reflect.MakeSlice(f.Type(), 0, 0)
 158  		ptr := newValue(f.Type())
 159  		ptr.Elem().Set(v)
 160  		if err := unmarshalSlice(d.dec, ptr); err != nil {
 161  			return nil, err
 162  		}
 163  		f.Set(ptr.Elem())
 164  		return d.next, nil
 165  	}
 166  
 167  	if !isPtr {
 168  		f = f.Addr()
 169  	}
 170  
 171  	// For values that are pointers, we need them to be non-nil in order
 172  	// to decode into them.
 173  	if f.IsNil() {
 174  		f.Set(reflect.New(t))
 175  	}
 176  
 177  	if err := d.dec.Decode(f.Interface()); err != nil {
 178  		return nil, err
 179  	}
 180  
 181  	return d.next, nil
 182  }
 183  
 184  // storeAdditional pushes the key/value into our .AdditionalFields map.
 185  func (d *decoder) storeAdditional() (stateFn, error) {
 186  	rw := json.RawMessage{}
 187  	if err := d.dec.Decode(&rw); err != nil {
 188  		return nil, err
 189  	}
 190  	field := d.value.FieldByName(addField)
 191  	if field.IsNil() {
 192  		field.Set(reflect.MakeMap(field.Type()))
 193  	}
 194  	field.SetMapIndex(reflect.ValueOf(d.key), reflect.ValueOf(rw))
 195  	return d.next, nil
 196  }
 197  
 198  func fieldBaseType(v reflect.Value, fieldName string) (t reflect.Type, isPtr bool, err error) {
 199  	sf, ok := v.Type().FieldByName(fieldName)
 200  	if !ok {
 201  		return nil, false, fmt.Errorf("bug: fieldBaseType() lookup of field(%s) on type(%s): do not have field", fieldName, v.Type().Name())
 202  	}
 203  	t = sf.Type
 204  	if t.Kind() == reflect.Ptr {
 205  		t = t.Elem()
 206  		isPtr = true
 207  	}
 208  	if t.Kind() == reflect.Ptr {
 209  		return nil, isPtr, fmt.Errorf("received pointer to pointer type, not supported")
 210  	}
 211  	return t, isPtr, nil
 212  }
 213  
 214  type translateField struct {
 215  	jsonName string
 216  	goName   string
 217  }
 218  
 219  // translateFields is a list of translateFields with a handy lookup method.
 220  type translateFields []translateField
 221  
 222  // goName loops through a list of fields looking for one contaning the jsonName and
 223  // returning the goName. If not found, returns the empty string.
 224  // Note: not a map because at this size slices are faster even in tight loops.
 225  func (t translateFields) goName(jsonName string) string {
 226  	for _, entry := range t {
 227  		if entry.jsonName == jsonName {
 228  			return entry.goName
 229  		}
 230  	}
 231  	return ""
 232  }
 233  
 234  // jsonName loops through a list of fields looking for one contaning the goName and
 235  // returning the jsonName. If not found, returns the empty string.
 236  // Note: not a map because at this size slices are faster even in tight loops.
 237  func (t translateFields) jsonName(goName string) string {
 238  	for _, entry := range t {
 239  		if entry.goName == goName {
 240  			return entry.jsonName
 241  		}
 242  	}
 243  	return ""
 244  }
 245  
 246  var umarshalerType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
 247  
 248  // findFields parses a struct and writes the field tags for lookup. It will return an error
 249  // if any field has a type of *struct or struct that does not implement json.Marshaler.
 250  func findFields(v reflect.Value) (translateFields, error) {
 251  	if v.Kind() == reflect.Ptr {
 252  		v = v.Elem()
 253  	}
 254  	if v.Kind() != reflect.Struct {
 255  		return nil, fmt.Errorf("findFields received a %s type, expected *struct or struct", v.Type().Name())
 256  	}
 257  	tfs := make([]translateField, 0, v.NumField())
 258  	for i := 0; i < v.NumField(); i++ {
 259  		tf := translateField{
 260  			goName:   v.Type().Field(i).Name,
 261  			jsonName: parseTag(v.Type().Field(i).Tag.Get("json")),
 262  		}
 263  		switch tf.jsonName {
 264  		case "", "-":
 265  			tf.jsonName = tf.goName
 266  		}
 267  		tfs = append(tfs, tf)
 268  
 269  		f := v.Field(i)
 270  		if f.Kind() == reflect.Ptr {
 271  			f = f.Elem()
 272  		}
 273  		if f.Kind() == reflect.Struct {
 274  			if f.Type().Implements(umarshalerType) {
 275  				return nil, fmt.Errorf("struct type %q which has field %q which "+
 276  					"doesn't implement json.Unmarshaler", v.Type().Name(), v.Type().Field(i).Name)
 277  			}
 278  		}
 279  	}
 280  	return tfs, nil
 281  }
 282  
 283  // parseTag just returns the first entry in the tag. tag is the string
 284  // returned by reflect.StructField.Tag().Get().
 285  func parseTag(tag string) string {
 286  	if idx := strings.Index(tag, ","); idx != -1 {
 287  		return tag[:idx]
 288  	}
 289  	return tag
 290  }
 291