defaults.go raw

   1  // Copyright 2019 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package proto
   6  
   7  import (
   8  	"google.golang.org/protobuf/reflect/protoreflect"
   9  )
  10  
  11  // SetDefaults sets unpopulated scalar fields to their default values.
  12  // Fields within a oneof are not set even if they have a default value.
  13  // SetDefaults is recursively called upon any populated message fields.
  14  func SetDefaults(m Message) {
  15  	if m != nil {
  16  		setDefaults(MessageReflect(m))
  17  	}
  18  }
  19  
  20  func setDefaults(m protoreflect.Message) {
  21  	fds := m.Descriptor().Fields()
  22  	for i := 0; i < fds.Len(); i++ {
  23  		fd := fds.Get(i)
  24  		if !m.Has(fd) {
  25  			if fd.HasDefault() && fd.ContainingOneof() == nil {
  26  				v := fd.Default()
  27  				if fd.Kind() == protoreflect.BytesKind {
  28  					v = protoreflect.ValueOf(append([]byte(nil), v.Bytes()...)) // copy the default bytes
  29  				}
  30  				m.Set(fd, v)
  31  			}
  32  			continue
  33  		}
  34  	}
  35  
  36  	m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
  37  		switch {
  38  		// Handle singular message.
  39  		case fd.Cardinality() != protoreflect.Repeated:
  40  			if fd.Message() != nil {
  41  				setDefaults(m.Get(fd).Message())
  42  			}
  43  		// Handle list of messages.
  44  		case fd.IsList():
  45  			if fd.Message() != nil {
  46  				ls := m.Get(fd).List()
  47  				for i := 0; i < ls.Len(); i++ {
  48  					setDefaults(ls.Get(i).Message())
  49  				}
  50  			}
  51  		// Handle map of messages.
  52  		case fd.IsMap():
  53  			if fd.MapValue().Message() != nil {
  54  				ms := m.Get(fd).Map()
  55  				ms.Range(func(_ protoreflect.MapKey, v protoreflect.Value) bool {
  56  					setDefaults(v.Message())
  57  					return true
  58  				})
  59  			}
  60  		}
  61  		return true
  62  	})
  63  }
  64