write.go raw

   1  //go:build !js
   2  // +build !js
   3  
   4  package websocket
   5  
   6  import (
   7  	"bufio"
   8  	"context"
   9  	"crypto/rand"
  10  	"encoding/binary"
  11  	"errors"
  12  	"fmt"
  13  	"io"
  14  	"net"
  15  	"time"
  16  
  17  	"compress/flate"
  18  
  19  	"github.com/coder/websocket/internal/errd"
  20  	"github.com/coder/websocket/internal/util"
  21  )
  22  
  23  // Writer returns a writer bounded by the context that will write
  24  // a WebSocket message of type dataType to the connection.
  25  //
  26  // You must close the writer once you have written the entire message.
  27  //
  28  // Only one writer can be open at a time, multiple calls will block until the previous writer
  29  // is closed.
  30  func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
  31  	w, err := c.writer(ctx, typ)
  32  	if err != nil {
  33  		return nil, fmt.Errorf("failed to get writer: %w", err)
  34  	}
  35  	return w, nil
  36  }
  37  
  38  // Write writes a message to the connection.
  39  //
  40  // See the Writer method if you want to stream a message.
  41  //
  42  // If compression is disabled or the compression threshold is not met, then it
  43  // will write the message in a single frame.
  44  func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
  45  	_, err := c.write(ctx, typ, p)
  46  	if err != nil {
  47  		return fmt.Errorf("failed to write msg: %w", err)
  48  	}
  49  	return nil
  50  }
  51  
// msgWriter writes one WebSocket data message at a time, possibly split
// across several frames. newMsgWriter creates it for a Conn, and reset
// re-arms it at the start of every message.
type msgWriter struct {
	c *Conn

	// mu is taken by reset and released when the message completes
	// successfully (Close, or the single-frame path in Conn.write), so only
	// one message writer is active at a time.
	// writeMu serializes Write and Close calls on the current message.
	// closed records whether Close already ran for the current message.
	mu      *mu
	writeMu *mu
	closed  bool

	// ctx bounds every frame write of the current message.
	// opcode is used for the next frame: the message type for the first
	// frame, then opContinuation (set by write).
	// flate reports whether the current message is being compressed.
	ctx    context.Context
	opcode opcode
	flate  bool

	// Both are created lazily by ensureFlate: trimWriter forwards to write,
	// and flateWriter is obtained from getFlateWriter(trimWriter).
	trimWriter  *trimLastFourBytesWriter
	flateWriter *flate.Writer
}
  66  
  67  func newMsgWriter(c *Conn) *msgWriter {
  68  	mw := &msgWriter{
  69  		c:       c,
  70  		mu:      newMu(c),
  71  		writeMu: newMu(c),
  72  	}
  73  	return mw
  74  }
  75  
  76  func (mw *msgWriter) ensureFlate() {
  77  	if mw.trimWriter == nil {
  78  		mw.trimWriter = &trimLastFourBytesWriter{
  79  			w: util.WriterFunc(mw.write),
  80  		}
  81  	}
  82  
  83  	if mw.flateWriter == nil {
  84  		mw.flateWriter = getFlateWriter(mw.trimWriter)
  85  	}
  86  	mw.flate = true
  87  }
  88  
  89  func (mw *msgWriter) flateContextTakeover() bool {
  90  	if mw.c.client {
  91  		return !mw.c.copts.clientNoContextTakeover
  92  	}
  93  	return !mw.c.copts.serverNoContextTakeover
  94  }
  95  
  96  func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
  97  	err := c.msgWriter.reset(ctx, typ)
  98  	if err != nil {
  99  		return nil, err
 100  	}
 101  	return c.msgWriter, nil
 102  }
 103  
 104  func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
 105  	mw, err := c.writer(ctx, typ)
 106  	if err != nil {
 107  		return 0, err
 108  	}
 109  
 110  	if !c.flate() {
 111  		defer c.msgWriter.mu.unlock()
 112  		return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
 113  	}
 114  
 115  	n, err := mw.Write(p)
 116  	if err != nil {
 117  		return n, err
 118  	}
 119  
 120  	err = mw.Close()
 121  	return n, err
 122  }
 123  
 124  func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
 125  	err := mw.mu.lock(ctx)
 126  	if err != nil {
 127  		return err
 128  	}
 129  
 130  	mw.ctx = ctx
 131  	mw.opcode = opcode(typ)
 132  	mw.flate = false
 133  	mw.closed = false
 134  
 135  	mw.trimWriter.reset()
 136  
 137  	return nil
 138  }
 139  
 140  func (mw *msgWriter) putFlateWriter() {
 141  	if mw.flateWriter != nil {
 142  		putFlateWriter(mw.flateWriter)
 143  		mw.flateWriter = nil
 144  	}
 145  }
 146  
 147  // Write writes the given bytes to the WebSocket connection.
 148  func (mw *msgWriter) Write(p []byte) (_ int, err error) {
 149  	err = mw.writeMu.lock(mw.ctx)
 150  	if err != nil {
 151  		return 0, fmt.Errorf("failed to write: %w", err)
 152  	}
 153  	defer mw.writeMu.unlock()
 154  
 155  	if mw.closed {
 156  		return 0, errors.New("cannot use closed writer")
 157  	}
 158  
 159  	defer func() {
 160  		if err != nil {
 161  			err = fmt.Errorf("failed to write: %w", err)
 162  		}
 163  	}()
 164  
 165  	if mw.c.flate() {
 166  		// Only enables flate if the length crosses the
 167  		// threshold on the first frame
 168  		if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold {
 169  			mw.ensureFlate()
 170  		}
 171  	}
 172  
 173  	if mw.flate {
 174  		return mw.flateWriter.Write(p)
 175  	}
 176  
 177  	return mw.write(p)
 178  }
 179  
 180  func (mw *msgWriter) write(p []byte) (int, error) {
 181  	n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
 182  	if err != nil {
 183  		return n, fmt.Errorf("failed to write data frame: %w", err)
 184  	}
 185  	mw.opcode = opContinuation
 186  	return n, nil
 187  }
 188  
// Close flushes the frame to the connection.
//
// More precisely, Close finishes the current message:
//   - when the message is compressed, it first flushes the flate writer;
//   - it then writes the final frame (FIN set, empty payload);
//   - finally it releases mw.mu so the next Writer call can proceed.
//
// Calling Close a second time on the same message returns an error.
//
// NOTE(review): on every error path below, mw.mu is NOT released and the
// flate writer is not returned. This is presumably intentional, because a
// half-written message leaves the stream unusable — confirm.
func (mw *msgWriter) Close() (err error) {
	defer errd.Wrap(&err, "failed to close writer")

	// writeMu serializes Close with Write calls on the same message.
	err = mw.writeMu.lock(mw.ctx)
	if err != nil {
		return err
	}
	defer mw.writeMu.unlock()

	if mw.closed {
		return errors.New("writer already closed")
	}
	mw.closed = true

	// Push data still held by the compressor out through trimWriter and
	// write, which emit it as non-final frames.
	if mw.flate {
		err = mw.flateWriter.Flush()
		if err != nil {
			return fmt.Errorf("failed to flush flate: %w", err)
		}
	}

	// Terminate the message with an empty FIN frame.
	_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
	if err != nil {
		return fmt.Errorf("failed to write fin frame: %w", err)
	}

	// Without context takeover the compressor must not carry state into the
	// next message. Drop it, and ensureFlate will obtain another one.
	if mw.flate && !mw.flateContextTakeover() {
		mw.putFlateWriter()
	}
	mw.mu.unlock()
	return nil
}
 222  
 223  func (mw *msgWriter) close() {
 224  	if mw.c.client {
 225  		mw.c.writeFrameMu.forceLock()
 226  		putBufioWriter(mw.c.bw)
 227  	}
 228  
 229  	mw.writeMu.forceLock()
 230  	mw.putFlateWriter()
 231  }
 232  
 233  func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
 234  	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
 235  	defer cancel()
 236  
 237  	_, err := c.writeFrame(ctx, true, false, opcode, p)
 238  	if err != nil {
 239  		return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
 240  	}
 241  	return nil
 242  }
 243  
// writeFrame handles all writes to the connection.
//
// It writes a single frame to c.bw: the header followed by payload p, with
// the given FIN bit and opcode. RSV1 is set when flate is true and the
// opcode is opText or opBinary. c.bw is flushed only when fin is set. The
// return value is the number of payload bytes written.
//
// Concurrent callers are serialized by c.writeFrameMu. ctx bounds the wait
// for that lock and is sent on c.writeTimeout for the duration of the write
// (presumably so a watcher goroutine can enforce its deadline — confirm).
// After a successful write, context.Background() is sent on c.writeTimeout.
//
// NOTE(review): the early error returns that follow the first send on
// c.writeTimeout never send context.Background(), so ctx stays registered
// after a failed write — confirm this is intended.
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
	err = c.writeFrameMu.lock(ctx)
	if err != nil {
		return 0, err
	}
	defer c.writeFrameMu.unlock()

	// Refuse to write on a closed connection; otherwise register ctx as the
	// active write context.
	select {
	case <-c.closed:
		return 0, net.ErrClosed
	case c.writeTimeout <- ctx:
	}

	// On failure, report net.ErrClosed or ctx.Err() instead of the underlying
	// error when the connection has closed or ctx is done.
	defer func() {
		if err != nil {
			select {
			case <-c.closed:
				err = net.ErrClosed
			case <-ctx.Done():
				err = ctx.Err()
			default:
			}
			err = fmt.Errorf("failed to write frame: %w", err)
		}
	}()

	c.writeHeader.fin = fin
	c.writeHeader.opcode = opcode
	c.writeHeader.payloadLength = int64(len(p))

	// Client frames are masked with a fresh random 4-byte key; writeHeaderBuf
	// is used as scratch space to read it.
	if c.client {
		c.writeHeader.masked = true
		_, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4])
		if err != nil {
			return 0, fmt.Errorf("failed to generate masking key: %w", err)
		}
		c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:])
	}

	// RSV1 marks a compressed message. It is only ever set on opText/opBinary
	// frames, never on continuation or control frames.
	c.writeHeader.rsv1 = false
	if flate && (opcode == opText || opcode == opBinary) {
		c.writeHeader.rsv1 = true
	}

	err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:])
	if err != nil {
		return 0, err
	}

	n, err := c.writeFramePayload(p)
	if err != nil {
		return n, err
	}

	// Frames without FIN may stay buffered. Frames with FIN (message ends and
	// control frames from writeControl) are flushed immediately.
	if c.writeHeader.fin {
		err = c.bw.Flush()
		if err != nil {
			return n, fmt.Errorf("failed to flush: %w", err)
		}
	}

	// Clear the registered write context. If the connection closed in the
	// meantime, report net.ErrClosed, except for a close frame, for which
	// nil is returned.
	select {
	case <-c.closed:
		if opcode == opClose {
			return n, nil
		}
		return n, net.ErrClosed
	case c.writeTimeout <- context.Background():
	}

	return n, nil
}
 317  
 318  func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
 319  	defer errd.Wrap(&err, "failed to write frame payload")
 320  
 321  	if !c.writeHeader.masked {
 322  		return c.bw.Write(p)
 323  	}
 324  
 325  	maskKey := c.writeHeader.maskKey
 326  	for len(p) > 0 {
 327  		// If the buffer is full, we need to flush.
 328  		if c.bw.Available() == 0 {
 329  			err = c.bw.Flush()
 330  			if err != nil {
 331  				return n, err
 332  			}
 333  		}
 334  
 335  		// Start of next write in the buffer.
 336  		i := c.bw.Buffered()
 337  
 338  		j := len(p)
 339  		if j > c.bw.Available() {
 340  			j = c.bw.Available()
 341  		}
 342  
 343  		_, err := c.bw.Write(p[:j])
 344  		if err != nil {
 345  			return n, err
 346  		}
 347  
 348  		maskKey = mask(c.writeBuf[i:c.bw.Buffered()], maskKey)
 349  
 350  		p = p[j:]
 351  		n += j
 352  	}
 353  
 354  	return n, nil
 355  }
 356  
 357  // extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
 358  // and returns it.
 359  func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
 360  	var writeBuf []byte
 361  	bw.Reset(util.WriterFunc(func(p2 []byte) (int, error) {
 362  		writeBuf = p2[:cap(p2)]
 363  		return len(p2), nil
 364  	}))
 365  
 366  	bw.WriteByte(0)
 367  	bw.Flush()
 368  
 369  	bw.Reset(w)
 370  
 371  	return writeBuf
 372  }
 373  
 374  func (c *Conn) writeError(code StatusCode, err error) {
 375  	c.writeClose(code, err.Error())
 376  }
 377