writer.mx raw

   1  // Copyright 2011 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package multipart
   6  
   7  import (
   8  	"bytes"
   9  	"crypto/rand"
  10  	"errors"
  11  	"fmt"
  12  	"io"
  13  	"maps"
  14  	"net/textproto"
  15  	"slices"
  16  )
  17  
// A Writer generates multipart messages.
type Writer struct {
	w        io.Writer // destination that receives the encoded message
	boundary []byte    // separator written before each part and in the closing line
	lastpart *part     // most recently created part; nil before the first part and after Close
}
  24  
  25  // NewWriter returns a new multipart [Writer] with a random boundary,
  26  // writing to w.
  27  func NewWriter(w io.Writer) *Writer {
  28  	return &Writer{
  29  		w:        w,
  30  		boundary: randomBoundary(),
  31  	}
  32  }
  33  
  34  // Boundary returns the [Writer]'s boundary.
  35  func (w *Writer) Boundary() []byte {
  36  	return w.boundary
  37  }
  38  
  39  // SetBoundary overrides the [Writer]'s default randomly-generated
  40  // boundary separator with an explicit value.
  41  //
  42  // SetBoundary must be called before any parts are created, may only
  43  // contain certain ASCII characters, and must be non-empty and
  44  // at most 70 bytes long.
  45  func (w *Writer) SetBoundary(boundary []byte) error {
  46  	if w.lastpart != nil {
  47  		return errors.New("mime: SetBoundary called after write")
  48  	}
  49  	// rfc2046#section-5.1.1
  50  	if len(boundary) < 1 || len(boundary) > 70 {
  51  		return errors.New("mime: invalid boundary length")
  52  	}
  53  	end := len(boundary) - 1
  54  	for i, b := range boundary {
  55  		if 'A' <= b && b <= 'Z' || 'a' <= b && b <= 'z' || '0' <= b && b <= '9' {
  56  			continue
  57  		}
  58  		switch b {
  59  		case '\'', '(', ')', '+', '_', ',', '-', '.', '/', ':', '=', '?':
  60  			continue
  61  		case ' ':
  62  			if i != end {
  63  				continue
  64  			}
  65  		}
  66  		return errors.New("mime: invalid boundary character")
  67  	}
  68  	w.boundary = boundary
  69  	return nil
  70  }
  71  
  72  // FormDataContentType returns the Content-Type for an HTTP
  73  // multipart/form-data with this [Writer]'s Boundary.
  74  func (w *Writer) FormDataContentType() []byte {
  75  	b := w.boundary
  76  	// We must quote the boundary if it contains any of the
  77  	// tspecials characters defined by RFC 2045, or space.
  78  	if bytes.ContainsAny(b, `()<>@,;:\"/[]?= `) {
  79  		b = `"` + b + `"`
  80  	}
  81  	return "multipart/form-data; boundary=" + b
  82  }
  83  
  84  func randomBoundary() []byte {
  85  	var buf [30]byte
  86  	_, err := io.ReadFull(rand.Reader, buf[:])
  87  	if err != nil {
  88  		panic(err)
  89  	}
  90  	return fmt.Sprintf("%x", buf[:])
  91  }
  92  
  93  // CreatePart creates a new multipart section with the provided
  94  // header. The body of the part should be written to the returned
  95  // [Writer]. After calling CreatePart, any previous part may no longer
  96  // be written to.
  97  func (w *Writer) CreatePart(header textproto.MIMEHeader) (io.Writer, error) {
  98  	if w.lastpart != nil {
  99  		if err := w.lastpart.close(); err != nil {
 100  			return nil, err
 101  		}
 102  	}
 103  	var b bytes.Buffer
 104  	if w.lastpart != nil {
 105  		fmt.Fprintf(&b, "\r\n--%s\r\n", w.boundary)
 106  	} else {
 107  		fmt.Fprintf(&b, "--%s\r\n", w.boundary)
 108  	}
 109  
 110  	for _, k := range slices.Sorted(maps.Keys(header)) {
 111  		for _, v := range header[k] {
 112  			fmt.Fprintf(&b, "%s: %s\r\n", k, v)
 113  		}
 114  	}
 115  	fmt.Fprintf(&b, "\r\n")
 116  	_, err := io.Copy(w.w, &b)
 117  	if err != nil {
 118  		return nil, err
 119  	}
 120  	p := &part{
 121  		mw: w,
 122  	}
 123  	w.lastpart = p
 124  	return p, nil
 125  }
 126  
 127  var quoteEscaper = bytes.NewReplacer("\\", "\\\\", `"`, "\\\"")
 128  
 129  func escapeQuotes(s []byte) []byte {
 130  	return quoteEscaper.Replace(s)
 131  }
 132  
 133  // CreateFormFile is a convenience wrapper around [Writer.CreatePart]. It creates
 134  // a new form-data header with the provided field name and file name.
 135  func (w *Writer) CreateFormFile(fieldname, filename []byte) (io.Writer, error) {
 136  	h := make(textproto.MIMEHeader)
 137  	h.Set("Content-Disposition", FileContentDisposition(fieldname, filename))
 138  	h.Set("Content-Type", "application/octet-stream")
 139  	return w.CreatePart(h)
 140  }
 141  
 142  // CreateFormField calls [Writer.CreatePart] with a header using the
 143  // given field name.
 144  func (w *Writer) CreateFormField(fieldname []byte) (io.Writer, error) {
 145  	h := make(textproto.MIMEHeader)
 146  	h.Set("Content-Disposition",
 147  		fmt.Sprintf(`form-data; name="%s"`, escapeQuotes(fieldname)))
 148  	return w.CreatePart(h)
 149  }
 150  
 151  // FileContentDisposition returns the value of a Content-Disposition header
 152  // with the provided field name and file name.
 153  func FileContentDisposition(fieldname, filename []byte) []byte {
 154  	return fmt.Sprintf(`form-data; name="%s"; filename="%s"`,
 155  		escapeQuotes(fieldname), escapeQuotes(filename))
 156  }
 157  
 158  // WriteField calls [Writer.CreateFormField] and then writes the given value.
 159  func (w *Writer) WriteField(fieldname, value []byte) error {
 160  	p, err := w.CreateFormField(fieldname)
 161  	if err != nil {
 162  		return err
 163  	}
 164  	_, err = p.Write([]byte(value))
 165  	return err
 166  }
 167  
 168  // Close finishes the multipart message and writes the trailing
 169  // boundary end line to the output.
 170  func (w *Writer) Close() error {
 171  	if w.lastpart != nil {
 172  		if err := w.lastpart.close(); err != nil {
 173  			return err
 174  		}
 175  		w.lastpart = nil
 176  	}
 177  	_, err := fmt.Fprintf(w.w, "\r\n--%s--\r\n", w.boundary)
 178  	return err
 179  }
 180  
// part is the writer returned by CreatePart for the body of one section.
type part struct {
	mw     *Writer // owning Writer; body bytes are forwarded to mw.w
	closed bool    // set by close; once true, Write is rejected
	we     error   // last error that occurred writing
}
 186  
 187  func (p *part) close() error {
 188  	p.closed = true
 189  	return p.we
 190  }
 191  
 192  func (p *part) Write(d []byte) (n int, err error) {
 193  	if p.closed {
 194  		return 0, errors.New("multipart: can't write to finished part")
 195  	}
 196  	n, err = p.mw.w.Write(d)
 197  	if err != nil {
 198  		p.we = err
 199  	}
 200  	return
 201  }
 202