retry.go raw

   1  // Copyright (c) 2015-2024 Jeevanandam M (jeeva@myjeeva.com), All rights reserved.
   2  // resty source code and usage is governed by a MIT style
   3  // license that can be found in the LICENSE file.
   4  
   5  package resty
   6  
   7  import (
   8  	"context"
   9  	"io"
  10  	"math"
  11  	"math/rand"
  12  	"sync"
  13  	"time"
  14  )
  15  
  16  const (
  17  	defaultMaxRetries  = 3
  18  	defaultWaitTime    = time.Duration(100) * time.Millisecond
  19  	defaultMaxWaitTime = time.Duration(2000) * time.Millisecond
  20  )
  21  
  22  type (
  23  	// Option is to create convenient retry options like wait time, max retries, etc.
  24  	Option func(*Options)
  25  
  26  	// RetryConditionFunc type is for the retry condition function
  27  	// input: non-nil Response OR request execution error
  28  	RetryConditionFunc func(*Response, error) bool
  29  
  30  	// OnRetryFunc is for side-effecting functions triggered on retry
  31  	OnRetryFunc func(*Response, error)
  32  
  33  	// RetryAfterFunc returns time to wait before retry
  34  	// For example, it can parse HTTP Retry-After header
  35  	// https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html
  36  	// Non-nil error is returned if it is found that the request is not retryable
  37  	// (0, nil) is a special result that means 'use default algorithm'
  38  	RetryAfterFunc func(*Client, *Response) (time.Duration, error)
  39  
  40  	// Options struct is used to hold retry settings.
  41  	Options struct {
  42  		maxRetries      int
  43  		waitTime        time.Duration
  44  		maxWaitTime     time.Duration
  45  		retryConditions []RetryConditionFunc
  46  		retryHooks      []OnRetryFunc
  47  		resetReaders    bool
  48  	}
  49  )
  50  
  51  // Retries sets the max number of retries
  52  func Retries(value int) Option {
  53  	return func(o *Options) {
  54  		o.maxRetries = value
  55  	}
  56  }
  57  
  58  // WaitTime sets the default wait time to sleep between requests
  59  func WaitTime(value time.Duration) Option {
  60  	return func(o *Options) {
  61  		o.waitTime = value
  62  	}
  63  }
  64  
  65  // MaxWaitTime sets the max wait time to sleep between requests
  66  func MaxWaitTime(value time.Duration) Option {
  67  	return func(o *Options) {
  68  		o.maxWaitTime = value
  69  	}
  70  }
  71  
  72  // RetryConditions sets the conditions that will be checked for retry
  73  func RetryConditions(conditions []RetryConditionFunc) Option {
  74  	return func(o *Options) {
  75  		o.retryConditions = conditions
  76  	}
  77  }
  78  
  79  // RetryHooks sets the hooks that will be executed after each retry
  80  func RetryHooks(hooks []OnRetryFunc) Option {
  81  	return func(o *Options) {
  82  		o.retryHooks = hooks
  83  	}
  84  }
  85  
  86  // ResetMultipartReaders sets a boolean value which will lead the start being seeked out
  87  // on all multipart file readers if they implement [io.ReadSeeker]
  88  func ResetMultipartReaders(value bool) Option {
  89  	return func(o *Options) {
  90  		o.resetReaders = value
  91  	}
  92  }
  93  
  94  // Backoff retries with increasing timeout duration up until X amount of retries
  95  // (Default is 3 attempts, Override with option Retries(n))
  96  func Backoff(operation func() (*Response, error), options ...Option) error {
  97  	// Defaults
  98  	opts := Options{
  99  		maxRetries:      defaultMaxRetries,
 100  		waitTime:        defaultWaitTime,
 101  		maxWaitTime:     defaultMaxWaitTime,
 102  		retryConditions: []RetryConditionFunc{},
 103  	}
 104  
 105  	for _, o := range options {
 106  		o(&opts)
 107  	}
 108  
 109  	var (
 110  		resp *Response
 111  		err  error
 112  	)
 113  
 114  	for attempt := 0; attempt <= opts.maxRetries; attempt++ {
 115  		resp, err = operation()
 116  		ctx := context.Background()
 117  		if resp != nil && resp.Request.ctx != nil {
 118  			ctx = resp.Request.ctx
 119  		}
 120  		if ctx.Err() != nil {
 121  			return err
 122  		}
 123  
 124  		err1 := unwrapNoRetryErr(err)           // raw error, it used for return users callback.
 125  		needsRetry := err != nil && err == err1 // retry on a few operation errors by default
 126  
 127  		for _, condition := range opts.retryConditions {
 128  			needsRetry = condition(resp, err1)
 129  			if needsRetry {
 130  				break
 131  			}
 132  		}
 133  
 134  		if !needsRetry {
 135  			return err
 136  		}
 137  
 138  		if opts.resetReaders {
 139  			if err := resetFileReaders(resp.Request.multipartFiles); err != nil {
 140  				return err
 141  			}
 142  			if err := resetFieldReaders(resp.Request.multipartFields); err != nil {
 143  				return err
 144  			}
 145  		}
 146  
 147  		for _, hook := range opts.retryHooks {
 148  			hook(resp, err)
 149  		}
 150  
 151  		// Don't need to wait when no retries left.
 152  		// Still run retry hooks even on last retry to keep compatibility.
 153  		if attempt == opts.maxRetries {
 154  			return err
 155  		}
 156  
 157  		waitTime, err2 := sleepDuration(resp, opts.waitTime, opts.maxWaitTime, attempt)
 158  		if err2 != nil {
 159  			if err == nil {
 160  				err = err2
 161  			}
 162  			return err
 163  		}
 164  
 165  		select {
 166  		case <-time.After(waitTime):
 167  		case <-ctx.Done():
 168  			return ctx.Err()
 169  		}
 170  	}
 171  
 172  	return err
 173  }
 174  
 175  func sleepDuration(resp *Response, min, max time.Duration, attempt int) (time.Duration, error) {
 176  	const maxInt = 1<<31 - 1 // max int for arch 386
 177  	if max < 0 {
 178  		max = maxInt
 179  	}
 180  	if resp == nil {
 181  		return jitterBackoff(min, max, attempt), nil
 182  	}
 183  
 184  	retryAfterFunc := resp.Request.client.RetryAfter
 185  
 186  	// Check for custom callback
 187  	if retryAfterFunc == nil {
 188  		return jitterBackoff(min, max, attempt), nil
 189  	}
 190  
 191  	result, err := retryAfterFunc(resp.Request.client, resp)
 192  	if err != nil {
 193  		return 0, err // i.e. 'API quota exceeded'
 194  	}
 195  	if result == 0 {
 196  		return jitterBackoff(min, max, attempt), nil
 197  	}
 198  	if result < 0 || max < result {
 199  		result = max
 200  	}
 201  	if result < min {
 202  		result = min
 203  	}
 204  	return result, nil
 205  }
 206  
 207  // Return capped exponential backoff with jitter
 208  // https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
 209  func jitterBackoff(min, max time.Duration, attempt int) time.Duration {
 210  	base := float64(min)
 211  	capLevel := float64(max)
 212  
 213  	temp := math.Min(capLevel, base*math.Exp2(float64(attempt)))
 214  	ri := time.Duration(temp / 2)
 215  	if ri == 0 {
 216  		ri = time.Nanosecond
 217  	}
 218  	result := randDuration(ri)
 219  
 220  	if result < min {
 221  		result = min
 222  	}
 223  
 224  	return result
 225  }
 226  
 227  var rnd = newRnd()
 228  var rndMu sync.Mutex
 229  
 230  func randDuration(center time.Duration) time.Duration {
 231  	rndMu.Lock()
 232  	defer rndMu.Unlock()
 233  
 234  	var ri = int64(center)
 235  	var jitter = rnd.Int63n(ri)
 236  	return time.Duration(math.Abs(float64(ri + jitter)))
 237  }
 238  
 239  func newRnd() *rand.Rand {
 240  	var seed = time.Now().UnixNano()
 241  	var src = rand.NewSource(seed)
 242  	return rand.New(src)
 243  }
 244  
 245  func resetFileReaders(files []*File) error {
 246  	for _, f := range files {
 247  		if rs, ok := f.Reader.(io.ReadSeeker); ok {
 248  			if _, err := rs.Seek(0, io.SeekStart); err != nil {
 249  				return err
 250  			}
 251  		}
 252  	}
 253  
 254  	return nil
 255  }
 256  
 257  func resetFieldReaders(fields []*MultipartField) error {
 258  	for _, f := range fields {
 259  		if rs, ok := f.Reader.(io.ReadSeeker); ok {
 260  			if _, err := rs.Seek(0, io.SeekStart); err != nil {
 261  				return err
 262  			}
 263  		}
 264  	}
 265  
 266  	return nil
 267  }
 268