// Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package asn1 import ( "bytes" "errors" "fmt" "math/big" "slices" "time" "unicode/utf8" "unsafe" ) var ( byte00Encoder encoder = byteEncoder(0x00) byteFFEncoder encoder = byteEncoder(0xff) ) // encoder represents an ASN.1 element that is waiting to be marshaled. type encoder interface { // Len returns the number of bytes needed to marshal this element. Len() int // Encode encodes this element by writing Len() bytes to dst. Encode(dst []byte) } type byteEncoder byte func (c byteEncoder) Len() int { return 1 } func (c byteEncoder) Encode(dst []byte) { dst[0] = byte(c) } type bytesEncoder []byte func (b bytesEncoder) Len() int { return len(b) } func (b bytesEncoder) Encode(dst []byte) { if copy(dst, b) != len(b) { panic("internal error") } } type stringEncoder []byte func (s stringEncoder) Len() int { return len(s) } func (s stringEncoder) Encode(dst []byte) { if copy(dst, s) != len(s) { panic("internal error") } } type multiEncoder []encoder func (m multiEncoder) Len() int { var size int for _, e := range m { size += e.Len() } return size } func (m multiEncoder) Encode(dst []byte) { var off int for _, e := range m { e.Encode(dst[off:]) off += e.Len() } } type setEncoder []encoder func (s setEncoder) Len() int { var size int for _, e := range s { size += e.Len() } return size } func (s setEncoder) Encode(dst []byte) { // Per X690 Section 11.6: The encodings of the component values of a // set-of value shall appear in ascending order, the encodings being // compared as octet strings with the shorter components being padded // at their trailing end with 0-octets. // // First we encode each element to its TLV encoding and then use // octetSort to get the ordering expected by X690 DER rules before // writing the sorted encodings out to dst. l := [][]byte{:len(s)} for i, e := range s { l[i] = []byte{:e.Len()} e.Encode(l[i]) } // Since we are using bytes.Compare to compare TLV encodings we // don't need to right pad s[i] and s[j] to the same length as // suggested in X690. If len(s[i]) < len(s[j]) the length octet of // s[i], which is the first determining byte, will inherently be // smaller than the length octet of s[j]. This lets us skip the // padding step. slices.SortFunc(l, bytes.Compare) var off int for _, b := range l { copy(dst[off:], b) off += len(b) } } type taggedEncoder struct { // scratch contains temporary space for encoding the tag and length of // an element in order to avoid extra allocations. scratch [8]byte tag encoder body encoder } func (t *taggedEncoder) Len() int { return t.tag.Len() + t.body.Len() } func (t *taggedEncoder) Encode(dst []byte) { t.tag.Encode(dst) t.body.Encode(dst[t.tag.Len():]) } type int64Encoder int64 func (i int64Encoder) Len() int { n := 1 for i > 127 { n++ i >>= 8 } for i < -128 { n++ i >>= 8 } return n } func (i int64Encoder) Encode(dst []byte) { n := i.Len() for j := 0; j < n; j++ { dst[j] = byte(i >> uint((n-1-j)*8)) } } func base128IntLength(n int64) int { if n == 0 { return 1 } l := 0 for i := n; i > 0; i >>= 7 { l++ } return l } func appendBase128Int(dst []byte, n int64) []byte { l := base128IntLength(n) for i := l - 1; i >= 0; i-- { o := byte(n >> uint(i*7)) o &= 0x7f if i != 0 { o |= 0x80 } dst = append(dst, o) } return dst } func makeBigInt(n *big.Int) (encoder, error) { if n == nil { return nil, StructuralError{"empty integer"} } if n.Sign() < 0 { // A negative number has to be converted to two's-complement // form. So we'll invert and subtract 1. If the // most-significant-bit isn't set then we'll need to pad the // beginning with 0xff in order to keep the number negative. nMinus1 := (&big.Int{}).Neg(n) nMinus1.Sub(nMinus1, bigOne) bytes := nMinus1.Bytes() for i := range bytes { bytes[i] ^= 0xff } if len(bytes) == 0 || bytes[0]&0x80 == 0 { return multiEncoder([]encoder{byteFFEncoder, bytesEncoder(bytes)}), nil } return bytesEncoder(bytes), nil } else if n.Sign() == 0 { // Zero is written as a single 0 zero rather than no bytes. return byte00Encoder, nil } else { bytes := n.Bytes() if len(bytes) > 0 && bytes[0]&0x80 != 0 { // We'll have to pad this with 0x00 in order to stop it // looking like a negative number. return multiEncoder([]encoder{byte00Encoder, bytesEncoder(bytes)}), nil } return bytesEncoder(bytes), nil } } func appendLength(dst []byte, i int) []byte { n := lengthLength(i) for ; n > 0; n-- { dst = append(dst, byte(i>>uint((n-1)*8))) } return dst } func lengthLength(i int) (numBytes int) { numBytes = 1 for i > 255 { numBytes++ i >>= 8 } return } func appendTagAndLength(dst []byte, t tagAndLength) []byte { b := uint8(t.class) << 6 if t.isCompound { b |= 0x20 } if t.tag >= 31 { b |= 0x1f dst = append(dst, b) dst = appendBase128Int(dst, int64(t.tag)) } else { b |= uint8(t.tag) dst = append(dst, b) } if t.length >= 128 { l := lengthLength(t.length) dst = append(dst, 0x80|byte(l)) dst = appendLength(dst, t.length) } else { dst = append(dst, byte(t.length)) } return dst } type bitStringEncoder BitString func (b bitStringEncoder) Len() int { return len(b.Bytes) + 1 } func (b bitStringEncoder) Encode(dst []byte) { dst[0] = byte((8 - b.BitLength%8) % 8) if copy(dst[1:], b.Bytes) != len(b.Bytes) { panic("internal error") } } type oidEncoder []int func (oid oidEncoder) Len() int { l := base128IntLength(int64(oid[0]*40 + oid[1])) for i := 2; i < len(oid); i++ { l += base128IntLength(int64(oid[i])) } return l } func (oid oidEncoder) Encode(dst []byte) { dst = appendBase128Int(dst[:0], int64(oid[0]*40+oid[1])) for i := 2; i < len(oid); i++ { dst = appendBase128Int(dst, int64(oid[i])) } } func makeObjectIdentifier(oid []int) (e encoder, err error) { if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) { return nil, StructuralError{"invalid object identifier"} } return oidEncoder(oid), nil } func makePrintableString(s []byte) (e encoder, err error) { for i := 0; i < len(s); i++ { // The asterisk is often used in PrintableString, even though // it is invalid. If a PrintableString was specifically // requested then the asterisk is permitted by this code. // Ampersand is allowed in parsing due a handful of CA // certificates, however when making new certificates // it is rejected. if !isPrintable(s[i], allowAsterisk, rejectAmpersand) { return nil, StructuralError{"PrintableString contains invalid character"} } } return stringEncoder(s), nil } func makeIA5String(s []byte) (e encoder, err error) { for i := 0; i < len(s); i++ { if s[i] > 127 { return nil, StructuralError{"IA5String contains invalid character"} } } return stringEncoder(s), nil } func makeNumericString(s []byte) (e encoder, err error) { for i := 0; i < len(s); i++ { if !isNumeric(s[i]) { return nil, StructuralError{"NumericString contains invalid character"} } } return stringEncoder(s), nil } func makeUTF8String(s []byte) encoder { return stringEncoder(s) } func appendTwoDigits(dst []byte, v int) []byte { return append(dst, byte('0'+(v/10)%10), byte('0'+v%10)) } func appendFourDigits(dst []byte, v int) []byte { return append(dst, byte('0'+(v/1000)%10), byte('0'+(v/100)%10), byte('0'+(v/10)%10), byte('0'+v%10)) } func outsideUTCRange(t time.Time) bool { year := t.Year() return year < 1950 || year >= 2050 } func makeUTCTime(t time.Time) (e encoder, err error) { dst := []byte{:0:18} dst, err = appendUTCTime(dst, t) if err != nil { return nil, err } return bytesEncoder(dst), nil } func makeGeneralizedTime(t time.Time) (e encoder, err error) { dst := []byte{:0:20} dst, err = appendGeneralizedTime(dst, t) if err != nil { return nil, err } return bytesEncoder(dst), nil } func appendUTCTime(dst []byte, t time.Time) (ret []byte, err error) { year := t.Year() switch { case 1950 <= year && year < 2000: dst = appendTwoDigits(dst, year-1900) case 2000 <= year && year < 2050: dst = appendTwoDigits(dst, year-2000) default: return nil, StructuralError{"cannot represent time as UTCTime"} } return appendTimeCommon(dst, t), nil } func appendGeneralizedTime(dst []byte, t time.Time) (ret []byte, err error) { year := t.Year() if year < 0 || year > 9999 { return nil, StructuralError{"cannot represent time as GeneralizedTime"} } dst = appendFourDigits(dst, year) return appendTimeCommon(dst, t), nil } func appendTimeCommon(dst []byte, t time.Time) []byte { _, month, day := t.Date() dst = appendTwoDigits(dst, int(month)) dst = appendTwoDigits(dst, day) hour, min, sec := t.Clock() dst = appendTwoDigits(dst, hour) dst = appendTwoDigits(dst, min) dst = appendTwoDigits(dst, sec) _, offset := t.Zone() switch { case offset/60 == 0: return append(dst, 'Z') case offset > 0: dst = append(dst, '+') case offset < 0: dst = append(dst, '-') } offsetMinutes := offset / 60 if offsetMinutes < 0 { offsetMinutes = -offsetMinutes } dst = appendTwoDigits(dst, offsetMinutes/60) dst = appendTwoDigits(dst, offsetMinutes%60) return dst } func stripTagAndLength(in []byte) []byte { _, offset, err := parseTagAndLength(in, 0) if err != nil { return in } return in[offset:] } func makeBody(typ *_rawType, ptr unsafe.Pointer, params fieldParameters) (e encoder, err error) { // Known concrete types by type code pointer comparison. if typ == _flagTC { return bytesEncoder(nil), nil } if typ == _timeTC { t := *(*time.Time)(ptr) if params.timeType == TagGeneralizedTime || outsideUTCRange(t) { return makeGeneralizedTime(t) } return makeUTCTime(t) } if typ == _bitStringTC { return bitStringEncoder(*(*BitString)(ptr)), nil } if typ == _objectIdentifierTC { return makeObjectIdentifier(*(*ObjectIdentifier)(ptr)) } if typ == _bigIntPtrTC { return makeBigInt(*(**big.Int)(ptr)) } // Kind-based dispatch. switch typ.kind() { case akBool: if *(*bool)(ptr) { return byteFFEncoder, nil } return byte00Encoder, nil case akInt, akInt8, akInt16, akInt32, akInt64: var intVal int64 switch typ.size() { case 1: intVal = int64(*(*int8)(ptr)) case 2: intVal = int64(*(*int16)(ptr)) case 4: intVal = int64(*(*int32)(ptr)) default: intVal = *(*int64)(ptr) } return int64Encoder(intVal), nil case akStruct: for i := 0; i < typ.numField(); i++ { if !typ.fieldExported(i) { return nil, StructuralError{"struct contains unexported fields"} } } startingField := 0 n := typ.numField() if n == 0 { return bytesEncoder(nil), nil } if typ.fieldType(0) == _rawContentsTC { rc := *(*RawContent)(unsafe.Add(ptr, typ.fieldOffset(0))) if len(rc) > 0 { return bytesEncoder(stripTagAndLength([]byte(rc))), nil } startingField = 1 } switch n1 := n - startingField; n1 { case 0: return bytesEncoder(nil), nil case 1: ft := typ.fieldType(startingField) fp := unsafe.Add(ptr, typ.fieldOffset(startingField)) tag := getASN1Tag(typ.fieldTag(startingField)) return makeField(ft, fp, parseFieldParameters(tag)) default: m := []encoder{:n1} for i := 0; i < n1; i++ { fi := i + startingField ft := typ.fieldType(fi) fp := unsafe.Add(ptr, typ.fieldOffset(fi)) tag := getASN1Tag(typ.fieldTag(fi)) m[i], err = makeField(ft, fp, parseFieldParameters(tag)) if err != nil { return nil, err } } return multiEncoder(m), nil } case akSlice: if typ.elem().kind() == akUint8 { return bytesEncoder(*(*[]byte)(ptr)), nil } hdr := (*_sliceHeader)(ptr) l := hdr.len elemType := typ.elem() var fp fieldParameters switch l { case 0: return bytesEncoder(nil), nil case 1: return makeField(elemType, sliceIndex(ptr, elemType, 0), fp) default: m := []encoder{:l} for i := 0; i < l; i++ { m[i], err = makeField(elemType, sliceIndex(ptr, elemType, i), fp) if err != nil { return nil, err } } if params.set { return setEncoder(m), nil } return multiEncoder(m), nil } case akBytes: s := *(*[]byte)(ptr) switch params.stringType { case TagIA5String: return makeIA5String(s) case TagPrintableString: return makePrintableString(s) case TagNumericString: return makeNumericString(s) default: return makeUTF8String(s), nil } } return nil, StructuralError{"unknown Go type"} } // isZeroValue reports whether the value at ptr is all zero bytes. func isZeroValue(typ *_rawType, ptr unsafe.Pointer) bool { sz := typ.size() for i := uintptr(0); i < sz; i++ { if *(*byte)(unsafe.Add(ptr, i)) != 0 { return false } } return true } func makeField(typ *_rawType, ptr unsafe.Pointer, params fieldParameters) (e encoder, err error) { if typ == nil || ptr == nil { return nil, fmt.Errorf("asn1: cannot marshal nil value") } // If the field is an interface{} then recurse into the concrete value. if typ.kind() == akInterface { iface := (*_ifacePair)(ptr) if iface.typecode == nil { return nil, fmt.Errorf("asn1: cannot marshal nil value") } concreteType := (*_rawType)(iface.typecode) var concretePtr unsafe.Pointer if concreteType.size() <= unsafe.Sizeof(uintptr(0)) { concretePtr = unsafe.Pointer(&iface.value) } else { concretePtr = iface.value } return makeField(concreteType, concretePtr, params) } if typ.kind() == akSlice && (*_sliceHeader)(ptr).len == 0 && params.omitEmpty { return bytesEncoder(nil), nil } if params.optional && params.defaultValue != nil { k := typ.kind() if k == akInt || k == akInt8 || k == akInt16 || k == akInt32 || k == akInt64 { var current int64 switch typ.size() { case 1: current = int64(*(*int8)(ptr)) case 2: current = int64(*(*int16)(ptr)) case 4: current = int64(*(*int32)(ptr)) default: current = *(*int64)(ptr) } if current == *params.defaultValue { return bytesEncoder(nil), nil } } } if params.optional && params.defaultValue == nil { if isZeroValue(typ, ptr) { return bytesEncoder(nil), nil } } if typ == _rawValueTC { rv := *(*RawValue)(ptr) if len(rv.FullBytes) != 0 { return bytesEncoder(rv.FullBytes), nil } t := &taggedEncoder{} t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound})) t.body = bytesEncoder(rv.Bytes) return t, nil } matchAny, tag, isCompound, ok := getUniversalType(typ) if !ok || matchAny { return nil, StructuralError{fmt.Sprintf("unknown Go type (kind %d)", typ.kind())} } if params.timeType != 0 && tag != TagUTCTime { return nil, StructuralError{"explicit time type given to non-time member"} } if params.stringType != 0 && tag != TagPrintableString { return nil, StructuralError{"explicit string type given to non-string member"} } switch tag { case TagPrintableString: if params.stringType == 0 { s := *(*[]byte)(ptr) for _, r := range s { if r >= utf8.RuneSelf || !isPrintable(r, rejectAsterisk, rejectAmpersand) { if !utf8.Valid(s) { return nil, errors.New("asn1: string not valid UTF-8") } tag = TagUTF8String break } } } else { tag = params.stringType } case TagUTCTime: if params.timeType == TagGeneralizedTime || outsideUTCRange(*(*time.Time)(ptr)) { tag = TagGeneralizedTime } } if params.set { if tag != TagSequence { return nil, StructuralError{"non sequence tagged as set"} } tag = TagSet } if tag == TagSet && !params.set { params.set = true } t := &taggedEncoder{} t.body, err = makeBody(typ, ptr, params) if err != nil { return nil, err } bodyLen := t.body.Len() class := ClassUniversal if params.tag != nil { if params.application { class = ClassApplication } else if params.private { class = ClassPrivate } else { class = ClassContextSpecific } if params.explicit { t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{ClassUniversal, tag, bodyLen, isCompound})) tt := &taggedEncoder{} tt.body = t tt.tag = bytesEncoder(appendTagAndLength(tt.scratch[:0], tagAndLength{ class: class, tag: *params.tag, length: bodyLen + t.tag.Len(), isCompound: true, })) return tt, nil } // implicit tag. tag = *params.tag } t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{class, tag, bodyLen, isCompound})) return t, nil } // Marshal returns the ASN.1 encoding of val. func Marshal(val any) ([]byte, error) { return MarshalWithParams(val, "") } // MarshalWithParams allows field parameters to be specified for the // top-level element. The form of the params is the same as the field tags. func MarshalWithParams(val any, params []byte) ([]byte, error) { iface := (*_ifacePair)(unsafe.Pointer(&val)) tc := (*_rawType)(iface.typecode) if tc == nil { return nil, fmt.Errorf("asn1: cannot marshal nil value") } var dp unsafe.Pointer if tc.size() <= unsafe.Sizeof(uintptr(0)) { dp = unsafe.Pointer(&iface.value) } else { dp = iface.value } e, err := makeField(tc, dp, parseFieldParameters(params)) if err != nil { return nil, err } b := []byte{:e.Len()} e.Encode(b) return b, nil }