marshal.mx raw
1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package asn1
6
7 import (
8 "bytes"
9 "errors"
10 "fmt"
11 "math/big"
12 "slices"
13 "time"
14 "unicode/utf8"
15 "unsafe"
16 )
17
18 var (
19 byte00Encoder encoder = byteEncoder(0x00)
20 byteFFEncoder encoder = byteEncoder(0xff)
21 )
22
// encoder is an ASN.1 element that has been prepared for marshaling but not
// yet written out. Callers size their output buffer with Len and then call
// Encode to fill it.
type encoder interface {
	// Len reports how many bytes Encode will write.
	Len() int
	// Encode writes Len() bytes to the start of dst, which must be at least
	// that long.
	Encode(dst []byte)
}
30
// byteEncoder emits one fixed octet.
type byteEncoder byte

// Len always reports a single octet.
func (b byteEncoder) Len() int { return 1 }

// Encode stores the octet in dst[0]; dst must not be empty.
func (b byteEncoder) Encode(dst []byte) { dst[0] = byte(b) }
40
// bytesEncoder emits a pre-built byte sequence verbatim.
type bytesEncoder []byte

// Len is the number of stored bytes.
func (enc bytesEncoder) Len() int { return len(enc) }

// Encode copies the stored bytes to the front of dst and panics if dst is
// too short to hold all of them.
func (enc bytesEncoder) Encode(dst []byte) {
	if n := copy(dst, enc); n < len(enc) {
		panic("internal error")
	}
}
52
// stringEncoder holds the octets of a character-string body (as produced by
// the make*String constructors) and emits them unchanged.
type stringEncoder []byte

// Len is the number of octets in the string.
func (str stringEncoder) Len() int { return len(str) }

// Encode copies the string octets into dst, panicking if dst is too short.
func (str stringEncoder) Encode(dst []byte) {
	written := copy(dst, str)
	if written != len(str) {
		panic("internal error")
	}
}
64
65 type multiEncoder []encoder
66
67 func (m multiEncoder) Len() int {
68 var size int
69 for _, e := range m {
70 size += e.Len()
71 }
72 return size
73 }
74
75 func (m multiEncoder) Encode(dst []byte) {
76 var off int
77 for _, e := range m {
78 e.Encode(dst[off:])
79 off += e.Len()
80 }
81 }
82
83 type setEncoder []encoder
84
85 func (s setEncoder) Len() int {
86 var size int
87 for _, e := range s {
88 size += e.Len()
89 }
90 return size
91 }
92
93 func (s setEncoder) Encode(dst []byte) {
94 // Per X690 Section 11.6: The encodings of the component values of a
95 // set-of value shall appear in ascending order, the encodings being
96 // compared as octet strings with the shorter components being padded
97 // at their trailing end with 0-octets.
98 //
99 // First we encode each element to its TLV encoding and then use
100 // octetSort to get the ordering expected by X690 DER rules before
101 // writing the sorted encodings out to dst.
102 l := [][]byte{:len(s)}
103 for i, e := range s {
104 l[i] = []byte{:e.Len()}
105 e.Encode(l[i])
106 }
107
108 // Since we are using bytes.Compare to compare TLV encodings we
109 // don't need to right pad s[i] and s[j] to the same length as
110 // suggested in X690. If len(s[i]) < len(s[j]) the length octet of
111 // s[i], which is the first determining byte, will inherently be
112 // smaller than the length octet of s[j]. This lets us skip the
113 // padding step.
114 slices.SortFunc(l, bytes.Compare)
115
116 var off int
117 for _, b := range l {
118 copy(dst[off:], b)
119 off += len(b)
120 }
121 }
122
123 type taggedEncoder struct {
124 // scratch contains temporary space for encoding the tag and length of
125 // an element in order to avoid extra allocations.
126 scratch [8]byte
127 tag encoder
128 body encoder
129 }
130
131 func (t *taggedEncoder) Len() int {
132 return t.tag.Len() + t.body.Len()
133 }
134
135 func (t *taggedEncoder) Encode(dst []byte) {
136 t.tag.Encode(dst)
137 t.body.Encode(dst[t.tag.Len():])
138 }
139
// int64Encoder encodes an integer as the shortest big-endian two's-complement
// byte string that preserves its sign (the contents of an INTEGER).
type int64Encoder int64

// Len counts the octets needed. Each further octet is required while the
// remaining value does not fit in a signed byte (-128..127).
func (i int64Encoder) Len() int {
	n := 1
	for v := int64(i); v > 127 || v < -128; v >>= 8 {
		n++
	}
	return n
}

// Encode writes the Len() low-order octets of i, most significant first.
func (i int64Encoder) Encode(dst []byte) {
	v := int64(i)
	for j := i.Len() - 1; j >= 0; j-- {
		dst[j] = byte(v)
		v >>= 8
	}
}
165
// base128IntLength reports how many 7-bit groups are needed to write n in
// base 128. Zero takes one group. Negative values yield 0, so nothing would
// be emitted for them.
func base128IntLength(n int64) int {
	if n == 0 {
		return 1
	}
	groups := 0
	for ; n > 0; n >>= 7 {
		groups++
	}
	return groups
}
178
179 func appendBase128Int(dst []byte, n int64) []byte {
180 l := base128IntLength(n)
181
182 for i := l - 1; i >= 0; i-- {
183 o := byte(n >> uint(i*7))
184 o &= 0x7f
185 if i != 0 {
186 o |= 0x80
187 }
188
189 dst = append(dst, o)
190 }
191
192 return dst
193 }
194
195 func makeBigInt(n *big.Int) (encoder, error) {
196 if n == nil {
197 return nil, StructuralError{"empty integer"}
198 }
199
200 if n.Sign() < 0 {
201 // A negative number has to be converted to two's-complement
202 // form. So we'll invert and subtract 1. If the
203 // most-significant-bit isn't set then we'll need to pad the
204 // beginning with 0xff in order to keep the number negative.
205 nMinus1 := (&big.Int{}).Neg(n)
206 nMinus1.Sub(nMinus1, bigOne)
207 bytes := nMinus1.Bytes()
208 for i := range bytes {
209 bytes[i] ^= 0xff
210 }
211 if len(bytes) == 0 || bytes[0]&0x80 == 0 {
212 return multiEncoder([]encoder{byteFFEncoder, bytesEncoder(bytes)}), nil
213 }
214 return bytesEncoder(bytes), nil
215 } else if n.Sign() == 0 {
216 // Zero is written as a single 0 zero rather than no bytes.
217 return byte00Encoder, nil
218 } else {
219 bytes := n.Bytes()
220 if len(bytes) > 0 && bytes[0]&0x80 != 0 {
221 // We'll have to pad this with 0x00 in order to stop it
222 // looking like a negative number.
223 return multiEncoder([]encoder{byte00Encoder, bytesEncoder(bytes)}), nil
224 }
225 return bytesEncoder(bytes), nil
226 }
227 }
228
229 func appendLength(dst []byte, i int) []byte {
230 n := lengthLength(i)
231
232 for ; n > 0; n-- {
233 dst = append(dst, byte(i>>uint((n-1)*8)))
234 }
235
236 return dst
237 }
238
// lengthLength reports how many octets are needed to hold i as an unsigned
// big-endian number. The result is always at least one, and every i <= 255
// (including negative i) needs exactly one.
func lengthLength(i int) (numBytes int) {
	for numBytes = 1; i > 255; i >>= 8 {
		numBytes++
	}
	return numBytes
}
247
248 func appendTagAndLength(dst []byte, t tagAndLength) []byte {
249 b := uint8(t.class) << 6
250 if t.isCompound {
251 b |= 0x20
252 }
253 if t.tag >= 31 {
254 b |= 0x1f
255 dst = append(dst, b)
256 dst = appendBase128Int(dst, int64(t.tag))
257 } else {
258 b |= uint8(t.tag)
259 dst = append(dst, b)
260 }
261
262 if t.length >= 128 {
263 l := lengthLength(t.length)
264 dst = append(dst, 0x80|byte(l))
265 dst = appendLength(dst, t.length)
266 } else {
267 dst = append(dst, byte(t.length))
268 }
269
270 return dst
271 }
272
273 type bitStringEncoder BitString
274
275 func (b bitStringEncoder) Len() int {
276 return len(b.Bytes) + 1
277 }
278
279 func (b bitStringEncoder) Encode(dst []byte) {
280 dst[0] = byte((8 - b.BitLength%8) % 8)
281 if copy(dst[1:], b.Bytes) != len(b.Bytes) {
282 panic("internal error")
283 }
284 }
285
286 type oidEncoder []int
287
288 func (oid oidEncoder) Len() int {
289 l := base128IntLength(int64(oid[0]*40 + oid[1]))
290 for i := 2; i < len(oid); i++ {
291 l += base128IntLength(int64(oid[i]))
292 }
293 return l
294 }
295
296 func (oid oidEncoder) Encode(dst []byte) {
297 dst = appendBase128Int(dst[:0], int64(oid[0]*40+oid[1]))
298 for i := 2; i < len(oid); i++ {
299 dst = appendBase128Int(dst, int64(oid[i]))
300 }
301 }
302
303 func makeObjectIdentifier(oid []int) (e encoder, err error) {
304 if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
305 return nil, StructuralError{"invalid object identifier"}
306 }
307
308 return oidEncoder(oid), nil
309 }
310
311 func makePrintableString(s []byte) (e encoder, err error) {
312 for i := 0; i < len(s); i++ {
313 // The asterisk is often used in PrintableString, even though
314 // it is invalid. If a PrintableString was specifically
315 // requested then the asterisk is permitted by this code.
316 // Ampersand is allowed in parsing due a handful of CA
317 // certificates, however when making new certificates
318 // it is rejected.
319 if !isPrintable(s[i], allowAsterisk, rejectAmpersand) {
320 return nil, StructuralError{"PrintableString contains invalid character"}
321 }
322 }
323
324 return stringEncoder(s), nil
325 }
326
327 func makeIA5String(s []byte) (e encoder, err error) {
328 for i := 0; i < len(s); i++ {
329 if s[i] > 127 {
330 return nil, StructuralError{"IA5String contains invalid character"}
331 }
332 }
333
334 return stringEncoder(s), nil
335 }
336
337 func makeNumericString(s []byte) (e encoder, err error) {
338 for i := 0; i < len(s); i++ {
339 if !isNumeric(s[i]) {
340 return nil, StructuralError{"NumericString contains invalid character"}
341 }
342 }
343
344 return stringEncoder(s), nil
345 }
346
347 func makeUTF8String(s []byte) encoder {
348 return stringEncoder(s)
349 }
350
// appendTwoDigits appends the last two decimal digits of v, zero-padded.
// It is intended for non-negative v.
func appendTwoDigits(dst []byte, v int) []byte {
	tens := byte('0' + (v/10)%10)
	ones := byte('0' + v%10)
	return append(dst, tens, ones)
}
354
// appendFourDigits appends the last four decimal digits of v, zero-padded.
// It is intended for non-negative v.
func appendFourDigits(dst []byte, v int) []byte {
	var digits [4]byte
	for i := len(digits) - 1; i >= 0; i-- {
		digits[i] = byte('0' + v%10)
		v /= 10
	}
	return append(dst, digits[:]...)
}
362
// outsideUTCRange reports whether t's year, taken in t's own location, falls
// outside 1950 through 2049, the span a UTCTime two-digit year can
// represent.
func outsideUTCRange(t time.Time) bool {
	year := t.Year()
	inRange := year >= 1950 && year < 2050
	return !inRange
}
367
368 func makeUTCTime(t time.Time) (e encoder, err error) {
369 dst := []byte{:0:18}
370
371 dst, err = appendUTCTime(dst, t)
372 if err != nil {
373 return nil, err
374 }
375
376 return bytesEncoder(dst), nil
377 }
378
379 func makeGeneralizedTime(t time.Time) (e encoder, err error) {
380 dst := []byte{:0:20}
381
382 dst, err = appendGeneralizedTime(dst, t)
383 if err != nil {
384 return nil, err
385 }
386
387 return bytesEncoder(dst), nil
388 }
389
390 func appendUTCTime(dst []byte, t time.Time) (ret []byte, err error) {
391 year := t.Year()
392
393 switch {
394 case 1950 <= year && year < 2000:
395 dst = appendTwoDigits(dst, year-1900)
396 case 2000 <= year && year < 2050:
397 dst = appendTwoDigits(dst, year-2000)
398 default:
399 return nil, StructuralError{"cannot represent time as UTCTime"}
400 }
401
402 return appendTimeCommon(dst, t), nil
403 }
404
405 func appendGeneralizedTime(dst []byte, t time.Time) (ret []byte, err error) {
406 year := t.Year()
407 if year < 0 || year > 9999 {
408 return nil, StructuralError{"cannot represent time as GeneralizedTime"}
409 }
410
411 dst = appendFourDigits(dst, year)
412
413 return appendTimeCommon(dst, t), nil
414 }
415
416 func appendTimeCommon(dst []byte, t time.Time) []byte {
417 _, month, day := t.Date()
418
419 dst = appendTwoDigits(dst, int(month))
420 dst = appendTwoDigits(dst, day)
421
422 hour, min, sec := t.Clock()
423
424 dst = appendTwoDigits(dst, hour)
425 dst = appendTwoDigits(dst, min)
426 dst = appendTwoDigits(dst, sec)
427
428 _, offset := t.Zone()
429
430 switch {
431 case offset/60 == 0:
432 return append(dst, 'Z')
433 case offset > 0:
434 dst = append(dst, '+')
435 case offset < 0:
436 dst = append(dst, '-')
437 }
438
439 offsetMinutes := offset / 60
440 if offsetMinutes < 0 {
441 offsetMinutes = -offsetMinutes
442 }
443
444 dst = appendTwoDigits(dst, offsetMinutes/60)
445 dst = appendTwoDigits(dst, offsetMinutes%60)
446
447 return dst
448 }
449
450 func stripTagAndLength(in []byte) []byte {
451 _, offset, err := parseTagAndLength(in, 0)
452 if err != nil {
453 return in
454 }
455 return in[offset:]
456 }
457
458
// makeBody returns an encoder for the contents octets of the value of type
// typ stored at ptr. This is the body only; makeField writes the identifier
// and length octets around it. params supplies the time-type and string-type
// overrides and the set flag.
//
// Types matched by descriptor identity (_flagTC, _timeTC, _bitStringTC,
// _objectIdentifierTC, _bigIntPtrTC) are handled first. Everything else is
// dispatched on typ.kind(). Struct fields and slice elements are encoded
// through makeField, so each of them contributes its own identifier, length
// and contents. Kinds with no mapping here produce a StructuralError.
func makeBody(typ *_rawType, ptr unsafe.Pointer, params fieldParameters) (e encoder, err error) {
	// Known concrete types by type code pointer comparison.
	// _flagTC values have an empty body; only the header is emitted.
	if typ == _flagTC {
		return bytesEncoder(nil), nil
	}
	// Use GeneralizedTime when it was requested explicitly or when the year
	// cannot be written as a UTCTime.
	if typ == _timeTC {
		t := *(*time.Time)(ptr)
		if params.timeType == TagGeneralizedTime || outsideUTCRange(t) {
			return makeGeneralizedTime(t)
		}
		return makeUTCTime(t)
	}
	if typ == _bitStringTC {
		return bitStringEncoder(*(*BitString)(ptr)), nil
	}
	if typ == _objectIdentifierTC {
		return makeObjectIdentifier(*(*ObjectIdentifier)(ptr))
	}
	if typ == _bigIntPtrTC {
		return makeBigInt(*(**big.Int)(ptr))
	}

	// Kind-based dispatch.
	switch typ.kind() {
	case akBool:
		if *(*bool)(ptr) {
			return byteFFEncoder, nil
		}
		return byte00Encoder, nil
	case akInt, akInt8, akInt16, akInt32, akInt64:
		// Read the integer at its stored width and sign-extend it. Any size
		// other than 1, 2 or 4 bytes is read as an int64.
		var intVal int64
		switch typ.size() {
		case 1:
			intVal = int64(*(*int8)(ptr))
		case 2:
			intVal = int64(*(*int16)(ptr))
		case 4:
			intVal = int64(*(*int32)(ptr))
		default:
			intVal = *(*int64)(ptr)
		}
		return int64Encoder(intVal), nil
	case akStruct:
		// A single unexported field makes the whole struct unencodable.
		for i := 0; i < typ.numField(); i++ {
			if !typ.fieldExported(i) {
				return nil, StructuralError{"struct contains unexported fields"}
			}
		}

		startingField := 0
		n := typ.numField()
		if n == 0 {
			return bytesEncoder(nil), nil
		}

		// If the first field is a RawContent with data, those bytes replace
		// the encoding of the remaining fields. The RawContent's own
		// identifier and length octets are stripped, because the caller
		// writes a fresh header. An empty RawContent field is skipped.
		if typ.fieldType(0) == _rawContentsTC {
			rc := *(*RawContent)(unsafe.Add(ptr, typ.fieldOffset(0)))
			if len(rc) > 0 {
				return bytesEncoder(stripTagAndLength([]byte(rc))), nil
			}
			startingField = 1
		}

		// Each remaining field is encoded with the parameters parsed from its
		// struct tag (extracted by getASN1Tag).
		switch n1 := n - startingField; n1 {
		case 0:
			return bytesEncoder(nil), nil
		case 1:
			ft := typ.fieldType(startingField)
			fp := unsafe.Add(ptr, typ.fieldOffset(startingField))
			tag := getASN1Tag(typ.fieldTag(startingField))
			return makeField(ft, fp, parseFieldParameters(tag))
		default:
			m := []encoder{:n1}
			for i := 0; i < n1; i++ {
				fi := i + startingField
				ft := typ.fieldType(fi)
				fp := unsafe.Add(ptr, typ.fieldOffset(fi))
				tag := getASN1Tag(typ.fieldTag(fi))
				m[i], err = makeField(ft, fp, parseFieldParameters(tag))
				if err != nil {
					return nil, err
				}
			}
			return multiEncoder(m), nil
		}
	case akSlice:
		// Slices of uint8 are emitted verbatim as the body.
		if typ.elem().kind() == akUint8 {
			return bytesEncoder(*(*[]byte)(ptr)), nil
		}
		// Any other slice: each element is fully encoded via makeField with
		// default (zero) field parameters.
		hdr := (*_sliceHeader)(ptr)
		l := hdr.len
		elemType := typ.elem()
		var fp fieldParameters
		switch l {
		case 0:
			return bytesEncoder(nil), nil
		case 1:
			// A lone element needs no ordering, so it is returned directly
			// even when params.set is true.
			return makeField(elemType, sliceIndex(ptr, elemType, 0), fp)
		default:
			m := []encoder{:l}
			for i := 0; i < l; i++ {
				m[i], err = makeField(elemType, sliceIndex(ptr, elemType, i), fp)
				if err != nil {
					return nil, err
				}
			}
			// SET OF elements are emitted in sorted order by setEncoder.
			if params.set {
				return setEncoder(m), nil
			}
			return multiEncoder(m), nil
		}
	case akBytes:
		// Character strings are validated against an explicitly requested
		// string type. With no recognised type, the octets are emitted
		// unvalidated by makeUTF8String.
		s := *(*[]byte)(ptr)
		switch params.stringType {
		case TagIA5String:
			return makeIA5String(s)
		case TagPrintableString:
			return makePrintableString(s)
		case TagNumericString:
			return makeNumericString(s)
		default:
			return makeUTF8String(s), nil
		}
	}

	return nil, StructuralError{"unknown Go type"}
}
586
587 // isZeroValue reports whether the value at ptr is all zero bytes.
588 func isZeroValue(typ *_rawType, ptr unsafe.Pointer) bool {
589 sz := typ.size()
590 for i := uintptr(0); i < sz; i++ {
591 if *(*byte)(unsafe.Add(ptr, i)) != 0 {
592 return false
593 }
594 }
595 return true
596 }
597
// makeField returns an encoder for the complete encoding (identifier, length
// and contents octets) of the value of type typ stored at ptr, interpreted
// according to params.
//
// An empty encoder is returned, meaning the field is omitted, in three cases:
//   - an empty slice with omitEmpty;
//   - an optional integer equal to its default value;
//   - an optional field without a default whose bytes are all zero.
//
// Otherwise the universal tag comes from getUniversalType and is refined for
// strings, times and SETs. When params.tag is set, that tag is then either
// replaced (implicit tagging) or wrapped in an outer constructed element
// (explicit tagging).
func makeField(typ *_rawType, ptr unsafe.Pointer, params fieldParameters) (e encoder, err error) {
	if typ == nil || ptr == nil {
		return nil, fmt.Errorf("asn1: cannot marshal nil value")
	}
	// If the field is an interface{} then recurse into the concrete value.
	// This code reads values no larger than a pointer from the interface's
	// value word itself and treats the word as a pointer for larger values.
	// NOTE(review): this mirrors the runtime interface layout; confirm it if
	// that layout changes.
	if typ.kind() == akInterface {
		iface := (*_ifacePair)(ptr)
		if iface.typecode == nil {
			return nil, fmt.Errorf("asn1: cannot marshal nil value")
		}
		concreteType := (*_rawType)(iface.typecode)
		var concretePtr unsafe.Pointer
		if concreteType.size() <= unsafe.Sizeof(uintptr(0)) {
			concretePtr = unsafe.Pointer(&iface.value)
		} else {
			concretePtr = iface.value
		}
		return makeField(concreteType, concretePtr, params)
	}

	if typ.kind() == akSlice && (*_sliceHeader)(ptr).len == 0 && params.omitEmpty {
		return bytesEncoder(nil), nil
	}

	// An optional integer equal to its declared default is omitted. Only
	// integer kinds are compared; other kinds fall through and are encoded.
	if params.optional && params.defaultValue != nil {
		k := typ.kind()
		if k == akInt || k == akInt8 || k == akInt16 || k == akInt32 || k == akInt64 {
			var current int64
			switch typ.size() {
			case 1:
				current = int64(*(*int8)(ptr))
			case 2:
				current = int64(*(*int16)(ptr))
			case 4:
				current = int64(*(*int32)(ptr))
			default:
				current = *(*int64)(ptr)
			}
			if current == *params.defaultValue {
				return bytesEncoder(nil), nil
			}
		}
	}

	// With no explicit default, the all-zero value is treated as the default
	// and omitted.
	if params.optional && params.defaultValue == nil {
		if isZeroValue(typ, ptr) {
			return bytesEncoder(nil), nil
		}
	}

	// A RawValue is emitted from FullBytes verbatim when present. Otherwise
	// its Class, Tag, IsCompound and Bytes are framed with a freshly built
	// header.
	if typ == _rawValueTC {
		rv := *(*RawValue)(ptr)
		if len(rv.FullBytes) != 0 {
			return bytesEncoder(rv.FullBytes), nil
		}
		t := &taggedEncoder{}
		t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound}))
		t.body = bytesEncoder(rv.Bytes)
		return t, nil
	}

	// Determine the universal tag for the Go type. Types that are unknown or
	// that match any tag cannot be marshaled.
	matchAny, tag, isCompound, ok := getUniversalType(typ)
	if !ok || matchAny {
		return nil, StructuralError{fmt.Sprintf("unknown Go type (kind %d)", typ.kind())}
	}

	// Time and string type overrides are only valid on members whose
	// universal tag is UTCTime or PrintableString, respectively.
	if params.timeType != 0 && tag != TagUTCTime {
		return nil, StructuralError{"explicit time type given to non-time member"}
	}

	if params.stringType != 0 && tag != TagPrintableString {
		return nil, StructuralError{"explicit string type given to non-string member"}
	}

	switch tag {
	case TagPrintableString:
		if params.stringType == 0 {
			// No explicit string type: use PrintableString when every byte is
			// printable (asterisk and ampersand rejected). Otherwise require
			// valid UTF-8 and use UTF8String. The break exits the range loop.
			s := *(*[]byte)(ptr)
			for _, r := range s {
				if r >= utf8.RuneSelf || !isPrintable(r, rejectAsterisk, rejectAmpersand) {
					if !utf8.Valid(s) {
						return nil, errors.New("asn1: string not valid UTF-8")
					}
					tag = TagUTF8String
					break
				}
			}
		} else {
			tag = params.stringType
		}
	case TagUTCTime:
		// Promote to GeneralizedTime when it was requested or when the year
		// is outside UTCTime's range.
		if params.timeType == TagGeneralizedTime || outsideUTCRange(*(*time.Time)(ptr)) {
			tag = TagGeneralizedTime
		}
	}

	// The set parameter is only accepted on SEQUENCE-typed members and turns
	// them into SETs.
	if params.set {
		if tag != TagSequence {
			return nil, StructuralError{"non sequence tagged as set"}
		}
		tag = TagSet
	}

	// getUniversalType can report TagSet while params.set is false. Set the
	// flag so that makeBody emits the elements in sorted SET OF order.
	if tag == TagSet && !params.set {
		params.set = true
	}

	t := &taggedEncoder{}

	// Build the body first, because its length goes into the header.
	t.body, err = makeBody(typ, ptr, params)
	if err != nil {
		return nil, err
	}

	bodyLen := t.body.Len()

	// A tag parameter moves the element out of the universal class. The class
	// is application, private or (by default) context-specific.
	class := ClassUniversal
	if params.tag != nil {
		if params.application {
			class = ClassApplication
		} else if params.private {
			class = ClassPrivate
		} else {
			class = ClassContextSpecific
		}

		// Explicit tagging: the inner element keeps its universal header and
		// is wrapped in a constructed element carrying the requested class
		// and tag.
		if params.explicit {
			t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{ClassUniversal, tag, bodyLen, isCompound}))

			tt := &taggedEncoder{}
			tt.body = t
			tt.tag = bytesEncoder(appendTagAndLength(tt.scratch[:0], tagAndLength{
				class:      class,
				tag:        *params.tag,
				length:     bodyLen + t.tag.Len(),
				isCompound: true,
			}))

			return tt, nil
		}

		// Implicit tagging: the requested tag replaces the universal one.
		tag = *params.tag
	}

	t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{class, tag, bodyLen, isCompound}))

	return t, nil
}
747
748 // Marshal returns the ASN.1 encoding of val.
749 func Marshal(val any) ([]byte, error) {
750 return MarshalWithParams(val, "")
751 }
752
753 // MarshalWithParams allows field parameters to be specified for the
754 // top-level element. The form of the params is the same as the field tags.
755 func MarshalWithParams(val any, params []byte) ([]byte, error) {
756 iface := (*_ifacePair)(unsafe.Pointer(&val))
757 tc := (*_rawType)(iface.typecode)
758 if tc == nil {
759 return nil, fmt.Errorf("asn1: cannot marshal nil value")
760 }
761 var dp unsafe.Pointer
762 if tc.size() <= unsafe.Sizeof(uintptr(0)) {
763 dp = unsafe.Pointer(&iface.value)
764 } else {
765 dp = iface.value
766 }
767 e, err := makeField(tc, dp, parseFieldParameters(params))
768 if err != nil {
769 return nil, err
770 }
771 b := []byte{:e.Len()}
772 e.Encode(b)
773 return b, nil
774 }
775