verify.go raw
1 package dkim
2
3 import (
4 "bufio"
5 "crypto"
6 "crypto/subtle"
7 "encoding/base64"
8 "errors"
9 "fmt"
10 "io"
11 "io/ioutil"
12 "regexp"
13 "strconv"
14 "strings"
15 "time"
16 "unicode"
17 )
18
// permFailError is the error type used for permanent verification failures,
// such as a missing required tag or a malformed signature header. The string
// is the human-readable reason, without the "dkim: " prefix.
type permFailError string

// Error implements the error interface, prefixing the reason with "dkim: ".
func (err permFailError) Error() string {
	return "dkim: " + string(err)
}

// IsPermFail returns true if the error returned by Verify is a permanent
// failure. A permanent failure is for instance a missing required field or a
// malformed header.
//
// The whole error chain is inspected, so a permanent failure that has been
// wrapped (for example with fmt.Errorf and %w) is still recognized. A nil
// error is never a permanent failure.
func IsPermFail(err error) bool {
	var target permFailError
	return errors.As(err, &target)
}
32
// tempFailError is the error type used for temporary verification failures.
// The string is the human-readable reason, without the "dkim: " prefix.
// NOTE(review): no tempFailError is constructed in this part of the file;
// presumably the public key query code produces them. Confirm against the
// query implementation.
type tempFailError string

// Error implements the error interface, prefixing the reason with "dkim: ".
func (err tempFailError) Error() string {
	return "dkim: " + string(err)
}

// IsTempFail returns true if the error returned by Verify is a temporary
// failure.
//
// The whole error chain is inspected, so a temporary failure that has been
// wrapped (for example with fmt.Errorf and %w) is still recognized. A nil
// error is never a temporary failure.
func IsTempFail(err error) bool {
	var target tempFailError
	return errors.As(err, &target)
}
45
// failError is the error type used when a signature was well-formed but did
// not verify (body hash mismatch, bad signature, rejected body length tag).
// The string is the human-readable reason, without the "dkim: " prefix.
type failError string

// Error implements the error interface, prefixing the reason with "dkim: ".
func (err failError) Error() string {
	return "dkim: " + string(err)
}

// isFail returns true if the error returned by Verify is a signature error.
//
// The whole error chain is inspected, so a signature error that has been
// wrapped (for example with fmt.Errorf and %w) is still recognized. A nil
// error is never a signature error.
func isFail(err error) bool {
	var target failError
	return errors.As(err, &target)
}
57
// ErrTooManySignatures is returned by Verify when the message exceeds the
// maximum number of signatures. The verifications for the signatures that
// were checked (the first VerifyOptions.MaxVerifications ones) are returned
// alongside this error.
var ErrTooManySignatures = errors.New("dkim: too many signatures")
61
// requiredTags lists the signature tags that must be present; verify reports
// a permanent failure when any of them is absent from the signature.
var requiredTags = []string{"v", "a", "b", "bh", "d", "h", "s"}
63
// A Verification is produced by Verify when it checks if one signature is
// valid. If the signature is valid, Err is nil.
//
// When verification fails part-way, the fields that were parsed before the
// failure are still populated; the remaining ones keep their zero value.
type Verification struct {
	// The SDID claiming responsibility for an introduction of a message into the
	// mail stream. Taken from the signature's "d" tag.
	Domain string
	// The Agent or User Identifier (AUID) on behalf of which the SDID is taking
	// responsibility. Taken from the signature's "i" tag; when that tag is
	// absent it defaults to "@" followed by Domain.
	Identifier string

	// The list of signed header fields, as listed in the signature's "h" tag.
	HeaderKeys []string

	// The time that this signature was created. If unknown, it's set to zero.
	Time time.Time
	// The expiration time. If the signature doesn't expire, it's set to zero.
	Expiration time.Time

	// Err is nil if the signature is valid.
	Err error
}
85
// signature records one signature header field found while scanning the
// message header.
type signature struct {
	// i is the index of the signature field within the parsed header.
	i int
	// v is the value of the signature field, as returned by parseHeaderField.
	v string
}
90
// VerifyOptions allows customizing the default signature verification
// behavior.
type VerifyOptions struct {
	// LookupTXT returns the DNS TXT records for the given domain name. If nil,
	// net.LookupTXT is used.
	LookupTXT func(domain string) ([]string, error)
	// MaxVerifications controls the maximum number of signature verifications
	// to perform. If more signatures are present, the first MaxVerifications
	// signatures are verified, the rest are ignored and ErrTooManySignatures
	// is returned. If zero, there is no maximum.
	MaxVerifications int
}
103
104 // Verify checks if a message's signatures are valid. It returns one
105 // verification per signature.
106 //
107 // There is no guarantee that the reader will be completely consumed.
108 func Verify(r io.Reader) ([]*Verification, error) {
109 return VerifyWithOptions(r, nil)
110 }
111
112 // VerifyWithOptions performs the same task as Verify, but allows specifying
113 // verification options.
114 func VerifyWithOptions(r io.Reader, options *VerifyOptions) ([]*Verification, error) {
115 // Read header
116 bufr := bufio.NewReader(r)
117 h, err := readHeader(bufr)
118 if err != nil {
119 return nil, err
120 }
121
122 // Scan header fields for signatures
123 var signatures []*signature
124 for i, kv := range h {
125 k, v := parseHeaderField(kv)
126 if strings.EqualFold(k, headerFieldName) {
127 signatures = append(signatures, &signature{i, v})
128 }
129 }
130
131 tooManySignatures := false
132 if options != nil && options.MaxVerifications > 0 && len(signatures) > options.MaxVerifications {
133 tooManySignatures = true
134 signatures = signatures[:options.MaxVerifications]
135 }
136
137 var verifs []*Verification
138 if len(signatures) == 1 {
139 // If there is only one signature - just verify it.
140 v, err := verify(h, bufr, h[signatures[0].i], signatures[0].v, options)
141 if err != nil && !IsTempFail(err) && !IsPermFail(err) && !isFail(err) {
142 return nil, err
143 }
144 v.Err = err
145 verifs = []*Verification{v}
146 } else {
147 verifs, err = parallelVerify(bufr, h, signatures, options)
148 if err != nil {
149 return nil, err
150 }
151 }
152
153 if tooManySignatures {
154 return verifs, ErrTooManySignatures
155 }
156 return verifs, nil
157 }
158
159 func parallelVerify(r io.Reader, h header, signatures []*signature, options *VerifyOptions) ([]*Verification, error) {
160 pipeWriters := make([]*io.PipeWriter, len(signatures))
161 // We can't pass pipeWriter to io.MultiWriter directly,
162 // we need a slice of io.Writer, but we also need *io.PipeWriter
163 // to call Close on it.
164 writers := make([]io.Writer, len(signatures))
165 chans := make([]chan *Verification, len(signatures))
166
167 for i, sig := range signatures {
168 // Be careful with loop variables and goroutines.
169 i, sig := i, sig
170
171 chans[i] = make(chan *Verification, 1)
172
173 pr, pw := io.Pipe()
174 writers[i] = pw
175 pipeWriters[i] = pw
176
177 go func() {
178 v, err := verify(h, pr, h[sig.i], sig.v, options)
179
180 // Make sure we consume the whole reader, otherwise io.Copy on
181 // other side can block forever.
182 io.Copy(ioutil.Discard, pr)
183
184 v.Err = err
185 chans[i] <- v
186 }()
187 }
188
189 if _, err := io.Copy(io.MultiWriter(writers...), r); err != nil {
190 return nil, err
191 }
192 for _, wr := range pipeWriters {
193 wr.Close()
194 }
195
196 verifications := make([]*Verification, len(signatures))
197 for i, ch := range chans {
198 verifications[i] = <-ch
199 }
200
201 // Return unexpected failures as a separate error.
202 for _, v := range verifications {
203 err := v.Err
204 if err != nil && !IsTempFail(err) && !IsPermFail(err) && !isFail(err) {
205 v.Err = nil
206 return verifications, err
207 }
208 }
209 return verifications, nil
210 }
211
// verify checks a single signature against the message.
//
// h is the parsed message header, r yields the message body, sigField is the
// complete signature header field being checked and sigValue is that field's
// value. options may be nil.
//
// A non-nil *Verification is always returned, filled with whatever was parsed
// before the first failure. The error is nil when the signature is valid.
// Otherwise it is a permFailError (malformed or unacceptable signature), a
// failError (body hash or signature mismatch, or a body length tag), the
// error returned by the public key query, or an error from reading r or
// writing to the hasher.
//
// The checks run in a fixed order and the first failing one determines the
// returned error. Unless a check fails earlier, r is read to the end.
func verify(h header, r io.Reader, sigField, sigValue string, options *VerifyOptions) (*Verification, error) {
	verif := new(Verification)

	params, err := parseHeaderParams(sigValue)
	if err != nil {
		return verif, permFailError("malformed signature tags: " + err.Error())
	}

	// The version is checked before the required-tag scan, so a missing "v"
	// tag is reported as an incompatible version.
	if params["v"] != "1" {
		return verif, permFailError("incompatible signature version")
	}

	verif.Domain = stripWhitespace(params["d"])

	for _, tag := range requiredTags {
		if _, ok := params[tag]; !ok {
			return verif, permFailError("signature missing required tag")
		}
	}

	// The identifier, when given, must end in "@<domain>" or ".<domain>";
	// when absent it defaults to "@<domain>".
	if i, ok := params["i"]; ok {
		verif.Identifier = stripWhitespace(i)
		if !strings.HasSuffix(verif.Identifier, "@"+verif.Domain) && !strings.HasSuffix(verif.Identifier, "."+verif.Domain) {
			return verif, permFailError("domain mismatch")
		}
	} else {
		verif.Identifier = "@" + verif.Domain
	}

	// The From field must be among the signed header fields.
	headerKeys := parseTagList(params["h"])
	ok := false
	for _, k := range headerKeys {
		if strings.EqualFold(k, "from") {
			ok = true
			break
		}
	}
	if !ok {
		return verif, permFailError("From field not signed")
	}
	verif.HeaderKeys = headerKeys

	if timeStr, ok := params["t"]; ok {
		t, err := parseTime(timeStr)
		if err != nil {
			return verif, permFailError("malformed time: " + err.Error())
		}
		verif.Time = t
	}
	if expiresStr, ok := params["x"]; ok {
		t, err := parseTime(expiresStr)
		if err != nil {
			return verif, permFailError("malformed expiration time: " + err.Error())
		}
		verif.Expiration = t
		if now().After(t) {
			return verif, permFailError("signature has expired")
		}
	}

	// Query public key
	// TODO: compute hash in parallel
	methods := []string{string(QueryMethodDNSTXT)}
	if methodsStr, ok := params["q"]; ok {
		methods = parseTagList(methodsStr)
	}
	// Only the first listed method that has a registered query function is
	// tried; its result or error is final.
	var res *queryResult
	for _, method := range methods {
		if query, ok := queryMethods[QueryMethod(method)]; ok {
			if options != nil {
				res, err = query(verif.Domain, stripWhitespace(params["s"]), options.LookupTXT)
			} else {
				res, err = query(verif.Domain, stripWhitespace(params["s"]), nil)
			}
			break
		}
	}
	// The query error is passed through unchanged.
	if err != nil {
		return verif, err
	} else if res == nil {
		return verif, permFailError("unsupported public key query method")
	}

	// Parse algos
	keyAlgo, hashAlgo, ok := strings.Cut(stripWhitespace(params["a"]), "-")
	if !ok {
		return verif, permFailError("malformed algorithm name")
	}

	// Check hash algo
	// A nil HashAlgos list places no restriction on the hash algorithm.
	if res.HashAlgos != nil {
		ok := false
		for _, algo := range res.HashAlgos {
			if algo == hashAlgo {
				ok = true
				break
			}
		}
		if !ok {
			return verif, permFailError("inappropriate hash algorithm")
		}
	}
	var hash crypto.Hash
	switch hashAlgo {
	case "sha1":
		// RFC 8301 section 3.1: rsa-sha1 MUST NOT be used for signing or
		// verifying.
		return verif, permFailError(fmt.Sprintf("hash algorithm too weak: %v", hashAlgo))
	case "sha256":
		hash = crypto.SHA256
	default:
		return verif, permFailError("unsupported hash algorithm")
	}

	// Check key algo
	if res.KeyAlgo != keyAlgo {
		return verif, permFailError("inappropriate key algorithm")
	}

	// A nil Services list places no restriction; otherwise "email" must be
	// listed.
	if res.Services != nil {
		ok := false
		for _, s := range res.Services {
			if s == "email" {
				ok = true
				break
			}
		}
		if !ok {
			return verif, permFailError("inappropriate service")
		}
	}

	headerCan, bodyCan := parseCanonicalization(params["c"])
	if _, ok := canonicalizers[headerCan]; !ok {
		return verif, permFailError("unsupported header canonicalization algorithm")
	}
	if _, ok := canonicalizers[bodyCan]; !ok {
		return verif, permFailError("unsupported body canonicalization algorithm")
	}

	// The body length "l" parameter is insecure, because it allows parts of
	// the message body to not be signed. Reject messages which have it set.
	if _, ok := params["l"]; ok {
		// TODO: technically should be policyError
		return verif, failError("message contains an insecure body length tag")
	}

	// Parse body hash and signature
	bodyHashed, err := decodeBase64String(params["bh"])
	if err != nil {
		return verif, permFailError("malformed body hash: " + err.Error())
	}
	sig, err := decodeBase64String(params["b"])
	if err != nil {
		return verif, permFailError("malformed signature: " + err.Error())
	}

	// Check body hash
	// The body is streamed through the body canonicalizer into the hasher.
	hasher := hash.New()
	wc := canonicalizers[bodyCan].CanonicalizeBody(hasher)
	if _, err := io.Copy(wc, r); err != nil {
		return verif, err
	}
	if err := wc.Close(); err != nil {
		return verif, err
	}
	if subtle.ConstantTimeCompare(hasher.Sum(nil), bodyHashed) != 1 {
		return verif, failError("body hash did not verify")
	}

	// Compute data hash
	// The same hasher is reused: signed header fields first, in the order
	// listed in the "h" tag, then the signature field itself.
	hasher.Reset()
	picker := newHeaderPicker(h)
	for _, key := range headerKeys {
		kv := picker.Pick(key)
		if kv == "" {
			// The field MAY contain names of header fields that do not exist
			// when signed; nonexistent header fields do not contribute to the
			// signature computation
			continue
		}

		kv = canonicalizers[headerCan].CanonicalizeHeader(kv)
		if _, err := hasher.Write([]byte(kv)); err != nil {
			return verif, err
		}
	}
	// The signature field is hashed with its "b" value emptied and without
	// its trailing line break.
	canSigField := removeSignature(sigField)
	canSigField = canonicalizers[headerCan].CanonicalizeHeader(canSigField)
	canSigField = strings.TrimRight(canSigField, "\r\n")
	if _, err := hasher.Write([]byte(canSigField)); err != nil {
		return verif, err
	}
	hashed := hasher.Sum(nil)

	// Check signature
	if err := res.Verifier.Verify(hash, hashed, sig); err != nil {
		return verif, failError("signature did not verify: " + err.Error())
	}

	return verif, nil
}
414
415 func parseTagList(s string) []string {
416 tags := strings.Split(s, ":")
417 for i, t := range tags {
418 tags[i] = stripWhitespace(t)
419 }
420 return tags
421 }
422
423 func parseCanonicalization(s string) (headerCan, bodyCan Canonicalization) {
424 headerCan = CanonicalizationSimple
425 bodyCan = CanonicalizationSimple
426
427 cans := strings.SplitN(stripWhitespace(s), "/", 2)
428 if cans[0] != "" {
429 headerCan = Canonicalization(cans[0])
430 }
431 if len(cans) > 1 {
432 bodyCan = Canonicalization(cans[1])
433 }
434 return
435 }
436
437 func parseTime(s string) (time.Time, error) {
438 sec, err := strconv.ParseInt(stripWhitespace(s), 10, 64)
439 if err != nil {
440 return time.Time{}, err
441 }
442 return time.Unix(sec, 0), nil
443 }
444
445 func decodeBase64String(s string) ([]byte, error) {
446 return base64.StdEncoding.DecodeString(stripWhitespace(s))
447 }
448
// stripWhitespace returns s with every rune that unicode.IsSpace reports as
// whitespace removed, wherever it occurs in the string.
func stripWhitespace(s string) string {
	// A negative result tells strings.Map to drop the rune.
	dropSpace := func(r rune) rune {
		if !unicode.IsSpace(r) {
			return r
		}
		return -1
	}
	return strings.Map(dropSpace, s)
}
457
// sigRegex matches the letter "b", optional whitespace and "=", followed by
// every character up to (not including) the next ";". The leading "b...="
// part is captured as group 1.
// NOTE(review): the pattern is not anchored to a tag boundary, so it relies
// on "b" followed by "=" and a value not occurring elsewhere in the field.
var sigRegex = regexp.MustCompile(`(b\s*=)[^;]+`)

// removeSignature returns s with the value of each match of sigRegex removed,
// keeping the tag name and "=" sign. It is used to blank the signature data
// of a signature field before that field is hashed.
func removeSignature(s string) string {
	const keepTagOnly = "$1"
	blanked := sigRegex.ReplaceAllString(s, keepTagOnly)
	return blanked
}
463