reflect_marshaler.go raw

   1  package jsoniter
   2  
   3  import (
   4  	"encoding"
   5  	"encoding/json"
   6  	"unsafe"
   7  
   8  	"github.com/modern-go/reflect2"
   9  )
  10  
  11  var marshalerType = reflect2.TypeOfPtr((*json.Marshaler)(nil)).Elem()
  12  var unmarshalerType = reflect2.TypeOfPtr((*json.Unmarshaler)(nil)).Elem()
  13  var textMarshalerType = reflect2.TypeOfPtr((*encoding.TextMarshaler)(nil)).Elem()
  14  var textUnmarshalerType = reflect2.TypeOfPtr((*encoding.TextUnmarshaler)(nil)).Elem()
  15  
  16  func createDecoderOfMarshaler(ctx *ctx, typ reflect2.Type) ValDecoder {
  17  	ptrType := reflect2.PtrTo(typ)
  18  	if ptrType.Implements(unmarshalerType) {
  19  		return &referenceDecoder{
  20  			&unmarshalerDecoder{ptrType},
  21  		}
  22  	}
  23  	if ptrType.Implements(textUnmarshalerType) {
  24  		return &referenceDecoder{
  25  			&textUnmarshalerDecoder{ptrType},
  26  		}
  27  	}
  28  	return nil
  29  }
  30  
  31  func createEncoderOfMarshaler(ctx *ctx, typ reflect2.Type) ValEncoder {
  32  	if typ == marshalerType {
  33  		checkIsEmpty := createCheckIsEmpty(ctx, typ)
  34  		var encoder ValEncoder = &directMarshalerEncoder{
  35  			checkIsEmpty: checkIsEmpty,
  36  		}
  37  		return encoder
  38  	}
  39  	if typ.Implements(marshalerType) {
  40  		checkIsEmpty := createCheckIsEmpty(ctx, typ)
  41  		var encoder ValEncoder = &marshalerEncoder{
  42  			valType:      typ,
  43  			checkIsEmpty: checkIsEmpty,
  44  		}
  45  		return encoder
  46  	}
  47  	ptrType := reflect2.PtrTo(typ)
  48  	if ctx.prefix != "" && ptrType.Implements(marshalerType) {
  49  		checkIsEmpty := createCheckIsEmpty(ctx, ptrType)
  50  		var encoder ValEncoder = &marshalerEncoder{
  51  			valType:      ptrType,
  52  			checkIsEmpty: checkIsEmpty,
  53  		}
  54  		return &referenceEncoder{encoder}
  55  	}
  56  	if typ == textMarshalerType {
  57  		checkIsEmpty := createCheckIsEmpty(ctx, typ)
  58  		var encoder ValEncoder = &directTextMarshalerEncoder{
  59  			checkIsEmpty:  checkIsEmpty,
  60  			stringEncoder: ctx.EncoderOf(reflect2.TypeOf("")),
  61  		}
  62  		return encoder
  63  	}
  64  	if typ.Implements(textMarshalerType) {
  65  		checkIsEmpty := createCheckIsEmpty(ctx, typ)
  66  		var encoder ValEncoder = &textMarshalerEncoder{
  67  			valType:       typ,
  68  			stringEncoder: ctx.EncoderOf(reflect2.TypeOf("")),
  69  			checkIsEmpty:  checkIsEmpty,
  70  		}
  71  		return encoder
  72  	}
  73  	// if prefix is empty, the type is the root type
  74  	if ctx.prefix != "" && ptrType.Implements(textMarshalerType) {
  75  		checkIsEmpty := createCheckIsEmpty(ctx, ptrType)
  76  		var encoder ValEncoder = &textMarshalerEncoder{
  77  			valType:       ptrType,
  78  			stringEncoder: ctx.EncoderOf(reflect2.TypeOf("")),
  79  			checkIsEmpty:  checkIsEmpty,
  80  		}
  81  		return &referenceEncoder{encoder}
  82  	}
  83  	return nil
  84  }
  85  
  86  type marshalerEncoder struct {
  87  	checkIsEmpty checkIsEmpty
  88  	valType      reflect2.Type
  89  }
  90  
  91  func (encoder *marshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
  92  	obj := encoder.valType.UnsafeIndirect(ptr)
  93  	if encoder.valType.IsNullable() && reflect2.IsNil(obj) {
  94  		stream.WriteNil()
  95  		return
  96  	}
  97  	marshaler := obj.(json.Marshaler)
  98  	bytes, err := marshaler.MarshalJSON()
  99  	if err != nil {
 100  		stream.Error = err
 101  	} else {
 102  		// html escape was already done by jsoniter
 103  		// but the extra '\n' should be trimed
 104  		l := len(bytes)
 105  		if l > 0 && bytes[l-1] == '\n' {
 106  			bytes = bytes[:l-1]
 107  		}
 108  		stream.Write(bytes)
 109  	}
 110  }
 111  
 112  func (encoder *marshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
 113  	return encoder.checkIsEmpty.IsEmpty(ptr)
 114  }
 115  
 116  type directMarshalerEncoder struct {
 117  	checkIsEmpty checkIsEmpty
 118  }
 119  
 120  func (encoder *directMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
 121  	marshaler := *(*json.Marshaler)(ptr)
 122  	if marshaler == nil {
 123  		stream.WriteNil()
 124  		return
 125  	}
 126  	bytes, err := marshaler.MarshalJSON()
 127  	if err != nil {
 128  		stream.Error = err
 129  	} else {
 130  		stream.Write(bytes)
 131  	}
 132  }
 133  
 134  func (encoder *directMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
 135  	return encoder.checkIsEmpty.IsEmpty(ptr)
 136  }
 137  
 138  type textMarshalerEncoder struct {
 139  	valType       reflect2.Type
 140  	stringEncoder ValEncoder
 141  	checkIsEmpty  checkIsEmpty
 142  }
 143  
 144  func (encoder *textMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
 145  	obj := encoder.valType.UnsafeIndirect(ptr)
 146  	if encoder.valType.IsNullable() && reflect2.IsNil(obj) {
 147  		stream.WriteNil()
 148  		return
 149  	}
 150  	marshaler := (obj).(encoding.TextMarshaler)
 151  	bytes, err := marshaler.MarshalText()
 152  	if err != nil {
 153  		stream.Error = err
 154  	} else {
 155  		str := string(bytes)
 156  		encoder.stringEncoder.Encode(unsafe.Pointer(&str), stream)
 157  	}
 158  }
 159  
 160  func (encoder *textMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
 161  	return encoder.checkIsEmpty.IsEmpty(ptr)
 162  }
 163  
 164  type directTextMarshalerEncoder struct {
 165  	stringEncoder ValEncoder
 166  	checkIsEmpty  checkIsEmpty
 167  }
 168  
 169  func (encoder *directTextMarshalerEncoder) Encode(ptr unsafe.Pointer, stream *Stream) {
 170  	marshaler := *(*encoding.TextMarshaler)(ptr)
 171  	if marshaler == nil {
 172  		stream.WriteNil()
 173  		return
 174  	}
 175  	bytes, err := marshaler.MarshalText()
 176  	if err != nil {
 177  		stream.Error = err
 178  	} else {
 179  		str := string(bytes)
 180  		encoder.stringEncoder.Encode(unsafe.Pointer(&str), stream)
 181  	}
 182  }
 183  
 184  func (encoder *directTextMarshalerEncoder) IsEmpty(ptr unsafe.Pointer) bool {
 185  	return encoder.checkIsEmpty.IsEmpty(ptr)
 186  }
 187  
 188  type unmarshalerDecoder struct {
 189  	valType reflect2.Type
 190  }
 191  
 192  func (decoder *unmarshalerDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
 193  	valType := decoder.valType
 194  	obj := valType.UnsafeIndirect(ptr)
 195  	unmarshaler := obj.(json.Unmarshaler)
 196  	iter.nextToken()
 197  	iter.unreadByte() // skip spaces
 198  	bytes := iter.SkipAndReturnBytes()
 199  	err := unmarshaler.UnmarshalJSON(bytes)
 200  	if err != nil {
 201  		iter.ReportError("unmarshalerDecoder", err.Error())
 202  	}
 203  }
 204  
// textUnmarshalerDecoder decodes a JSON string into a value through its
// encoding.TextUnmarshaler implementation. valType is expected to be a
// pointer type: it is type-asserted to *reflect2.UnsafePtrType below.
type textUnmarshalerDecoder struct {
	valType reflect2.Type
}

// Decode reads a string from iter and passes its bytes to UnmarshalText.
// If the pointer read at ptr is nil, a new zero element is allocated and its
// address stored at ptr first, so UnmarshalText has a non-nil receiver.
// An UnmarshalText error is reported on iter.
func (decoder *textUnmarshalerDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) {
	valType := decoder.valType
	obj := valType.UnsafeIndirect(ptr)
	if reflect2.IsNil(obj) {
		// NOTE(review): this single-value type assertion panics if valType is
		// not a *reflect2.UnsafePtrType — confirm callers always pass one.
		ptrType := valType.(*reflect2.UnsafePtrType)
		elemType := ptrType.Elem()
		// Allocate a zero element and store its address at ptr, then re-read
		// obj so it refers to the new, non-nil pointer.
		elem := elemType.UnsafeNew()
		ptrType.UnsafeSet(ptr, unsafe.Pointer(&elem))
		obj = valType.UnsafeIndirect(ptr)
	}
	unmarshaler := (obj).(encoding.TextUnmarshaler)
	// NOTE(review): presumably ReadString reports its own error on iter for
	// non-string input; UnmarshalText is still called with whatever it returns.
	str := iter.ReadString()
	err := unmarshaler.UnmarshalText([]byte(str))
	if err != nil {
		iter.ReportError("textUnmarshalerDecoder", err.Error())
	}
}
 226