http_util.go raw
1 /*
2 *
3 * Copyright 2014 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19 package transport
20
21 import (
22 "bufio"
23 "encoding/base64"
24 "errors"
25 "fmt"
26 "io"
27 "math"
28 "net/http"
29 "net/url"
30 "strconv"
31 "strings"
32 "sync"
33 "time"
34 "unicode/utf8"
35
36 "golang.org/x/net/http2"
37 "golang.org/x/net/http2/hpack"
38 "google.golang.org/grpc/codes"
39 "google.golang.org/grpc/mem"
40 )
41
const (
	// http2MaxFrameLen specifies the max length of a HTTP2 frame. It is the
	// value installed as the framer's maximum read frame size.
	http2MaxFrameLen = 16384 // 16KB frame
	// http2InitHeaderTableSize is the header table size handed to the HPACK
	// decoder when a framer is created. See
	// https://httpwg.org/specs/rfc7540.html#SettingValues
	http2InitHeaderTableSize = 4096
)
48
var (
	// clientPreface is the HTTP/2 client connection preface as a byte slice.
	clientPreface = []byte(http2.ClientPreface)
	// http2ErrConvTab maps HTTP/2 error codes to the gRPC status code that
	// represents them.
	http2ErrConvTab = map[http2.ErrCode]codes.Code{
		http2.ErrCodeNo:                 codes.Internal,
		http2.ErrCodeProtocol:           codes.Internal,
		http2.ErrCodeInternal:           codes.Internal,
		http2.ErrCodeFlowControl:        codes.ResourceExhausted,
		http2.ErrCodeSettingsTimeout:    codes.Internal,
		http2.ErrCodeStreamClosed:       codes.Internal,
		http2.ErrCodeFrameSize:          codes.Internal,
		http2.ErrCodeRefusedStream:      codes.Unavailable,
		http2.ErrCodeCancel:             codes.Canceled,
		http2.ErrCodeCompression:        codes.Internal,
		http2.ErrCodeConnect:            codes.Internal,
		http2.ErrCodeEnhanceYourCalm:    codes.ResourceExhausted,
		http2.ErrCodeInadequateSecurity: codes.PermissionDenied,
		http2.ErrCodeHTTP11Required:     codes.Internal,
	}
	// HTTPStatusConvTab is the HTTP status code to gRPC error code conversion table.
	// Only the statuses listed here have an entry; callers must handle a
	// missing key themselves.
	HTTPStatusConvTab = map[int]codes.Code{
		// 400 Bad Request - INTERNAL.
		http.StatusBadRequest: codes.Internal,
		// 401 Unauthorized - UNAUTHENTICATED.
		http.StatusUnauthorized: codes.Unauthenticated,
		// 403 Forbidden - PERMISSION_DENIED.
		http.StatusForbidden: codes.PermissionDenied,
		// 404 Not Found - UNIMPLEMENTED.
		http.StatusNotFound: codes.Unimplemented,
		// 429 Too Many Requests - UNAVAILABLE.
		http.StatusTooManyRequests: codes.Unavailable,
		// 502 Bad Gateway - UNAVAILABLE.
		http.StatusBadGateway: codes.Unavailable,
		// 503 Service Unavailable - UNAVAILABLE.
		http.StatusServiceUnavailable: codes.Unavailable,
		// 504 Gateway timeout - UNAVAILABLE.
		http.StatusGatewayTimeout: codes.Unavailable,
	}
)
87
// grpcStatusDetailsBinHeader is the header key "grpc-status-details-bin".
// NOTE(review): presumably the binary header carrying serialized status
// details; its users are outside this part of the file, so confirm there.
var grpcStatusDetailsBinHeader = "grpc-status-details-bin"
89
// isReservedHeader reports whether hdr is one of the HTTP2 headers that the
// gRPC protocol reserves for itself. Every header for which it returns false
// is treated as user-specified metadata.
//
// The comparison is exact: hdr is expected to already be in the lower-case
// form used on the wire.
func isReservedHeader(hdr string) bool {
	// All pseudo-headers (names that start with ':') are reserved.
	if strings.HasPrefix(hdr, ":") {
		return true
	}
	// Intentionally exclude grpc-previous-rpc-attempts and
	// grpc-retry-pushback-ms, which are "reserved", but their API
	// intentionally works via metadata.
	switch hdr {
	case "content-type", "user-agent", "te",
		"grpc-message-type", "grpc-encoding", "grpc-message",
		"grpc-status", "grpc-timeout":
		return true
	}
	return false
}
114
// isWhitelistedHeader reports whether hdr, despite being classified as
// "reserved" by isReservedHeader, should still be propagated into the
// metadata that users can see. Only ":authority" and "user-agent" qualify.
func isWhitelistedHeader(hdr string) bool {
	return hdr == ":authority" || hdr == "user-agent"
}
125
// binHdrSuffix is the key suffix that marks a metadata header as binary.
// Values of such headers are base64-encoded when sent and decoded when
// received.
const binHdrSuffix = "-bin"
127
// encodeBinHeader returns v encoded as standard-alphabet base64 without
// padding.
func encodeBinHeader(v []byte) string {
	enc := base64.RawStdEncoding
	out := make([]byte, enc.EncodedLen(len(v)))
	enc.Encode(out, v)
	return string(out)
}
131
// decodeBinHeader decodes a base64 header value that may or may not carry
// padding. A value whose length is a multiple of four is either padded or
// needs no padding, so it is decoded with the padded encoding; any other
// length is decoded with the unpadded encoding.
func decodeBinHeader(v string) ([]byte, error) {
	enc := base64.RawStdEncoding
	if len(v)%4 == 0 {
		enc = base64.StdEncoding
	}
	return enc.DecodeString(v)
}
139
140 func encodeMetadataHeader(k, v string) string {
141 if strings.HasSuffix(k, binHdrSuffix) {
142 return encodeBinHeader(([]byte)(v))
143 }
144 return v
145 }
146
147 func decodeMetadataHeader(k, v string) (string, error) {
148 if strings.HasSuffix(k, binHdrSuffix) {
149 b, err := decodeBinHeader(v)
150 return string(b), err
151 }
152 return v, nil
153 }
154
// timeoutUnit is the single-byte unit suffix of a timeout string.
type timeoutUnit uint8

// Unit suffixes recognized in a timeout string.
const (
	hour        timeoutUnit = 'H'
	minute      timeoutUnit = 'M'
	second      timeoutUnit = 'S'
	millisecond timeoutUnit = 'm'
	microsecond timeoutUnit = 'u'
	nanosecond  timeoutUnit = 'n'
)

// timeoutUnitToDuration returns the duration of one unit u. ok is false, and
// d is zero, when u is not one of the recognized unit suffixes.
func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) {
	switch u {
	case hour:
		d = time.Hour
	case minute:
		d = time.Minute
	case second:
		d = time.Second
	case millisecond:
		d = time.Millisecond
	case microsecond:
		d = time.Microsecond
	case nanosecond:
		d = time.Nanosecond
	default:
		return 0, false
	}
	return d, true
}
184
185 func decodeTimeout(s string) (time.Duration, error) {
186 size := len(s)
187 if size < 2 {
188 return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
189 }
190 if size > 9 {
191 // Spec allows for 8 digits plus the unit.
192 return 0, fmt.Errorf("transport: timeout string is too long: %q", s)
193 }
194 unit := timeoutUnit(s[size-1])
195 d, ok := timeoutUnitToDuration(unit)
196 if !ok {
197 return 0, fmt.Errorf("transport: timeout unit is not recognized: %q", s)
198 }
199 t, err := strconv.ParseUint(s[:size-1], 10, 64)
200 if err != nil {
201 return 0, err
202 }
203 const maxHours = math.MaxInt64 / uint64(time.Hour)
204 if d == time.Hour && t > maxHours {
205 // This timeout would overflow math.MaxInt64; clamp it.
206 return time.Duration(math.MaxInt64), nil
207 }
208 return d * time.Duration(t), nil
209 }
210
// Byte values that delimit what may appear unescaped in "grpc-message":
// printable ASCII from spaceByte to tildeByte, excluding percentByte, which
// introduces an escape sequence.
const (
	spaceByte   = ' '
	tildeByte   = '~'
	percentByte = '%'
)

// encodeGrpcMessage is used to encode the status message in header field
// "grpc-message". It does percent encoding and also replaces invalid utf-8
// characters with Unicode replacement character.
//
// It checks to see if each individual byte in msg is an allowable byte, and
// then either percent encoding or passing it through. When percent encoding,
// the byte is converted into hexadecimal notation with a '%' prepended.
//
// If msg contains only allowable bytes it is returned unchanged, with no
// allocation.
func encodeGrpcMessage(msg string) string {
	for i := 0; i < len(msg); i++ {
		if c := msg[i]; c < spaceByte || c > tildeByte || c == percentByte {
			return encodeGrpcMessageUnchecked(msg)
		}
	}
	return msg
}

// encodeGrpcMessageUnchecked performs the encoding described on
// encodeGrpcMessage without first checking whether any escaping is needed.
//
// msg is decoded one rune at a time so that each invalid utf-8 byte becomes
// utf8.RuneError, which is then emitted as its three percent-encoded utf-8
// bytes ("%EF%BF%BD").
func encodeGrpcMessageUnchecked(msg string) string {
	const upperHex = "0123456789ABCDEF"
	var sb strings.Builder
	// The output is never shorter than the input.
	sb.Grow(len(msg))
	// Scratch space for the utf-8 form of one rune; avoids a per-rune
	// allocation.
	var scratch [utf8.UTFMax]byte
	for len(msg) > 0 {
		r, size := utf8.DecodeRuneInString(msg)
		// Re-encode r rather than copying msg[:size]: for an invalid byte r
		// is utf8.RuneError and the replacement character must be written.
		n := utf8.EncodeRune(scratch[:], r)
		for _, b := range scratch[:n] {
			// Every byte of a multi-byte rune is >= 0x80, so only
			// single-byte printable ASCII other than '%' passes through.
			if b >= spaceByte && b <= tildeByte && b != percentByte {
				sb.WriteByte(b)
				continue
			}
			sb.WriteByte(percentByte)
			sb.WriteByte(upperHex[b>>4])
			sb.WriteByte(upperHex[b&0x0F])
		}
		msg = msg[size:]
	}
	return sb.String()
}
263
264 // decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage.
265 func decodeGrpcMessage(msg string) string {
266 if msg == "" {
267 return ""
268 }
269 lenMsg := len(msg)
270 for i := 0; i < lenMsg; i++ {
271 if msg[i] == percentByte && i+2 < lenMsg {
272 return decodeGrpcMessageUnchecked(msg)
273 }
274 }
275 return msg
276 }
277
278 func decodeGrpcMessageUnchecked(msg string) string {
279 var sb strings.Builder
280 lenMsg := len(msg)
281 for i := 0; i < lenMsg; i++ {
282 c := msg[i]
283 if c == percentByte && i+2 < lenMsg {
284 parsed, err := strconv.ParseUint(msg[i+1:i+3], 16, 8)
285 if err != nil {
286 sb.WriteByte(c)
287 } else {
288 sb.WriteByte(byte(parsed))
289 i += 2
290 }
291 } else {
292 sb.WriteByte(c)
293 }
294 }
295 return sb.String()
296 }
297
// bufWriter batches writes to conn: bytes are accumulated in buf and written
// out once batchSize bytes are pending or Flush is called.
type bufWriter struct {
	// pool, when non-nil, supplies buf on the first Write after a Flush and
	// receives it back on Flush ("shared" mode). When nil, buf is owned by
	// this writer for its whole lifetime.
	pool *sync.Pool
	// buf holds bytes not yet written to conn; it is nil in shared mode
	// while no buffer is checked out.
	buf []byte
	// offset is the number of pending bytes at the start of buf.
	offset int
	// batchSize is the number of pending bytes that triggers a flush. Zero
	// disables buffering: writes go straight to conn.
	batchSize int
	// conn is the destination of all flushed bytes.
	conn io.Writer
	// err is the error of the first failed flush; once set, Write and
	// flushKeepBuffer return it without touching conn.
	err error
}
306
307 func newBufWriter(conn io.Writer, batchSize int, pool *sync.Pool) *bufWriter {
308 w := &bufWriter{
309 batchSize: batchSize,
310 conn: conn,
311 pool: pool,
312 }
313 // this indicates that we should use non shared buf
314 if pool == nil {
315 w.buf = make([]byte, batchSize)
316 }
317 return w
318 }
319
320 func (w *bufWriter) Write(b []byte) (int, error) {
321 if w.err != nil {
322 return 0, w.err
323 }
324 if w.batchSize == 0 { // Buffer has been disabled.
325 n, err := w.conn.Write(b)
326 return n, toIOError(err)
327 }
328 if w.buf == nil {
329 b := w.pool.Get().(*[]byte)
330 w.buf = *b
331 }
332 written := 0
333 for len(b) > 0 {
334 copied := copy(w.buf[w.offset:], b)
335 b = b[copied:]
336 written += copied
337 w.offset += copied
338 if w.offset < w.batchSize {
339 continue
340 }
341 if err := w.flushKeepBuffer(); err != nil {
342 return written, err
343 }
344 }
345 return written, nil
346 }
347
348 func (w *bufWriter) Flush() error {
349 err := w.flushKeepBuffer()
350 // Only release the buffer if we are in a "shared" mode
351 if w.buf != nil && w.pool != nil {
352 b := w.buf
353 w.pool.Put(&b)
354 w.buf = nil
355 }
356 return err
357 }
358
359 func (w *bufWriter) flushKeepBuffer() error {
360 if w.err != nil {
361 return w.err
362 }
363 if w.offset == 0 {
364 return nil
365 }
366 _, w.err = w.conn.Write(w.buf[:w.offset])
367 w.err = toIOError(w.err)
368 w.offset = 0
369 return w.err
370 }
371
// ioError wraps an error returned by the underlying connection so that it
// can later be recognized with isIOError.
type ioError struct{ error }

// Unwrap returns the wrapped error, making ioError transparent to errors.Is
// and errors.As.
func (i ioError) Unwrap() error { return i.error }

// isIOError reports whether err, or any error it wraps, is an ioError.
func isIOError(err error) bool {
	var target ioError
	return errors.As(err, &target)
}

// toIOError wraps a non-nil err in an ioError. A nil err stays nil, so the
// result can be compared against nil as usual.
func toIOError(err error) error {
	if err != nil {
		return ioError{error: err}
	}
	return nil
}
390
391 type parsedDataFrame struct {
392 http2.FrameHeader
393 data mem.Buffer
394 }
395
396 func (df *parsedDataFrame) StreamEnded() bool {
397 return df.FrameHeader.Flags.Has(http2.FlagDataEndStream)
398 }
399
// framer reads and writes HTTP/2 frames on a single connection. It wraps an
// http2.Framer and reads DATA frame payloads itself, directly from reader.
type framer struct {
	// writer buffers everything written through fr and writeData.
	writer *bufWriter
	// fr is the underlying HTTP/2 framer; it reads from reader and writes to
	// writer.
	fr *http2.Framer
	headerBuf []byte // cached slice for framer headers to reduce heap allocs.
	// reader is the source fr reads from; DATA payloads are read from it
	// directly.
	reader    io.Reader
	dataFrame parsedDataFrame // Cached data frame to avoid heap allocations.
	// pool supplies buffers for DATA payloads at or above the pooling
	// threshold.
	pool mem.BufferPool
	// errDetail holds extra detail about the last error returned by
	// readFrame; it is cleared at the start of every readFrame call.
	errDetail error
}
409
// writeBufferPoolMap holds one shared write-buffer pool per buffer size. It
// is read and populated by getWriteBufferPool.
var writeBufferPoolMap = make(map[int]*sync.Pool)

// writeBufferMutex guards writeBufferPoolMap.
var writeBufferMutex sync.Mutex
412
413 func newFramer(conn io.ReadWriter, writeBufferSize, readBufferSize int, sharedWriteBuffer bool, maxHeaderListSize uint32, memPool mem.BufferPool) *framer {
414 if writeBufferSize < 0 {
415 writeBufferSize = 0
416 }
417 var r io.Reader = conn
418 if readBufferSize > 0 {
419 r = bufio.NewReaderSize(r, readBufferSize)
420 }
421 var pool *sync.Pool
422 if sharedWriteBuffer {
423 pool = getWriteBufferPool(writeBufferSize)
424 }
425 w := newBufWriter(conn, writeBufferSize, pool)
426 f := &framer{
427 writer: w,
428 fr: http2.NewFramer(w, r),
429 reader: r,
430 pool: memPool,
431 }
432 f.fr.SetMaxReadFrameSize(http2MaxFrameLen)
433 // Opt-in to Frame reuse API on framer to reduce garbage.
434 // Frames aren't safe to read from after a subsequent call to ReadFrame.
435 f.fr.SetReuseFrames()
436 f.fr.MaxHeaderListSize = maxHeaderListSize
437 f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil)
438 return f
439 }
440
441 // writeData writes a DATA frame.
442 //
443 // It is the caller's responsibility not to violate the maximum frame size.
444 func (f *framer) writeData(streamID uint32, endStream bool, data [][]byte) error {
445 var flags http2.Flags
446 if endStream {
447 flags = http2.FlagDataEndStream
448 }
449 length := uint32(0)
450 for _, d := range data {
451 length += uint32(len(d))
452 }
453 // TODO: Replace the header write with the framer API being added in
454 // https://github.com/golang/go/issues/66655.
455 f.headerBuf = append(f.headerBuf[:0],
456 byte(length>>16),
457 byte(length>>8),
458 byte(length),
459 byte(http2.FrameData),
460 byte(flags),
461 byte(streamID>>24),
462 byte(streamID>>16),
463 byte(streamID>>8),
464 byte(streamID))
465 if _, err := f.writer.Write(f.headerBuf); err != nil {
466 return err
467 }
468 for _, d := range data {
469 if _, err := f.writer.Write(d); err != nil {
470 return err
471 }
472 }
473 return nil
474 }
475
476 // readFrame reads a single frame. The returned Frame is only valid
477 // until the next call to readFrame.
478 func (f *framer) readFrame() (any, error) {
479 f.errDetail = nil
480 fh, err := f.fr.ReadFrameHeader()
481 if err != nil {
482 f.errDetail = f.fr.ErrorDetail()
483 return nil, err
484 }
485 // Read the data frame directly from the underlying io.Reader to avoid
486 // copies.
487 if fh.Type == http2.FrameData {
488 err = f.readDataFrame(fh)
489 return &f.dataFrame, err
490 }
491 fr, err := f.fr.ReadFrameForHeader(fh)
492 if err != nil {
493 f.errDetail = f.fr.ErrorDetail()
494 return nil, err
495 }
496 return fr, err
497 }
498
// errorDetail returns a more detailed error of the last error
// returned by framer.readFrame. For instance, if readFrame
// returns a StreamError with code PROTOCOL_ERROR, errorDetail
// will say exactly what was invalid. errorDetail is not guaranteed
// to return a non-nil value.
// errorDetail is reset to nil at the start of the next call to readFrame.
func (f *framer) errorDetail() error {
	return f.errDetail
}
508
// readDataFrame reads the payload of the DATA frame described by fh directly
// from f.reader and stores the header and payload in f.dataFrame.
//
// Payloads at or above the pooling threshold (as decided by
// mem.IsBelowBufferPoolingThreshold) are read into a buffer taken from
// f.pool; smaller payloads are read into a freshly allocated slice. If the
// frame is padded, the pad-length byte and the padding are stripped so that
// f.dataFrame.data holds only the data bytes.
//
// It returns a PROTOCOL_ERROR connection error, and sets f.errDetail, when
// the stream ID is 0 or the pad length exceeds the rest of the payload;
// io.ErrUnexpectedEOF when a padded frame has an empty payload; and otherwise
// any error from reading f.reader. On error, a buffer taken from f.pool is
// returned to it and f.dataFrame is left unmodified.
func (f *framer) readDataFrame(fh http2.FrameHeader) (err error) {
	if fh.StreamID == 0 {
		// DATA frames MUST be associated with a stream. If a
		// DATA frame is received whose stream identifier
		// field is 0x0, the recipient MUST respond with a
		// connection error (Section 5.4.1) of type
		// PROTOCOL_ERROR.
		f.errDetail = errors.New("DATA frame with stream ID 0")
		return http2.ConnectionError(http2.ErrCodeProtocol)
	}
	// Converting a *[]byte to a mem.SliceBuffer incurs a heap allocation. This
	// conversion is performed by mem.NewBuffer. To avoid the extra allocation
	// a []byte is allocated directly if required and cast to a mem.SliceBuffer.
	var buf []byte
	// poolHandle is the pointer returned by the buffer pool (if it's used).
	var poolHandle *[]byte
	useBufferPool := !mem.IsBelowBufferPoolingThreshold(int(fh.Length))
	if useBufferPool {
		poolHandle = f.pool.Get(int(fh.Length))
		buf = *poolHandle
		// err is the named result, so this sees the error of every return
		// below and gives the buffer back whenever the read fails.
		defer func() {
			if err != nil {
				f.pool.Put(poolHandle)
			}
		}()
	} else {
		buf = make([]byte, int(fh.Length))
	}
	if fh.Flags.Has(http2.FlagDataPadded) {
		// A padded frame must contain at least the pad-length byte.
		if fh.Length == 0 {
			return io.ErrUnexpectedEOF
		}
		// This initial 1-byte read can be inefficient for unbuffered readers,
		// but it allows the rest of the payload to be read directly to the
		// start of the destination slice. This makes it easy to return the
		// original slice back to the buffer pool.
		if _, err := io.ReadFull(f.reader, buf[:1]); err != nil {
			return err
		}
		padSize := buf[0]
		// The pad-length byte is consumed: what remains (data plus padding)
		// is one byte shorter and is read starting at index 0, overwriting
		// the pad-length byte.
		buf = buf[:len(buf)-1]
		if int(padSize) > len(buf) {
			// If the length of the padding is greater than the
			// length of the frame payload, the recipient MUST
			// treat this as a connection error.
			// Filed: https://github.com/http2/http2-spec/issues/610
			f.errDetail = errors.New("pad size larger than data payload")
			return http2.ConnectionError(http2.ErrCodeProtocol)
		}
		if _, err := io.ReadFull(f.reader, buf); err != nil {
			return err
		}
		// Drop the trailing padding, leaving only the data bytes.
		buf = buf[:len(buf)-int(padSize)]
	} else if _, err := io.ReadFull(f.reader, buf); err != nil {
		return err
	}

	f.dataFrame.FrameHeader = fh
	if useBufferPool {
		// Update the handle to point to the (potentially re-sliced) buf.
		*poolHandle = buf
		f.dataFrame.data = mem.NewBuffer(poolHandle, f.pool)
	} else {
		f.dataFrame.data = mem.SliceBuffer(buf)
	}
	return nil
}
576
// Header returns a copy of the frame header of the DATA frame.
func (df *parsedDataFrame) Header() http2.FrameHeader {
	return df.FrameHeader
}
580
581 func getWriteBufferPool(size int) *sync.Pool {
582 writeBufferMutex.Lock()
583 defer writeBufferMutex.Unlock()
584 pool, ok := writeBufferPoolMap[size]
585 if ok {
586 return pool
587 }
588 pool = &sync.Pool{
589 New: func() any {
590 b := make([]byte, size)
591 return &b
592 },
593 }
594 writeBufferPoolMap[size] = pool
595 return pool
596 }
597
// ParseDialTarget returns the network and address to pass to dialer.
//
// Targets with the "unix" scheme yield network "unix": "unix:addr" gives
// address "addr", and URL forms such as "unix:/path" or "unix:///path" give
// the URL path (or the URL host when the path is empty). Every other target,
// including one that fails to parse as a URL, yields network "tcp" with the
// target itself as the address.
func ParseDialTarget(target string) (string, string) {
	const defaultNetwork = "tcp"

	if !strings.Contains(target, ":/") {
		// handle unix:addr which will fail with url.Parse
		if scheme, addr, found := strings.Cut(target, ":"); found && scheme == "unix" {
			return scheme, addr
		}
		return defaultNetwork, target
	}

	u, err := url.Parse(target)
	if err != nil || u.Scheme != "unix" {
		return defaultNetwork, target
	}
	if u.Path != "" {
		return u.Scheme, u.Path
	}
	return u.Scheme, u.Host
}
625