conn.go raw
1 //go:build !js
2 // +build !js
3
4 package websocket
5
6 import (
7 "bufio"
8 "context"
9 "fmt"
10 "io"
11 "net"
12 "runtime"
13 "strconv"
14 "sync"
15 "sync/atomic"
16 )
17
// MessageType identifies the kind of payload carried by a WebSocket data message.
// See https://tools.ietf.org/html/rfc6455#section-5.6
type MessageType int

// The supported data message types. Numbering starts at 1, so the zero value
// of MessageType matches neither constant.
//
// NOTE(review): these values coincide with the RFC 6455 text (0x1) and binary
// (0x2) opcodes; confirm whether other code relies on that before renumbering.
const (
	// MessageText marks a message whose payload is UTF-8 text, such as JSON.
	MessageText MessageType = 1

	// MessageBinary marks a message whose payload is arbitrary bytes, such as protobufs.
	MessageBinary MessageType = 2
)
29
30 // Conn represents a WebSocket connection.
31 // All methods may be called concurrently except for Reader and Read.
32 //
33 // You must always read from the connection. Otherwise control
34 // frames will not be handled. See Reader and CloseRead.
35 //
36 // Be sure to call Close on the connection when you
37 // are finished with it to release associated resources.
38 //
39 // On any error from any method, the connection is closed
40 // with an appropriate reason.
41 //
42 // This applies to context expirations as well unfortunately.
43 // See https://github.com/nhooyr/websocket/issues/242#issuecomment-633182220
44 type Conn struct {
45 noCopy noCopy
46
47 subprotocol string
48 rwc io.ReadWriteCloser
49 client bool
50 copts *compressionOptions
51 flateThreshold int
52 br *bufio.Reader
53 bw *bufio.Writer
54
55 readTimeout chan context.Context
56 writeTimeout chan context.Context
57 timeoutLoopDone chan struct{}
58
59 // Read state.
60 readMu *mu
61 readHeaderBuf [8]byte
62 readControlBuf [maxControlPayload]byte
63 msgReader *msgReader
64
65 // Write state.
66 msgWriter *msgWriter
67 writeFrameMu *mu
68 writeBuf []byte
69 writeHeaderBuf [8]byte
70 writeHeader header
71
72 closeReadMu sync.Mutex
73 closeReadCtx context.Context
74 closeReadDone chan struct{}
75
76 closed chan struct{}
77 closeMu sync.Mutex
78 closing bool
79
80 pingCounter int32
81 activePingsMu sync.Mutex
82 activePings map[string]chan<- struct{}
83 }
84
85 type connConfig struct {
86 subprotocol string
87 rwc io.ReadWriteCloser
88 client bool
89 copts *compressionOptions
90 flateThreshold int
91
92 br *bufio.Reader
93 bw *bufio.Writer
94 }
95
96 func newConn(cfg connConfig) *Conn {
97 c := &Conn{
98 subprotocol: cfg.subprotocol,
99 rwc: cfg.rwc,
100 client: cfg.client,
101 copts: cfg.copts,
102 flateThreshold: cfg.flateThreshold,
103
104 br: cfg.br,
105 bw: cfg.bw,
106
107 readTimeout: make(chan context.Context),
108 writeTimeout: make(chan context.Context),
109 timeoutLoopDone: make(chan struct{}),
110
111 closed: make(chan struct{}),
112 activePings: make(map[string]chan<- struct{}),
113 }
114
115 c.readMu = newMu(c)
116 c.writeFrameMu = newMu(c)
117
118 c.msgReader = newMsgReader(c)
119
120 c.msgWriter = newMsgWriter(c)
121 if c.client {
122 c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
123 }
124
125 if c.flate() && c.flateThreshold == 0 {
126 c.flateThreshold = 128
127 if !c.msgWriter.flateContextTakeover() {
128 c.flateThreshold = 512
129 }
130 }
131
132 runtime.SetFinalizer(c, func(c *Conn) {
133 c.close()
134 })
135
136 go c.timeoutLoop()
137
138 return c
139 }
140
141 // Subprotocol returns the negotiated subprotocol.
142 // An empty string means the default protocol.
143 func (c *Conn) Subprotocol() string {
144 return c.subprotocol
145 }
146
147 func (c *Conn) close() error {
148 c.closeMu.Lock()
149 defer c.closeMu.Unlock()
150
151 if c.isClosed() {
152 return net.ErrClosed
153 }
154 runtime.SetFinalizer(c, nil)
155 close(c.closed)
156
157 // Have to close after c.closed is closed to ensure any goroutine that wakes up
158 // from the connection being closed also sees that c.closed is closed and returns
159 // closeErr.
160 err := c.rwc.Close()
161 // With the close of rwc, these become safe to close.
162 c.msgWriter.close()
163 c.msgReader.close()
164 return err
165 }
166
167 func (c *Conn) timeoutLoop() {
168 defer close(c.timeoutLoopDone)
169
170 readCtx := context.Background()
171 writeCtx := context.Background()
172
173 for {
174 select {
175 case <-c.closed:
176 return
177
178 case writeCtx = <-c.writeTimeout:
179 case readCtx = <-c.readTimeout:
180
181 case <-readCtx.Done():
182 c.close()
183 return
184 case <-writeCtx.Done():
185 c.close()
186 return
187 }
188 }
189 }
190
191 func (c *Conn) flate() bool {
192 return c.copts != nil
193 }
194
195 // Ping sends a ping to the peer and waits for a pong.
196 // Use this to measure latency or ensure the peer is responsive.
197 // Ping must be called concurrently with Reader as it does
198 // not read from the connection but instead waits for a Reader call
199 // to read the pong.
200 //
201 // TCP Keepalives should suffice for most use cases.
202 func (c *Conn) Ping(ctx context.Context) error {
203 p := atomic.AddInt32(&c.pingCounter, 1)
204
205 err := c.ping(ctx, strconv.Itoa(int(p)))
206 if err != nil {
207 return fmt.Errorf("failed to ping: %w", err)
208 }
209 return nil
210 }
211
212 func (c *Conn) ping(ctx context.Context, p string) error {
213 pong := make(chan struct{}, 1)
214
215 c.activePingsMu.Lock()
216 c.activePings[p] = pong
217 c.activePingsMu.Unlock()
218
219 defer func() {
220 c.activePingsMu.Lock()
221 delete(c.activePings, p)
222 c.activePingsMu.Unlock()
223 }()
224
225 err := c.writeControl(ctx, opPing, []byte(p))
226 if err != nil {
227 return err
228 }
229
230 select {
231 case <-c.closed:
232 return net.ErrClosed
233 case <-ctx.Done():
234 return fmt.Errorf("failed to wait for pong: %w", ctx.Err())
235 case <-pong:
236 return nil
237 }
238 }
239
240 type mu struct {
241 c *Conn
242 ch chan struct{}
243 }
244
245 func newMu(c *Conn) *mu {
246 return &mu{
247 c: c,
248 ch: make(chan struct{}, 1),
249 }
250 }
251
252 func (m *mu) forceLock() {
253 m.ch <- struct{}{}
254 }
255
256 func (m *mu) tryLock() bool {
257 select {
258 case m.ch <- struct{}{}:
259 return true
260 default:
261 return false
262 }
263 }
264
265 func (m *mu) lock(ctx context.Context) error {
266 select {
267 case <-m.c.closed:
268 return net.ErrClosed
269 case <-ctx.Done():
270 return fmt.Errorf("failed to acquire lock: %w", ctx.Err())
271 case m.ch <- struct{}{}:
272 // To make sure the connection is certainly alive.
273 // As it's possible the send on m.ch was selected
274 // over the receive on closed.
275 select {
276 case <-m.c.closed:
277 // Make sure to release.
278 m.unlock()
279 return net.ErrClosed
280 default:
281 }
282 return nil
283 }
284 }
285
286 func (m *mu) unlock() {
287 select {
288 case <-m.ch:
289 default:
290 }
291 }
292
// noCopy is a zero-size marker field for structs that must not be copied
// after first use.
//
// `go vet`'s copylocks analyzer reports by-value copies of any type whose
// pointer implements sync.Locker while its value does not. *noCopy needs both
// Lock and Unlock to satisfy sync.Locker. With only Lock defined, vet would
// silently ignore copies of structs containing a noCopy field. Both methods
// are no-ops.
//
// It should be used as a named field, not embedded, so the containing type
// does not gain Lock and Unlock methods.
type noCopy struct{}

// Lock is a no-op used by the copylocks checker from `go vet`.
func (*noCopy) Lock() {}

// Unlock is a no-op used by the copylocks checker from `go vet`.
func (*noCopy) Unlock() {}
296