merge.go raw

   1  // Copyright 2020 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package impl
   6  
   7  import (
   8  	"fmt"
   9  	"reflect"
  10  
  11  	"google.golang.org/protobuf/proto"
  12  	"google.golang.org/protobuf/reflect/protoreflect"
  13  	"google.golang.org/protobuf/runtime/protoiface"
  14  )
  15  
  16  type mergeOptions struct{}
  17  
  18  func (o mergeOptions) Merge(dst, src proto.Message) {
  19  	proto.Merge(dst, src)
  20  }
  21  
  22  // merge is protoreflect.Methods.Merge.
  23  func (mi *MessageInfo) merge(in protoiface.MergeInput) protoiface.MergeOutput {
  24  	dp, ok := mi.getPointer(in.Destination)
  25  	if !ok {
  26  		return protoiface.MergeOutput{}
  27  	}
  28  	sp, ok := mi.getPointer(in.Source)
  29  	if !ok {
  30  		return protoiface.MergeOutput{}
  31  	}
  32  	mi.mergePointer(dp, sp, mergeOptions{})
  33  	return protoiface.MergeOutput{Flags: protoiface.MergeComplete}
  34  }
  35  
  36  func (mi *MessageInfo) mergePointer(dst, src pointer, opts mergeOptions) {
  37  	mi.init()
  38  	if dst.IsNil() {
  39  		panic(fmt.Sprintf("invalid value: merging into nil message"))
  40  	}
  41  	if src.IsNil() {
  42  		return
  43  	}
  44  
  45  	var presenceSrc presence
  46  	var presenceDst presence
  47  	if mi.presenceOffset.IsValid() {
  48  		presenceSrc = src.Apply(mi.presenceOffset).PresenceInfo()
  49  		presenceDst = dst.Apply(mi.presenceOffset).PresenceInfo()
  50  	}
  51  
  52  	for _, f := range mi.orderedCoderFields {
  53  		if f.funcs.merge == nil {
  54  			continue
  55  		}
  56  		sfptr := src.Apply(f.offset)
  57  
  58  		if f.presenceIndex != noPresence {
  59  			if !presenceSrc.Present(f.presenceIndex) {
  60  				continue
  61  			}
  62  			dfptr := dst.Apply(f.offset)
  63  			if f.isLazy {
  64  				if sfptr.AtomicGetPointer().IsNil() {
  65  					mi.lazyUnmarshal(src, f.num)
  66  				}
  67  				if presenceDst.Present(f.presenceIndex) && dfptr.AtomicGetPointer().IsNil() {
  68  					mi.lazyUnmarshal(dst, f.num)
  69  				}
  70  			}
  71  			f.funcs.merge(dst.Apply(f.offset), sfptr, f, opts)
  72  			presenceDst.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
  73  			continue
  74  		}
  75  
  76  		if f.isPointer && sfptr.Elem().IsNil() {
  77  			continue
  78  		}
  79  		f.funcs.merge(dst.Apply(f.offset), sfptr, f, opts)
  80  	}
  81  	if mi.extensionOffset.IsValid() {
  82  		sext := src.Apply(mi.extensionOffset).Extensions()
  83  		dext := dst.Apply(mi.extensionOffset).Extensions()
  84  		if *dext == nil {
  85  			*dext = make(map[int32]ExtensionField)
  86  		}
  87  		for num, sx := range *sext {
  88  			xt := sx.Type()
  89  			xi := getExtensionFieldInfo(xt)
  90  			if xi.funcs.merge == nil {
  91  				continue
  92  			}
  93  			dx := (*dext)[num]
  94  			var dv protoreflect.Value
  95  			if dx.Type() == sx.Type() {
  96  				dv = dx.Value()
  97  			}
  98  			if !dv.IsValid() && xi.unmarshalNeedsValue {
  99  				dv = xt.New()
 100  			}
 101  			dv = xi.funcs.merge(dv, sx.Value(), opts)
 102  			dx.Set(sx.Type(), dv)
 103  			(*dext)[num] = dx
 104  		}
 105  	}
 106  	if mi.unknownOffset.IsValid() {
 107  		su := mi.getUnknownBytes(src)
 108  		if su != nil && len(*su) > 0 {
 109  			du := mi.mutableUnknownBytes(dst)
 110  			*du = append(*du, *su...)
 111  		}
 112  	}
 113  }
 114  
 115  func mergeScalarValue(dst, src protoreflect.Value, opts mergeOptions) protoreflect.Value {
 116  	return src
 117  }
 118  
 119  func mergeBytesValue(dst, src protoreflect.Value, opts mergeOptions) protoreflect.Value {
 120  	return protoreflect.ValueOfBytes(append(emptyBuf[:], src.Bytes()...))
 121  }
 122  
 123  func mergeListValue(dst, src protoreflect.Value, opts mergeOptions) protoreflect.Value {
 124  	dstl := dst.List()
 125  	srcl := src.List()
 126  	for i, llen := 0, srcl.Len(); i < llen; i++ {
 127  		dstl.Append(srcl.Get(i))
 128  	}
 129  	return dst
 130  }
 131  
 132  func mergeBytesListValue(dst, src protoreflect.Value, opts mergeOptions) protoreflect.Value {
 133  	dstl := dst.List()
 134  	srcl := src.List()
 135  	for i, llen := 0, srcl.Len(); i < llen; i++ {
 136  		sb := srcl.Get(i).Bytes()
 137  		db := append(emptyBuf[:], sb...)
 138  		dstl.Append(protoreflect.ValueOfBytes(db))
 139  	}
 140  	return dst
 141  }
 142  
 143  func mergeMessageListValue(dst, src protoreflect.Value, opts mergeOptions) protoreflect.Value {
 144  	dstl := dst.List()
 145  	srcl := src.List()
 146  	for i, llen := 0, srcl.Len(); i < llen; i++ {
 147  		sm := srcl.Get(i).Message()
 148  		dm := proto.Clone(sm.Interface()).ProtoReflect()
 149  		dstl.Append(protoreflect.ValueOfMessage(dm))
 150  	}
 151  	return dst
 152  }
 153  
 154  func mergeMessageValue(dst, src protoreflect.Value, opts mergeOptions) protoreflect.Value {
 155  	opts.Merge(dst.Message().Interface(), src.Message().Interface())
 156  	return dst
 157  }
 158  
 159  func mergeMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
 160  	if f.mi != nil {
 161  		if dst.Elem().IsNil() {
 162  			dst.SetPointer(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
 163  		}
 164  		f.mi.mergePointer(dst.Elem(), src.Elem(), opts)
 165  	} else {
 166  		dm := dst.AsValueOf(f.ft).Elem()
 167  		sm := src.AsValueOf(f.ft).Elem()
 168  		if dm.IsNil() {
 169  			dm.Set(reflect.New(f.ft.Elem()))
 170  		}
 171  		opts.Merge(asMessage(dm), asMessage(sm))
 172  	}
 173  }
 174  
 175  func mergeMessageSlice(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
 176  	for _, sp := range src.PointerSlice() {
 177  		dm := reflect.New(f.ft.Elem().Elem())
 178  		if f.mi != nil {
 179  			f.mi.mergePointer(pointerOfValue(dm), sp, opts)
 180  		} else {
 181  			opts.Merge(asMessage(dm), asMessage(sp.AsValueOf(f.ft.Elem().Elem())))
 182  		}
 183  		dst.AppendPointerSlice(pointerOfValue(dm))
 184  	}
 185  }
 186  
 187  func mergeBytes(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
 188  	*dst.Bytes() = append(emptyBuf[:], *src.Bytes()...)
 189  }
 190  
 191  func mergeBytesNoZero(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
 192  	v := *src.Bytes()
 193  	if len(v) > 0 {
 194  		*dst.Bytes() = append(emptyBuf[:], v...)
 195  	}
 196  }
 197  
 198  func mergeBytesSlice(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
 199  	ds := dst.BytesSlice()
 200  	for _, v := range *src.BytesSlice() {
 201  		*ds = append(*ds, append(emptyBuf[:], v...))
 202  	}
 203  }
 204