1 //go:build !js
2 // +build !js
3 4 package websocket
5 6 import (
7 "compress/flate"
8 "io"
9 "sync"
10 )
// CompressionMode represents the modes available to the permessage-deflate extension.
// See https://tools.ietf.org/html/rfc7692
//
// It works in all modern browsers except Safari, which does not implement the
// permessage-deflate extension.
//
// Compression is only used if the peer supports the selected mode.
//
// The zero value is CompressionDisabled.
type CompressionMode int
const (
	// CompressionDisabled disables the negotiation of the permessage-deflate extension.
	//
	// This is the default. Do not enable compression without benchmarking for your particular use case first.
	CompressionDisabled CompressionMode = iota

	// CompressionContextTakeover compresses each message greater than 128 bytes, reusing the 32 KB sliding window
	// from previous messages, i.e. the compression context is preserved across messages.
	//
	// As most WebSocket protocols are text based and repetitive, this compression mode can be very efficient.
	//
	// The memory overhead is a fixed 32 KB sliding window, a fixed 1.2 MB flate.Writer and a sync.Pool of
	// 40 KB flate.Readers that are used when reading and then returned.
	//
	// Thus, it uses more memory than CompressionNoContextTakeover but compresses more efficiently.
	//
	// If the peer does not support CompressionContextTakeover then we will fall back to CompressionNoContextTakeover.
	CompressionContextTakeover

	// CompressionNoContextTakeover compresses each message greater than 512 bytes. Each message is compressed with
	// a new 1.2 MB flate.Writer pulled from a sync.Pool. Each message is read with a 40 KB flate.Reader pulled from
	// a sync.Pool.
	//
	// This means less efficient compression, as the sliding window from previous messages will not be used, but the
	// memory overhead will be lower as there is no fixed cost for the flate.Writer nor the 32 KB sliding window.
	// This matters especially if the connections are long lived and seldom written to.
	//
	// Thus, it uses less memory than CompressionContextTakeover but compresses less efficiently.
	//
	// If the peer does not support CompressionNoContextTakeover then we will fall back to CompressionDisabled.
	CompressionNoContextTakeover
)
52 53 func (m CompressionMode) opts() *compressionOptions {
54 return &compressionOptions{
55 clientNoContextTakeover: m == CompressionNoContextTakeover,
56 serverNoContextTakeover: m == CompressionNoContextTakeover,
57 }
58 }
// compressionOptions records which no_context_takeover parameters of the
// permessage-deflate extension are set.
type compressionOptions struct {
	clientNoContextTakeover bool
	serverNoContextTakeover bool
}

// String renders the options as a permessage-deflate extension string: the
// extension name followed by one "; <parameter>" suffix for each flag that is
// set, client parameter first.
func (copts *compressionOptions) String() string {
	params := [...]struct {
		enabled bool
		suffix  string
	}{
		{copts.clientNoContextTakeover, "; client_no_context_takeover"},
		{copts.serverNoContextTakeover, "; server_no_context_takeover"},
	}

	ext := "permessage-deflate"
	for _, param := range params {
		if param.enabled {
			ext += param.suffix
		}
	}
	return ext
}
// These bytes are required to get flate.Reader to return.
// They are removed when sending to avoid the overhead, as
// WebSocket framing tells us when the message has ended, but
// we then need to add them back when reading, otherwise
// flate.Reader keeps trying to read more bytes.
const deflateMessageTail = "\x00\x00\xff\xff"
// trimLastFourBytesWriter forwards everything written to it on to w except
// for the most recent four bytes, which are held back in tail. Held-back
// bytes are only forwarded once later writes push them out of the final
// four-byte window, so the last four bytes of the stream never reach w.
type trimLastFourBytesWriter struct {
	w    io.Writer
	tail []byte
}

// reset discards any held-back bytes while keeping the tail's storage.
// It is safe to call on a nil receiver or before the first Write.
func (tw *trimLastFourBytesWriter) reset() {
	if tw == nil || tw.tail == nil {
		return
	}
	tw.tail = tw.tail[:0]
}

// Write buffers the final four bytes seen so far and passes everything
// before them to the underlying writer.
//
// Bytes that are merely held back are reported as written. If flushing the
// previously held bytes fails, Write returns 0 and the error and leaves the
// tail untouched.
func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) {
	if tw.tail == nil {
		tw.tail = make([]byte, 0, 4)
	}

	// Number of bytes that no longer fit in the four-byte holdback.
	overflow := len(tw.tail) + len(p) - 4
	if overflow <= 0 {
		// Everything seen so far still fits in the tail.
		tw.tail = append(tw.tail, p...)
		return len(p), nil
	}

	// Only bytes that are already held back can be released from the tail.
	release := overflow
	if held := len(tw.tail); release > held {
		release = held
	}
	if release > 0 {
		if _, err := tw.w.Write(tw.tail[:release]); err != nil {
			return 0, err
		}

		// Slide the bytes that are still held back to the front.
		kept := copy(tw.tail, tw.tail[release:])
		tw.tail = tw.tail[:kept]
	}

	if len(p) > 4 {
		// Only the last four bytes of p stay behind; the rest goes straight through.
		split := len(p) - 4
		tw.tail = append(tw.tail, p[split:]...)

		n, err := tw.w.Write(p[:split])
		return n + 4, err
	}

	// p is at most four bytes long, so all of it belongs to the tail.
	tw.tail = append(tw.tail, p...)
	return len(p), nil
}
// flateReaderPool holds flate readers that putFlateReader has handed back so
// that getFlateReader can reuse them.
var flateReaderPool sync.Pool

// getFlateReader returns a flate reader that decompresses r using dict as
// the preset dictionary. A pooled reader is reset and reused when one is
// available; otherwise a new one is created.
func getFlateReader(r io.Reader, dict []byte) io.Reader {
	if pooled, ok := flateReaderPool.Get().(io.Reader); ok {
		// Readers created by flate.NewReaderDict also implement flate.Resetter.
		pooled.(flate.Resetter).Reset(r, dict)
		return pooled
	}
	return flate.NewReaderDict(r, dict)
}

// putFlateReader hands fr back to the pool for later reuse by getFlateReader.
func putFlateReader(fr io.Reader) {
	flateReaderPool.Put(fr)
}
// flateWriterPool holds flate writers that putFlateWriter has handed back so
// that getFlateWriter can reuse them.
var flateWriterPool sync.Pool

// getFlateWriter returns a flate.Writer that compresses into w at
// flate.BestSpeed. A pooled writer is reset onto w when one is available;
// otherwise a new one is created.
func getFlateWriter(w io.Writer) *flate.Writer {
	if pooled, ok := flateWriterPool.Get().(*flate.Writer); ok {
		pooled.Reset(w)
		return pooled
	}

	// flate.BestSpeed is a valid level, so the error is deliberately ignored.
	fresh, _ := flate.NewWriter(w, flate.BestSpeed)
	return fresh
}

// putFlateWriter hands w back to the pool for later reuse by getFlateWriter.
func putFlateWriter(w *flate.Writer) {
	flateWriterPool.Put(w)
}
// slidingWindow retains the most recent cap(buf) bytes passed to write.
// A zero slidingWindow has no buffer; init gives it one and close gives the
// buffer back for reuse.
type slidingWindow struct {
	buf []byte
}

var (
	// swPoolMu guards swPool.
	swPoolMu sync.RWMutex
	// swPool maps a window capacity to the pool of idle windows of that capacity.
	swPool = map[int]*sync.Pool{}
)

// slidingWindowPool returns the pool of idle windows with capacity n,
// registering it on first use. Every caller asking for the same n gets the
// same pool.
func slidingWindowPool(n int) *sync.Pool {
	swPoolMu.RLock()
	p, ok := swPool[n]
	swPoolMu.RUnlock()
	if ok {
		return p
	}

	swPoolMu.Lock()
	defer swPoolMu.Unlock()

	// Look again under the exclusive lock: another goroutine may have
	// registered the pool after the read lock was released, and replacing
	// it would orphan any windows already stored in it.
	if p, ok := swPool[n]; ok {
		return p
	}
	p = &sync.Pool{}
	swPool[n] = p
	return p
}

// init gives sw an empty buffer of capacity n, reusing an idle one from the
// pool when possible. An n of 0 selects the default of 32768 bytes. It does
// nothing if sw already has a buffer.
func (sw *slidingWindow) init(n int) {
	if sw.buf != nil {
		return
	}

	if n == 0 {
		n = 32768
	}

	if pooled, ok := slidingWindowPool(n).Get().(*slidingWindow); ok {
		sw.buf = pooled.buf[:0]
		return
	}
	sw.buf = make([]byte, 0, n)
}

// close returns sw's buffer to the pool matching its capacity and leaves sw
// without a buffer. It is a no-op on a window that has no buffer, so it is
// safe to call on a window that was never initialised and to call more than
// once.
func (sw *slidingWindow) close() {
	if sw.buf == nil {
		return
	}

	// Pool a detached holder and drop our own reference so that this window
	// and whichever window picks the buffer up next never share storage.
	idle := &slidingWindow{buf: sw.buf[:0]}
	sw.buf = nil

	// Only the map is read here; sync.Pool does its own synchronisation.
	swPoolMu.RLock()
	p := swPool[cap(idle.buf)]
	swPoolMu.RUnlock()

	// No pool is registered for buffers that did not come from init; let
	// the garbage collector reclaim those.
	if p != nil {
		p.Put(idle)
	}
}

// write appends p to the window, discarding the oldest bytes as needed so
// that the window holds at most cap(sw.buf) of the most recent bytes.
func (sw *slidingWindow) write(p []byte) {
	if len(p) >= cap(sw.buf) {
		// p alone fills the window: keep only its last cap(sw.buf) bytes.
		sw.buf = sw.buf[:cap(sw.buf)]
		p = p[len(p)-cap(sw.buf):]
		copy(sw.buf, p)
		return
	}

	left := cap(sw.buf) - len(sw.buf)
	if left < len(p) {
		// Drop the oldest spaceNeeded bytes by shifting the rest to the
		// front, making room for p at the end.
		spaceNeeded := len(p) - left
		copy(sw.buf, sw.buf[spaceNeeded:])
		sw.buf = sw.buf[:len(sw.buf)-spaceNeeded]
	}

	sw.buf = append(sw.buf, p...)
}
234