encode.go raw

   1  // Copyright 2019 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package impl
   6  
   7  import (
   8  	"math"
   9  	"sort"
  10  	"sync/atomic"
  11  
  12  	"google.golang.org/protobuf/internal/flags"
  13  	"google.golang.org/protobuf/internal/protolazy"
  14  	"google.golang.org/protobuf/proto"
  15  	piface "google.golang.org/protobuf/runtime/protoiface"
  16  )
  17  
  18  type marshalOptions struct {
  19  	flags piface.MarshalInputFlags
  20  }
  21  
  22  func (o marshalOptions) Options() proto.MarshalOptions {
  23  	return proto.MarshalOptions{
  24  		AllowPartial:  true,
  25  		Deterministic: o.Deterministic(),
  26  		UseCachedSize: o.UseCachedSize(),
  27  	}
  28  }
  29  
  30  func (o marshalOptions) Deterministic() bool { return o.flags&piface.MarshalDeterministic != 0 }
  31  func (o marshalOptions) UseCachedSize() bool { return o.flags&piface.MarshalUseCachedSize != 0 }
  32  
  33  // size is protoreflect.Methods.Size.
  34  func (mi *MessageInfo) size(in piface.SizeInput) piface.SizeOutput {
  35  	var p pointer
  36  	if ms, ok := in.Message.(*messageState); ok {
  37  		p = ms.pointer()
  38  	} else {
  39  		p = in.Message.(*messageReflectWrapper).pointer()
  40  	}
  41  	size := mi.sizePointer(p, marshalOptions{
  42  		flags: in.Flags,
  43  	})
  44  	return piface.SizeOutput{Size: size}
  45  }
  46  
  47  func (mi *MessageInfo) sizePointer(p pointer, opts marshalOptions) (size int) {
  48  	mi.init()
  49  	if p.IsNil() {
  50  		return 0
  51  	}
  52  	if opts.UseCachedSize() && mi.sizecacheOffset.IsValid() {
  53  		// The size cache contains the size + 1, to allow the
  54  		// zero value to be invalid, while also allowing for a
  55  		// 0 size to be cached.
  56  		if size := atomic.LoadInt32(p.Apply(mi.sizecacheOffset).Int32()); size > 0 {
  57  			return int(size - 1)
  58  		}
  59  	}
  60  	return mi.sizePointerSlow(p, opts)
  61  }
  62  
  63  func (mi *MessageInfo) sizePointerSlow(p pointer, opts marshalOptions) (size int) {
  64  	if flags.ProtoLegacy && mi.isMessageSet {
  65  		size = sizeMessageSet(mi, p, opts)
  66  		if mi.sizecacheOffset.IsValid() {
  67  			atomic.StoreInt32(p.Apply(mi.sizecacheOffset).Int32(), int32(size+1))
  68  		}
  69  		return size
  70  	}
  71  	if mi.extensionOffset.IsValid() {
  72  		e := p.Apply(mi.extensionOffset).Extensions()
  73  		size += mi.sizeExtensions(e, opts)
  74  	}
  75  
  76  	var lazy **protolazy.XXX_lazyUnmarshalInfo
  77  	var presence presence
  78  	if mi.presenceOffset.IsValid() {
  79  		presence = p.Apply(mi.presenceOffset).PresenceInfo()
  80  		if mi.lazyOffset.IsValid() {
  81  			lazy = p.Apply(mi.lazyOffset).LazyInfoPtr()
  82  		}
  83  	}
  84  
  85  	for _, f := range mi.orderedCoderFields {
  86  		if f.funcs.size == nil {
  87  			continue
  88  		}
  89  		fptr := p.Apply(f.offset)
  90  
  91  		if f.presenceIndex != noPresence {
  92  			if !presence.Present(f.presenceIndex) {
  93  				continue
  94  			}
  95  
  96  			if f.isLazy && fptr.AtomicGetPointer().IsNil() {
  97  				if lazyFields(opts) {
  98  					size += (*lazy).SizeField(uint32(f.num))
  99  					continue
 100  				} else {
 101  					mi.lazyUnmarshal(p, f.num)
 102  				}
 103  			}
 104  			size += f.funcs.size(fptr, f, opts)
 105  			continue
 106  		}
 107  
 108  		if f.isPointer && fptr.Elem().IsNil() {
 109  			continue
 110  		}
 111  		size += f.funcs.size(fptr, f, opts)
 112  	}
 113  	if mi.unknownOffset.IsValid() {
 114  		if u := mi.getUnknownBytes(p); u != nil {
 115  			size += len(*u)
 116  		}
 117  	}
 118  	if mi.sizecacheOffset.IsValid() {
 119  		if size > (math.MaxInt32 - 1) {
 120  			// The size is too large for the int32 sizecache field.
 121  			// We will need to recompute the size when encoding;
 122  			// unfortunately expensive, but better than invalid output.
 123  			atomic.StoreInt32(p.Apply(mi.sizecacheOffset).Int32(), 0)
 124  		} else {
 125  			// The size cache contains the size + 1, to allow the
 126  			// zero value to be invalid, while also allowing for a
 127  			// 0 size to be cached.
 128  			atomic.StoreInt32(p.Apply(mi.sizecacheOffset).Int32(), int32(size+1))
 129  		}
 130  	}
 131  	return size
 132  }
 133  
 134  // marshal is protoreflect.Methods.Marshal.
 135  func (mi *MessageInfo) marshal(in piface.MarshalInput) (out piface.MarshalOutput, err error) {
 136  	var p pointer
 137  	if ms, ok := in.Message.(*messageState); ok {
 138  		p = ms.pointer()
 139  	} else {
 140  		p = in.Message.(*messageReflectWrapper).pointer()
 141  	}
 142  	b, err := mi.marshalAppendPointer(in.Buf, p, marshalOptions{
 143  		flags: in.Flags,
 144  	})
 145  	return piface.MarshalOutput{Buf: b}, err
 146  }
 147  
 148  func (mi *MessageInfo) marshalAppendPointer(b []byte, p pointer, opts marshalOptions) ([]byte, error) {
 149  	mi.init()
 150  	if p.IsNil() {
 151  		return b, nil
 152  	}
 153  	if flags.ProtoLegacy && mi.isMessageSet {
 154  		return marshalMessageSet(mi, b, p, opts)
 155  	}
 156  	var err error
 157  	// The old marshaler encodes extensions at beginning.
 158  	if mi.extensionOffset.IsValid() {
 159  		e := p.Apply(mi.extensionOffset).Extensions()
 160  		// TODO: Special handling for MessageSet?
 161  		b, err = mi.appendExtensions(b, e, opts)
 162  		if err != nil {
 163  			return b, err
 164  		}
 165  	}
 166  
 167  	var lazy **protolazy.XXX_lazyUnmarshalInfo
 168  	var presence presence
 169  	if mi.presenceOffset.IsValid() {
 170  		presence = p.Apply(mi.presenceOffset).PresenceInfo()
 171  		if mi.lazyOffset.IsValid() {
 172  			lazy = p.Apply(mi.lazyOffset).LazyInfoPtr()
 173  		}
 174  	}
 175  
 176  	for _, f := range mi.orderedCoderFields {
 177  		if f.funcs.marshal == nil {
 178  			continue
 179  		}
 180  		fptr := p.Apply(f.offset)
 181  
 182  		if f.presenceIndex != noPresence {
 183  			if !presence.Present(f.presenceIndex) {
 184  				continue
 185  			}
 186  			if f.isLazy {
 187  				// Be careful, this field needs to be read atomically, like for a get
 188  				if f.isPointer && fptr.AtomicGetPointer().IsNil() {
 189  					if lazyFields(opts) {
 190  						b, _ = (*lazy).AppendField(b, uint32(f.num))
 191  						continue
 192  					} else {
 193  						mi.lazyUnmarshal(p, f.num)
 194  					}
 195  				}
 196  
 197  				b, err = f.funcs.marshal(b, fptr, f, opts)
 198  				if err != nil {
 199  					return b, err
 200  				}
 201  				continue
 202  			} else if f.isPointer && fptr.Elem().IsNil() {
 203  				continue
 204  			}
 205  			b, err = f.funcs.marshal(b, fptr, f, opts)
 206  			if err != nil {
 207  				return b, err
 208  			}
 209  			continue
 210  		}
 211  
 212  		if f.isPointer && fptr.Elem().IsNil() {
 213  			continue
 214  		}
 215  		b, err = f.funcs.marshal(b, fptr, f, opts)
 216  		if err != nil {
 217  			return b, err
 218  		}
 219  	}
 220  	if mi.unknownOffset.IsValid() && !mi.isMessageSet {
 221  		if u := mi.getUnknownBytes(p); u != nil {
 222  			b = append(b, (*u)...)
 223  		}
 224  	}
 225  	return b, nil
 226  }
 227  
 228  // fullyLazyExtensions returns true if we should attempt to keep extensions lazy over size and marshal.
 229  func fullyLazyExtensions(opts marshalOptions) bool {
 230  	// When deterministic marshaling is requested, force an unmarshal for lazy
 231  	// extensions to produce a deterministic result, instead of passing through
 232  	// bytes lazily that may or may not match what Go Protobuf would produce.
 233  	return opts.flags&piface.MarshalDeterministic == 0
 234  }
 235  
 236  // lazyFields returns true if we should attempt to keep fields lazy over size and marshal.
 237  func lazyFields(opts marshalOptions) bool {
 238  	// When deterministic marshaling is requested, force an unmarshal for lazy
 239  	// fields to produce a deterministic result, instead of passing through
 240  	// bytes lazily that may or may not match what Go Protobuf would produce.
 241  	return opts.flags&piface.MarshalDeterministic == 0
 242  }
 243  
 244  func (mi *MessageInfo) sizeExtensions(ext *map[int32]ExtensionField, opts marshalOptions) (n int) {
 245  	if ext == nil {
 246  		return 0
 247  	}
 248  	for _, x := range *ext {
 249  		xi := getExtensionFieldInfo(x.Type())
 250  		if xi.funcs.size == nil {
 251  			continue
 252  		}
 253  		if fullyLazyExtensions(opts) {
 254  			// Don't expand the extension, instead use the buffer to calculate size
 255  			if lb := x.lazyBuffer(); lb != nil {
 256  				// We got hold of the buffer, so it's still lazy.
 257  				n += len(lb)
 258  				continue
 259  			}
 260  		}
 261  		n += xi.funcs.size(x.Value(), xi.tagsize, opts)
 262  	}
 263  	return n
 264  }
 265  
 266  func (mi *MessageInfo) appendExtensions(b []byte, ext *map[int32]ExtensionField, opts marshalOptions) ([]byte, error) {
 267  	if ext == nil {
 268  		return b, nil
 269  	}
 270  
 271  	switch len(*ext) {
 272  	case 0:
 273  		return b, nil
 274  	case 1:
 275  		// Fast-path for one extension: Don't bother sorting the keys.
 276  		var err error
 277  		for _, x := range *ext {
 278  			xi := getExtensionFieldInfo(x.Type())
 279  			if fullyLazyExtensions(opts) {
 280  				// Don't expand the extension if it's still in wire format, instead use the buffer content.
 281  				if lb := x.lazyBuffer(); lb != nil {
 282  					b = append(b, lb...)
 283  					continue
 284  				}
 285  			}
 286  			b, err = xi.funcs.marshal(b, x.Value(), xi.wiretag, opts)
 287  		}
 288  		return b, err
 289  	default:
 290  		// Sort the keys to provide a deterministic encoding.
 291  		// Not sure this is required, but the old code does it.
 292  		keys := make([]int, 0, len(*ext))
 293  		for k := range *ext {
 294  			keys = append(keys, int(k))
 295  		}
 296  		sort.Ints(keys)
 297  		var err error
 298  		for _, k := range keys {
 299  			x := (*ext)[int32(k)]
 300  			xi := getExtensionFieldInfo(x.Type())
 301  			if fullyLazyExtensions(opts) {
 302  				// Don't expand the extension if it's still in wire format, instead use the buffer content.
 303  				if lb := x.lazyBuffer(); lb != nil {
 304  					b = append(b, lb...)
 305  					continue
 306  				}
 307  			}
 308  			b, err = xi.funcs.marshal(b, x.Value(), xi.wiretag, opts)
 309  			if err != nil {
 310  				return b, err
 311  			}
 312  		}
 313  		return b, nil
 314  	}
 315  }
 316