read.go raw
1 //go:build !js
2 // +build !js
3
4 package websocket
5
6 import (
7 "bufio"
8 "context"
9 "errors"
10 "fmt"
11 "io"
12 "net"
13 "strings"
14 "time"
15
16 "github.com/coder/websocket/internal/errd"
17 "github.com/coder/websocket/internal/util"
18 "github.com/coder/websocket/internal/xsync"
19 )
20
// Reader blocks until a WebSocket data message (text or binary) is available
// and returns its type together with an io.Reader over its payload. Ping,
// pong and close frames that arrive before it are handled internally.
//
// ctx bounds both the wait for the message and every subsequent Read on the
// returned reader. Read the message to EOF, otherwise the connection will hang.
//
// Call CloseRead if you do not expect any data messages from the peer.
//
// Only one Reader may be open at a time.
//
// If you need a separate timeout on the Reader call and the Read itself,
// use time.AfterFunc to cancel the context passed in.
// See https://github.com/nhooyr/websocket/issues/87#issue-451703332
// Most users should not need this.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
	typ, r, err := c.reader(ctx)
	return typ, r, err
}
39
40 // Read is a convenience method around Reader to read a single message
41 // from the connection.
42 func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
43 typ, r, err := c.Reader(ctx)
44 if err != nil {
45 return 0, nil, err
46 }
47
48 b, err := io.ReadAll(r)
49 return typ, b, err
50 }
51
// CloseRead starts a background goroutine that keeps reading from the
// connection until it is closed or a data message arrives. Because that
// goroutine is actively reading, ping, pong and close frames are still
// answered, so c.Ping and c.Close keep working as expected.
//
// Once CloseRead has been called, no messages can be read from the connection.
// The returned context is cancelled when the goroutine exits, which happens
// when the connection is closed. It is also cancelled if ctx is cancelled.
//
// If a data message is received, the connection is closed with
// StatusPolicyViolation.
//
// Call CloseRead when you do not expect to read any more messages.
//
// CloseRead is idempotent. Later calls return the context created by the
// first call, ignore their ctx argument and start no new goroutine.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
	c.closeReadMu.Lock()
	if existing := c.closeReadCtx; existing != nil {
		c.closeReadMu.Unlock()
		return existing
	}
	readCtx, cancel := context.WithCancel(ctx)
	done := make(chan struct{})
	c.closeReadCtx, c.closeReadDone = readCtx, done
	c.closeReadMu.Unlock()

	go func() {
		// Deferred calls run in reverse order: close the connection, then
		// cancel readCtx, then signal completion on closeReadDone.
		defer close(done)
		defer cancel()
		defer c.close()

		if _, _, err := c.Reader(readCtx); err == nil {
			// A data message arrived while none was expected.
			c.Close(StatusPolicyViolation, "unexpected data message")
		}
	}()
	return readCtx
}
88
89 // SetReadLimit sets the max number of bytes to read for a single message.
90 // It applies to the Reader and Read methods.
91 //
92 // By default, the connection has a message read limit of 32768 bytes.
93 //
94 // When the limit is hit, the connection will be closed with StatusMessageTooBig.
95 //
96 // Set to -1 to disable.
97 func (c *Conn) SetReadLimit(n int64) {
98 if n >= 0 {
99 // We read one more byte than the limit in case
100 // there is a fin frame that needs to be read.
101 n++
102 }
103
104 c.msgReader.limitReader.limit.Store(n)
105 }
106
// defaultReadLimit is the per-message read limit, in bytes, used until
// SetReadLimit is called (32 KiB).
const defaultReadLimit = 32 << 10
108
109 func newMsgReader(c *Conn) *msgReader {
110 mr := &msgReader{
111 c: c,
112 fin: true,
113 }
114 mr.readFunc = mr.read
115
116 mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1)
117 return mr
118 }
119
120 func (mr *msgReader) resetFlate() {
121 if mr.flateContextTakeover() {
122 if mr.dict == nil {
123 mr.dict = &slidingWindow{}
124 }
125 mr.dict.init(32768)
126 }
127 if mr.flateBufio == nil {
128 mr.flateBufio = getBufioReader(mr.readFunc)
129 }
130
131 if mr.flateContextTakeover() {
132 mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
133 } else {
134 mr.flateReader = getFlateReader(mr.flateBufio, nil)
135 }
136 mr.limitReader.r = mr.flateReader
137 mr.flateTail.Reset(deflateMessageTail)
138 }
139
140 func (mr *msgReader) putFlateReader() {
141 if mr.flateReader != nil {
142 putFlateReader(mr.flateReader)
143 mr.flateReader = nil
144 }
145 }
146
147 func (mr *msgReader) close() {
148 mr.c.readMu.forceLock()
149 mr.putFlateReader()
150 if mr.dict != nil {
151 mr.dict.close()
152 mr.dict = nil
153 }
154 if mr.flateBufio != nil {
155 putBufioReader(mr.flateBufio)
156 }
157
158 if mr.c.client {
159 putBufioReader(mr.c.br)
160 mr.c.br = nil
161 }
162 }
163
164 func (mr *msgReader) flateContextTakeover() bool {
165 if mr.c.client {
166 return !mr.c.copts.serverNoContextTakeover
167 }
168 return !mr.c.copts.clientNoContextTakeover
169 }
170
171 func (c *Conn) readRSV1Illegal(h header) bool {
172 // If compression is disabled, rsv1 is illegal.
173 if !c.flate() {
174 return true
175 }
176 // rsv1 is only allowed on data frames beginning messages.
177 if h.opcode != opText && h.opcode != opBinary {
178 return true
179 }
180 return false
181 }
182
// readLoop reads frames until it reaches a data frame (text, binary or
// continuation) and returns that frame's header. Close, ping and pong frames
// met along the way are consumed by handleControl.
//
// The following are rejected with StatusProtocolError written to the peer:
//   - headers with RSV bits that the connection does not allow,
//   - frames whose masking does not match the direction of travel
//     (RFC 6455 section 5.1: client-to-server frames must be masked and
//     server-to-client frames must not be),
//   - unknown opcodes.
//
// Errors from a received close frame that carry a close status are returned
// as-is, so CloseStatus keeps working on them. Other control-frame failures
// are wrapped with the opcode.
func (c *Conn) readLoop(ctx context.Context) (header, error) {
	for {
		h, err := c.readFrameHeader(ctx)
		if err != nil {
			return header{}, err
		}

		if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
			err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
			c.writeError(StatusProtocolError, err)
			return header{}, err
		}

		if !c.client && !h.masked {
			// Treat this like every other protocol violation in this loop
			// and tell the peer why the connection is being dropped.
			err := errors.New("received unmasked frame from client")
			c.writeError(StatusProtocolError, err)
			return header{}, err
		}
		if c.client && h.masked {
			// msgReader.read only unmasks payloads on the server side, so a
			// masked server frame would otherwise surface as garbage.
			err := errors.New("received masked frame from server")
			c.writeError(StatusProtocolError, err)
			return header{}, err
		}

		switch h.opcode {
		case opClose, opPing, opPong:
			err = c.handleControl(ctx, h)
			if err != nil {
				// Pass through CloseErrors when receiving a close frame.
				if h.opcode == opClose && CloseStatus(err) != -1 {
					return header{}, err
				}
				return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
			}
		case opContinuation, opText, opBinary:
			return h, nil
		default:
			err := fmt.Errorf("received unknown opcode %v", h.opcode)
			c.writeError(StatusProtocolError, err)
			return header{}, err
		}
	}
}
219
// readFrameHeader reads one frame header from the connection.
//
// ctx is sent on c.readTimeout before the read and context.Background() is
// sent after a successful read. NOTE(review): presumably this arms and then
// disarms a ctx-driven deadline in whatever receives from readTimeout; confirm.
//
// net.ErrClosed is returned if the connection is closed while waiting to
// hand off either context. If the header read fails, net.ErrClosed is
// returned instead when the connection has been closed, or ctx.Err() when
// ctx is done. Otherwise the read error itself is returned.
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
	select {
	case c.readTimeout <- ctx:
	case <-c.closed:
		return header{}, net.ErrClosed
	}

	h, readErr := readFrameHeader(c.br, c.readHeaderBuf[:])
	if readErr != nil {
		select {
		case <-ctx.Done():
			return header{}, ctx.Err()
		case <-c.closed:
			return header{}, net.ErrClosed
		default:
		}
		return header{}, readErr
	}

	select {
	case c.readTimeout <- context.Background():
	case <-c.closed:
		return header{}, net.ErrClosed
	}
	return h, nil
}
247
// readFramePayload fills p completely with payload bytes from the connection
// using io.ReadFull, and returns the number of bytes read.
//
// ctx is sent on c.readTimeout before the read and context.Background() after
// a successful read. net.ErrClosed is returned if the connection is closed
// while waiting to hand off either context.
//
// If the read fails, net.ErrClosed is returned instead when the connection
// has been closed, or ctx.Err() when ctx is done. Otherwise the error is
// wrapped with "failed to read frame payload".
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
	select {
	case c.readTimeout <- ctx:
	case <-c.closed:
		return 0, net.ErrClosed
	}

	n, readErr := io.ReadFull(c.br, p)
	if readErr != nil {
		select {
		case <-ctx.Done():
			return n, ctx.Err()
		case <-c.closed:
			return n, net.ErrClosed
		default:
		}
		return n, fmt.Errorf("failed to read frame payload: %w", readErr)
	}

	select {
	case c.readTimeout <- context.Background():
	case <-c.closed:
		return n, net.ErrClosed
	}
	return n, nil
}
275
// handleControl processes one control frame (ping, pong or close) whose
// header h has already been read.
//
// The payload length must be within [0, maxControlPayload] and the frame
// must not be fragmented. Otherwise a protocol error is written to the peer
// and returned. The payload is read into c.readControlBuf under a 5 second
// timeout derived from ctx, and unmasked if the frame is masked.
//
//   - ping: a pong carrying the same payload is written back.
//   - pong: if the payload matches a key in c.activePings, that channel is
//     signalled without blocking; if the channel cannot take the value, the
//     signal is dropped.
//   - close: the payload is parsed, and an invalid payload is a protocol
//     error. Otherwise the received code and reason are echoed back with
//     writeClose, the connection is closed, and an error wrapping the parsed
//     close error is returned.
func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
	if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
		err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
		c.writeError(StatusProtocolError, err)
		return err
	}

	if !h.fin {
		err := errors.New("received fragmented control frame")
		c.writeError(StatusProtocolError, err)
		return err
	}

	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
	defer cancel()

	// b aliases c.readControlBuf, so its contents are overwritten by the
	// next control frame.
	b := c.readControlBuf[:h.payloadLength]
	_, err = c.readFramePayload(ctx, b)
	if err != nil {
		return err
	}

	if h.masked {
		mask(b, h.maskKey)
	}

	switch h.opcode {
	case opPing:
		return c.writeControl(ctx, opPong, b)
	case opPong:
		c.activePingsMu.Lock()
		pong, ok := c.activePings[string(b)]
		c.activePingsMu.Unlock()
		if ok {
			select {
			case pong <- struct{}{}:
			default:
			}
		}
		return nil
	}

	// opClose

	ce, err := parseClosePayload(b)
	if err != nil {
		err = fmt.Errorf("received invalid close payload: %w", err)
		c.writeError(StatusProtocolError, err)
		return err
	}

	err = fmt.Errorf("received close frame: %w", ce)
	// The result of writeClose is ignored; the connection is closed just below
	// either way.
	c.writeClose(ce.Code, ce.Reason)
	// NOTE(review): readMu is released here, presumably so that c.close() can
	// take it (msgReader.close calls readMu.forceLock) without deadlocking.
	// The caller's deferred unlock then runs as well; confirm that mu.unlock
	// tolerates this.
	c.readMu.unlock()
	c.close()
	return err
}
333
// reader implements Reader.
//
// It holds the read lock for the duration of the call and refuses to start
// while the previous message is unfinished. It then reads up to the first
// frame of the next data message and resets c.msgReader to serve that
// message. A stray continuation frame is a protocol error. All returned
// errors are wrapped via errd.Wrap with "failed to get reader".
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
	defer errd.Wrap(&err, "failed to get reader")

	if err = c.readMu.lock(ctx); err != nil {
		return 0, nil, err
	}
	defer c.readMu.unlock()

	if !c.msgReader.fin {
		return 0, nil, errors.New("previous message not read to completion")
	}

	h, err := c.readLoop(ctx)
	switch {
	case err != nil:
		return 0, nil, err
	case h.opcode == opContinuation:
		protoErr := errors.New("received continuation frame without text or binary frame")
		c.writeError(StatusProtocolError, protoErr)
		return 0, nil, protoErr
	}

	c.msgReader.reset(ctx, h)
	return MessageType(h.opcode), c.msgReader, nil
}
362
363 type msgReader struct {
364 c *Conn
365
366 ctx context.Context
367 flate bool
368 flateReader io.Reader
369 flateBufio *bufio.Reader
370 flateTail strings.Reader
371 limitReader *limitReader
372 dict *slidingWindow
373
374 fin bool
375 payloadLength int64
376 maskKey uint32
377
378 // util.ReaderFunc(mr.Read) to avoid continuous allocations.
379 readFunc util.ReaderFunc
380 }
381
// reset prepares mr for a new message whose first frame header is h.
//
// It records ctx for later reads and marks the message as compressed when
// RSV1 is set. It reloads the limit reader's budget around mr.readFunc,
// which resetFlate then re-points at the decompressor for compressed
// messages. Finally it adopts h as the current frame.
func (mr *msgReader) reset(ctx context.Context, h header) {
	mr.ctx, mr.flate = ctx, h.rsv1
	mr.limitReader.reset(mr.readFunc)

	if mr.flate {
		mr.resetFlate()
	}
	mr.setFrame(h)
}
393
394 func (mr *msgReader) setFrame(h header) {
395 mr.fin = h.fin
396 mr.payloadLength = h.payloadLength
397 mr.maskKey = h.maskKey
398 }
399
// Read implements io.Reader over the current message's payload.
//
// Each call takes the connection's read lock, bounded by the context of the
// Reader call that produced this message. For compressed messages with
// context takeover, every byte returned is also written to the sliding-window
// dictionary.
//
// io.EOF is returned only once the final frame's payload has been fully
// consumed. Any EOF that surfaces before that point is returned wrapped,
// never as a bare io.EOF. For example, the underlying connection can hit EOF
// in the middle of a frame or between fragments. Wrapping it means callers
// such as io.ReadAll cannot mistake a truncated message for a complete one.
func (mr *msgReader) Read(p []byte) (n int, err error) {
	err = mr.c.readMu.lock(mr.ctx)
	if err != nil {
		return 0, fmt.Errorf("failed to read: %w", err)
	}
	defer mr.c.readMu.unlock()

	n, err = mr.limitReader.Read(p)
	if mr.flate && mr.flateContextTakeover() {
		mr.dict.write(p[:n])
	}
	if err == nil {
		return n, nil
	}

	// The message is complete only when the final frame has no payload left.
	done := mr.fin && mr.payloadLength == 0
	// A compressed message ends with the replayed deflate tail. The flate
	// reader reports the end of that input as io.ErrUnexpectedEOF, which is
	// therefore also a normal end of message.
	endOfInput := errors.Is(err, io.EOF) || mr.flate && errors.Is(err, io.ErrUnexpectedEOF)
	if done && endOfInput {
		mr.putFlateReader()
		return n, io.EOF
	}
	return n, fmt.Errorf("failed to read: %w", err)
}
421
// read is the raw payload source for the current message (cached as
// mr.readFunc).
//
// While the current frame has no payload left and fin is not set, it
// advances to the next frame via readLoop. That frame must be a
// continuation; any other data frame is a protocol error. Once the final
// frame is exhausted it returns io.EOF. For compressed messages it instead
// serves the bytes of mr.flateTail, followed by that reader's io.EOF.
//
// At most the remaining payload of the current frame is read per call. On
// the server side the bytes are unmasked in place.
func (mr *msgReader) read(p []byte) (int, error) {
	for mr.payloadLength == 0 {
		if mr.fin {
			if mr.flate {
				return mr.flateTail.Read(p)
			}
			return 0, io.EOF
		}

		next, err := mr.c.readLoop(mr.ctx)
		if err != nil {
			return 0, err
		}
		if next.opcode != opContinuation {
			protoErr := errors.New("received new data message without finishing the previous message")
			mr.c.writeError(StatusProtocolError, protoErr)
			return 0, protoErr
		}
		mr.setFrame(next)
	}

	if int64(len(p)) > mr.payloadLength {
		p = p[:mr.payloadLength]
	}

	n, err := mr.c.readFramePayload(mr.ctx, p)
	if err != nil {
		return n, err
	}
	mr.payloadLength -= int64(n)

	if !mr.c.client {
		mr.maskKey = mask(p, mr.maskKey)
	}
	return n, nil
}
464
465 type limitReader struct {
466 c *Conn
467 r io.Reader
468 limit xsync.Int64
469 n int64
470 }
471
// newLimitReader returns a limitReader over r for connection c, with limit
// stored as its cap and its budget already loaded from it.
func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
	lr := new(limitReader)
	lr.c = c
	lr.limit.Store(limit)
	lr.reset(r)
	return lr
}
480
// reset points lr at r and refills the byte budget from the current limit.
func (lr *limitReader) reset(r io.Reader) {
	lr.r = r
	lr.n = lr.limit.Load()
}
485
// Read reads from the underlying reader without exceeding the remaining
// budget. A negative budget means limiting is disabled.
//
// Once the budget is exhausted, the next call writes StatusMessageTooBig to
// the peer and returns an error. The stored limit is one byte larger than
// the limit configured through SetReadLimit (and the default set up by
// newMsgReader). The reported figure removes that byte, so it matches what
// the user configured.
func (lr *limitReader) Read(p []byte) (int, error) {
	if lr.n < 0 {
		return lr.r.Read(p)
	}

	if lr.n == 0 {
		limit := lr.limit.Load()
		if limit > 0 {
			// Undo the extra byte added for the trailing fin frame.
			limit--
		}
		err := fmt.Errorf("read limited at %v bytes", limit)
		lr.c.writeError(StatusMessageTooBig, err)
		return 0, err
	}

	if int64(len(p)) > lr.n {
		p = p[:lr.n]
	}
	n, err := lr.r.Read(p)
	lr.n -= int64(n)
	if lr.n < 0 {
		// Only reachable if r violates the io.Reader contract (n > len(p)).
		// Clamp so a negative budget is not mistaken for "unlimited".
		lr.n = 0
	}
	return n, err
}
507