reflect_extension.go raw

   1  package jsoniter
   2  
   3  import (
   4  	"fmt"
   5  	"github.com/modern-go/reflect2"
   6  	"reflect"
   7  	"sort"
   8  	"strings"
   9  	"unicode"
  10  	"unsafe"
  11  )
  12  
  13  var typeDecoders = map[string]ValDecoder{}
  14  var fieldDecoders = map[string]ValDecoder{}
  15  var typeEncoders = map[string]ValEncoder{}
  16  var fieldEncoders = map[string]ValEncoder{}
  17  var extensions = []Extension{}
  18  
// StructDescriptor describes how a struct should be encoded/decoded.
type StructDescriptor struct {
	// Type is the struct type this descriptor was built for.
	Type reflect2.Type
	// Fields holds one binding per encodable/decodable field. It is a slice
	// rather than a map so that field order is preserved.
	Fields []*Binding
}
  24  
  25  // GetField get one field from the descriptor by its name.
  26  // Can not use map here to keep field orders.
  27  func (structDescriptor *StructDescriptor) GetField(fieldName string) *Binding {
  28  	for _, binding := range structDescriptor.Fields {
  29  		if binding.Field.Name() == fieldName {
  30  			return binding
  31  		}
  32  	}
  33  	return nil
  34  }
  35  
// Binding describes how a single struct field is encoded/decoded.
type Binding struct {
	// levels is the path of field indices leading to this field: one index
	// for a direct field, with the embedding field's index prepended for
	// each level of embedding. It is used to sort bindings.
	levels []int
	// Field is the struct field this binding refers to.
	Field reflect2.StructField
	// FromNames and ToNames are both initialised to the names computed from
	// the field name and tag.
	// NOTE(review): presumably FromNames are the keys accepted when decoding
	// and ToNames the keys written when encoding — confirm against callers.
	FromNames []string
	ToNames   []string
	// Encoder and Decoder handle the field's value.
	Encoder ValEncoder
	Decoder ValDecoder
}
  45  
// Extension is the single SPI for customizing encoding/decoding by
// supplying alternate encoders/decoders.
// Fields can also be renamed through UpdateStructDescriptor.
type Extension interface {
	// UpdateStructDescriptor is called with every freshly built struct
	// descriptor and may modify it in place.
	UpdateStructDescriptor(structDescriptor *StructDescriptor)
	// CreateMapKeyDecoder and CreateMapKeyEncoder are not called in this file.
	// NOTE(review): presumably they supply codecs for map keys of the given
	// type, with nil meaning "no opinion" — confirm against callers.
	CreateMapKeyDecoder(typ reflect2.Type) ValDecoder
	CreateMapKeyEncoder(typ reflect2.Type) ValEncoder
	// CreateDecoder returns a decoder for typ, or nil to let the next
	// extension or registry be consulted.
	CreateDecoder(typ reflect2.Type) ValDecoder
	// CreateEncoder returns an encoder for typ, or nil to let the next
	// extension or registry be consulted.
	CreateEncoder(typ reflect2.Type) ValEncoder
	// DecorateDecoder may wrap or replace the decoder chosen for typ; it
	// must return the decoder to use.
	DecorateDecoder(typ reflect2.Type, decoder ValDecoder) ValDecoder
	// DecorateEncoder may wrap or replace the encoder chosen for typ; it
	// must return the encoder to use.
	DecorateEncoder(typ reflect2.Type, encoder ValEncoder) ValEncoder
}
  57  
  58  // DummyExtension embed this type get dummy implementation for all methods of Extension
  59  type DummyExtension struct {
  60  }
  61  
  62  // UpdateStructDescriptor No-op
  63  func (extension *DummyExtension) UpdateStructDescriptor(structDescriptor *StructDescriptor) {
  64  }
  65  
  66  // CreateMapKeyDecoder No-op
  67  func (extension *DummyExtension) CreateMapKeyDecoder(typ reflect2.Type) ValDecoder {
  68  	return nil
  69  }
  70  
  71  // CreateMapKeyEncoder No-op
  72  func (extension *DummyExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEncoder {
  73  	return nil
  74  }
  75  
  76  // CreateDecoder No-op
  77  func (extension *DummyExtension) CreateDecoder(typ reflect2.Type) ValDecoder {
  78  	return nil
  79  }
  80  
  81  // CreateEncoder No-op
  82  func (extension *DummyExtension) CreateEncoder(typ reflect2.Type) ValEncoder {
  83  	return nil
  84  }
  85  
  86  // DecorateDecoder No-op
  87  func (extension *DummyExtension) DecorateDecoder(typ reflect2.Type, decoder ValDecoder) ValDecoder {
  88  	return decoder
  89  }
  90  
  91  // DecorateEncoder No-op
  92  func (extension *DummyExtension) DecorateEncoder(typ reflect2.Type, encoder ValEncoder) ValEncoder {
  93  	return encoder
  94  }
  95  
  96  type EncoderExtension map[reflect2.Type]ValEncoder
  97  
  98  // UpdateStructDescriptor No-op
  99  func (extension EncoderExtension) UpdateStructDescriptor(structDescriptor *StructDescriptor) {
 100  }
 101  
 102  // CreateDecoder No-op
 103  func (extension EncoderExtension) CreateDecoder(typ reflect2.Type) ValDecoder {
 104  	return nil
 105  }
 106  
 107  // CreateEncoder get encoder from map
 108  func (extension EncoderExtension) CreateEncoder(typ reflect2.Type) ValEncoder {
 109  	return extension[typ]
 110  }
 111  
 112  // CreateMapKeyDecoder No-op
 113  func (extension EncoderExtension) CreateMapKeyDecoder(typ reflect2.Type) ValDecoder {
 114  	return nil
 115  }
 116  
 117  // CreateMapKeyEncoder No-op
 118  func (extension EncoderExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEncoder {
 119  	return nil
 120  }
 121  
 122  // DecorateDecoder No-op
 123  func (extension EncoderExtension) DecorateDecoder(typ reflect2.Type, decoder ValDecoder) ValDecoder {
 124  	return decoder
 125  }
 126  
 127  // DecorateEncoder No-op
 128  func (extension EncoderExtension) DecorateEncoder(typ reflect2.Type, encoder ValEncoder) ValEncoder {
 129  	return encoder
 130  }
 131  
 132  type DecoderExtension map[reflect2.Type]ValDecoder
 133  
 134  // UpdateStructDescriptor No-op
 135  func (extension DecoderExtension) UpdateStructDescriptor(structDescriptor *StructDescriptor) {
 136  }
 137  
 138  // CreateMapKeyDecoder No-op
 139  func (extension DecoderExtension) CreateMapKeyDecoder(typ reflect2.Type) ValDecoder {
 140  	return nil
 141  }
 142  
 143  // CreateMapKeyEncoder No-op
 144  func (extension DecoderExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEncoder {
 145  	return nil
 146  }
 147  
 148  // CreateDecoder get decoder from map
 149  func (extension DecoderExtension) CreateDecoder(typ reflect2.Type) ValDecoder {
 150  	return extension[typ]
 151  }
 152  
 153  // CreateEncoder No-op
 154  func (extension DecoderExtension) CreateEncoder(typ reflect2.Type) ValEncoder {
 155  	return nil
 156  }
 157  
 158  // DecorateDecoder No-op
 159  func (extension DecoderExtension) DecorateDecoder(typ reflect2.Type, decoder ValDecoder) ValDecoder {
 160  	return decoder
 161  }
 162  
 163  // DecorateEncoder No-op
 164  func (extension DecoderExtension) DecorateEncoder(typ reflect2.Type, encoder ValEncoder) ValEncoder {
 165  	return encoder
 166  }
 167  
 168  type funcDecoder struct {
 169  	fun DecoderFunc
 170  }
 171  
 172  func (decoder *funcDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
 173  	decoder.fun(ptr, iter)
 174  }
 175  
 176  type funcEncoder struct {
 177  	fun         EncoderFunc
 178  	isEmptyFunc func(ptr unsafe.Pointer) bool
 179  }
 180  
 181  func (encoder *funcEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
 182  	encoder.fun(ptr, stream)
 183  }
 184  
 185  func (encoder *funcEncoder) IsEmpty(ptr unsafe.Pointer) bool {
 186  	if encoder.isEmptyFunc == nil {
 187  		return false
 188  	}
 189  	return encoder.isEmptyFunc(ptr)
 190  }
 191  
// DecoderFunc is the function form of a ValDecoder: it decodes the next
// value from iter into the memory ptr points to.
type DecoderFunc func(ptr unsafe.Pointer, iter *Iterator)

// EncoderFunc is the function form of a ValEncoder: it writes the value
// ptr points to into stream.
type EncoderFunc func(ptr unsafe.Pointer, stream *Stream)
 197  
 198  // RegisterTypeDecoderFunc register TypeDecoder for a type with function
 199  func RegisterTypeDecoderFunc(typ string, fun DecoderFunc) {
 200  	typeDecoders[typ] = &funcDecoder{fun}
 201  }
 202  
 203  // RegisterTypeDecoder register TypeDecoder for a typ
 204  func RegisterTypeDecoder(typ string, decoder ValDecoder) {
 205  	typeDecoders[typ] = decoder
 206  }
 207  
 208  // RegisterFieldDecoderFunc register TypeDecoder for a struct field with function
 209  func RegisterFieldDecoderFunc(typ string, field string, fun DecoderFunc) {
 210  	RegisterFieldDecoder(typ, field, &funcDecoder{fun})
 211  }
 212  
 213  // RegisterFieldDecoder register TypeDecoder for a struct field
 214  func RegisterFieldDecoder(typ string, field string, decoder ValDecoder) {
 215  	fieldDecoders[fmt.Sprintf("%s/%s", typ, field)] = decoder
 216  }
 217  
 218  // RegisterTypeEncoderFunc register TypeEncoder for a type with encode/isEmpty function
 219  func RegisterTypeEncoderFunc(typ string, fun EncoderFunc, isEmptyFunc func(unsafe.Pointer) bool) {
 220  	typeEncoders[typ] = &funcEncoder{fun, isEmptyFunc}
 221  }
 222  
 223  // RegisterTypeEncoder register TypeEncoder for a type
 224  func RegisterTypeEncoder(typ string, encoder ValEncoder) {
 225  	typeEncoders[typ] = encoder
 226  }
 227  
 228  // RegisterFieldEncoderFunc register TypeEncoder for a struct field with encode/isEmpty function
 229  func RegisterFieldEncoderFunc(typ string, field string, fun EncoderFunc, isEmptyFunc func(unsafe.Pointer) bool) {
 230  	RegisterFieldEncoder(typ, field, &funcEncoder{fun, isEmptyFunc})
 231  }
 232  
 233  // RegisterFieldEncoder register TypeEncoder for a struct field
 234  func RegisterFieldEncoder(typ string, field string, encoder ValEncoder) {
 235  	fieldEncoders[fmt.Sprintf("%s/%s", typ, field)] = encoder
 236  }
 237  
 238  // RegisterExtension register extension
 239  func RegisterExtension(extension Extension) {
 240  	extensions = append(extensions, extension)
 241  }
 242  
 243  func getTypeDecoderFromExtension(ctx *ctx, typ reflect2.Type) ValDecoder {
 244  	decoder := _getTypeDecoderFromExtension(ctx, typ)
 245  	if decoder != nil {
 246  		for _, extension := range extensions {
 247  			decoder = extension.DecorateDecoder(typ, decoder)
 248  		}
 249  		decoder = ctx.decoderExtension.DecorateDecoder(typ, decoder)
 250  		for _, extension := range ctx.extraExtensions {
 251  			decoder = extension.DecorateDecoder(typ, decoder)
 252  		}
 253  	}
 254  	return decoder
 255  }
 256  func _getTypeDecoderFromExtension(ctx *ctx, typ reflect2.Type) ValDecoder {
 257  	for _, extension := range extensions {
 258  		decoder := extension.CreateDecoder(typ)
 259  		if decoder != nil {
 260  			return decoder
 261  		}
 262  	}
 263  	decoder := ctx.decoderExtension.CreateDecoder(typ)
 264  	if decoder != nil {
 265  		return decoder
 266  	}
 267  	for _, extension := range ctx.extraExtensions {
 268  		decoder := extension.CreateDecoder(typ)
 269  		if decoder != nil {
 270  			return decoder
 271  		}
 272  	}
 273  	typeName := typ.String()
 274  	decoder = typeDecoders[typeName]
 275  	if decoder != nil {
 276  		return decoder
 277  	}
 278  	if typ.Kind() == reflect.Ptr {
 279  		ptrType := typ.(*reflect2.UnsafePtrType)
 280  		decoder := typeDecoders[ptrType.Elem().String()]
 281  		if decoder != nil {
 282  			return &OptionalDecoder{ptrType.Elem(), decoder}
 283  		}
 284  	}
 285  	return nil
 286  }
 287  
 288  func getTypeEncoderFromExtension(ctx *ctx, typ reflect2.Type) ValEncoder {
 289  	encoder := _getTypeEncoderFromExtension(ctx, typ)
 290  	if encoder != nil {
 291  		for _, extension := range extensions {
 292  			encoder = extension.DecorateEncoder(typ, encoder)
 293  		}
 294  		encoder = ctx.encoderExtension.DecorateEncoder(typ, encoder)
 295  		for _, extension := range ctx.extraExtensions {
 296  			encoder = extension.DecorateEncoder(typ, encoder)
 297  		}
 298  	}
 299  	return encoder
 300  }
 301  
 302  func _getTypeEncoderFromExtension(ctx *ctx, typ reflect2.Type) ValEncoder {
 303  	for _, extension := range extensions {
 304  		encoder := extension.CreateEncoder(typ)
 305  		if encoder != nil {
 306  			return encoder
 307  		}
 308  	}
 309  	encoder := ctx.encoderExtension.CreateEncoder(typ)
 310  	if encoder != nil {
 311  		return encoder
 312  	}
 313  	for _, extension := range ctx.extraExtensions {
 314  		encoder := extension.CreateEncoder(typ)
 315  		if encoder != nil {
 316  			return encoder
 317  		}
 318  	}
 319  	typeName := typ.String()
 320  	encoder = typeEncoders[typeName]
 321  	if encoder != nil {
 322  		return encoder
 323  	}
 324  	if typ.Kind() == reflect.Ptr {
 325  		typePtr := typ.(*reflect2.UnsafePtrType)
 326  		encoder := typeEncoders[typePtr.Elem().String()]
 327  		if encoder != nil {
 328  			return &OptionalEncoder{encoder}
 329  		}
 330  	}
 331  	return nil
 332  }
 333  
 334  func describeStruct(ctx *ctx, typ reflect2.Type) *StructDescriptor {
 335  	structType := typ.(*reflect2.UnsafeStructType)
 336  	embeddedBindings := []*Binding{}
 337  	bindings := []*Binding{}
 338  	for i := 0; i < structType.NumField(); i++ {
 339  		field := structType.Field(i)
 340  		tag, hastag := field.Tag().Lookup(ctx.getTagKey())
 341  		if ctx.onlyTaggedField && !hastag && !field.Anonymous() {
 342  			continue
 343  		}
 344  		if tag == "-" || field.Name() == "_" {
 345  			continue
 346  		}
 347  		tagParts := strings.Split(tag, ",")
 348  		if field.Anonymous() && (tag == "" || tagParts[0] == "") {
 349  			if field.Type().Kind() == reflect.Struct {
 350  				structDescriptor := describeStruct(ctx, field.Type())
 351  				for _, binding := range structDescriptor.Fields {
 352  					binding.levels = append([]int{i}, binding.levels...)
 353  					omitempty := binding.Encoder.(*structFieldEncoder).omitempty
 354  					binding.Encoder = &structFieldEncoder{field, binding.Encoder, omitempty}
 355  					binding.Decoder = &structFieldDecoder{field, binding.Decoder}
 356  					embeddedBindings = append(embeddedBindings, binding)
 357  				}
 358  				continue
 359  			} else if field.Type().Kind() == reflect.Ptr {
 360  				ptrType := field.Type().(*reflect2.UnsafePtrType)
 361  				if ptrType.Elem().Kind() == reflect.Struct {
 362  					structDescriptor := describeStruct(ctx, ptrType.Elem())
 363  					for _, binding := range structDescriptor.Fields {
 364  						binding.levels = append([]int{i}, binding.levels...)
 365  						omitempty := binding.Encoder.(*structFieldEncoder).omitempty
 366  						binding.Encoder = &dereferenceEncoder{binding.Encoder}
 367  						binding.Encoder = &structFieldEncoder{field, binding.Encoder, omitempty}
 368  						binding.Decoder = &dereferenceDecoder{ptrType.Elem(), binding.Decoder}
 369  						binding.Decoder = &structFieldDecoder{field, binding.Decoder}
 370  						embeddedBindings = append(embeddedBindings, binding)
 371  					}
 372  					continue
 373  				}
 374  			}
 375  		}
 376  		fieldNames := calcFieldNames(field.Name(), tagParts[0], tag)
 377  		fieldCacheKey := fmt.Sprintf("%s/%s", typ.String(), field.Name())
 378  		decoder := fieldDecoders[fieldCacheKey]
 379  		if decoder == nil {
 380  			decoder = decoderOfType(ctx.append(field.Name()), field.Type())
 381  		}
 382  		encoder := fieldEncoders[fieldCacheKey]
 383  		if encoder == nil {
 384  			encoder = encoderOfType(ctx.append(field.Name()), field.Type())
 385  		}
 386  		binding := &Binding{
 387  			Field:     field,
 388  			FromNames: fieldNames,
 389  			ToNames:   fieldNames,
 390  			Decoder:   decoder,
 391  			Encoder:   encoder,
 392  		}
 393  		binding.levels = []int{i}
 394  		bindings = append(bindings, binding)
 395  	}
 396  	return createStructDescriptor(ctx, typ, bindings, embeddedBindings)
 397  }
 398  func createStructDescriptor(ctx *ctx, typ reflect2.Type, bindings []*Binding, embeddedBindings []*Binding) *StructDescriptor {
 399  	structDescriptor := &StructDescriptor{
 400  		Type:   typ,
 401  		Fields: bindings,
 402  	}
 403  	for _, extension := range extensions {
 404  		extension.UpdateStructDescriptor(structDescriptor)
 405  	}
 406  	ctx.encoderExtension.UpdateStructDescriptor(structDescriptor)
 407  	ctx.decoderExtension.UpdateStructDescriptor(structDescriptor)
 408  	for _, extension := range ctx.extraExtensions {
 409  		extension.UpdateStructDescriptor(structDescriptor)
 410  	}
 411  	processTags(structDescriptor, ctx.frozenConfig)
 412  	// merge normal & embedded bindings & sort with original order
 413  	allBindings := sortableBindings(append(embeddedBindings, structDescriptor.Fields...))
 414  	sort.Sort(allBindings)
 415  	structDescriptor.Fields = allBindings
 416  	return structDescriptor
 417  }
 418  
 419  type sortableBindings []*Binding
 420  
 421  func (bindings sortableBindings) Len() int {
 422  	return len(bindings)
 423  }
 424  
 425  func (bindings sortableBindings) Less(i, j int) bool {
 426  	left := bindings[i].levels
 427  	right := bindings[j].levels
 428  	k := 0
 429  	for {
 430  		if left[k] < right[k] {
 431  			return true
 432  		} else if left[k] > right[k] {
 433  			return false
 434  		}
 435  		k++
 436  	}
 437  }
 438  
 439  func (bindings sortableBindings) Swap(i, j int) {
 440  	bindings[i], bindings[j] = bindings[j], bindings[i]
 441  }
 442  
 443  func processTags(structDescriptor *StructDescriptor, cfg *frozenConfig) {
 444  	for _, binding := range structDescriptor.Fields {
 445  		shouldOmitEmpty := false
 446  		tagParts := strings.Split(binding.Field.Tag().Get(cfg.getTagKey()), ",")
 447  		for _, tagPart := range tagParts[1:] {
 448  			if tagPart == "omitempty" {
 449  				shouldOmitEmpty = true
 450  			} else if tagPart == "string" {
 451  				if binding.Field.Type().Kind() == reflect.String {
 452  					binding.Decoder = &stringModeStringDecoder{binding.Decoder, cfg}
 453  					binding.Encoder = &stringModeStringEncoder{binding.Encoder, cfg}
 454  				} else {
 455  					binding.Decoder = &stringModeNumberDecoder{binding.Decoder}
 456  					binding.Encoder = &stringModeNumberEncoder{binding.Encoder}
 457  				}
 458  			}
 459  		}
 460  		binding.Decoder = &structFieldDecoder{binding.Field, binding.Decoder}
 461  		binding.Encoder = &structFieldEncoder{binding.Field, binding.Encoder, shouldOmitEmpty}
 462  	}
 463  }
 464  
 465  func calcFieldNames(originalFieldName string, tagProvidedFieldName string, wholeTag string) []string {
 466  	// ignore?
 467  	if wholeTag == "-" {
 468  		return []string{}
 469  	}
 470  	// rename?
 471  	var fieldNames []string
 472  	if tagProvidedFieldName == "" {
 473  		fieldNames = []string{originalFieldName}
 474  	} else {
 475  		fieldNames = []string{tagProvidedFieldName}
 476  	}
 477  	// private?
 478  	isNotExported := unicode.IsLower(rune(originalFieldName[0])) || originalFieldName[0] == '_'
 479  	if isNotExported {
 480  		fieldNames = []string{}
 481  	}
 482  	return fieldNames
 483  }
 484