messageset.go raw

   1  // Copyright 2019 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  // Package messageset encodes and decodes the obsolete MessageSet wire format.
   6  package messageset
   7  
   8  import (
   9  	"math"
  10  
  11  	"google.golang.org/protobuf/encoding/protowire"
  12  	"google.golang.org/protobuf/internal/errors"
  13  	"google.golang.org/protobuf/reflect/protoreflect"
  14  )
  15  
  16  // The MessageSet wire format is equivalent to a message defined as follows,
  17  // where each Item defines an extension field with a field number of 'type_id'
  18  // and content of 'message'. MessageSet extensions must be non-repeated message
  19  // fields.
  20  //
  21  //	message MessageSet {
  22  //		repeated group Item = 1 {
  23  //			required int32 type_id = 2;
  24  //			required string message = 3;
  25  //		}
  26  //	}
  27  const (
  28  	FieldItem    = protowire.Number(1)
  29  	FieldTypeID  = protowire.Number(2)
  30  	FieldMessage = protowire.Number(3)
  31  )
  32  
  33  // ExtensionName is the field name for extensions of MessageSet.
  34  //
  35  // A valid MessageSet extension must be of the form:
  36  //
  37  //	message MyMessage {
  38  //		extend proto2.bridge.MessageSet {
  39  //			optional MyMessage message_set_extension = 1234;
  40  //		}
  41  //		...
  42  //	}
  43  const ExtensionName = "message_set_extension"
  44  
  45  // IsMessageSet returns whether the message uses the MessageSet wire format.
  46  func IsMessageSet(md protoreflect.MessageDescriptor) bool {
  47  	xmd, ok := md.(interface{ IsMessageSet() bool })
  48  	return ok && xmd.IsMessageSet()
  49  }
  50  
  51  // IsMessageSetExtension reports this field properly extends a MessageSet.
  52  func IsMessageSetExtension(fd protoreflect.FieldDescriptor) bool {
  53  	switch {
  54  	case fd.Name() != ExtensionName:
  55  		return false
  56  	case !IsMessageSet(fd.ContainingMessage()):
  57  		return false
  58  	case fd.FullName().Parent() != fd.Message().FullName():
  59  		return false
  60  	}
  61  	return true
  62  }
  63  
  64  // SizeField returns the size of a MessageSet item field containing an extension
  65  // with the given field number, not counting the contents of the message subfield.
  66  func SizeField(num protowire.Number) int {
  67  	return 2*protowire.SizeTag(FieldItem) + protowire.SizeTag(FieldTypeID) + protowire.SizeVarint(uint64(num))
  68  }
  69  
  70  // Unmarshal parses a MessageSet.
  71  //
  72  // It calls fn with the type ID and value of each item in the MessageSet.
  73  // Unknown fields are discarded.
  74  //
  75  // If wantLen is true, the item values include the varint length prefix.
  76  // This is ugly, but simplifies the fast-path decoder in internal/impl.
  77  func Unmarshal(b []byte, wantLen bool, fn func(typeID protowire.Number, value []byte) error) error {
  78  	for len(b) > 0 {
  79  		num, wtyp, n := protowire.ConsumeTag(b)
  80  		if n < 0 {
  81  			return protowire.ParseError(n)
  82  		}
  83  		b = b[n:]
  84  		if num != FieldItem || wtyp != protowire.StartGroupType {
  85  			n := protowire.ConsumeFieldValue(num, wtyp, b)
  86  			if n < 0 {
  87  				return protowire.ParseError(n)
  88  			}
  89  			b = b[n:]
  90  			continue
  91  		}
  92  		typeID, value, n, err := ConsumeFieldValue(b, wantLen)
  93  		if err != nil {
  94  			return err
  95  		}
  96  		b = b[n:]
  97  		if typeID == 0 {
  98  			continue
  99  		}
 100  		if err := fn(typeID, value); err != nil {
 101  			return err
 102  		}
 103  	}
 104  	return nil
 105  }
 106  
 107  // ConsumeFieldValue parses b as a MessageSet item field value until and including
 108  // the trailing end group marker. It assumes the start group tag has already been parsed.
 109  // It returns the contents of the type_id and message subfields and the total
 110  // item length.
 111  //
 112  // If wantLen is true, the returned message value includes the length prefix.
 113  func ConsumeFieldValue(b []byte, wantLen bool) (typeid protowire.Number, message []byte, n int, err error) {
 114  	ilen := len(b)
 115  	for {
 116  		num, wtyp, n := protowire.ConsumeTag(b)
 117  		if n < 0 {
 118  			return 0, nil, 0, protowire.ParseError(n)
 119  		}
 120  		b = b[n:]
 121  		switch {
 122  		case num == FieldItem && wtyp == protowire.EndGroupType:
 123  			if wantLen && len(message) == 0 {
 124  				// The message field was missing, which should never happen.
 125  				// Be prepared for this case anyway.
 126  				message = protowire.AppendVarint(message, 0)
 127  			}
 128  			return typeid, message, ilen - len(b), nil
 129  		case num == FieldTypeID && wtyp == protowire.VarintType:
 130  			v, n := protowire.ConsumeVarint(b)
 131  			if n < 0 {
 132  				return 0, nil, 0, protowire.ParseError(n)
 133  			}
 134  			b = b[n:]
 135  			if v < 1 || v > math.MaxInt32 {
 136  				return 0, nil, 0, errors.New("invalid type_id in message set")
 137  			}
 138  			typeid = protowire.Number(v)
 139  		case num == FieldMessage && wtyp == protowire.BytesType:
 140  			m, n := protowire.ConsumeBytes(b)
 141  			if n < 0 {
 142  				return 0, nil, 0, protowire.ParseError(n)
 143  			}
 144  			if message == nil {
 145  				if wantLen {
 146  					message = b[:n:n]
 147  				} else {
 148  					message = m[:len(m):len(m)]
 149  				}
 150  			} else {
 151  				// This case should never happen in practice, but handle it for
 152  				// correctness: The MessageSet item contains multiple message
 153  				// fields, which need to be merged.
 154  				//
 155  				// In the case where we're returning the length, this becomes
 156  				// quite inefficient since we need to strip the length off
 157  				// the existing data and reconstruct it with the combined length.
 158  				if wantLen {
 159  					_, nn := protowire.ConsumeVarint(message)
 160  					m0 := message[nn:]
 161  					message = nil
 162  					message = protowire.AppendVarint(message, uint64(len(m0)+len(m)))
 163  					message = append(message, m0...)
 164  					message = append(message, m...)
 165  				} else {
 166  					message = append(message, m...)
 167  				}
 168  			}
 169  			b = b[n:]
 170  		default:
 171  			// We have no place to put it, so we just ignore unknown fields.
 172  			n := protowire.ConsumeFieldValue(num, wtyp, b)
 173  			if n < 0 {
 174  				return 0, nil, 0, protowire.ParseError(n)
 175  			}
 176  			b = b[n:]
 177  		}
 178  	}
 179  }
 180  
 181  // AppendFieldStart appends the start of a MessageSet item field containing
 182  // an extension with the given number. The caller must add the message
 183  // subfield (including the tag).
 184  func AppendFieldStart(b []byte, num protowire.Number) []byte {
 185  	b = protowire.AppendTag(b, FieldItem, protowire.StartGroupType)
 186  	b = protowire.AppendTag(b, FieldTypeID, protowire.VarintType)
 187  	b = protowire.AppendVarint(b, uint64(num))
 188  	return b
 189  }
 190  
 191  // AppendFieldEnd appends the trailing end group marker for a MessageSet item field.
 192  func AppendFieldEnd(b []byte) []byte {
 193  	return protowire.AppendTag(b, FieldItem, protowire.EndGroupType)
 194  }
 195  
 196  // SizeUnknown returns the size of an unknown fields section in MessageSet format.
 197  //
 198  // See AppendUnknown.
 199  func SizeUnknown(unknown []byte) (size int) {
 200  	for len(unknown) > 0 {
 201  		num, typ, n := protowire.ConsumeTag(unknown)
 202  		if n < 0 || typ != protowire.BytesType {
 203  			return 0
 204  		}
 205  		unknown = unknown[n:]
 206  		_, n = protowire.ConsumeBytes(unknown)
 207  		if n < 0 {
 208  			return 0
 209  		}
 210  		unknown = unknown[n:]
 211  		size += SizeField(num) + protowire.SizeTag(FieldMessage) + n
 212  	}
 213  	return size
 214  }
 215  
 216  // AppendUnknown appends unknown fields to b in MessageSet format.
 217  //
 218  // For historic reasons, unresolved items in a MessageSet are stored in a
 219  // message's unknown fields section in non-MessageSet format. That is, an
 220  // unknown item with typeID T and value V appears in the unknown fields as
 221  // a field with number T and value V.
 222  //
 223  // This function converts the unknown fields back into MessageSet form.
 224  func AppendUnknown(b, unknown []byte) ([]byte, error) {
 225  	for len(unknown) > 0 {
 226  		num, typ, n := protowire.ConsumeTag(unknown)
 227  		if n < 0 || typ != protowire.BytesType {
 228  			return nil, errors.New("invalid data in message set unknown fields")
 229  		}
 230  		unknown = unknown[n:]
 231  		_, n = protowire.ConsumeBytes(unknown)
 232  		if n < 0 {
 233  			return nil, errors.New("invalid data in message set unknown fields")
 234  		}
 235  		b = AppendFieldStart(b, num)
 236  		b = protowire.AppendTag(b, FieldMessage, protowire.BytesType)
 237  		b = append(b, unknown[:n]...)
 238  		b = AppendFieldEnd(b)
 239  		unknown = unknown[n:]
 240  	}
 241  	return b, nil
 242  }
 243