duration.go raw

   1  /*
   2   *
   3   * Copyright 2023 gRPC authors.
   4   *
   5   * Licensed under the Apache License, Version 2.0 (the "License");
   6   * you may not use this file except in compliance with the License.
   7   * You may obtain a copy of the License at
   8   *
   9   *     http://www.apache.org/licenses/LICENSE-2.0
  10   *
  11   * Unless required by applicable law or agreed to in writing, software
  12   * distributed under the License is distributed on an "AS IS" BASIS,
  13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14   * See the License for the specific language governing permissions and
  15   * limitations under the License.
  16   *
  17   */
  18  
  19  package serviceconfig
  20  
  21  import (
  22  	"encoding/json"
  23  	"fmt"
  24  	"math"
  25  	"strconv"
  26  	"strings"
  27  	"time"
  28  )
  29  
// Duration is a time.Duration whose JSON marshal and unmarshal methods conform
// to the protobuf JSON spec defined [here]: the value is encoded as a JSON
// string holding decimal seconds followed by "s" (for example "1.500s").
//
// [here]: https://protobuf.dev/reference/protobuf/google.protobuf/#duration
type Duration time.Duration
  35  
  36  func (d Duration) String() string {
  37  	return fmt.Sprint(time.Duration(d))
  38  }
  39  
  40  // MarshalJSON converts from d to a JSON string output.
  41  func (d Duration) MarshalJSON() ([]byte, error) {
  42  	ns := time.Duration(d).Nanoseconds()
  43  	sec := ns / int64(time.Second)
  44  	ns = ns % int64(time.Second)
  45  
  46  	var sign string
  47  	if sec < 0 || ns < 0 {
  48  		sign, sec, ns = "-", -1*sec, -1*ns
  49  	}
  50  
  51  	// Generated output always contains 0, 3, 6, or 9 fractional digits,
  52  	// depending on required precision.
  53  	str := fmt.Sprintf("%s%d.%09d", sign, sec, ns)
  54  	str = strings.TrimSuffix(str, "000")
  55  	str = strings.TrimSuffix(str, "000")
  56  	str = strings.TrimSuffix(str, ".000")
  57  	return []byte(fmt.Sprintf("\"%ss\"", str)), nil
  58  }
  59  
  60  // UnmarshalJSON unmarshals b as a duration JSON string into d.
  61  func (d *Duration) UnmarshalJSON(b []byte) error {
  62  	var s string
  63  	if err := json.Unmarshal(b, &s); err != nil {
  64  		return err
  65  	}
  66  	if !strings.HasSuffix(s, "s") {
  67  		return fmt.Errorf("malformed duration %q: missing seconds unit", s)
  68  	}
  69  	neg := false
  70  	if s[0] == '-' {
  71  		neg = true
  72  		s = s[1:]
  73  	}
  74  	ss := strings.SplitN(s[:len(s)-1], ".", 3)
  75  	if len(ss) > 2 {
  76  		return fmt.Errorf("malformed duration %q: too many decimals", s)
  77  	}
  78  	// hasDigits is set if either the whole or fractional part of the number is
  79  	// present, since both are optional but one is required.
  80  	hasDigits := false
  81  	var sec, ns int64
  82  	if len(ss[0]) > 0 {
  83  		var err error
  84  		if sec, err = strconv.ParseInt(ss[0], 10, 64); err != nil {
  85  			return fmt.Errorf("malformed duration %q: %v", s, err)
  86  		}
  87  		// Maximum seconds value per the durationpb spec.
  88  		const maxProtoSeconds = 315_576_000_000
  89  		if sec > maxProtoSeconds {
  90  			return fmt.Errorf("out of range: %q", s)
  91  		}
  92  		hasDigits = true
  93  	}
  94  	if len(ss) == 2 && len(ss[1]) > 0 {
  95  		if len(ss[1]) > 9 {
  96  			return fmt.Errorf("malformed duration %q: too many digits after decimal", s)
  97  		}
  98  		var err error
  99  		if ns, err = strconv.ParseInt(ss[1], 10, 64); err != nil {
 100  			return fmt.Errorf("malformed duration %q: %v", s, err)
 101  		}
 102  		for i := 9; i > len(ss[1]); i-- {
 103  			ns *= 10
 104  		}
 105  		hasDigits = true
 106  	}
 107  	if !hasDigits {
 108  		return fmt.Errorf("malformed duration %q: contains no numbers", s)
 109  	}
 110  
 111  	if neg {
 112  		sec *= -1
 113  		ns *= -1
 114  	}
 115  
 116  	// Maximum/minimum seconds/nanoseconds representable by Go's time.Duration.
 117  	const maxSeconds = math.MaxInt64 / int64(time.Second)
 118  	const maxNanosAtMaxSeconds = math.MaxInt64 % int64(time.Second)
 119  	const minSeconds = math.MinInt64 / int64(time.Second)
 120  	const minNanosAtMinSeconds = math.MinInt64 % int64(time.Second)
 121  
 122  	if sec > maxSeconds || (sec == maxSeconds && ns >= maxNanosAtMaxSeconds) {
 123  		*d = Duration(math.MaxInt64)
 124  	} else if sec < minSeconds || (sec == minSeconds && ns <= minNanosAtMinSeconds) {
 125  		*d = Duration(math.MinInt64)
 126  	} else {
 127  		*d = Duration(sec*int64(time.Second) + ns)
 128  	}
 129  	return nil
 130  }
 131