validate.go raw

   1  // Copyright 2019 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package impl
   6  
   7  import (
   8  	"fmt"
   9  	"math"
  10  	"math/bits"
  11  	"reflect"
  12  	"unicode/utf8"
  13  
  14  	"google.golang.org/protobuf/encoding/protowire"
  15  	"google.golang.org/protobuf/internal/encoding/messageset"
  16  	"google.golang.org/protobuf/internal/flags"
  17  	"google.golang.org/protobuf/internal/genid"
  18  	"google.golang.org/protobuf/internal/strs"
  19  	"google.golang.org/protobuf/reflect/protoreflect"
  20  	"google.golang.org/protobuf/reflect/protoregistry"
  21  	"google.golang.org/protobuf/runtime/protoiface"
  22  )
  23  
  24  // ValidationStatus is the result of validating the wire-format encoding of a message.
  25  type ValidationStatus int
  26  
  27  const (
  28  	// ValidationUnknown indicates that unmarshaling the message might succeed or fail.
  29  	// The validator was unable to render a judgement.
  30  	//
  31  	// The only causes of this status are an aberrant message type appearing somewhere
  32  	// in the message or a failure in the extension resolver.
  33  	ValidationUnknown ValidationStatus = iota + 1
  34  
  35  	// ValidationInvalid indicates that unmarshaling the message will fail.
  36  	ValidationInvalid
  37  
  38  	// ValidationValid indicates that unmarshaling the message will succeed.
  39  	ValidationValid
  40  
  41  	// ValidationWrongWireType indicates that a validated field does not have
  42  	// the expected wire type.
  43  	ValidationWrongWireType
  44  )
  45  
  46  func (v ValidationStatus) String() string {
  47  	switch v {
  48  	case ValidationUnknown:
  49  		return "ValidationUnknown"
  50  	case ValidationInvalid:
  51  		return "ValidationInvalid"
  52  	case ValidationValid:
  53  		return "ValidationValid"
  54  	default:
  55  		return fmt.Sprintf("ValidationStatus(%d)", int(v))
  56  	}
  57  }
  58  
  59  // Validate determines whether the contents of the buffer are a valid wire encoding
  60  // of the message type.
  61  //
  62  // This function is exposed for testing.
  63  func Validate(mt protoreflect.MessageType, in protoiface.UnmarshalInput) (out protoiface.UnmarshalOutput, _ ValidationStatus) {
  64  	mi, ok := mt.(*MessageInfo)
  65  	if !ok {
  66  		return out, ValidationUnknown
  67  	}
  68  	if in.Resolver == nil {
  69  		in.Resolver = protoregistry.GlobalTypes
  70  	}
  71  	if in.Depth == 0 {
  72  		in.Depth = protowire.DefaultRecursionLimit
  73  	}
  74  	o, st := mi.validate(in.Buf, 0, unmarshalOptions{
  75  		flags:    in.Flags,
  76  		resolver: in.Resolver,
  77  		depth:    in.Depth,
  78  	})
  79  	if o.initialized {
  80  		out.Flags |= protoiface.UnmarshalInitialized
  81  	}
  82  	return out, st
  83  }
  84  
  85  type validationInfo struct {
  86  	mi               *MessageInfo
  87  	typ              validationType
  88  	keyType, valType validationType
  89  
  90  	// For non-required fields, requiredBit is 0.
  91  	//
  92  	// For required fields, requiredBit's nth bit is set, where n is a
  93  	// unique index in the range [0, MessageInfo.numRequiredFields).
  94  	//
  95  	// If there are more than 64 required fields, requiredBit is 0.
  96  	requiredBit uint64
  97  }
  98  
// validationType classifies a field for the validator. It decides which
// check is applied to the payload after the field's tag and, for message,
// group, and map fields, whether a nested validation state is pushed.
type validationType uint8

const (
	// validationTypeOther is the zero value (also used for unknown fields):
	// the payload is skipped according to its wire type with no further check.
	validationTypeOther validationType = iota
	// validationTypeMessage: a length-delimited payload validated as a nested
	// message described by validationInfo.mi.
	validationTypeMessage
	// validationTypeGroup: a START_GROUP payload validated as a nested message
	// that must end with an END_GROUP tag for the same field number.
	validationTypeGroup
	// validationTypeMap: a length-delimited map entry whose key and value are
	// checked according to keyType and valType.
	validationTypeMap
	// validationTypeRepeatedVarint: a packed payload must be a sequence of
	// well-formed varints.
	validationTypeRepeatedVarint
	// validationTypeRepeatedFixed32: a packed payload length must be a
	// multiple of 4.
	validationTypeRepeatedFixed32
	// validationTypeRepeatedFixed64: a packed payload length must be a
	// multiple of 8.
	validationTypeRepeatedFixed64
	// validationTypeVarint, validationTypeFixed32 and validationTypeFixed64
	// are singular scalars. The validator uses them to check that a required
	// field arrived with a compatible wire type.
	validationTypeVarint
	validationTypeFixed32
	validationTypeFixed64
	// validationTypeBytes: a length-delimited payload with no content check.
	validationTypeBytes
	// validationTypeUTF8String: a length-delimited payload that must be
	// valid UTF-8.
	validationTypeUTF8String
	// validationTypeMessageSetItem: a legacy MessageSet item group, only
	// recognized when flags.ProtoLegacy is set.
	validationTypeMessageSetItem
)
 116  
 117  func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
 118  	var vi validationInfo
 119  	switch {
 120  	case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
 121  		switch fd.Kind() {
 122  		case protoreflect.MessageKind:
 123  			vi.typ = validationTypeMessage
 124  			if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
 125  				vi.mi = getMessageInfo(ot.Field(0).Type)
 126  			}
 127  		case protoreflect.GroupKind:
 128  			vi.typ = validationTypeGroup
 129  			if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
 130  				vi.mi = getMessageInfo(ot.Field(0).Type)
 131  			}
 132  		case protoreflect.StringKind:
 133  			if strs.EnforceUTF8(fd) {
 134  				vi.typ = validationTypeUTF8String
 135  			}
 136  		}
 137  	default:
 138  		vi = newValidationInfo(fd, ft)
 139  	}
 140  	if fd.Cardinality() == protoreflect.Required {
 141  		// Avoid overflow. The required field check is done with a 64-bit mask, with
 142  		// any message containing more than 64 required fields always reported as
 143  		// potentially uninitialized, so it is not important to get a precise count
 144  		// of the required fields past 64.
 145  		if mi.numRequiredFields < math.MaxUint8 {
 146  			mi.numRequiredFields++
 147  			vi.requiredBit = 1 << (mi.numRequiredFields - 1)
 148  		}
 149  	}
 150  	return vi
 151  }
 152  
 153  func newValidationInfo(fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
 154  	var vi validationInfo
 155  	switch {
 156  	case fd.IsList():
 157  		switch fd.Kind() {
 158  		case protoreflect.MessageKind:
 159  			vi.typ = validationTypeMessage
 160  
 161  			if ft.Kind() == reflect.Ptr {
 162  				// Repeated opaque message fields are *[]*T.
 163  				ft = ft.Elem()
 164  			}
 165  
 166  			if ft.Kind() == reflect.Slice {
 167  				vi.mi = getMessageInfo(ft.Elem())
 168  			}
 169  		case protoreflect.GroupKind:
 170  			vi.typ = validationTypeGroup
 171  
 172  			if ft.Kind() == reflect.Ptr {
 173  				// Repeated opaque message fields are *[]*T.
 174  				ft = ft.Elem()
 175  			}
 176  
 177  			if ft.Kind() == reflect.Slice {
 178  				vi.mi = getMessageInfo(ft.Elem())
 179  			}
 180  		case protoreflect.StringKind:
 181  			vi.typ = validationTypeBytes
 182  			if strs.EnforceUTF8(fd) {
 183  				vi.typ = validationTypeUTF8String
 184  			}
 185  		default:
 186  			switch wireTypes[fd.Kind()] {
 187  			case protowire.VarintType:
 188  				vi.typ = validationTypeRepeatedVarint
 189  			case protowire.Fixed32Type:
 190  				vi.typ = validationTypeRepeatedFixed32
 191  			case protowire.Fixed64Type:
 192  				vi.typ = validationTypeRepeatedFixed64
 193  			}
 194  		}
 195  	case fd.IsMap():
 196  		vi.typ = validationTypeMap
 197  		switch fd.MapKey().Kind() {
 198  		case protoreflect.StringKind:
 199  			if strs.EnforceUTF8(fd) {
 200  				vi.keyType = validationTypeUTF8String
 201  			}
 202  		}
 203  		switch fd.MapValue().Kind() {
 204  		case protoreflect.MessageKind:
 205  			vi.valType = validationTypeMessage
 206  			if ft.Kind() == reflect.Map {
 207  				vi.mi = getMessageInfo(ft.Elem())
 208  			}
 209  		case protoreflect.StringKind:
 210  			if strs.EnforceUTF8(fd) {
 211  				vi.valType = validationTypeUTF8String
 212  			}
 213  		}
 214  	default:
 215  		switch fd.Kind() {
 216  		case protoreflect.MessageKind:
 217  			vi.typ = validationTypeMessage
 218  			vi.mi = getMessageInfo(ft)
 219  		case protoreflect.GroupKind:
 220  			vi.typ = validationTypeGroup
 221  			vi.mi = getMessageInfo(ft)
 222  		case protoreflect.StringKind:
 223  			vi.typ = validationTypeBytes
 224  			if strs.EnforceUTF8(fd) {
 225  				vi.typ = validationTypeUTF8String
 226  			}
 227  		default:
 228  			switch wireTypes[fd.Kind()] {
 229  			case protowire.VarintType:
 230  				vi.typ = validationTypeVarint
 231  			case protowire.Fixed32Type:
 232  				vi.typ = validationTypeFixed32
 233  			case protowire.Fixed64Type:
 234  				vi.typ = validationTypeFixed64
 235  			case protowire.BytesType:
 236  				vi.typ = validationTypeBytes
 237  			}
 238  		}
 239  	}
 240  	return vi
 241  }
 242  
// validate checks that b is a well-formed wire encoding of the message
// described by mi, without constructing any message values.
//
// If groupTag is non-zero, b is the body of a group that must be closed by an
// END_GROUP tag with field number groupTag. Otherwise all of b is the message.
// Nested messages, groups and map entries are walked iteratively, using an
// explicit stack of validationState values instead of recursion.
//
// opts.depth is decremented on entry and for every nested message, group or
// map entry, and is restored when that level is popped. If it drops below
// zero, validate returns ValidationInvalid.
//
// On ValidationValid, out.n is the number of bytes consumed. out.initialized
// reports whether every message encountered had all of its required fields
// present with a compatible wire type; messages with more than 64 required
// fields are always reported as not initialized. For every other status, out
// is returned at its zero value.
func (mi *MessageInfo) validate(b []byte, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
	mi.init()
	// validationState is one level of the explicit nesting stack.
	type validationState struct {
		typ validationType
		// keyType and valType are only meaningful when typ is validationTypeMap.
		keyType, valType validationType
		// endGroup is the field number of the END_GROUP tag that closes this
		// level, or 0 if this level is not a group.
		endGroup protowire.Number
		mi       *MessageInfo
		// tail holds the bytes of the enclosing level that follow this
		// length-delimited value; scanning resumes there after the pop.
		tail []byte
		// requiredMask accumulates the requiredBit of each required field
		// seen at this level.
		requiredMask uint64
	}

	// Pre-allocate some slots to avoid repeated slice reallocation.
	states := make([]validationState, 0, 16)
	states = append(states, validationState{
		typ: validationTypeMessage,
		mi:  mi,
	})
	if groupTag > 0 {
		states[0].typ = validationTypeGroup
		states[0].endGroup = groupTag
	}
	if opts.depth--; opts.depth < 0 {
		return out, ValidationInvalid
	}
	initialized := true
	start := len(b)
State:
	for len(states) > 0 {
		st := &states[len(states)-1]
		for len(b) > 0 {
			// Parse the tag (field number and wire type).
			// Inline fast paths handle 1- and 2-byte tags; longer tags use
			// protowire.ConsumeVarint.
			var tag uint64
			if b[0] < 0x80 {
				tag = uint64(b[0])
				b = b[1:]
			} else if len(b) >= 2 && b[1] < 128 {
				tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
				b = b[2:]
			} else {
				var n int
				tag, n = protowire.ConsumeVarint(b)
				if n < 0 {
					return out, ValidationInvalid
				}
				b = b[n:]
			}
			var num protowire.Number
			if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
				return out, ValidationInvalid
			} else {
				num = protowire.Number(n)
			}
			wtyp := protowire.Type(tag & 7)

			// An END_GROUP tag is legal only when it closes the group being
			// validated at this level.
			if wtyp == protowire.EndGroupType {
				if st.endGroup == num {
					goto PopState
				}
				return out, ValidationInvalid
			}
			// Find the validation info for this field number.
			var vi validationInfo
			switch {
			case st.typ == validationTypeMap:
				// Inside a map entry, the key and value fields take their types
				// from the parent map field. The value carries requiredBit 1 so
				// that PopState can check the value was present when the
				// value's message type has required fields.
				switch num {
				case genid.MapEntry_Key_field_number:
					vi.typ = st.keyType
				case genid.MapEntry_Value_field_number:
					vi.typ = st.valType
					vi.mi = st.mi
					vi.requiredBit = 1
				}
			case flags.ProtoLegacy && st.mi.isMessageSet:
				switch num {
				case messageset.FieldItem:
					vi.typ = validationTypeMessageSetItem
				}
			default:
				var f *coderFieldInfo
				if int(num) < len(st.mi.denseCoderFields) {
					f = st.mi.denseCoderFields[num]
				} else {
					f = st.mi.coderFields[num]
				}
				if f != nil {
					vi = f.validation
					break
				}
				// Possible extension field.
				//
				// TODO: We should return ValidationUnknown when:
				//   1. The resolver is not frozen. (More extensions may be added to it.)
				//   2. The resolver returns protoregistry.NotFound.
				// In this case, a type added to the resolver in the future could cause
				// unmarshaling to begin failing. Supporting this requires some way to
				// determine if the resolver is frozen.
				xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
				if err != nil && err != protoregistry.NotFound {
					return out, ValidationUnknown
				}
				// An unresolved field keeps the zero validationInfo
				// (validationTypeOther) and is skipped below.
				if err == nil {
					vi = getExtensionFieldInfo(xt).validation
				}
			}
			if vi.requiredBit != 0 {
				// Check that the field has a compatible wire type.
				// We only need to consider non-repeated field types,
				// since repeated fields (and maps) can never be required.
				ok := false
				switch vi.typ {
				case validationTypeVarint:
					ok = wtyp == protowire.VarintType
				case validationTypeFixed32:
					ok = wtyp == protowire.Fixed32Type
				case validationTypeFixed64:
					ok = wtyp == protowire.Fixed64Type
				case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
					ok = wtyp == protowire.BytesType
				case validationTypeGroup:
					ok = wtyp == protowire.StartGroupType
				}
				if ok {
					st.requiredMask |= vi.requiredBit
				}
			}

			switch wtyp {
			case protowire.VarintType:
				// Skip the varint without decoding it. A 10-byte varint's final
				// byte may only be 0 or 1, since a 64-bit value needs at most
				// one bit in the tenth byte.
				if len(b) >= 10 {
					switch {
					case b[0] < 0x80:
						b = b[1:]
					case b[1] < 0x80:
						b = b[2:]
					case b[2] < 0x80:
						b = b[3:]
					case b[3] < 0x80:
						b = b[4:]
					case b[4] < 0x80:
						b = b[5:]
					case b[5] < 0x80:
						b = b[6:]
					case b[6] < 0x80:
						b = b[7:]
					case b[7] < 0x80:
						b = b[8:]
					case b[8] < 0x80:
						b = b[9:]
					case b[9] < 0x80 && b[9] < 2:
						b = b[10:]
					default:
						return out, ValidationInvalid
					}
				} else {
					// Fewer than 10 bytes remain: bounds-check each index.
					switch {
					case len(b) > 0 && b[0] < 0x80:
						b = b[1:]
					case len(b) > 1 && b[1] < 0x80:
						b = b[2:]
					case len(b) > 2 && b[2] < 0x80:
						b = b[3:]
					case len(b) > 3 && b[3] < 0x80:
						b = b[4:]
					case len(b) > 4 && b[4] < 0x80:
						b = b[5:]
					case len(b) > 5 && b[5] < 0x80:
						b = b[6:]
					case len(b) > 6 && b[6] < 0x80:
						b = b[7:]
					case len(b) > 7 && b[7] < 0x80:
						b = b[8:]
					case len(b) > 8 && b[8] < 0x80:
						b = b[9:]
					case len(b) > 9 && b[9] < 2:
						b = b[10:]
					default:
						return out, ValidationInvalid
					}
				}
				continue State
			case protowire.BytesType:
				// Decode the length prefix (same fast paths as the tag), then
				// split off the payload v.
				var size uint64
				if len(b) >= 1 && b[0] < 0x80 {
					size = uint64(b[0])
					b = b[1:]
				} else if len(b) >= 2 && b[1] < 128 {
					size = uint64(b[0]&0x7f) + uint64(b[1])<<7
					b = b[2:]
				} else {
					var n int
					size, n = protowire.ConsumeVarint(b)
					if n < 0 {
						return out, ValidationInvalid
					}
					b = b[n:]
				}
				if size > uint64(len(b)) {
					return out, ValidationInvalid
				}
				v := b[:size]
				b = b[size:]
				switch vi.typ {
				case validationTypeMessage:
					if vi.mi == nil {
						return out, ValidationUnknown
					}
					vi.mi.init()
					fallthrough
				case validationTypeMap:
					if vi.mi != nil {
						vi.mi.init()
					}
					// Descend into the nested message or map entry. The bytes
					// after it are saved in tail and restored when it is popped.
					states = append(states, validationState{
						typ:     vi.typ,
						keyType: vi.keyType,
						valType: vi.valType,
						mi:      vi.mi,
						tail:    b,
					})
					if vi.typ == validationTypeMessage ||
						vi.typ == validationTypeGroup ||
						vi.typ == validationTypeMap {
						if opts.depth--; opts.depth < 0 {
							return out, ValidationInvalid
						}
					}
					b = v
					continue State
				case validationTypeRepeatedVarint:
					// Packed field.
					for len(v) > 0 {
						_, n := protowire.ConsumeVarint(v)
						if n < 0 {
							return out, ValidationInvalid
						}
						v = v[n:]
					}
				case validationTypeRepeatedFixed32:
					// Packed field.
					if len(v)%4 != 0 {
						return out, ValidationInvalid
					}
				case validationTypeRepeatedFixed64:
					// Packed field.
					if len(v)%8 != 0 {
						return out, ValidationInvalid
					}
				case validationTypeUTF8String:
					if !utf8.Valid(v) {
						return out, ValidationInvalid
					}
				}
			case protowire.Fixed32Type:
				if len(b) < 4 {
					return out, ValidationInvalid
				}
				b = b[4:]
			case protowire.Fixed64Type:
				if len(b) < 8 {
					return out, ValidationInvalid
				}
				b = b[8:]
			case protowire.StartGroupType:
				switch {
				case vi.typ == validationTypeGroup:
					if vi.mi == nil {
						return out, ValidationUnknown
					}
					vi.mi.init()
					// A group shares the current buffer; it ends at the
					// matching END_GROUP tag rather than at a length, so no tail
					// is recorded.
					states = append(states, validationState{
						typ:      validationTypeGroup,
						mi:       vi.mi,
						endGroup: num,
					})
					if opts.depth--; opts.depth < 0 {
						return out, ValidationInvalid
					}
					continue State
				case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
					// Legacy MessageSet item: resolve the extension by its type
					// id and validate the message payload v as a nested state.
					typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
					if err != nil {
						return out, ValidationInvalid
					}
					xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
					switch {
					case err == protoregistry.NotFound:
						b = b[n:]
					case err != nil:
						return out, ValidationUnknown
					default:
						xvi := getExtensionFieldInfo(xt).validation
						if xvi.mi != nil {
							xvi.mi.init()
						}
						states = append(states, validationState{
							typ:  xvi.typ,
							mi:   xvi.mi,
							tail: b[n:],
						})
						if xvi.typ == validationTypeMessage ||
							xvi.typ == validationTypeGroup ||
							xvi.typ == validationTypeMap {
							if opts.depth--; opts.depth < 0 {
								return out, ValidationInvalid
							}
						}
						b = v
						continue State
					}
				default:
					// Group of an unknown or non-group field: skip it whole.
					n := protowire.ConsumeFieldValue(num, wtyp, b)
					if n < 0 {
						return out, ValidationInvalid
					}
					b = b[n:]
				}
			default:
				return out, ValidationInvalid
			}
		}
		// This level's bytes are exhausted. A group that ran out of input
		// without its END_GROUP tag is invalid.
		if st.endGroup != 0 {
			return out, ValidationInvalid
		}
		// Defensive: the loop above only falls through here with b empty.
		if len(b) != 0 {
			return out, ValidationInvalid
		}
		b = st.tail
	PopState:
		// Leaving a level: return its depth budget and check that its
		// required fields were all seen.
		numRequiredFields := 0
		switch st.typ {
		case validationTypeMessage, validationTypeGroup:
			numRequiredFields = int(st.mi.numRequiredFields)
			opts.depth++
		case validationTypeMap:
			// If this is a map field with a message value that contains
			// required fields, require that the value be present.
			if st.mi != nil && st.mi.numRequiredFields > 0 {
				numRequiredFields = 1
			}
			opts.depth++
		}
		// If there are more than 64 required fields, this check will
		// always fail and we will report that the message is potentially
		// uninitialized.
		if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
			initialized = false
		}
		states = states[:len(states)-1]
	}
	out.n = start - len(b)
	if initialized {
		out.initialized = true
	}
	return out, ValidationValid
}
 597