retry.go raw
1 // Copyright (c) 2015-2024 Jeevanandam M (jeeva@myjeeva.com), All rights reserved.
2 // resty source code and usage is governed by a MIT style
3 // license that can be found in the LICENSE file.
4
5 package resty
6
7 import (
8 "context"
9 "io"
10 "math"
11 "math/rand"
12 "sync"
13 "time"
14 )
15
// Retry settings used by Backoff when the caller does not override them.
const (
	// defaultMaxRetries is the number of retries after the first attempt.
	defaultMaxRetries = 3
	// defaultWaitTime is the base wait between attempts.
	defaultWaitTime = 100 * time.Millisecond
	// defaultMaxWaitTime is the upper bound of the wait between attempts.
	defaultMaxWaitTime = 2 * time.Second
)
21
22 type (
23 // Option is to create convenient retry options like wait time, max retries, etc.
24 Option func(*Options)
25
26 // RetryConditionFunc type is for the retry condition function
27 // input: non-nil Response OR request execution error
28 RetryConditionFunc func(*Response, error) bool
29
30 // OnRetryFunc is for side-effecting functions triggered on retry
31 OnRetryFunc func(*Response, error)
32
33 // RetryAfterFunc returns time to wait before retry
34 // For example, it can parse HTTP Retry-After header
35 // https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html
36 // Non-nil error is returned if it is found that the request is not retryable
37 // (0, nil) is a special result that means 'use default algorithm'
38 RetryAfterFunc func(*Client, *Response) (time.Duration, error)
39
40 // Options struct is used to hold retry settings.
41 Options struct {
42 maxRetries int
43 waitTime time.Duration
44 maxWaitTime time.Duration
45 retryConditions []RetryConditionFunc
46 retryHooks []OnRetryFunc
47 resetReaders bool
48 }
49 )
50
51 // Retries sets the max number of retries
52 func Retries(value int) Option {
53 return func(o *Options) {
54 o.maxRetries = value
55 }
56 }
57
58 // WaitTime sets the default wait time to sleep between requests
59 func WaitTime(value time.Duration) Option {
60 return func(o *Options) {
61 o.waitTime = value
62 }
63 }
64
65 // MaxWaitTime sets the max wait time to sleep between requests
66 func MaxWaitTime(value time.Duration) Option {
67 return func(o *Options) {
68 o.maxWaitTime = value
69 }
70 }
71
72 // RetryConditions sets the conditions that will be checked for retry
73 func RetryConditions(conditions []RetryConditionFunc) Option {
74 return func(o *Options) {
75 o.retryConditions = conditions
76 }
77 }
78
79 // RetryHooks sets the hooks that will be executed after each retry
80 func RetryHooks(hooks []OnRetryFunc) Option {
81 return func(o *Options) {
82 o.retryHooks = hooks
83 }
84 }
85
86 // ResetMultipartReaders sets a boolean value which will lead the start being seeked out
87 // on all multipart file readers if they implement [io.ReadSeeker]
88 func ResetMultipartReaders(value bool) Option {
89 return func(o *Options) {
90 o.resetReaders = value
91 }
92 }
93
94 // Backoff retries with increasing timeout duration up until X amount of retries
95 // (Default is 3 attempts, Override with option Retries(n))
96 func Backoff(operation func() (*Response, error), options ...Option) error {
97 // Defaults
98 opts := Options{
99 maxRetries: defaultMaxRetries,
100 waitTime: defaultWaitTime,
101 maxWaitTime: defaultMaxWaitTime,
102 retryConditions: []RetryConditionFunc{},
103 }
104
105 for _, o := range options {
106 o(&opts)
107 }
108
109 var (
110 resp *Response
111 err error
112 )
113
114 for attempt := 0; attempt <= opts.maxRetries; attempt++ {
115 resp, err = operation()
116 ctx := context.Background()
117 if resp != nil && resp.Request.ctx != nil {
118 ctx = resp.Request.ctx
119 }
120 if ctx.Err() != nil {
121 return err
122 }
123
124 err1 := unwrapNoRetryErr(err) // raw error, it used for return users callback.
125 needsRetry := err != nil && err == err1 // retry on a few operation errors by default
126
127 for _, condition := range opts.retryConditions {
128 needsRetry = condition(resp, err1)
129 if needsRetry {
130 break
131 }
132 }
133
134 if !needsRetry {
135 return err
136 }
137
138 if opts.resetReaders {
139 if err := resetFileReaders(resp.Request.multipartFiles); err != nil {
140 return err
141 }
142 if err := resetFieldReaders(resp.Request.multipartFields); err != nil {
143 return err
144 }
145 }
146
147 for _, hook := range opts.retryHooks {
148 hook(resp, err)
149 }
150
151 // Don't need to wait when no retries left.
152 // Still run retry hooks even on last retry to keep compatibility.
153 if attempt == opts.maxRetries {
154 return err
155 }
156
157 waitTime, err2 := sleepDuration(resp, opts.waitTime, opts.maxWaitTime, attempt)
158 if err2 != nil {
159 if err == nil {
160 err = err2
161 }
162 return err
163 }
164
165 select {
166 case <-time.After(waitTime):
167 case <-ctx.Done():
168 return ctx.Err()
169 }
170 }
171
172 return err
173 }
174
175 func sleepDuration(resp *Response, min, max time.Duration, attempt int) (time.Duration, error) {
176 const maxInt = 1<<31 - 1 // max int for arch 386
177 if max < 0 {
178 max = maxInt
179 }
180 if resp == nil {
181 return jitterBackoff(min, max, attempt), nil
182 }
183
184 retryAfterFunc := resp.Request.client.RetryAfter
185
186 // Check for custom callback
187 if retryAfterFunc == nil {
188 return jitterBackoff(min, max, attempt), nil
189 }
190
191 result, err := retryAfterFunc(resp.Request.client, resp)
192 if err != nil {
193 return 0, err // i.e. 'API quota exceeded'
194 }
195 if result == 0 {
196 return jitterBackoff(min, max, attempt), nil
197 }
198 if result < 0 || max < result {
199 result = max
200 }
201 if result < min {
202 result = min
203 }
204 return result, nil
205 }
206
// jitterBackoff returns a capped exponential backoff with jitter for the
// given zero-based attempt.
// https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
//
// The exponential term min*2^attempt is capped at max. With half being that
// term divided by two, the wait is drawn at random from [half, 2*half) and is
// raised to min if it falls below it.
func jitterBackoff(min, max time.Duration, attempt int) time.Duration {
	base := float64(min)
	capLevel := float64(max)

	temp := math.Min(capLevel, base*math.Exp2(float64(attempt)))
	ri := time.Duration(temp / 2)
	// randDuration needs a strictly positive center: a zero or negative value
	// (e.g. from a negative min, or a float conversion gone out of range)
	// would make rand.Int63n panic, so fall back to the smallest duration.
	if ri <= 0 {
		ri = time.Nanosecond
	}
	result := randDuration(ri)

	if result < min {
		result = min
	}

	return result
}

var (
	// rnd is the package-level random source used for backoff jitter.
	rnd = newRnd()
	// rndMu serializes access to rnd; a *rand.Rand created with rand.New is
	// not safe for concurrent use.
	rndMu sync.Mutex
)

// randDuration returns center plus a random jitter in [0, center).
// center must be strictly positive, otherwise rand.Int63n panics.
func randDuration(center time.Duration) time.Duration {
	rndMu.Lock()
	defer rndMu.Unlock()

	var ri = int64(center)
	var jitter = rnd.Int63n(ri)
	return time.Duration(math.Abs(float64(ri + jitter)))
}

// newRnd creates a random generator seeded with the current time in
// nanoseconds.
func newRnd() *rand.Rand {
	var seed = time.Now().UnixNano()
	var src = rand.NewSource(seed)
	return rand.New(src)
}
244
245 func resetFileReaders(files []*File) error {
246 for _, f := range files {
247 if rs, ok := f.Reader.(io.ReadSeeker); ok {
248 if _, err := rs.Seek(0, io.SeekStart); err != nil {
249 return err
250 }
251 }
252 }
253
254 return nil
255 }
256
257 func resetFieldReaders(fields []*MultipartField) error {
258 for _, f := range fields {
259 if rs, ok := f.Reader.(io.ReadSeeker); ok {
260 if _, err := rs.Seek(0, io.SeekStart); err != nil {
261 return err
262 }
263 }
264 }
265
266 return nil
267 }
268