http_util.go raw

   1  /*
   2   *
   3   * Copyright 2014 gRPC authors.
   4   *
   5   * Licensed under the Apache License, Version 2.0 (the "License");
   6   * you may not use this file except in compliance with the License.
   7   * You may obtain a copy of the License at
   8   *
   9   *     http://www.apache.org/licenses/LICENSE-2.0
  10   *
  11   * Unless required by applicable law or agreed to in writing, software
  12   * distributed under the License is distributed on an "AS IS" BASIS,
  13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14   * See the License for the specific language governing permissions and
  15   * limitations under the License.
  16   *
  17   */
  18  
  19  package transport
  20  
  21  import (
  22  	"bufio"
  23  	"encoding/base64"
  24  	"errors"
  25  	"fmt"
  26  	"io"
  27  	"math"
  28  	"net/http"
  29  	"net/url"
  30  	"strconv"
  31  	"strings"
  32  	"sync"
  33  	"time"
  34  	"unicode/utf8"
  35  
  36  	"golang.org/x/net/http2"
  37  	"golang.org/x/net/http2/hpack"
  38  	"google.golang.org/grpc/codes"
  39  	"google.golang.org/grpc/mem"
  40  )
  41  
const (
	// http2MaxFrameLen specifies the max length of a HTTP2 frame. newFramer
	// passes it to http2.Framer.SetMaxReadFrameSize.
	http2MaxFrameLen = 16384 // 16KB frame
	// http2InitHeaderTableSize is the HPACK header table size given to the
	// header decoder that newFramer installs; 4096 is the initial value of
	// SETTINGS_HEADER_TABLE_SIZE.
	// https://httpwg.org/specs/rfc7540.html#SettingValues
	http2InitHeaderTableSize = 4096
)

var (
	clientPreface   = []byte(http2.ClientPreface) // HTTP/2 client connection preface bytes.
	// http2ErrConvTab maps every HTTP/2 error code to the gRPC status code
	// used to report it.
	http2ErrConvTab = map[http2.ErrCode]codes.Code{
		http2.ErrCodeNo:                 codes.Internal,
		http2.ErrCodeProtocol:           codes.Internal,
		http2.ErrCodeInternal:           codes.Internal,
		http2.ErrCodeFlowControl:        codes.ResourceExhausted,
		http2.ErrCodeSettingsTimeout:    codes.Internal,
		http2.ErrCodeStreamClosed:       codes.Internal,
		http2.ErrCodeFrameSize:          codes.Internal,
		http2.ErrCodeRefusedStream:      codes.Unavailable,
		http2.ErrCodeCancel:             codes.Canceled,
		http2.ErrCodeCompression:        codes.Internal,
		http2.ErrCodeConnect:            codes.Internal,
		http2.ErrCodeEnhanceYourCalm:    codes.ResourceExhausted,
		http2.ErrCodeInadequateSecurity: codes.PermissionDenied,
		http2.ErrCodeHTTP11Required:     codes.Internal,
	}
	// HTTPStatusConvTab is the HTTP status code to gRPC error code conversion table.
	// Statuses not listed here have no entry; how callers handle a missing
	// entry is decided at the use sites, not here.
	HTTPStatusConvTab = map[int]codes.Code{
		// 400 Bad Request - INTERNAL.
		http.StatusBadRequest: codes.Internal,
		// 401 Unauthorized - UNAUTHENTICATED.
		http.StatusUnauthorized: codes.Unauthenticated,
		// 403 Forbidden - PERMISSION_DENIED.
		http.StatusForbidden: codes.PermissionDenied,
		// 404 Not Found - UNIMPLEMENTED.
		http.StatusNotFound: codes.Unimplemented,
		// 429 Too Many Requests - UNAVAILABLE.
		http.StatusTooManyRequests: codes.Unavailable,
		// 502 Bad Gateway - UNAVAILABLE.
		http.StatusBadGateway: codes.Unavailable,
		// 503 Service Unavailable - UNAVAILABLE.
		http.StatusServiceUnavailable: codes.Unavailable,
		// 504 Gateway Timeout - UNAVAILABLE.
		http.StatusGatewayTimeout: codes.Unavailable,
	}
)

// grpcStatusDetailsBinHeader is the name of the grpc-status-details-bin
// header. Its "-bin" suffix (binHdrSuffix) marks its value as binary.
var grpcStatusDetailsBinHeader = "grpc-status-details-bin"
  89  
  90  // isReservedHeader checks whether hdr belongs to HTTP2 headers
  91  // reserved by gRPC protocol. Any other headers are classified as the
  92  // user-specified metadata.
  93  func isReservedHeader(hdr string) bool {
  94  	if hdr != "" && hdr[0] == ':' {
  95  		return true
  96  	}
  97  	switch hdr {
  98  	case "content-type",
  99  		"user-agent",
 100  		"grpc-message-type",
 101  		"grpc-encoding",
 102  		"grpc-message",
 103  		"grpc-status",
 104  		"grpc-timeout",
 105  		// Intentionally exclude grpc-previous-rpc-attempts and
 106  		// grpc-retry-pushback-ms, which are "reserved", but their API
 107  		// intentionally works via metadata.
 108  		"te":
 109  		return true
 110  	default:
 111  		return false
 112  	}
 113  }
 114  
 115  // isWhitelistedHeader checks whether hdr should be propagated into metadata
 116  // visible to users, even though it is classified as "reserved", above.
 117  func isWhitelistedHeader(hdr string) bool {
 118  	switch hdr {
 119  	case ":authority", "user-agent":
 120  		return true
 121  	default:
 122  		return false
 123  	}
 124  }
 125  
 126  const binHdrSuffix = "-bin"
 127  
 128  func encodeBinHeader(v []byte) string {
 129  	return base64.RawStdEncoding.EncodeToString(v)
 130  }
 131  
 132  func decodeBinHeader(v string) ([]byte, error) {
 133  	if len(v)%4 == 0 {
 134  		// Input was padded, or padding was not necessary.
 135  		return base64.StdEncoding.DecodeString(v)
 136  	}
 137  	return base64.RawStdEncoding.DecodeString(v)
 138  }
 139  
 140  func encodeMetadataHeader(k, v string) string {
 141  	if strings.HasSuffix(k, binHdrSuffix) {
 142  		return encodeBinHeader(([]byte)(v))
 143  	}
 144  	return v
 145  }
 146  
 147  func decodeMetadataHeader(k, v string) (string, error) {
 148  	if strings.HasSuffix(k, binHdrSuffix) {
 149  		b, err := decodeBinHeader(v)
 150  		return string(b), err
 151  	}
 152  	return v, nil
 153  }
 154  
 155  type timeoutUnit uint8
 156  
 157  const (
 158  	hour        timeoutUnit = 'H'
 159  	minute      timeoutUnit = 'M'
 160  	second      timeoutUnit = 'S'
 161  	millisecond timeoutUnit = 'm'
 162  	microsecond timeoutUnit = 'u'
 163  	nanosecond  timeoutUnit = 'n'
 164  )
 165  
 166  func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) {
 167  	switch u {
 168  	case hour:
 169  		return time.Hour, true
 170  	case minute:
 171  		return time.Minute, true
 172  	case second:
 173  		return time.Second, true
 174  	case millisecond:
 175  		return time.Millisecond, true
 176  	case microsecond:
 177  		return time.Microsecond, true
 178  	case nanosecond:
 179  		return time.Nanosecond, true
 180  	default:
 181  	}
 182  	return
 183  }
 184  
 185  func decodeTimeout(s string) (time.Duration, error) {
 186  	size := len(s)
 187  	if size < 2 {
 188  		return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
 189  	}
 190  	if size > 9 {
 191  		// Spec allows for 8 digits plus the unit.
 192  		return 0, fmt.Errorf("transport: timeout string is too long: %q", s)
 193  	}
 194  	unit := timeoutUnit(s[size-1])
 195  	d, ok := timeoutUnitToDuration(unit)
 196  	if !ok {
 197  		return 0, fmt.Errorf("transport: timeout unit is not recognized: %q", s)
 198  	}
 199  	t, err := strconv.ParseUint(s[:size-1], 10, 64)
 200  	if err != nil {
 201  		return 0, err
 202  	}
 203  	const maxHours = math.MaxInt64 / uint64(time.Hour)
 204  	if d == time.Hour && t > maxHours {
 205  		// This timeout would overflow math.MaxInt64; clamp it.
 206  		return time.Duration(math.MaxInt64), nil
 207  	}
 208  	return d * time.Duration(t), nil
 209  }
 210  
 211  const (
 212  	spaceByte   = ' '
 213  	tildeByte   = '~'
 214  	percentByte = '%'
 215  )
 216  
 217  // encodeGrpcMessage is used to encode status code in header field
 218  // "grpc-message". It does percent encoding and also replaces invalid utf-8
 219  // characters with Unicode replacement character.
 220  //
 221  // It checks to see if each individual byte in msg is an allowable byte, and
 222  // then either percent encoding or passing it through. When percent encoding,
 223  // the byte is converted into hexadecimal notation with a '%' prepended.
 224  func encodeGrpcMessage(msg string) string {
 225  	if msg == "" {
 226  		return ""
 227  	}
 228  	lenMsg := len(msg)
 229  	for i := 0; i < lenMsg; i++ {
 230  		c := msg[i]
 231  		if !(c >= spaceByte && c <= tildeByte && c != percentByte) {
 232  			return encodeGrpcMessageUnchecked(msg)
 233  		}
 234  	}
 235  	return msg
 236  }
 237  
 238  func encodeGrpcMessageUnchecked(msg string) string {
 239  	var sb strings.Builder
 240  	for len(msg) > 0 {
 241  		r, size := utf8.DecodeRuneInString(msg)
 242  		for _, b := range []byte(string(r)) {
 243  			if size > 1 {
 244  				// If size > 1, r is not ascii. Always do percent encoding.
 245  				fmt.Fprintf(&sb, "%%%02X", b)
 246  				continue
 247  			}
 248  
 249  			// The for loop is necessary even if size == 1. r could be
 250  			// utf8.RuneError.
 251  			//
 252  			// fmt.Sprintf("%%%02X", utf8.RuneError) gives "%FFFD".
 253  			if b >= spaceByte && b <= tildeByte && b != percentByte {
 254  				sb.WriteByte(b)
 255  			} else {
 256  				fmt.Fprintf(&sb, "%%%02X", b)
 257  			}
 258  		}
 259  		msg = msg[size:]
 260  	}
 261  	return sb.String()
 262  }
 263  
 264  // decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage.
 265  func decodeGrpcMessage(msg string) string {
 266  	if msg == "" {
 267  		return ""
 268  	}
 269  	lenMsg := len(msg)
 270  	for i := 0; i < lenMsg; i++ {
 271  		if msg[i] == percentByte && i+2 < lenMsg {
 272  			return decodeGrpcMessageUnchecked(msg)
 273  		}
 274  	}
 275  	return msg
 276  }
 277  
 278  func decodeGrpcMessageUnchecked(msg string) string {
 279  	var sb strings.Builder
 280  	lenMsg := len(msg)
 281  	for i := 0; i < lenMsg; i++ {
 282  		c := msg[i]
 283  		if c == percentByte && i+2 < lenMsg {
 284  			parsed, err := strconv.ParseUint(msg[i+1:i+3], 16, 8)
 285  			if err != nil {
 286  				sb.WriteByte(c)
 287  			} else {
 288  				sb.WriteByte(byte(parsed))
 289  				i += 2
 290  			}
 291  		} else {
 292  			sb.WriteByte(c)
 293  		}
 294  	}
 295  	return sb.String()
 296  }
 297  
 298  type bufWriter struct {
 299  	pool      *sync.Pool
 300  	buf       []byte
 301  	offset    int
 302  	batchSize int
 303  	conn      io.Writer
 304  	err       error
 305  }
 306  
 307  func newBufWriter(conn io.Writer, batchSize int, pool *sync.Pool) *bufWriter {
 308  	w := &bufWriter{
 309  		batchSize: batchSize,
 310  		conn:      conn,
 311  		pool:      pool,
 312  	}
 313  	// this indicates that we should use non shared buf
 314  	if pool == nil {
 315  		w.buf = make([]byte, batchSize)
 316  	}
 317  	return w
 318  }
 319  
 320  func (w *bufWriter) Write(b []byte) (int, error) {
 321  	if w.err != nil {
 322  		return 0, w.err
 323  	}
 324  	if w.batchSize == 0 { // Buffer has been disabled.
 325  		n, err := w.conn.Write(b)
 326  		return n, toIOError(err)
 327  	}
 328  	if w.buf == nil {
 329  		b := w.pool.Get().(*[]byte)
 330  		w.buf = *b
 331  	}
 332  	written := 0
 333  	for len(b) > 0 {
 334  		copied := copy(w.buf[w.offset:], b)
 335  		b = b[copied:]
 336  		written += copied
 337  		w.offset += copied
 338  		if w.offset < w.batchSize {
 339  			continue
 340  		}
 341  		if err := w.flushKeepBuffer(); err != nil {
 342  			return written, err
 343  		}
 344  	}
 345  	return written, nil
 346  }
 347  
 348  func (w *bufWriter) Flush() error {
 349  	err := w.flushKeepBuffer()
 350  	// Only release the buffer if we are in a "shared" mode
 351  	if w.buf != nil && w.pool != nil {
 352  		b := w.buf
 353  		w.pool.Put(&b)
 354  		w.buf = nil
 355  	}
 356  	return err
 357  }
 358  
 359  func (w *bufWriter) flushKeepBuffer() error {
 360  	if w.err != nil {
 361  		return w.err
 362  	}
 363  	if w.offset == 0 {
 364  		return nil
 365  	}
 366  	_, w.err = w.conn.Write(w.buf[:w.offset])
 367  	w.err = toIOError(w.err)
 368  	w.offset = 0
 369  	return w.err
 370  }
 371  
 372  type ioError struct {
 373  	error
 374  }
 375  
 376  func (i ioError) Unwrap() error {
 377  	return i.error
 378  }
 379  
 380  func isIOError(err error) bool {
 381  	return errors.As(err, &ioError{})
 382  }
 383  
 384  func toIOError(err error) error {
 385  	if err == nil {
 386  		return nil
 387  	}
 388  	return ioError{error: err}
 389  }
 390  
 391  type parsedDataFrame struct {
 392  	http2.FrameHeader
 393  	data mem.Buffer
 394  }
 395  
 396  func (df *parsedDataFrame) StreamEnded() bool {
 397  	return df.FrameHeader.Flags.Has(http2.FlagDataEndStream)
 398  }
 399  
// framer reads and writes HTTP/2 frames on a single connection. All writes go
// through writer. DATA frames are read by readDataFrame directly from reader;
// all other frames are read by fr.
type framer struct {
	writer    *bufWriter     // batching writer over the connection; also fr's output.
	fr        *http2.Framer  // parses non-DATA frames and reads every frame header.
	headerBuf []byte         // cached slice for framer headers to reduce heap allocs.
	reader    io.Reader      // the connection, possibly wrapped in a bufio.Reader (see newFramer).
	dataFrame parsedDataFrame // Cached data frame to avoid heap allocations.
	pool      mem.BufferPool  // source of pooled DATA payload buffers (see readDataFrame).
	errDetail error           // detail for the last readFrame error; see errorDetail.
}
 409  
 410  var writeBufferPoolMap = make(map[int]*sync.Pool)
 411  var writeBufferMutex sync.Mutex
 412  
 413  func newFramer(conn io.ReadWriter, writeBufferSize, readBufferSize int, sharedWriteBuffer bool, maxHeaderListSize uint32, memPool mem.BufferPool) *framer {
 414  	if writeBufferSize < 0 {
 415  		writeBufferSize = 0
 416  	}
 417  	var r io.Reader = conn
 418  	if readBufferSize > 0 {
 419  		r = bufio.NewReaderSize(r, readBufferSize)
 420  	}
 421  	var pool *sync.Pool
 422  	if sharedWriteBuffer {
 423  		pool = getWriteBufferPool(writeBufferSize)
 424  	}
 425  	w := newBufWriter(conn, writeBufferSize, pool)
 426  	f := &framer{
 427  		writer: w,
 428  		fr:     http2.NewFramer(w, r),
 429  		reader: r,
 430  		pool:   memPool,
 431  	}
 432  	f.fr.SetMaxReadFrameSize(http2MaxFrameLen)
 433  	// Opt-in to Frame reuse API on framer to reduce garbage.
 434  	// Frames aren't safe to read from after a subsequent call to ReadFrame.
 435  	f.fr.SetReuseFrames()
 436  	f.fr.MaxHeaderListSize = maxHeaderListSize
 437  	f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil)
 438  	return f
 439  }
 440  
 441  // writeData writes a DATA frame.
 442  //
 443  // It is the caller's responsibility not to violate the maximum frame size.
 444  func (f *framer) writeData(streamID uint32, endStream bool, data [][]byte) error {
 445  	var flags http2.Flags
 446  	if endStream {
 447  		flags = http2.FlagDataEndStream
 448  	}
 449  	length := uint32(0)
 450  	for _, d := range data {
 451  		length += uint32(len(d))
 452  	}
 453  	// TODO: Replace the header write with the framer API being added in
 454  	// https://github.com/golang/go/issues/66655.
 455  	f.headerBuf = append(f.headerBuf[:0],
 456  		byte(length>>16),
 457  		byte(length>>8),
 458  		byte(length),
 459  		byte(http2.FrameData),
 460  		byte(flags),
 461  		byte(streamID>>24),
 462  		byte(streamID>>16),
 463  		byte(streamID>>8),
 464  		byte(streamID))
 465  	if _, err := f.writer.Write(f.headerBuf); err != nil {
 466  		return err
 467  	}
 468  	for _, d := range data {
 469  		if _, err := f.writer.Write(d); err != nil {
 470  			return err
 471  		}
 472  	}
 473  	return nil
 474  }
 475  
 476  // readFrame reads a single frame. The returned Frame is only valid
 477  // until the next call to readFrame.
 478  func (f *framer) readFrame() (any, error) {
 479  	f.errDetail = nil
 480  	fh, err := f.fr.ReadFrameHeader()
 481  	if err != nil {
 482  		f.errDetail = f.fr.ErrorDetail()
 483  		return nil, err
 484  	}
 485  	// Read the data frame directly from the underlying io.Reader to avoid
 486  	// copies.
 487  	if fh.Type == http2.FrameData {
 488  		err = f.readDataFrame(fh)
 489  		return &f.dataFrame, err
 490  	}
 491  	fr, err := f.fr.ReadFrameForHeader(fh)
 492  	if err != nil {
 493  		f.errDetail = f.fr.ErrorDetail()
 494  		return nil, err
 495  	}
 496  	return fr, err
 497  }
 498  
 499  // errorDetail returns a more detailed error of the last error
 500  // returned by framer.readFrame. For instance, if readFrame
 501  // returns a StreamError with code PROTOCOL_ERROR, errorDetail
 502  // will say exactly what was invalid. errorDetail is not guaranteed
 503  // to return a non-nil value.
 504  // errorDetail is reset after the next call to readFrame.
 505  func (f *framer) errorDetail() error {
 506  	return f.errDetail
 507  }
 508  
// readDataFrame reads the payload of the DATA frame described by fh (whose
// header has already been consumed) directly from f.reader. On success it
// stores the header and the payload, with any padding removed, in
// f.dataFrame; on error f.dataFrame is left unmodified.
//
// Payloads for which mem.IsBelowBufferPoolingThreshold is false are read into
// a buffer obtained from f.pool. That buffer is returned to the pool if any
// step fails, and otherwise becomes owned by the mem.Buffer stored in
// f.dataFrame.data. Smaller payloads use a freshly allocated slice.
//
// A DATA frame on stream 0, or one whose pad length exceeds its payload,
// yields http2.ConnectionError(http2.ErrCodeProtocol) with f.errDetail set.
// A padded frame with zero length yields io.ErrUnexpectedEOF. Read errors from
// f.reader are returned as-is.
//
// The result is named so the deferred release below can observe the error
// from every return path.
func (f *framer) readDataFrame(fh http2.FrameHeader) (err error) {
	if fh.StreamID == 0 {
		// DATA frames MUST be associated with a stream. If a
		// DATA frame is received whose stream identifier
		// field is 0x0, the recipient MUST respond with a
		// connection error (Section 5.4.1) of type
		// PROTOCOL_ERROR.
		f.errDetail = errors.New("DATA frame with stream ID 0")
		return http2.ConnectionError(http2.ErrCodeProtocol)
	}
	// Converting a *[]byte to a mem.SliceBuffer incurs a heap allocation. This
	// conversion is performed by mem.NewBuffer. To avoid the extra allocation
	// a []byte is allocated directly if required and cast to a mem.SliceBuffer.
	var buf []byte
	// poolHandle is the pointer returned by the buffer pool (if it is used).
	var poolHandle *[]byte
	useBufferPool := !mem.IsBelowBufferPoolingThreshold(int(fh.Length))
	if useBufferPool {
		poolHandle = f.pool.Get(int(fh.Length))
		buf = *poolHandle
		// Give the pooled buffer back on any failure. This works even for the
		// `if _, err := ...; err != nil { return err }` paths below: their
		// `err` shadows the named result, but `return err` assigns to it.
		defer func() {
			if err != nil {
				f.pool.Put(poolHandle)
			}
		}()
	} else {
		buf = make([]byte, int(fh.Length))
	}
	if fh.Flags.Has(http2.FlagDataPadded) {
		// Padded payload layout: 1 pad-length byte, then the data, then
		// padSize bytes of padding.
		if fh.Length == 0 {
			return io.ErrUnexpectedEOF
		}
		// This initial 1-byte read can be inefficient for unbuffered readers,
		// but it allows the rest of the payload to be read directly to the
		// start of the destination slice. This makes it easy to return the
		// original slice back to the buffer pool.
		if _, err := io.ReadFull(f.reader, buf[:1]); err != nil {
			return err
		}
		padSize := buf[0]
		// The pad-length byte has been consumed; fh.Length-1 bytes remain.
		buf = buf[:len(buf)-1]
		if int(padSize) > len(buf) {
			// If the length of the padding is greater than the
			// length of the frame payload, the recipient MUST
			// treat this as a connection error.
			// Filed: https://github.com/http2/http2-spec/issues/610
			f.errDetail = errors.New("pad size larger than data payload")
			return http2.ConnectionError(http2.ErrCodeProtocol)
		}
		// Overwrites buf[0] (the pad-length byte) with the first data byte.
		if _, err := io.ReadFull(f.reader, buf); err != nil {
			return err
		}
		// Drop the trailing padding.
		buf = buf[:len(buf)-int(padSize)]
	} else if _, err := io.ReadFull(f.reader, buf); err != nil {
		return err
	}

	f.dataFrame.FrameHeader = fh
	if useBufferPool {
		// Update the handle to point to the (potentially re-sliced) buf.
		*poolHandle = buf
		f.dataFrame.data = mem.NewBuffer(poolHandle, f.pool)
	} else {
		f.dataFrame.data = mem.SliceBuffer(buf)
	}
	return nil
}
 576  
 577  func (df *parsedDataFrame) Header() http2.FrameHeader {
 578  	return df.FrameHeader
 579  }
 580  
 581  func getWriteBufferPool(size int) *sync.Pool {
 582  	writeBufferMutex.Lock()
 583  	defer writeBufferMutex.Unlock()
 584  	pool, ok := writeBufferPoolMap[size]
 585  	if ok {
 586  		return pool
 587  	}
 588  	pool = &sync.Pool{
 589  		New: func() any {
 590  			b := make([]byte, size)
 591  			return &b
 592  		},
 593  	}
 594  	writeBufferPoolMap[size] = pool
 595  	return pool
 596  }
 597  
 598  // ParseDialTarget returns the network and address to pass to dialer.
 599  func ParseDialTarget(target string) (string, string) {
 600  	net := "tcp"
 601  	m1 := strings.Index(target, ":")
 602  	m2 := strings.Index(target, ":/")
 603  	// handle unix:addr which will fail with url.Parse
 604  	if m1 >= 0 && m2 < 0 {
 605  		if n := target[0:m1]; n == "unix" {
 606  			return n, target[m1+1:]
 607  		}
 608  	}
 609  	if m2 >= 0 {
 610  		t, err := url.Parse(target)
 611  		if err != nil {
 612  			return net, target
 613  		}
 614  		scheme := t.Scheme
 615  		addr := t.Path
 616  		if scheme == "unix" {
 617  			if addr == "" {
 618  				addr = t.Host
 619  			}
 620  			return scheme, addr
 621  		}
 622  	}
 623  	return net, target
 624  }
 625