codec_field_opaque.go raw

   1  // Copyright 2024 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package impl
   6  
   7  import (
   8  	"fmt"
   9  	"reflect"
  10  
  11  	"google.golang.org/protobuf/encoding/protowire"
  12  	"google.golang.org/protobuf/internal/errors"
  13  	"google.golang.org/protobuf/reflect/protoreflect"
  14  )
  15  
  16  func makeOpaqueMessageFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) (*MessageInfo, pointerCoderFuncs) {
  17  	mi := getMessageInfo(ft)
  18  	if mi == nil {
  19  		panic(fmt.Sprintf("invalid field: %v: unsupported message type %v", fd.FullName(), ft))
  20  	}
  21  	switch fd.Kind() {
  22  	case protoreflect.MessageKind:
  23  		return mi, pointerCoderFuncs{
  24  			size:      sizeOpaqueMessage,
  25  			marshal:   appendOpaqueMessage,
  26  			unmarshal: consumeOpaqueMessage,
  27  			isInit:    isInitOpaqueMessage,
  28  			merge:     mergeOpaqueMessage,
  29  		}
  30  	case protoreflect.GroupKind:
  31  		return mi, pointerCoderFuncs{
  32  			size:      sizeOpaqueGroup,
  33  			marshal:   appendOpaqueGroup,
  34  			unmarshal: consumeOpaqueGroup,
  35  			isInit:    isInitOpaqueMessage,
  36  			merge:     mergeOpaqueMessage,
  37  		}
  38  	}
  39  	panic("unexpected field kind")
  40  }
  41  
  42  func sizeOpaqueMessage(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) {
  43  	return protowire.SizeBytes(f.mi.sizePointer(p.AtomicGetPointer(), opts)) + f.tagsize
  44  }
  45  
  46  func appendOpaqueMessage(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
  47  	mp := p.AtomicGetPointer()
  48  	calculatedSize := f.mi.sizePointer(mp, opts)
  49  	b = protowire.AppendVarint(b, f.wiretag)
  50  	b = protowire.AppendVarint(b, uint64(calculatedSize))
  51  	before := len(b)
  52  	b, err := f.mi.marshalAppendPointer(b, mp, opts)
  53  	if measuredSize := len(b) - before; calculatedSize != measuredSize && err == nil {
  54  		return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
  55  	}
  56  	return b, err
  57  }
  58  
  59  func consumeOpaqueMessage(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
  60  	if wtyp != protowire.BytesType {
  61  		return out, errUnknown
  62  	}
  63  	v, n := protowire.ConsumeBytes(b)
  64  	if n < 0 {
  65  		return out, errDecode
  66  	}
  67  	mp := p.AtomicGetPointer()
  68  	if mp.IsNil() {
  69  		mp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
  70  	}
  71  	o, err := f.mi.unmarshalPointer(v, mp, 0, opts)
  72  	if err != nil {
  73  		return out, err
  74  	}
  75  	out.n = n
  76  	out.initialized = o.initialized
  77  	return out, nil
  78  }
  79  
  80  func isInitOpaqueMessage(p pointer, f *coderFieldInfo) error {
  81  	mp := p.AtomicGetPointer()
  82  	if mp.IsNil() {
  83  		return nil
  84  	}
  85  	return f.mi.checkInitializedPointer(mp)
  86  }
  87  
  88  func mergeOpaqueMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
  89  	dstmp := dst.AtomicGetPointer()
  90  	if dstmp.IsNil() {
  91  		dstmp = dst.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
  92  	}
  93  	f.mi.mergePointer(dstmp, src.AtomicGetPointer(), opts)
  94  }
  95  
  96  func sizeOpaqueGroup(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) {
  97  	return 2*f.tagsize + f.mi.sizePointer(p.AtomicGetPointer(), opts)
  98  }
  99  
 100  func appendOpaqueGroup(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
 101  	b = protowire.AppendVarint(b, f.wiretag) // start group
 102  	b, err := f.mi.marshalAppendPointer(b, p.AtomicGetPointer(), opts)
 103  	b = protowire.AppendVarint(b, f.wiretag+1) // end group
 104  	return b, err
 105  }
 106  
 107  func consumeOpaqueGroup(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
 108  	if wtyp != protowire.StartGroupType {
 109  		return out, errUnknown
 110  	}
 111  	mp := p.AtomicGetPointer()
 112  	if mp.IsNil() {
 113  		mp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
 114  	}
 115  	o, e := f.mi.unmarshalPointer(b, mp, f.num, opts)
 116  	return o, e
 117  }
 118  
 119  func makeOpaqueRepeatedMessageFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) (*MessageInfo, pointerCoderFuncs) {
 120  	if ft.Kind() != reflect.Ptr || ft.Elem().Kind() != reflect.Slice {
 121  		panic(fmt.Sprintf("invalid field: %v: unsupported type for opaque repeated message: %v", fd.FullName(), ft))
 122  	}
 123  	mt := ft.Elem().Elem() // *[]*T -> *T
 124  	mi := getMessageInfo(mt)
 125  	if mi == nil {
 126  		panic(fmt.Sprintf("invalid field: %v: unsupported message type %v", fd.FullName(), mt))
 127  	}
 128  	switch fd.Kind() {
 129  	case protoreflect.MessageKind:
 130  		return mi, pointerCoderFuncs{
 131  			size:      sizeOpaqueMessageSlice,
 132  			marshal:   appendOpaqueMessageSlice,
 133  			unmarshal: consumeOpaqueMessageSlice,
 134  			isInit:    isInitOpaqueMessageSlice,
 135  			merge:     mergeOpaqueMessageSlice,
 136  		}
 137  	case protoreflect.GroupKind:
 138  		return mi, pointerCoderFuncs{
 139  			size:      sizeOpaqueGroupSlice,
 140  			marshal:   appendOpaqueGroupSlice,
 141  			unmarshal: consumeOpaqueGroupSlice,
 142  			isInit:    isInitOpaqueMessageSlice,
 143  			merge:     mergeOpaqueMessageSlice,
 144  		}
 145  	}
 146  	panic("unexpected field kind")
 147  }
 148  
 149  func sizeOpaqueMessageSlice(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) {
 150  	s := p.AtomicGetPointer().PointerSlice()
 151  	n := 0
 152  	for _, v := range s {
 153  		n += protowire.SizeBytes(f.mi.sizePointer(v, opts)) + f.tagsize
 154  	}
 155  	return n
 156  }
 157  
 158  func appendOpaqueMessageSlice(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
 159  	s := p.AtomicGetPointer().PointerSlice()
 160  	var err error
 161  	for _, v := range s {
 162  		b = protowire.AppendVarint(b, f.wiretag)
 163  		siz := f.mi.sizePointer(v, opts)
 164  		b = protowire.AppendVarint(b, uint64(siz))
 165  		before := len(b)
 166  		b, err = f.mi.marshalAppendPointer(b, v, opts)
 167  		if err != nil {
 168  			return b, err
 169  		}
 170  		if measuredSize := len(b) - before; siz != measuredSize {
 171  			return nil, errors.MismatchedSizeCalculation(siz, measuredSize)
 172  		}
 173  	}
 174  	return b, nil
 175  }
 176  
 177  func consumeOpaqueMessageSlice(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
 178  	if wtyp != protowire.BytesType {
 179  		return out, errUnknown
 180  	}
 181  	v, n := protowire.ConsumeBytes(b)
 182  	if n < 0 {
 183  		return out, errDecode
 184  	}
 185  	mp := pointerOfValue(reflect.New(f.mi.GoReflectType.Elem()))
 186  	o, err := f.mi.unmarshalPointer(v, mp, 0, opts)
 187  	if err != nil {
 188  		return out, err
 189  	}
 190  	sp := p.AtomicGetPointer()
 191  	if sp.IsNil() {
 192  		sp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.ft.Elem())))
 193  	}
 194  	sp.AppendPointerSlice(mp)
 195  	out.n = n
 196  	out.initialized = o.initialized
 197  	return out, nil
 198  }
 199  
 200  func isInitOpaqueMessageSlice(p pointer, f *coderFieldInfo) error {
 201  	sp := p.AtomicGetPointer()
 202  	if sp.IsNil() {
 203  		return nil
 204  	}
 205  	s := sp.PointerSlice()
 206  	for _, v := range s {
 207  		if err := f.mi.checkInitializedPointer(v); err != nil {
 208  			return err
 209  		}
 210  	}
 211  	return nil
 212  }
 213  
 214  func mergeOpaqueMessageSlice(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
 215  	ds := dst.AtomicGetPointer()
 216  	if ds.IsNil() {
 217  		ds = dst.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.ft.Elem())))
 218  	}
 219  	for _, sp := range src.AtomicGetPointer().PointerSlice() {
 220  		dm := pointerOfValue(reflect.New(f.mi.GoReflectType.Elem()))
 221  		f.mi.mergePointer(dm, sp, opts)
 222  		ds.AppendPointerSlice(dm)
 223  	}
 224  }
 225  
 226  func sizeOpaqueGroupSlice(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) {
 227  	s := p.AtomicGetPointer().PointerSlice()
 228  	n := 0
 229  	for _, v := range s {
 230  		n += 2*f.tagsize + f.mi.sizePointer(v, opts)
 231  	}
 232  	return n
 233  }
 234  
 235  func appendOpaqueGroupSlice(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
 236  	s := p.AtomicGetPointer().PointerSlice()
 237  	var err error
 238  	for _, v := range s {
 239  		b = protowire.AppendVarint(b, f.wiretag) // start group
 240  		b, err = f.mi.marshalAppendPointer(b, v, opts)
 241  		if err != nil {
 242  			return b, err
 243  		}
 244  		b = protowire.AppendVarint(b, f.wiretag+1) // end group
 245  	}
 246  	return b, nil
 247  }
 248  
 249  func consumeOpaqueGroupSlice(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
 250  	if wtyp != protowire.StartGroupType {
 251  		return out, errUnknown
 252  	}
 253  	mp := pointerOfValue(reflect.New(f.mi.GoReflectType.Elem()))
 254  	out, err = f.mi.unmarshalPointer(b, mp, f.num, opts)
 255  	if err != nil {
 256  		return out, err
 257  	}
 258  	sp := p.AtomicGetPointer()
 259  	if sp.IsNil() {
 260  		sp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.ft.Elem())))
 261  	}
 262  	sp.AppendPointerSlice(mp)
 263  	return out, err
 264  }
 265