write.go raw
1 //go:build !js
2 // +build !js
3
4 package websocket
5
6 import (
7 "bufio"
8 "context"
9 "crypto/rand"
10 "encoding/binary"
11 "errors"
12 "fmt"
13 "io"
14 "net"
15 "time"
16
17 "compress/flate"
18
19 "github.com/coder/websocket/internal/errd"
20 "github.com/coder/websocket/internal/util"
21 )
22
23 // Writer returns a writer bounded by the context that will write
24 // a WebSocket message of type dataType to the connection.
25 //
26 // You must close the writer once you have written the entire message.
27 //
28 // Only one writer can be open at a time, multiple calls will block until the previous writer
29 // is closed.
30 func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
31 w, err := c.writer(ctx, typ)
32 if err != nil {
33 return nil, fmt.Errorf("failed to get writer: %w", err)
34 }
35 return w, nil
36 }
37
38 // Write writes a message to the connection.
39 //
40 // See the Writer method if you want to stream a message.
41 //
42 // If compression is disabled or the compression threshold is not met, then it
43 // will write the message in a single frame.
44 func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
45 _, err := c.write(ctx, typ, p)
46 if err != nil {
47 return fmt.Errorf("failed to write msg: %w", err)
48 }
49 return nil
50 }
51
// msgWriter is the io.WriteCloser handed out by Conn.Writer. The connection
// keeps a single instance (c.msgWriter) and Conn.writer reuses it for every
// message after calling reset.
type msgWriter struct {
	c *Conn

	// mu is held for the whole message: reset locks it, and it is unlocked by
	// Close (or by Conn.write when the message is sent as a single frame).
	mu *mu
	// writeMu serializes Write and Close calls on this writer.
	writeMu *mu
	// closed is set by Close; Write refuses to write once it is true.
	closed bool

	// ctx is the context passed to reset and used for every frame write of
	// the current message.
	ctx context.Context
	// opcode is the opcode for the next data frame; it switches to
	// opContinuation once the first frame has been written.
	opcode opcode
	// flate reports whether the current message is being compressed.
	flate bool

	// trimWriter and flateWriter are created lazily by ensureFlate.
	// flateWriter is handed back via putFlateWriter (and set to nil) when
	// context takeover is disabled, and by close.
	trimWriter  *trimLastFourBytesWriter
	flateWriter *flate.Writer
}
66
67 func newMsgWriter(c *Conn) *msgWriter {
68 mw := &msgWriter{
69 c: c,
70 mu: newMu(c),
71 writeMu: newMu(c),
72 }
73 return mw
74 }
75
76 func (mw *msgWriter) ensureFlate() {
77 if mw.trimWriter == nil {
78 mw.trimWriter = &trimLastFourBytesWriter{
79 w: util.WriterFunc(mw.write),
80 }
81 }
82
83 if mw.flateWriter == nil {
84 mw.flateWriter = getFlateWriter(mw.trimWriter)
85 }
86 mw.flate = true
87 }
88
89 func (mw *msgWriter) flateContextTakeover() bool {
90 if mw.c.client {
91 return !mw.c.copts.clientNoContextTakeover
92 }
93 return !mw.c.copts.serverNoContextTakeover
94 }
95
96 func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
97 err := c.msgWriter.reset(ctx, typ)
98 if err != nil {
99 return nil, err
100 }
101 return c.msgWriter, nil
102 }
103
104 func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
105 mw, err := c.writer(ctx, typ)
106 if err != nil {
107 return 0, err
108 }
109
110 if !c.flate() {
111 defer c.msgWriter.mu.unlock()
112 return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
113 }
114
115 n, err := mw.Write(p)
116 if err != nil {
117 return n, err
118 }
119
120 err = mw.Close()
121 return n, err
122 }
123
124 func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
125 err := mw.mu.lock(ctx)
126 if err != nil {
127 return err
128 }
129
130 mw.ctx = ctx
131 mw.opcode = opcode(typ)
132 mw.flate = false
133 mw.closed = false
134
135 mw.trimWriter.reset()
136
137 return nil
138 }
139
140 func (mw *msgWriter) putFlateWriter() {
141 if mw.flateWriter != nil {
142 putFlateWriter(mw.flateWriter)
143 mw.flateWriter = nil
144 }
145 }
146
147 // Write writes the given bytes to the WebSocket connection.
148 func (mw *msgWriter) Write(p []byte) (_ int, err error) {
149 err = mw.writeMu.lock(mw.ctx)
150 if err != nil {
151 return 0, fmt.Errorf("failed to write: %w", err)
152 }
153 defer mw.writeMu.unlock()
154
155 if mw.closed {
156 return 0, errors.New("cannot use closed writer")
157 }
158
159 defer func() {
160 if err != nil {
161 err = fmt.Errorf("failed to write: %w", err)
162 }
163 }()
164
165 if mw.c.flate() {
166 // Only enables flate if the length crosses the
167 // threshold on the first frame
168 if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold {
169 mw.ensureFlate()
170 }
171 }
172
173 if mw.flate {
174 return mw.flateWriter.Write(p)
175 }
176
177 return mw.write(p)
178 }
179
180 func (mw *msgWriter) write(p []byte) (int, error) {
181 n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
182 if err != nil {
183 return n, fmt.Errorf("failed to write data frame: %w", err)
184 }
185 mw.opcode = opContinuation
186 return n, nil
187 }
188
189 // Close flushes the frame to the connection.
190 func (mw *msgWriter) Close() (err error) {
191 defer errd.Wrap(&err, "failed to close writer")
192
193 err = mw.writeMu.lock(mw.ctx)
194 if err != nil {
195 return err
196 }
197 defer mw.writeMu.unlock()
198
199 if mw.closed {
200 return errors.New("writer already closed")
201 }
202 mw.closed = true
203
204 if mw.flate {
205 err = mw.flateWriter.Flush()
206 if err != nil {
207 return fmt.Errorf("failed to flush flate: %w", err)
208 }
209 }
210
211 _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
212 if err != nil {
213 return fmt.Errorf("failed to write fin frame: %w", err)
214 }
215
216 if mw.flate && !mw.flateContextTakeover() {
217 mw.putFlateWriter()
218 }
219 mw.mu.unlock()
220 return nil
221 }
222
// close releases the writer's pooled resources during teardown.
//
// For clients it first force-locks c.writeFrameMu and only then hands c.bw to
// putBufioWriter. It then force-locks writeMu before releasing the flate
// writer. Neither lock is released here, so presumably no frame write or
// Write/Close call can touch these buffers afterwards. Confirm against
// forceLock's definition.
func (mw *msgWriter) close() {
	if mw.c.client {
		mw.c.writeFrameMu.forceLock()
		putBufioWriter(mw.c.bw)
	}

	mw.writeMu.forceLock()
	mw.putFlateWriter()
}
232
233 func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
234 ctx, cancel := context.WithTimeout(ctx, time.Second*5)
235 defer cancel()
236
237 _, err := c.writeFrame(ctx, true, false, opcode, p)
238 if err != nil {
239 return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
240 }
241 return nil
242 }
243
244 // writeFrame handles all writes to the connection.
245 func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
246 err = c.writeFrameMu.lock(ctx)
247 if err != nil {
248 return 0, err
249 }
250 defer c.writeFrameMu.unlock()
251
252 select {
253 case <-c.closed:
254 return 0, net.ErrClosed
255 case c.writeTimeout <- ctx:
256 }
257
258 defer func() {
259 if err != nil {
260 select {
261 case <-c.closed:
262 err = net.ErrClosed
263 case <-ctx.Done():
264 err = ctx.Err()
265 default:
266 }
267 err = fmt.Errorf("failed to write frame: %w", err)
268 }
269 }()
270
271 c.writeHeader.fin = fin
272 c.writeHeader.opcode = opcode
273 c.writeHeader.payloadLength = int64(len(p))
274
275 if c.client {
276 c.writeHeader.masked = true
277 _, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4])
278 if err != nil {
279 return 0, fmt.Errorf("failed to generate masking key: %w", err)
280 }
281 c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:])
282 }
283
284 c.writeHeader.rsv1 = false
285 if flate && (opcode == opText || opcode == opBinary) {
286 c.writeHeader.rsv1 = true
287 }
288
289 err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:])
290 if err != nil {
291 return 0, err
292 }
293
294 n, err := c.writeFramePayload(p)
295 if err != nil {
296 return n, err
297 }
298
299 if c.writeHeader.fin {
300 err = c.bw.Flush()
301 if err != nil {
302 return n, fmt.Errorf("failed to flush: %w", err)
303 }
304 }
305
306 select {
307 case <-c.closed:
308 if opcode == opClose {
309 return n, nil
310 }
311 return n, net.ErrClosed
312 case c.writeTimeout <- context.Background():
313 }
314
315 return n, nil
316 }
317
318 func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
319 defer errd.Wrap(&err, "failed to write frame payload")
320
321 if !c.writeHeader.masked {
322 return c.bw.Write(p)
323 }
324
325 maskKey := c.writeHeader.maskKey
326 for len(p) > 0 {
327 // If the buffer is full, we need to flush.
328 if c.bw.Available() == 0 {
329 err = c.bw.Flush()
330 if err != nil {
331 return n, err
332 }
333 }
334
335 // Start of next write in the buffer.
336 i := c.bw.Buffered()
337
338 j := len(p)
339 if j > c.bw.Available() {
340 j = c.bw.Available()
341 }
342
343 _, err := c.bw.Write(p[:j])
344 if err != nil {
345 return n, err
346 }
347
348 maskKey = mask(c.writeBuf[i:c.bw.Buffered()], maskKey)
349
350 p = p[j:]
351 n += j
352 }
353
354 return n, nil
355 }
356
357 // extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
358 // and returns it.
359 func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
360 var writeBuf []byte
361 bw.Reset(util.WriterFunc(func(p2 []byte) (int, error) {
362 writeBuf = p2[:cap(p2)]
363 return len(p2), nil
364 }))
365
366 bw.WriteByte(0)
367 bw.Flush()
368
369 bw.Reset(w)
370
371 return writeBuf
372 }
373
374 func (c *Conn) writeError(code StatusCode, err error) {
375 c.writeClose(code, err.Error())
376 }
377