verify.go raw

   1  package dkim
   2  
   3  import (
   4  	"bufio"
   5  	"crypto"
   6  	"crypto/subtle"
   7  	"encoding/base64"
   8  	"errors"
   9  	"fmt"
  10  	"io"
  11  	"io/ioutil"
  12  	"regexp"
  13  	"strconv"
  14  	"strings"
  15  	"time"
  16  	"unicode"
  17  )
  18  
  19  type permFailError string
  20  
  21  func (err permFailError) Error() string {
  22  	return "dkim: " + string(err)
  23  }
  24  
  25  // IsPermFail returns true if the error returned by Verify is a permanent
  26  // failure. A permanent failure is for instance a missing required field or a
  27  // malformed header.
  28  func IsPermFail(err error) bool {
  29  	_, ok := err.(permFailError)
  30  	return ok
  31  }
  32  
  33  type tempFailError string
  34  
  35  func (err tempFailError) Error() string {
  36  	return "dkim: " + string(err)
  37  }
  38  
  39  // IsTempFail returns true if the error returned by Verify is a temporary
  40  // failure.
  41  func IsTempFail(err error) bool {
  42  	_, ok := err.(tempFailError)
  43  	return ok
  44  }
  45  
  46  type failError string
  47  
  48  func (err failError) Error() string {
  49  	return "dkim: " + string(err)
  50  }
  51  
  52  // isFail returns true if the error returned by Verify is a signature error.
  53  func isFail(err error) bool {
  54  	_, ok := err.(failError)
  55  	return ok
  56  }
  57  
  58  // ErrTooManySignatures is returned by Verify when the message exceeds the
  59  // maximum number of signatures.
  60  var ErrTooManySignatures = errors.New("dkim: too many signatures")
  61  
  62  var requiredTags = []string{"v", "a", "b", "bh", "d", "h", "s"}
  63  
  64  // A Verification is produced by Verify when it checks if one signature is
  65  // valid. If the signature is valid, Err is nil.
  66  type Verification struct {
  67  	// The SDID claiming responsibility for an introduction of a message into the
  68  	// mail stream.
  69  	Domain string
  70  	// The Agent or User Identifier (AUID) on behalf of which the SDID is taking
  71  	// responsibility.
  72  	Identifier string
  73  
  74  	// The list of signed header fields.
  75  	HeaderKeys []string
  76  
  77  	// The time that this signature was created. If unknown, it's set to zero.
  78  	Time time.Time
  79  	// The expiration time. If the signature doesn't expire, it's set to zero.
  80  	Expiration time.Time
  81  
  82  	// Err is nil if the signature is valid.
  83  	Err error
  84  }
  85  
  86  type signature struct {
  87  	i int
  88  	v string
  89  }
  90  
  91  // VerifyOptions allows to customize the default signature verification
  92  // behavior.
  93  type VerifyOptions struct {
  94  	// LookupTXT returns the DNS TXT records for the given domain name. If nil,
  95  	// net.LookupTXT is used.
  96  	LookupTXT func(domain string) ([]string, error)
  97  	// MaxVerifications controls the maximum number of signature verifications
  98  	// to perform. If more signatures are present, the first MaxVerifications
  99  	// signatures are verified, the rest are ignored and ErrTooManySignatures
 100  	// is returned. If zero, there is no maximum.
 101  	MaxVerifications int
 102  }
 103  
 104  // Verify checks if a message's signatures are valid. It returns one
 105  // verification per signature.
 106  //
 107  // There is no guarantee that the reader will be completely consumed.
 108  func Verify(r io.Reader) ([]*Verification, error) {
 109  	return VerifyWithOptions(r, nil)
 110  }
 111  
 112  // VerifyWithOptions performs the same task as Verify, but allows specifying
 113  // verification options.
 114  func VerifyWithOptions(r io.Reader, options *VerifyOptions) ([]*Verification, error) {
 115  	// Read header
 116  	bufr := bufio.NewReader(r)
 117  	h, err := readHeader(bufr)
 118  	if err != nil {
 119  		return nil, err
 120  	}
 121  
 122  	// Scan header fields for signatures
 123  	var signatures []*signature
 124  	for i, kv := range h {
 125  		k, v := parseHeaderField(kv)
 126  		if strings.EqualFold(k, headerFieldName) {
 127  			signatures = append(signatures, &signature{i, v})
 128  		}
 129  	}
 130  
 131  	tooManySignatures := false
 132  	if options != nil && options.MaxVerifications > 0 && len(signatures) > options.MaxVerifications {
 133  		tooManySignatures = true
 134  		signatures = signatures[:options.MaxVerifications]
 135  	}
 136  
 137  	var verifs []*Verification
 138  	if len(signatures) == 1 {
 139  		// If there is only one signature - just verify it.
 140  		v, err := verify(h, bufr, h[signatures[0].i], signatures[0].v, options)
 141  		if err != nil && !IsTempFail(err) && !IsPermFail(err) && !isFail(err) {
 142  			return nil, err
 143  		}
 144  		v.Err = err
 145  		verifs = []*Verification{v}
 146  	} else {
 147  		verifs, err = parallelVerify(bufr, h, signatures, options)
 148  		if err != nil {
 149  			return nil, err
 150  		}
 151  	}
 152  
 153  	if tooManySignatures {
 154  		return verifs, ErrTooManySignatures
 155  	}
 156  	return verifs, nil
 157  }
 158  
 159  func parallelVerify(r io.Reader, h header, signatures []*signature, options *VerifyOptions) ([]*Verification, error) {
 160  	pipeWriters := make([]*io.PipeWriter, len(signatures))
 161  	// We can't pass pipeWriter to io.MultiWriter directly,
 162  	// we need a slice of io.Writer, but we also need *io.PipeWriter
 163  	// to call Close on it.
 164  	writers := make([]io.Writer, len(signatures))
 165  	chans := make([]chan *Verification, len(signatures))
 166  
 167  	for i, sig := range signatures {
 168  		// Be careful with loop variables and goroutines.
 169  		i, sig := i, sig
 170  
 171  		chans[i] = make(chan *Verification, 1)
 172  
 173  		pr, pw := io.Pipe()
 174  		writers[i] = pw
 175  		pipeWriters[i] = pw
 176  
 177  		go func() {
 178  			v, err := verify(h, pr, h[sig.i], sig.v, options)
 179  
 180  			// Make sure we consume the whole reader, otherwise io.Copy on
 181  			// other side can block forever.
 182  			io.Copy(ioutil.Discard, pr)
 183  
 184  			v.Err = err
 185  			chans[i] <- v
 186  		}()
 187  	}
 188  
 189  	if _, err := io.Copy(io.MultiWriter(writers...), r); err != nil {
 190  		return nil, err
 191  	}
 192  	for _, wr := range pipeWriters {
 193  		wr.Close()
 194  	}
 195  
 196  	verifications := make([]*Verification, len(signatures))
 197  	for i, ch := range chans {
 198  		verifications[i] = <-ch
 199  	}
 200  
 201  	// Return unexpected failures as a separate error.
 202  	for _, v := range verifications {
 203  		err := v.Err
 204  		if err != nil && !IsTempFail(err) && !IsPermFail(err) && !isFail(err) {
 205  			v.Err = nil
 206  			return verifications, err
 207  		}
 208  	}
 209  	return verifications, nil
 210  }
 211  
 212  func verify(h header, r io.Reader, sigField, sigValue string, options *VerifyOptions) (*Verification, error) {
 213  	verif := new(Verification)
 214  
 215  	params, err := parseHeaderParams(sigValue)
 216  	if err != nil {
 217  		return verif, permFailError("malformed signature tags: " + err.Error())
 218  	}
 219  
 220  	if params["v"] != "1" {
 221  		return verif, permFailError("incompatible signature version")
 222  	}
 223  
 224  	verif.Domain = stripWhitespace(params["d"])
 225  
 226  	for _, tag := range requiredTags {
 227  		if _, ok := params[tag]; !ok {
 228  			return verif, permFailError("signature missing required tag")
 229  		}
 230  	}
 231  
 232  	if i, ok := params["i"]; ok {
 233  		verif.Identifier = stripWhitespace(i)
 234  		if !strings.HasSuffix(verif.Identifier, "@"+verif.Domain) && !strings.HasSuffix(verif.Identifier, "."+verif.Domain) {
 235  			return verif, permFailError("domain mismatch")
 236  		}
 237  	} else {
 238  		verif.Identifier = "@" + verif.Domain
 239  	}
 240  
 241  	headerKeys := parseTagList(params["h"])
 242  	ok := false
 243  	for _, k := range headerKeys {
 244  		if strings.EqualFold(k, "from") {
 245  			ok = true
 246  			break
 247  		}
 248  	}
 249  	if !ok {
 250  		return verif, permFailError("From field not signed")
 251  	}
 252  	verif.HeaderKeys = headerKeys
 253  
 254  	if timeStr, ok := params["t"]; ok {
 255  		t, err := parseTime(timeStr)
 256  		if err != nil {
 257  			return verif, permFailError("malformed time: " + err.Error())
 258  		}
 259  		verif.Time = t
 260  	}
 261  	if expiresStr, ok := params["x"]; ok {
 262  		t, err := parseTime(expiresStr)
 263  		if err != nil {
 264  			return verif, permFailError("malformed expiration time: " + err.Error())
 265  		}
 266  		verif.Expiration = t
 267  		if now().After(t) {
 268  			return verif, permFailError("signature has expired")
 269  		}
 270  	}
 271  
 272  	// Query public key
 273  	// TODO: compute hash in parallel
 274  	methods := []string{string(QueryMethodDNSTXT)}
 275  	if methodsStr, ok := params["q"]; ok {
 276  		methods = parseTagList(methodsStr)
 277  	}
 278  	var res *queryResult
 279  	for _, method := range methods {
 280  		if query, ok := queryMethods[QueryMethod(method)]; ok {
 281  			if options != nil {
 282  				res, err = query(verif.Domain, stripWhitespace(params["s"]), options.LookupTXT)
 283  			} else {
 284  				res, err = query(verif.Domain, stripWhitespace(params["s"]), nil)
 285  			}
 286  			break
 287  		}
 288  	}
 289  	if err != nil {
 290  		return verif, err
 291  	} else if res == nil {
 292  		return verif, permFailError("unsupported public key query method")
 293  	}
 294  
 295  	// Parse algos
 296  	keyAlgo, hashAlgo, ok := strings.Cut(stripWhitespace(params["a"]), "-")
 297  	if !ok {
 298  		return verif, permFailError("malformed algorithm name")
 299  	}
 300  
 301  	// Check hash algo
 302  	if res.HashAlgos != nil {
 303  		ok := false
 304  		for _, algo := range res.HashAlgos {
 305  			if algo == hashAlgo {
 306  				ok = true
 307  				break
 308  			}
 309  		}
 310  		if !ok {
 311  			return verif, permFailError("inappropriate hash algorithm")
 312  		}
 313  	}
 314  	var hash crypto.Hash
 315  	switch hashAlgo {
 316  	case "sha1":
 317  		// RFC 8301 section 3.1: rsa-sha1 MUST NOT be used for signing or
 318  		// verifying.
 319  		return verif, permFailError(fmt.Sprintf("hash algorithm too weak: %v", hashAlgo))
 320  	case "sha256":
 321  		hash = crypto.SHA256
 322  	default:
 323  		return verif, permFailError("unsupported hash algorithm")
 324  	}
 325  
 326  	// Check key algo
 327  	if res.KeyAlgo != keyAlgo {
 328  		return verif, permFailError("inappropriate key algorithm")
 329  	}
 330  
 331  	if res.Services != nil {
 332  		ok := false
 333  		for _, s := range res.Services {
 334  			if s == "email" {
 335  				ok = true
 336  				break
 337  			}
 338  		}
 339  		if !ok {
 340  			return verif, permFailError("inappropriate service")
 341  		}
 342  	}
 343  
 344  	headerCan, bodyCan := parseCanonicalization(params["c"])
 345  	if _, ok := canonicalizers[headerCan]; !ok {
 346  		return verif, permFailError("unsupported header canonicalization algorithm")
 347  	}
 348  	if _, ok := canonicalizers[bodyCan]; !ok {
 349  		return verif, permFailError("unsupported body canonicalization algorithm")
 350  	}
 351  
 352  	// The body length "l" parameter is insecure, because it allows parts of
 353  	// the message body to not be signed. Reject messages which have it set.
 354  	if _, ok := params["l"]; ok {
 355  		// TODO: technically should be policyError
 356  		return verif, failError("message contains an insecure body length tag")
 357  	}
 358  
 359  	// Parse body hash and signature
 360  	bodyHashed, err := decodeBase64String(params["bh"])
 361  	if err != nil {
 362  		return verif, permFailError("malformed body hash: " + err.Error())
 363  	}
 364  	sig, err := decodeBase64String(params["b"])
 365  	if err != nil {
 366  		return verif, permFailError("malformed signature: " + err.Error())
 367  	}
 368  
 369  	// Check body hash
 370  	hasher := hash.New()
 371  	wc := canonicalizers[bodyCan].CanonicalizeBody(hasher)
 372  	if _, err := io.Copy(wc, r); err != nil {
 373  		return verif, err
 374  	}
 375  	if err := wc.Close(); err != nil {
 376  		return verif, err
 377  	}
 378  	if subtle.ConstantTimeCompare(hasher.Sum(nil), bodyHashed) != 1 {
 379  		return verif, failError("body hash did not verify")
 380  	}
 381  
 382  	// Compute data hash
 383  	hasher.Reset()
 384  	picker := newHeaderPicker(h)
 385  	for _, key := range headerKeys {
 386  		kv := picker.Pick(key)
 387  		if kv == "" {
 388  			// The field MAY contain names of header fields that do not exist
 389  			// when signed; nonexistent header fields do not contribute to the
 390  			// signature computation
 391  			continue
 392  		}
 393  
 394  		kv = canonicalizers[headerCan].CanonicalizeHeader(kv)
 395  		if _, err := hasher.Write([]byte(kv)); err != nil {
 396  			return verif, err
 397  		}
 398  	}
 399  	canSigField := removeSignature(sigField)
 400  	canSigField = canonicalizers[headerCan].CanonicalizeHeader(canSigField)
 401  	canSigField = strings.TrimRight(canSigField, "\r\n")
 402  	if _, err := hasher.Write([]byte(canSigField)); err != nil {
 403  		return verif, err
 404  	}
 405  	hashed := hasher.Sum(nil)
 406  
 407  	// Check signature
 408  	if err := res.Verifier.Verify(hash, hashed, sig); err != nil {
 409  		return verif, failError("signature did not verify: " + err.Error())
 410  	}
 411  
 412  	return verif, nil
 413  }
 414  
 415  func parseTagList(s string) []string {
 416  	tags := strings.Split(s, ":")
 417  	for i, t := range tags {
 418  		tags[i] = stripWhitespace(t)
 419  	}
 420  	return tags
 421  }
 422  
 423  func parseCanonicalization(s string) (headerCan, bodyCan Canonicalization) {
 424  	headerCan = CanonicalizationSimple
 425  	bodyCan = CanonicalizationSimple
 426  
 427  	cans := strings.SplitN(stripWhitespace(s), "/", 2)
 428  	if cans[0] != "" {
 429  		headerCan = Canonicalization(cans[0])
 430  	}
 431  	if len(cans) > 1 {
 432  		bodyCan = Canonicalization(cans[1])
 433  	}
 434  	return
 435  }
 436  
 437  func parseTime(s string) (time.Time, error) {
 438  	sec, err := strconv.ParseInt(stripWhitespace(s), 10, 64)
 439  	if err != nil {
 440  		return time.Time{}, err
 441  	}
 442  	return time.Unix(sec, 0), nil
 443  }
 444  
 445  func decodeBase64String(s string) ([]byte, error) {
 446  	return base64.StdEncoding.DecodeString(stripWhitespace(s))
 447  }
 448  
 449  func stripWhitespace(s string) string {
 450  	return strings.Map(func(r rune) rune {
 451  		if unicode.IsSpace(r) {
 452  			return -1
 453  		}
 454  		return r
 455  	}, s)
 456  }
 457  
 458  var sigRegex = regexp.MustCompile(`(b\s*=)[^;]+`)
 459  
 460  func removeSignature(s string) string {
 461  	return sigRegex.ReplaceAllString(s, "$1")
 462  }
 463