close.go raw
1 //go:build !js
2 // +build !js
3
4 package websocket
5
6 import (
7 "context"
8 "encoding/binary"
9 "errors"
10 "fmt"
11 "net"
12 "time"
13
14 "github.com/coder/websocket/internal/errd"
15 )
16
// StatusCode represents a WebSocket status code.
// https://tools.ietf.org/html/rfc6455#section-7.4
type StatusCode int

// Status codes defined by the WebSocket protocol itself, as registered at
// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
//
// Applications may define their own codes in the 3000-4999 range:
// 3000-3999 is reserved for use by libraries, frameworks and applications,
// and 4000-4999 is reserved for private use.
const (
	// StatusNormalClosure: the purpose of the connection has been fulfilled.
	StatusNormalClosure StatusCode = 1000
	// StatusGoingAway: an endpoint is going away (e.g. server shutdown).
	StatusGoingAway StatusCode = 1001
	// StatusProtocolError: the connection is terminated due to a protocol error.
	StatusProtocolError StatusCode = 1002
	// StatusUnsupportedData: a type of data was received that cannot be accepted.
	StatusUnsupportedData StatusCode = 1003

	// statusReserved (1004) is reserved by the protocol, hence unexported.
	statusReserved StatusCode = 1004

	// StatusNoStatusRcvd must never be sent in a close message; it stands in
	// for a close message that arrived without any status code.
	StatusNoStatusRcvd StatusCode = 1005

	// StatusAbnormalClosure is exported only for use with Wasm. Outside of
	// Wasm, an abnormal closure is signalled through the returned error.
	StatusAbnormalClosure StatusCode = 1006

	// StatusInvalidFramePayloadData: message data was inconsistent with its type.
	StatusInvalidFramePayloadData StatusCode = 1007
	// StatusPolicyViolation: a message violated the receiver's policy.
	StatusPolicyViolation StatusCode = 1008
	// StatusMessageTooBig: a message was too big to process.
	StatusMessageTooBig StatusCode = 1009
	// StatusMandatoryExtension: a required extension was not negotiated.
	StatusMandatoryExtension StatusCode = 1010
	// StatusInternalError: an unexpected condition prevented fulfilling the request.
	StatusInternalError StatusCode = 1011
	// StatusServiceRestart: the service is restarting.
	StatusServiceRestart StatusCode = 1012
	// StatusTryAgainLater: a temporary condition; the client should retry later.
	StatusTryAgainLater StatusCode = 1013
	// StatusBadGateway: a gateway or proxy received an invalid upstream response.
	StatusBadGateway StatusCode = 1014

	// StatusTLSHandshake is exported only for use with Wasm. Outside of Wasm,
	// a TLS handshake failure is signalled through the returned error.
	StatusTLSHandshake StatusCode = 1015
)
61
62 // CloseError is returned when the connection is closed with a status and reason.
63 //
64 // Use Go 1.13's errors.As to check for this error.
65 // Also see the CloseStatus helper.
66 type CloseError struct {
67 Code StatusCode
68 Reason string
69 }
70
71 func (ce CloseError) Error() string {
72 return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
73 }
74
75 // CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab
76 // the status code from a CloseError.
77 //
78 // -1 will be returned if the passed error is nil or not a CloseError.
79 func CloseStatus(err error) StatusCode {
80 var ce CloseError
81 if errors.As(err, &ce) {
82 return ce.Code
83 }
84 return -1
85 }
86
87 // Close performs the WebSocket close handshake with the given status code and reason.
88 //
89 // It will write a WebSocket close frame with a timeout of 5s and then wait 5s for
90 // the peer to send a close frame.
91 // All data messages received from the peer during the close handshake will be discarded.
92 //
93 // The connection can only be closed once. Additional calls to Close
94 // are no-ops.
95 //
96 // The maximum length of reason must be 125 bytes. Avoid sending a dynamic reason.
97 //
98 // Close will unblock all goroutines interacting with the connection once
99 // complete.
100 func (c *Conn) Close(code StatusCode, reason string) (err error) {
101 defer errd.Wrap(&err, "failed to close WebSocket")
102
103 if !c.casClosing() {
104 err = c.waitGoroutines()
105 if err != nil {
106 return err
107 }
108 return net.ErrClosed
109 }
110 defer func() {
111 if errors.Is(err, net.ErrClosed) {
112 err = nil
113 }
114 }()
115
116 err = c.closeHandshake(code, reason)
117
118 err2 := c.close()
119 if err == nil && err2 != nil {
120 err = err2
121 }
122
123 err2 = c.waitGoroutines()
124 if err == nil && err2 != nil {
125 err = err2
126 }
127
128 return err
129 }
130
131 // CloseNow closes the WebSocket connection without attempting a close handshake.
132 // Use when you do not want the overhead of the close handshake.
133 func (c *Conn) CloseNow() (err error) {
134 defer errd.Wrap(&err, "failed to immediately close WebSocket")
135
136 if !c.casClosing() {
137 err = c.waitGoroutines()
138 if err != nil {
139 return err
140 }
141 return net.ErrClosed
142 }
143 defer func() {
144 if errors.Is(err, net.ErrClosed) {
145 err = nil
146 }
147 }()
148
149 err = c.close()
150
151 err2 := c.waitGoroutines()
152 if err == nil && err2 != nil {
153 err = err2
154 }
155 return err
156 }
157
158 func (c *Conn) closeHandshake(code StatusCode, reason string) error {
159 err := c.writeClose(code, reason)
160 if err != nil {
161 return err
162 }
163
164 err = c.waitCloseHandshake()
165 if CloseStatus(err) != code {
166 return err
167 }
168 return nil
169 }
170
171 func (c *Conn) writeClose(code StatusCode, reason string) error {
172 ce := CloseError{
173 Code: code,
174 Reason: reason,
175 }
176
177 var p []byte
178 var err error
179 if ce.Code != StatusNoStatusRcvd {
180 p, err = ce.bytes()
181 if err != nil {
182 return err
183 }
184 }
185
186 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
187 defer cancel()
188
189 err = c.writeControl(ctx, opClose, p)
190 // If the connection closed as we're writing we ignore the error as we might
191 // have written the close frame, the peer responded and then someone else read it
192 // and closed the connection.
193 if err != nil && !errors.Is(err, net.ErrClosed) {
194 return err
195 }
196 return nil
197 }
198
199 func (c *Conn) waitCloseHandshake() error {
200 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
201 defer cancel()
202
203 err := c.readMu.lock(ctx)
204 if err != nil {
205 return err
206 }
207 defer c.readMu.unlock()
208
209 for i := int64(0); i < c.msgReader.payloadLength; i++ {
210 _, err := c.br.ReadByte()
211 if err != nil {
212 return err
213 }
214 }
215
216 for {
217 h, err := c.readLoop(ctx)
218 if err != nil {
219 return err
220 }
221
222 for i := int64(0); i < h.payloadLength; i++ {
223 _, err := c.br.ReadByte()
224 if err != nil {
225 return err
226 }
227 }
228 }
229 }
230
231 func (c *Conn) waitGoroutines() error {
232 t := time.NewTimer(time.Second * 15)
233 defer t.Stop()
234
235 select {
236 case <-c.timeoutLoopDone:
237 case <-t.C:
238 return errors.New("failed to wait for timeoutLoop goroutine to exit")
239 }
240
241 c.closeReadMu.Lock()
242 closeRead := c.closeReadCtx != nil
243 c.closeReadMu.Unlock()
244 if closeRead {
245 select {
246 case <-c.closeReadDone:
247 case <-t.C:
248 return errors.New("failed to wait for close read goroutine to exit")
249 }
250 }
251
252 select {
253 case <-c.closed:
254 case <-t.C:
255 return errors.New("failed to wait for connection to be closed")
256 }
257
258 return nil
259 }
260
261 func parseClosePayload(p []byte) (CloseError, error) {
262 if len(p) == 0 {
263 return CloseError{
264 Code: StatusNoStatusRcvd,
265 }, nil
266 }
267
268 if len(p) < 2 {
269 return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
270 }
271
272 ce := CloseError{
273 Code: StatusCode(binary.BigEndian.Uint16(p)),
274 Reason: string(p[2:]),
275 }
276
277 if !validWireCloseCode(ce.Code) {
278 return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code)
279 }
280
281 return ce, nil
282 }
283
284 // See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
285 // and https://tools.ietf.org/html/rfc6455#section-7.4.1
286 func validWireCloseCode(code StatusCode) bool {
287 switch code {
288 case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake:
289 return false
290 }
291
292 if code >= StatusNormalClosure && code <= StatusBadGateway {
293 return true
294 }
295 if code >= 3000 && code <= 4999 {
296 return true
297 }
298
299 return false
300 }
301
302 func (ce CloseError) bytes() ([]byte, error) {
303 p, err := ce.bytesErr()
304 if err != nil {
305 err = fmt.Errorf("failed to marshal close frame: %w", err)
306 ce = CloseError{
307 Code: StatusInternalError,
308 }
309 p, _ = ce.bytesErr()
310 }
311 return p, err
312 }
313
314 const maxCloseReason = maxControlPayload - 2
315
316 func (ce CloseError) bytesErr() ([]byte, error) {
317 if len(ce.Reason) > maxCloseReason {
318 return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason))
319 }
320
321 if !validWireCloseCode(ce.Code) {
322 return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
323 }
324
325 buf := make([]byte, 2+len(ce.Reason))
326 binary.BigEndian.PutUint16(buf, uint16(ce.Code))
327 copy(buf[2:], ce.Reason)
328 return buf, nil
329 }
330
331 func (c *Conn) casClosing() bool {
332 c.closeMu.Lock()
333 defer c.closeMu.Unlock()
334 if !c.closing {
335 c.closing = true
336 return true
337 }
338 return false
339 }
340
341 func (c *Conn) isClosed() bool {
342 select {
343 case <-c.closed:
344 return true
345 default:
346 return false
347 }
348 }
349