sign.go raw

   1  package dkim
   2  
   3  import (
   4  	"bufio"
   5  	"bytes"
   6  	"crypto"
   7  	"crypto/rand"
   8  	"crypto/rsa"
   9  	"encoding/base64"
  10  	"fmt"
  11  	"io"
  12  	"strconv"
  13  	"strings"
  14  	"time"
  15  
  16  	"golang.org/x/crypto/ed25519"
  17  )
  18  
  19  var randReader io.Reader = rand.Reader
  20  
// SignOptions is used to configure Sign and NewSigner. Domain, Selector and
// Signer are mandatory; NewSigner returns an error if any of them is unset or
// if another field holds an unsupported value.
type SignOptions struct {
	// The SDID claiming responsibility for an introduction of a message into the
	// mail stream. Hence, the SDID value is used to form the query for the public
	// key. The SDID MUST correspond to a valid DNS name under which the DKIM key
	// record is published. It is emitted as the "d=" tag.
	//
	// This can't be empty.
	Domain string
	// The selector subdividing the namespace for the domain. It is emitted as
	// the "s=" tag.
	//
	// This can't be empty.
	Selector string
	// The Agent or User Identifier (AUID) on behalf of which the SDID is taking
	// responsibility. When non-empty it is emitted as the "i=" tag.
	//
	// This is optional. NewSigner does not check it against Domain.
	Identifier string

	// The key used to sign the message.
	//
	// Supported Signer.Public() values are *rsa.PublicKey and
	// ed25519.PublicKey; any other key type is rejected by NewSigner. In both
	// cases Signer.Sign receives the digest of the signed header fields; for
	// ed25519 keys it is called with a zero crypto.Hash.
	Signer crypto.Signer
	// The hash algorithm used to sign the message. If zero, crypto.SHA256 is
	// used.
	//
	// The only supported hash algorithm is crypto.SHA256; crypto.SHA1 is
	// rejected as too weak and any other value is rejected as unsupported.
	Hash crypto.Hash

	// Header and body canonicalization algorithms, emitted together as the
	// "c=" tag ("header/body").
	//
	// If empty, CanonicalizationSimple is used. Unknown values are rejected by
	// NewSigner.
	HeaderCanonicalization Canonicalization
	BodyCanonicalization   Canonicalization

	// A list of header fields to include in the signature ("h=" tag). If nil,
	// all headers will be included. If not nil, "From" MUST be in the list
	// (compared case-insensitively).
	//
	// A name may be listed more times than the message contains such fields;
	// the extra entries are kept in "h=" so that header fields of that name
	// added later invalidate the signature.
	//
	// See RFC 6376 section 5.4.1 for recommended header fields.
	HeaderKeys []string

	// The expiration time, emitted as the "x=" tag in seconds since the Unix
	// epoch. A zero value means no expiration.
	Expiration time.Time

	// A list of query methods used to retrieve the public key ("q=" tag, colon
	// separated).
	//
	// If nil, no "q=" tag is written and it is implicitly defined as
	// QueryMethodDNSTXT.
	QueryMethods []QueryMethod
}
  72  
// Signer generates a DKIM signature. Create one with NewSigner.
//
// The whole message header and body must be written to the Signer. Close should
// always be called (either after the whole message has been written, or after
// an error occurred and the signer won't be used anymore). Close may return an
// error in case signing fails.
//
// After a successful Close, Signature can be called to retrieve the
// DKIM-Signature header field that the caller should prepend to the message.
type Signer struct {
	// The message is streamed through an io.Pipe to a goroutine started by
	// NewSigner: pw is the write end, fed by Write and closed by Close. That
	// goroutine sends exactly one result on done (nil on success) and then
	// closes it. sigParams holds the DKIM-Signature tags; the goroutine sets
	// it just before reporting success on done.
	pw        *io.PipeWriter
	done      <-chan error
	sigParams map[string]string // only valid after done received nil
}
  87  
  88  // NewSigner creates a new signer. It returns an error if SignOptions is
  89  // invalid.
  90  func NewSigner(options *SignOptions) (*Signer, error) {
  91  	if options == nil {
  92  		return nil, fmt.Errorf("dkim: no options specified")
  93  	}
  94  	if options.Domain == "" {
  95  		return nil, fmt.Errorf("dkim: no domain specified")
  96  	}
  97  	if options.Selector == "" {
  98  		return nil, fmt.Errorf("dkim: no selector specified")
  99  	}
 100  	if options.Signer == nil {
 101  		return nil, fmt.Errorf("dkim: no signer specified")
 102  	}
 103  
 104  	headerCan := options.HeaderCanonicalization
 105  	if headerCan == "" {
 106  		headerCan = CanonicalizationSimple
 107  	}
 108  	if _, ok := canonicalizers[headerCan]; !ok {
 109  		return nil, fmt.Errorf("dkim: unknown header canonicalization %q", headerCan)
 110  	}
 111  
 112  	bodyCan := options.BodyCanonicalization
 113  	if bodyCan == "" {
 114  		bodyCan = CanonicalizationSimple
 115  	}
 116  	if _, ok := canonicalizers[bodyCan]; !ok {
 117  		return nil, fmt.Errorf("dkim: unknown body canonicalization %q", bodyCan)
 118  	}
 119  
 120  	var keyAlgo string
 121  	switch options.Signer.Public().(type) {
 122  	case *rsa.PublicKey:
 123  		keyAlgo = "rsa"
 124  	case ed25519.PublicKey:
 125  		keyAlgo = "ed25519"
 126  	default:
 127  		return nil, fmt.Errorf("dkim: unsupported key algorithm %T", options.Signer.Public())
 128  	}
 129  
 130  	hash := options.Hash
 131  	var hashAlgo string
 132  	switch options.Hash {
 133  	case 0: // sha256 is the default
 134  		hash = crypto.SHA256
 135  		fallthrough
 136  	case crypto.SHA256:
 137  		hashAlgo = "sha256"
 138  	case crypto.SHA1:
 139  		return nil, fmt.Errorf("dkim: hash algorithm too weak: sha1")
 140  	default:
 141  		return nil, fmt.Errorf("dkim: unsupported hash algorithm")
 142  	}
 143  
 144  	if options.HeaderKeys != nil {
 145  		ok := false
 146  		for _, k := range options.HeaderKeys {
 147  			if strings.EqualFold(k, "From") {
 148  				ok = true
 149  				break
 150  			}
 151  		}
 152  		if !ok {
 153  			return nil, fmt.Errorf("dkim: the From header field must be signed")
 154  		}
 155  	}
 156  
 157  	done := make(chan error, 1)
 158  	pr, pw := io.Pipe()
 159  
 160  	s := &Signer{
 161  		pw:   pw,
 162  		done: done,
 163  	}
 164  
 165  	closeReadWithError := func(err error) {
 166  		pr.CloseWithError(err)
 167  		done <- err
 168  	}
 169  
 170  	go func() {
 171  		defer close(done)
 172  
 173  		// Read header
 174  		br := bufio.NewReader(pr)
 175  		h, err := readHeader(br)
 176  		if err != nil {
 177  			closeReadWithError(err)
 178  			return
 179  		}
 180  
 181  		// Hash body
 182  		hasher := hash.New()
 183  		can := canonicalizers[bodyCan].CanonicalizeBody(hasher)
 184  		if _, err := io.Copy(can, br); err != nil {
 185  			closeReadWithError(err)
 186  			return
 187  		}
 188  		if err := can.Close(); err != nil {
 189  			closeReadWithError(err)
 190  			return
 191  		}
 192  		bodyHashed := hasher.Sum(nil)
 193  
 194  		params := map[string]string{
 195  			"v":  "1",
 196  			"a":  keyAlgo + "-" + hashAlgo,
 197  			"bh": base64.StdEncoding.EncodeToString(bodyHashed),
 198  			"c":  string(headerCan) + "/" + string(bodyCan),
 199  			"d":  options.Domain,
 200  			//"l": "", // TODO
 201  			"s": options.Selector,
 202  			"t": formatTime(now()),
 203  			//"z": "", // TODO
 204  		}
 205  
 206  		var headerKeys []string
 207  		if options.HeaderKeys != nil {
 208  			headerKeys = options.HeaderKeys
 209  		} else {
 210  			for _, kv := range h {
 211  				k, _ := parseHeaderField(kv)
 212  				headerKeys = append(headerKeys, k)
 213  			}
 214  		}
 215  		params["h"] = formatTagList(headerKeys)
 216  
 217  		if options.Identifier != "" {
 218  			params["i"] = options.Identifier
 219  		}
 220  
 221  		if options.QueryMethods != nil {
 222  			methods := make([]string, len(options.QueryMethods))
 223  			for i, method := range options.QueryMethods {
 224  				methods[i] = string(method)
 225  			}
 226  			params["q"] = formatTagList(methods)
 227  		}
 228  
 229  		if !options.Expiration.IsZero() {
 230  			params["x"] = formatTime(options.Expiration)
 231  		}
 232  
 233  		// Hash and sign headers
 234  		hasher.Reset()
 235  		picker := newHeaderPicker(h)
 236  		for _, k := range headerKeys {
 237  			kv := picker.Pick(k)
 238  			if kv == "" {
 239  				// The Signer MAY include more instances of a header field name
 240  				// in "h=" than there are actual corresponding header fields so
 241  				// that the signature will not verify if additional header
 242  				// fields of that name are added.
 243  				continue
 244  			}
 245  
 246  			kv = canonicalizers[headerCan].CanonicalizeHeader(kv)
 247  			if _, err := io.WriteString(hasher, kv); err != nil {
 248  				closeReadWithError(err)
 249  				return
 250  			}
 251  		}
 252  
 253  		params["b"] = ""
 254  		sigField := formatSignature(params)
 255  		sigField = canonicalizers[headerCan].CanonicalizeHeader(sigField)
 256  		sigField = strings.TrimRight(sigField, crlf)
 257  		if _, err := io.WriteString(hasher, sigField); err != nil {
 258  			closeReadWithError(err)
 259  			return
 260  		}
 261  		hashed := hasher.Sum(nil)
 262  
 263  		// Don't pass Hash to Sign for ed25519 as it doesn't support it
 264  		// and will return an error ("ed25519: cannot sign hashed message").
 265  		if keyAlgo == "ed25519" {
 266  			hash = crypto.Hash(0)
 267  		}
 268  
 269  		sig, err := options.Signer.Sign(randReader, hashed, hash)
 270  		if err != nil {
 271  			closeReadWithError(err)
 272  			return
 273  		}
 274  		params["b"] = base64.StdEncoding.EncodeToString(sig)
 275  
 276  		s.sigParams = params
 277  		closeReadWithError(nil)
 278  	}()
 279  
 280  	return s, nil
 281  }
 282  
 283  // Write implements io.WriteCloser.
 284  func (s *Signer) Write(b []byte) (n int, err error) {
 285  	return s.pw.Write(b)
 286  }
 287  
 288  // Close implements io.WriteCloser. The error return by Close must be checked.
 289  func (s *Signer) Close() error {
 290  	if err := s.pw.Close(); err != nil {
 291  		return err
 292  	}
 293  	return <-s.done
 294  }
 295  
 296  // Signature returns the whole DKIM-Signature header field. It can only be
 297  // called after a successful Signer.Close call.
 298  //
 299  // The returned value contains both the header field name, its value and the
 300  // final CRLF.
 301  func (s *Signer) Signature() string {
 302  	if s.sigParams == nil {
 303  		panic("dkim: Signer.Signature must only be called after a succesful Signer.Close")
 304  	}
 305  	return formatSignature(s.sigParams)
 306  }
 307  
 308  // Sign signs a message. It reads it from r and writes the signed version to w.
 309  func Sign(w io.Writer, r io.Reader, options *SignOptions) error {
 310  	s, err := NewSigner(options)
 311  	if err != nil {
 312  		return err
 313  	}
 314  	defer s.Close()
 315  
 316  	// We need to keep the message in a buffer so we can write the new DKIM
 317  	// header field before the rest of the message
 318  	var b bytes.Buffer
 319  	mw := io.MultiWriter(&b, s)
 320  
 321  	if _, err := io.Copy(mw, r); err != nil {
 322  		return err
 323  	}
 324  	if err := s.Close(); err != nil {
 325  		return err
 326  	}
 327  
 328  	if _, err := io.WriteString(w, s.Signature()); err != nil {
 329  		return err
 330  	}
 331  	_, err = io.Copy(w, &b)
 332  	return err
 333  }
 334  
 335  func formatSignature(params map[string]string) string {
 336  	sig := formatHeaderParams(headerFieldName, params)
 337  	return sig
 338  }
 339  
 340  func formatTagList(l []string) string {
 341  	return strings.Join(l, ":")
 342  }
 343  
 344  func formatTime(t time.Time) string {
 345  	return strconv.FormatInt(t.Unix(), 10)
 346  }
 347