compress.go raw

   1  //go:build !js
   2  // +build !js
   3  
   4  package websocket
   5  
   6  import (
   7  	"compress/flate"
   8  	"io"
   9  	"sync"
  10  )
  11  
  12  // CompressionMode represents the modes available to the permessage-deflate extension.
  13  // See https://tools.ietf.org/html/rfc7692
  14  //
  15  // Works in all modern browsers except Safari which does not implement the permessage-deflate extension.
  16  //
  17  // Compression is only used if the peer supports the mode selected.
  18  type CompressionMode int
  19  
  20  const (
  21  	// CompressionDisabled disables the negotiation of the permessage-deflate extension.
  22  	//
  23  	// This is the default. Do not enable compression without benchmarking for your particular use case first.
  24  	CompressionDisabled CompressionMode = iota
  25  
  26  	// CompressionContextTakeover compresses each message greater than 128 bytes reusing the 32 KB sliding window from
  27  	// previous messages. i.e compression context across messages is preserved.
  28  	//
  29  	// As most WebSocket protocols are text based and repetitive, this compression mode can be very efficient.
  30  	//
  31  	// The memory overhead is a fixed 32 KB sliding window, a fixed 1.2 MB flate.Writer and a sync.Pool of 40 KB flate.Reader's
  32  	// that are used when reading and then returned.
  33  	//
  34  	// Thus, it uses more memory than CompressionNoContextTakeover but compresses more efficiently.
  35  	//
  36  	// If the peer does not support CompressionContextTakeover then we will fall back to CompressionNoContextTakeover.
  37  	CompressionContextTakeover
  38  
  39  	// CompressionNoContextTakeover compresses each message greater than 512 bytes. Each message is compressed with
  40  	// a new 1.2 MB flate.Writer pulled from a sync.Pool. Each message is read with a 40 KB flate.Reader pulled from
  41  	// a sync.Pool.
  42  	//
  43  	// This means less efficient compression as the sliding window from previous messages will not be used but the
  44  	// memory overhead will be lower as there will be no fixed cost for the flate.Writer nor the 32 KB sliding window.
  45  	// Especially if the connections are long lived and seldom written to.
  46  	//
  47  	// Thus, it uses less memory than CompressionContextTakeover but compresses less efficiently.
  48  	//
  49  	// If the peer does not support CompressionNoContextTakeover then we will fall back to CompressionDisabled.
  50  	CompressionNoContextTakeover
  51  )
  52  
  53  func (m CompressionMode) opts() *compressionOptions {
  54  	return &compressionOptions{
  55  		clientNoContextTakeover: m == CompressionNoContextTakeover,
  56  		serverNoContextTakeover: m == CompressionNoContextTakeover,
  57  	}
  58  }
  59  
  60  type compressionOptions struct {
  61  	clientNoContextTakeover bool
  62  	serverNoContextTakeover bool
  63  }
  64  
  65  func (copts *compressionOptions) String() string {
  66  	s := "permessage-deflate"
  67  	if copts.clientNoContextTakeover {
  68  		s += "; client_no_context_takeover"
  69  	}
  70  	if copts.serverNoContextTakeover {
  71  		s += "; server_no_context_takeover"
  72  	}
  73  	return s
  74  }
  75  
  76  // These bytes are required to get flate.Reader to return.
  77  // They are removed when sending to avoid the overhead as
  78  // WebSocket framing tell's when the message has ended but then
  79  // we need to add them back otherwise flate.Reader keeps
  80  // trying to read more bytes.
  81  const deflateMessageTail = "\x00\x00\xff\xff"
  82  
  83  type trimLastFourBytesWriter struct {
  84  	w    io.Writer
  85  	tail []byte
  86  }
  87  
  88  func (tw *trimLastFourBytesWriter) reset() {
  89  	if tw != nil && tw.tail != nil {
  90  		tw.tail = tw.tail[:0]
  91  	}
  92  }
  93  
  94  func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) {
  95  	if tw.tail == nil {
  96  		tw.tail = make([]byte, 0, 4)
  97  	}
  98  
  99  	extra := len(tw.tail) + len(p) - 4
 100  
 101  	if extra <= 0 {
 102  		tw.tail = append(tw.tail, p...)
 103  		return len(p), nil
 104  	}
 105  
 106  	// Now we need to write as many extra bytes as we can from the previous tail.
 107  	if extra > len(tw.tail) {
 108  		extra = len(tw.tail)
 109  	}
 110  	if extra > 0 {
 111  		_, err := tw.w.Write(tw.tail[:extra])
 112  		if err != nil {
 113  			return 0, err
 114  		}
 115  
 116  		// Shift remaining bytes in tail over.
 117  		n := copy(tw.tail, tw.tail[extra:])
 118  		tw.tail = tw.tail[:n]
 119  	}
 120  
 121  	// If p is less than or equal to 4 bytes,
 122  	// all of it is is part of the tail.
 123  	if len(p) <= 4 {
 124  		tw.tail = append(tw.tail, p...)
 125  		return len(p), nil
 126  	}
 127  
 128  	// Otherwise, only the last 4 bytes are.
 129  	tw.tail = append(tw.tail, p[len(p)-4:]...)
 130  
 131  	p = p[:len(p)-4]
 132  	n, err := tw.w.Write(p)
 133  	return n + 4, err
 134  }
 135  
 136  var flateReaderPool sync.Pool
 137  
 138  func getFlateReader(r io.Reader, dict []byte) io.Reader {
 139  	fr, ok := flateReaderPool.Get().(io.Reader)
 140  	if !ok {
 141  		return flate.NewReaderDict(r, dict)
 142  	}
 143  	fr.(flate.Resetter).Reset(r, dict)
 144  	return fr
 145  }
 146  
 147  func putFlateReader(fr io.Reader) {
 148  	flateReaderPool.Put(fr)
 149  }
 150  
 151  var flateWriterPool sync.Pool
 152  
 153  func getFlateWriter(w io.Writer) *flate.Writer {
 154  	fw, ok := flateWriterPool.Get().(*flate.Writer)
 155  	if !ok {
 156  		fw, _ = flate.NewWriter(w, flate.BestSpeed)
 157  		return fw
 158  	}
 159  	fw.Reset(w)
 160  	return fw
 161  }
 162  
 163  func putFlateWriter(w *flate.Writer) {
 164  	flateWriterPool.Put(w)
 165  }
 166  
 167  type slidingWindow struct {
 168  	buf []byte
 169  }
 170  
 171  var swPoolMu sync.RWMutex
 172  var swPool = map[int]*sync.Pool{}
 173  
 174  func slidingWindowPool(n int) *sync.Pool {
 175  	swPoolMu.RLock()
 176  	p, ok := swPool[n]
 177  	swPoolMu.RUnlock()
 178  	if ok {
 179  		return p
 180  	}
 181  
 182  	p = &sync.Pool{}
 183  
 184  	swPoolMu.Lock()
 185  	swPool[n] = p
 186  	swPoolMu.Unlock()
 187  
 188  	return p
 189  }
 190  
 191  func (sw *slidingWindow) init(n int) {
 192  	if sw.buf != nil {
 193  		return
 194  	}
 195  
 196  	if n == 0 {
 197  		n = 32768
 198  	}
 199  
 200  	p := slidingWindowPool(n)
 201  	sw2, ok := p.Get().(*slidingWindow)
 202  	if ok {
 203  		*sw = *sw2
 204  	} else {
 205  		sw.buf = make([]byte, 0, n)
 206  	}
 207  }
 208  
 209  func (sw *slidingWindow) close() {
 210  	sw.buf = sw.buf[:0]
 211  	swPoolMu.Lock()
 212  	swPool[cap(sw.buf)].Put(sw)
 213  	swPoolMu.Unlock()
 214  }
 215  
 216  func (sw *slidingWindow) write(p []byte) {
 217  	if len(p) >= cap(sw.buf) {
 218  		sw.buf = sw.buf[:cap(sw.buf)]
 219  		p = p[len(p)-cap(sw.buf):]
 220  		copy(sw.buf, p)
 221  		return
 222  	}
 223  
 224  	left := cap(sw.buf) - len(sw.buf)
 225  	if left < len(p) {
 226  		// We need to shift spaceNeeded bytes from the end to make room for p at the end.
 227  		spaceNeeded := len(p) - left
 228  		copy(sw.buf, sw.buf[spaceNeeded:])
 229  		sw.buf = sw.buf[:len(sw.buf)-spaceNeeded]
 230  	}
 231  
 232  	sw.buf = append(sw.buf, p...)
 233  }
 234