codec_map.go raw
1 // Copyright 2019 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package impl
6
7 import (
8 "reflect"
9 "sort"
10
11 "google.golang.org/protobuf/encoding/protowire"
12 "google.golang.org/protobuf/internal/errors"
13 "google.golang.org/protobuf/internal/genid"
14 "google.golang.org/protobuf/reflect/protoreflect"
15 )
16
17 type mapInfo struct {
18 goType reflect.Type
19 keyWiretag uint64
20 valWiretag uint64
21 keyFuncs valueCoderFuncs
22 valFuncs valueCoderFuncs
23 keyZero protoreflect.Value
24 keyKind protoreflect.Kind
25 conv *mapConverter
26 }
27
28 func encoderFuncsForMap(fd protoreflect.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
29 // TODO: Consider generating specialized map coders.
30 keyField := fd.MapKey()
31 valField := fd.MapValue()
32 keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()])
33 valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()])
34 keyFuncs := encoderFuncsForValue(keyField)
35 valFuncs := encoderFuncsForValue(valField)
36 conv := newMapConverter(ft, fd)
37
38 mapi := &mapInfo{
39 goType: ft,
40 keyWiretag: keyWiretag,
41 valWiretag: valWiretag,
42 keyFuncs: keyFuncs,
43 valFuncs: valFuncs,
44 keyZero: keyField.Default(),
45 keyKind: keyField.Kind(),
46 conv: conv,
47 }
48 if valField.Kind() == protoreflect.MessageKind {
49 valueMessage = getMessageInfo(ft.Elem())
50 }
51
52 funcs = pointerCoderFuncs{
53 size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
54 return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
55 },
56 marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
57 return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
58 },
59 unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
60 mp := p.AsValueOf(ft)
61 if mp.Elem().IsNil() {
62 mp.Elem().Set(reflect.MakeMap(mapi.goType))
63 }
64 if f.mi == nil {
65 return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
66 } else {
67 return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
68 }
69 },
70 }
71 switch valField.Kind() {
72 case protoreflect.MessageKind:
73 funcs.merge = mergeMapOfMessage
74 case protoreflect.BytesKind:
75 funcs.merge = mergeMapOfBytes
76 default:
77 funcs.merge = mergeMap
78 }
79 if valFuncs.isInit != nil {
80 funcs.isInit = func(p pointer, f *coderFieldInfo) error {
81 return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
82 }
83 }
84 return valueMessage, funcs
85 }
86
// Encoded sizes of the tags for a map entry's key (field 1) and value
// (field 2). A tag is (field number << 3 | wire type), so any field number
// below 16 encodes as a single varint byte whatever its wire type.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
91
92 func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
93 if mapv.Len() == 0 {
94 return 0
95 }
96 n := 0
97 iter := mapv.MapRange()
98 for iter.Next() {
99 key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
100 keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
101 var valSize int
102 value := mapi.conv.valConv.PBValueOf(iter.Value())
103 if f.mi == nil {
104 valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
105 } else {
106 p := pointerOfValue(iter.Value())
107 valSize += mapValTagSize
108 valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts))
109 }
110 n += f.tagsize + protowire.SizeBytes(keySize+valSize)
111 }
112 return n
113 }
114
115 func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
116 if opts.depth--; opts.depth < 0 {
117 return out, errRecursionDepth
118 }
119 if wtyp != protowire.BytesType {
120 return out, errUnknown
121 }
122 b, n := protowire.ConsumeBytes(b)
123 if n < 0 {
124 return out, errDecode
125 }
126 var (
127 key = mapi.keyZero
128 val = mapi.conv.valConv.New()
129 )
130 for len(b) > 0 {
131 num, wtyp, n := protowire.ConsumeTag(b)
132 if n < 0 {
133 return out, errDecode
134 }
135 if num > protowire.MaxValidNumber {
136 return out, errDecode
137 }
138 b = b[n:]
139 err := errUnknown
140 switch num {
141 case genid.MapEntry_Key_field_number:
142 var v protoreflect.Value
143 var o unmarshalOutput
144 v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
145 if err != nil {
146 break
147 }
148 key = v
149 n = o.n
150 case genid.MapEntry_Value_field_number:
151 var v protoreflect.Value
152 var o unmarshalOutput
153 v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
154 if err != nil {
155 break
156 }
157 val = v
158 n = o.n
159 }
160 if err == errUnknown {
161 n = protowire.ConsumeFieldValue(num, wtyp, b)
162 if n < 0 {
163 return out, errDecode
164 }
165 } else if err != nil {
166 return out, err
167 }
168 b = b[n:]
169 }
170 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
171 out.n = n
172 return out, nil
173 }
174
175 func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
176 if opts.depth--; opts.depth < 0 {
177 return out, errRecursionDepth
178 }
179 if wtyp != protowire.BytesType {
180 return out, errUnknown
181 }
182 b, n := protowire.ConsumeBytes(b)
183 if n < 0 {
184 return out, errDecode
185 }
186 var (
187 key = mapi.keyZero
188 val = reflect.New(f.mi.GoReflectType.Elem())
189 )
190 for len(b) > 0 {
191 num, wtyp, n := protowire.ConsumeTag(b)
192 if n < 0 {
193 return out, errDecode
194 }
195 if num > protowire.MaxValidNumber {
196 return out, errDecode
197 }
198 b = b[n:]
199 err := errUnknown
200 switch num {
201 case 1:
202 var v protoreflect.Value
203 var o unmarshalOutput
204 v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
205 if err != nil {
206 break
207 }
208 key = v
209 n = o.n
210 case 2:
211 if wtyp != protowire.BytesType {
212 break
213 }
214 var v []byte
215 v, n = protowire.ConsumeBytes(b)
216 if n < 0 {
217 return out, errDecode
218 }
219 var o unmarshalOutput
220 o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
221 if o.initialized {
222 // Consider this map item initialized so long as we see
223 // an initialized value.
224 out.initialized = true
225 }
226 }
227 if err == errUnknown {
228 n = protowire.ConsumeFieldValue(num, wtyp, b)
229 if n < 0 {
230 return out, errDecode
231 }
232 } else if err != nil {
233 return out, err
234 }
235 b = b[n:]
236 }
237 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
238 out.n = n
239 return out, nil
240 }
241
242 func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
243 if f.mi == nil {
244 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
245 val := mapi.conv.valConv.PBValueOf(valrv)
246 size := 0
247 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
248 size += mapi.valFuncs.size(val, mapValTagSize, opts)
249 b = protowire.AppendVarint(b, uint64(size))
250 before := len(b)
251 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
252 if err != nil {
253 return nil, err
254 }
255 b, err = mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
256 if measuredSize := len(b) - before; size != measuredSize && err == nil {
257 return nil, errors.MismatchedSizeCalculation(size, measuredSize)
258 }
259 return b, err
260 } else {
261 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
262 val := pointerOfValue(valrv)
263 valSize := f.mi.sizePointer(val, opts)
264 size := 0
265 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
266 size += mapValTagSize + protowire.SizeBytes(valSize)
267 b = protowire.AppendVarint(b, uint64(size))
268 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
269 if err != nil {
270 return nil, err
271 }
272 b = protowire.AppendVarint(b, mapi.valWiretag)
273 b = protowire.AppendVarint(b, uint64(valSize))
274 before := len(b)
275 b, err = f.mi.marshalAppendPointer(b, val, opts)
276 if measuredSize := len(b) - before; valSize != measuredSize && err == nil {
277 return nil, errors.MismatchedSizeCalculation(valSize, measuredSize)
278 }
279 return b, err
280 }
281 }
282
283 func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
284 if mapv.Len() == 0 {
285 return b, nil
286 }
287 if opts.Deterministic() {
288 return appendMapDeterministic(b, mapv, mapi, f, opts)
289 }
290 iter := mapv.MapRange()
291 for iter.Next() {
292 var err error
293 b = protowire.AppendVarint(b, f.wiretag)
294 b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
295 if err != nil {
296 return b, err
297 }
298 }
299 return b, nil
300 }
301
302 func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
303 keys := mapv.MapKeys()
304 sort.Slice(keys, func(i, j int) bool {
305 switch keys[i].Kind() {
306 case reflect.Bool:
307 return !keys[i].Bool() && keys[j].Bool()
308 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
309 return keys[i].Int() < keys[j].Int()
310 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
311 return keys[i].Uint() < keys[j].Uint()
312 case reflect.Float32, reflect.Float64:
313 return keys[i].Float() < keys[j].Float()
314 case reflect.String:
315 return keys[i].String() < keys[j].String()
316 default:
317 panic("invalid kind: " + keys[i].Kind().String())
318 }
319 })
320 for _, key := range keys {
321 var err error
322 b = protowire.AppendVarint(b, f.wiretag)
323 b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
324 if err != nil {
325 return b, err
326 }
327 }
328 return b, nil
329 }
330
331 func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
332 if mi := f.mi; mi != nil {
333 mi.init()
334 if !mi.needsInitCheck {
335 return nil
336 }
337 iter := mapv.MapRange()
338 for iter.Next() {
339 val := pointerOfValue(iter.Value())
340 if err := mi.checkInitializedPointer(val); err != nil {
341 return err
342 }
343 }
344 } else {
345 iter := mapv.MapRange()
346 for iter.Next() {
347 val := mapi.conv.valConv.PBValueOf(iter.Value())
348 if err := mapi.valFuncs.isInit(val); err != nil {
349 return err
350 }
351 }
352 }
353 return nil
354 }
355
356 func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
357 dstm := dst.AsValueOf(f.ft).Elem()
358 srcm := src.AsValueOf(f.ft).Elem()
359 if srcm.Len() == 0 {
360 return
361 }
362 if dstm.IsNil() {
363 dstm.Set(reflect.MakeMap(f.ft))
364 }
365 iter := srcm.MapRange()
366 for iter.Next() {
367 dstm.SetMapIndex(iter.Key(), iter.Value())
368 }
369 }
370
371 func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
372 dstm := dst.AsValueOf(f.ft).Elem()
373 srcm := src.AsValueOf(f.ft).Elem()
374 if srcm.Len() == 0 {
375 return
376 }
377 if dstm.IsNil() {
378 dstm.Set(reflect.MakeMap(f.ft))
379 }
380 iter := srcm.MapRange()
381 for iter.Next() {
382 dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...)))
383 }
384 }
385
386 func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
387 dstm := dst.AsValueOf(f.ft).Elem()
388 srcm := src.AsValueOf(f.ft).Elem()
389 if srcm.Len() == 0 {
390 return
391 }
392 if dstm.IsNil() {
393 dstm.Set(reflect.MakeMap(f.ft))
394 }
395 iter := srcm.MapRange()
396 for iter.Next() {
397 val := reflect.New(f.ft.Elem().Elem())
398 if f.mi != nil {
399 f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts)
400 } else {
401 opts.Merge(asMessage(val), asMessage(iter.Value()))
402 }
403 dstm.SetMapIndex(iter.Key(), val)
404 }
405 }
406