checkinit.go raw

   1  // Copyright 2019 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package impl
   6  
   7  import (
   8  	"sync"
   9  
  10  	"google.golang.org/protobuf/internal/errors"
  11  	"google.golang.org/protobuf/reflect/protoreflect"
  12  	"google.golang.org/protobuf/runtime/protoiface"
  13  )
  14  
  15  func (mi *MessageInfo) checkInitialized(in protoiface.CheckInitializedInput) (protoiface.CheckInitializedOutput, error) {
  16  	var p pointer
  17  	if ms, ok := in.Message.(*messageState); ok {
  18  		p = ms.pointer()
  19  	} else {
  20  		p = in.Message.(*messageReflectWrapper).pointer()
  21  	}
  22  	return protoiface.CheckInitializedOutput{}, mi.checkInitializedPointer(p)
  23  }
  24  
// checkInitializedPointer reports an error if the message at p, described by
// mi, is missing a required field. It checks the message's own required
// fields and also runs the per-field isInit checks and the extension checks
// (via isInitExtensions). It returns the first error it encounters, or nil.
//
// mi.init is called first. If mi.needsInitCheck is then false, nil is
// returned without inspecting p.
func (mi *MessageInfo) checkInitializedPointer(p pointer) error {
	mi.init()
	if !mi.needsInitCheck {
		return nil
	}
	if p.IsNil() {
		// A nil message has nothing set: the first required field in
		// orderedCoderFields order is reported as missing.
		for _, f := range mi.orderedCoderFields {
			if f.isRequired {
				return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName()))
			}
		}
		return nil
	}

	// presence is only read from the message when it has a presence offset;
	// otherwise it stays the zero value.
	var presence presence
	if mi.presenceOffset.IsValid() {
		presence = p.Apply(mi.presenceOffset).PresenceInfo()
	}

	// Extensions are checked before any regular field.
	if mi.extensionOffset.IsValid() {
		e := p.Apply(mi.extensionOffset).Extensions()
		if err := mi.isInitExtensions(e); err != nil {
			return err
		}
	}
	for _, f := range mi.orderedCoderFields {
		// A field that is neither required nor has a nested isInit check
		// cannot make this message uninitialized.
		if !f.isRequired && f.funcs.isInit == nil {
			continue
		}

		// Fields tracked in the presence bitmap: whether the field is set is
		// taken from presence rather than from the field's storage.
		if f.presenceIndex != noPresence {
			if !presence.Present(f.presenceIndex) {
				if f.isRequired {
					return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName()))
				}
				continue
			}
			if f.funcs.isInit != nil {
				// NOTE(review): f.mi is presumably the sub-message's
				// MessageInfo; its needsInitCheck gates the nested check.
				f.mi.init()
				if f.mi.needsInitCheck {
					// NOTE(review): a lazy field whose pointer is still nil
					// presumably has not been decoded yet.
					if f.isLazy && p.Apply(f.offset).AtomicGetPointer().IsNil() {
						lazy := *p.Apply(mi.lazyOffset).LazyInfoPtr()
						if !lazy.AllowedPartial() {
							// Nothing to see here, it was checked on unmarshal
							continue
						}
						// Partial data was allowed, so decode the field now
						// so that isInit below can inspect it.
						mi.lazyUnmarshal(p, f.num)
					}
					if err := f.funcs.isInit(p.Apply(f.offset), f); err != nil {
						return err
					}
				}
			}
			continue
		}

		// Fields without a presence bit: a pointer-typed field with a nil
		// element is treated as unset.
		fptr := p.Apply(f.offset)
		if f.isPointer && fptr.Elem().IsNil() {
			if f.isRequired {
				return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName()))
			}
			continue
		}
		if f.funcs.isInit == nil {
			continue
		}
		if err := f.funcs.isInit(fptr, f); err != nil {
			return err
		}
	}
	return nil
}
  97  
  98  func (mi *MessageInfo) isInitExtensions(ext *map[int32]ExtensionField) error {
  99  	if ext == nil {
 100  		return nil
 101  	}
 102  	for _, x := range *ext {
 103  		ei := getExtensionFieldInfo(x.Type())
 104  		if ei.funcs.isInit == nil || x.isUnexpandedLazy() {
 105  			continue
 106  		}
 107  		v := x.Value()
 108  		if !v.IsValid() {
 109  			continue
 110  		}
 111  		if err := ei.funcs.isInit(v); err != nil {
 112  			return err
 113  		}
 114  	}
 115  	return nil
 116  }
 117  
 118  var (
 119  	needsInitCheckMu  sync.Mutex
 120  	needsInitCheckMap sync.Map
 121  )
 122  
 123  // needsInitCheck reports whether a message needs to be checked for partial initialization.
 124  //
 125  // It returns true if the message transitively includes any required or extension fields.
 126  func needsInitCheck(md protoreflect.MessageDescriptor) bool {
 127  	if v, ok := needsInitCheckMap.Load(md); ok {
 128  		if has, ok := v.(bool); ok {
 129  			return has
 130  		}
 131  	}
 132  	needsInitCheckMu.Lock()
 133  	defer needsInitCheckMu.Unlock()
 134  	return needsInitCheckLocked(md)
 135  }
 136  
 137  func needsInitCheckLocked(md protoreflect.MessageDescriptor) (has bool) {
 138  	if v, ok := needsInitCheckMap.Load(md); ok {
 139  		// If has is true, we've previously determined that this message
 140  		// needs init checks.
 141  		//
 142  		// If has is false, we've previously determined that it can never
 143  		// be uninitialized.
 144  		//
 145  		// If has is not a bool, we've just encountered a cycle in the
 146  		// message graph. In this case, it is safe to return false: If
 147  		// the message does have required fields, we'll detect them later
 148  		// in the graph traversal.
 149  		has, ok := v.(bool)
 150  		return ok && has
 151  	}
 152  	needsInitCheckMap.Store(md, struct{}{}) // avoid cycles while descending into this message
 153  	defer func() {
 154  		needsInitCheckMap.Store(md, has)
 155  	}()
 156  	if md.RequiredNumbers().Len() > 0 {
 157  		return true
 158  	}
 159  	if md.ExtensionRanges().Len() > 0 {
 160  		return true
 161  	}
 162  	for i := 0; i < md.Fields().Len(); i++ {
 163  		fd := md.Fields().Get(i)
 164  		// Map keys are never messages, so just consider the map value.
 165  		if fd.IsMap() {
 166  			fd = fd.MapValue()
 167  		}
 168  		fmd := fd.Message()
 169  		if fmd != nil && needsInitCheckLocked(fmd) {
 170  			return true
 171  		}
 172  	}
 173  	return false
 174  }
 175