// client.go
1 package rest
2
3 import (
4 "bytes"
5 "encoding/json"
6 "fmt"
7 "io"
8 "net/http"
9 "net/url"
10 "regexp"
11 "strconv"
12 "time"
13 )
14
// Client identity and default configuration.
const (
	// clientVersion is this SDK's version; it is embedded in defaultUserAgent.
	clientVersion = "2.16.0"

	// defaultBase is the NS1 API host; defaultEndpoint is the versioned base
	// URL that NewRequest resolves request paths against.
	defaultBase     = "https://api.nsone.net"
	defaultEndpoint = defaultBase + "/v1/"

	// defaultShouldFollowPagination is the initial Client.FollowPagination.
	defaultShouldFollowPagination = true

	// defaultUserAgent is the initial Client.UserAgent.
	defaultUserAgent = "go-ns1/" + clientVersion
)

// HTTP header names used for authentication and rate-limit reporting.
const (
	headerAuth          = "X-NSONE-Key"
	headerRateLimit     = "X-Ratelimit-Limit"
	headerRateRemaining = "X-Ratelimit-Remaining"
	headerRatePeriod    = "X-Ratelimit-Period"
)

// defaultRateLimitWaitTime is what RateLimit.WaitTime returns when the
// rate-limit headers are missing or zero.
const defaultRateLimitWaitTime = 100 * time.Millisecond
30
// Doer is the minimal transport the Client needs: anything able to execute a
// single *http.Request. *http.Client satisfies it, so callers can pass one
// directly or wrap one to extend or augment its behavior.
type Doer interface {
	Do(req *http.Request) (*http.Response, error)
}
36
// Client manages communication with the NS1 Rest API.
//
// Create one with NewClient, which fills in the defaults and points every
// service field below at the same shared `common` value.
type Client struct {
	// httpClient executes every outgoing request. Any Doer is accepted;
	// *http.Client satisfies the interface.
	httpClient Doer

	// NS1 rest endpoint, overrides default if given. NewRequest resolves
	// request paths relative to this URL.
	Endpoint *url.URL

	// NS1 api key (value for http request header 'X-NSONE-Key').
	APIKey string

	// NS1 go rest user agent (value for http request header 'User-Agent').
	UserAgent string

	// RateLimitFunc is called by Do with the parsed X-Ratelimit-* headers
	// after each response arrives, before the status code is checked.
	RateLimitFunc func(RateLimit)

	// Whether the client should handle paginated responses automatically.
	// NewClient defaults it to true. DoWithPagination does not read it itself;
	// presumably the services consult it to choose between Do and
	// DoWithPagination (TODO confirm).
	FollowPagination bool

	// common is the single service value that every service field below
	// points at (a pattern borrowed from the go-github client). Reusing one
	// struct avoids allocating a separate one per service.
	common service

	// Services used for communicating with different components of the NS1 API.
	APIKeys              *APIKeysService
	DataFeeds            *DataFeedsService
	DataSources          *DataSourcesService
	Jobs                 *JobsService
	MonitorRegions       *MonitorRegionsService
	PulsarJobs           *PulsarJobsService
	PulsarDecisions      *PulsarDecisionsService
	Notifications        *NotificationsService
	Records              *RecordsService
	Applications         *ApplicationsService
	RecordSearch         *RecordSearchService
	ZoneSearch           *ZoneSearchService
	Settings             *SettingsService
	Stats                *StatsService
	Teams                *TeamsService
	Users                *UsersService
	Warnings             *WarningsService
	Zones                *ZonesService
	Versions             *VersionsService
	DNSSEC               *DNSSECService
	TSIG                 *TsigService
	View                 *DNSViewService
	Network              *NetworkService
	GlobalIPWhitelist    *GlobalIPWhitelistService
	Datasets             *DatasetsService
	Activity             *ActivityService
	Redirects            *RedirectService
	RedirectCertificates *RedirectCertificateService
	Alerts               *AlertsService
	BillingUsage         *BillingUsageService
}
93
94 // NewClient constructs and returns a reference to an instantiated Client.
95 func NewClient(httpClient Doer, options ...func(*Client)) *Client {
96 endpoint, _ := url.Parse(defaultEndpoint)
97
98 if httpClient == nil {
99 httpClient = http.DefaultClient
100 }
101
102 c := &Client{
103 httpClient: httpClient,
104 Endpoint: endpoint,
105 RateLimitFunc: defaultRateLimitFunc,
106 UserAgent: defaultUserAgent,
107 FollowPagination: defaultShouldFollowPagination,
108 }
109
110 c.common.client = c
111 c.APIKeys = (*APIKeysService)(&c.common)
112 c.DataFeeds = (*DataFeedsService)(&c.common)
113 c.DataSources = (*DataSourcesService)(&c.common)
114 c.Jobs = (*JobsService)(&c.common)
115 c.MonitorRegions = (*MonitorRegionsService)(&c.common)
116 c.PulsarJobs = (*PulsarJobsService)(&c.common)
117 c.PulsarDecisions = (*PulsarDecisionsService)(&c.common)
118 c.Notifications = (*NotificationsService)(&c.common)
119 c.Records = (*RecordsService)(&c.common)
120 c.Applications = (*ApplicationsService)(&c.common)
121 c.RecordSearch = (*RecordSearchService)(&c.common)
122 c.ZoneSearch = (*ZoneSearchService)(&c.common)
123 c.Settings = (*SettingsService)(&c.common)
124 c.Stats = (*StatsService)(&c.common)
125 c.Teams = (*TeamsService)(&c.common)
126 c.Users = (*UsersService)(&c.common)
127 c.Warnings = (*WarningsService)(&c.common)
128 c.Zones = (*ZonesService)(&c.common)
129 c.Versions = (*VersionsService)(&c.common)
130 c.DNSSEC = (*DNSSECService)(&c.common)
131 c.TSIG = (*TsigService)(&c.common)
132 c.View = (*DNSViewService)(&c.common)
133 c.Network = (*NetworkService)(&c.common)
134 c.GlobalIPWhitelist = (*GlobalIPWhitelistService)(&c.common)
135 c.Datasets = (*DatasetsService)(&c.common)
136 c.Activity = (*ActivityService)(&c.common)
137 c.Redirects = (*RedirectService)(&c.common)
138 c.RedirectCertificates = (*RedirectCertificateService)(&c.common)
139 c.Alerts = (*AlertsService)(&c.common)
140 c.BillingUsage = (*BillingUsageService)(&c.common)
141
142 for _, option := range options {
143 option(c)
144 }
145 return c
146 }
147
148 type service struct {
149 client *Client
150 }
151
152 // SetHTTPClient sets a Client instances' httpClient.
153 func SetHTTPClient(httpClient Doer) func(*Client) {
154 return func(c *Client) { c.httpClient = httpClient }
155 }
156
157 // SetAPIKey sets a Client instances' APIKey.
158 func SetAPIKey(key string) func(*Client) {
159 return func(c *Client) { c.APIKey = key }
160 }
161
162 // SetEndpoint sets a Client instances' Endpoint.
163 func SetEndpoint(endpoint string) func(*Client) {
164 return func(c *Client) { c.Endpoint, _ = url.Parse(endpoint) }
165 }
166
167 // SetUserAgent sets a Client instances' user agent.
168 func SetUserAgent(ua string) func(*Client) {
169 return func(c *Client) { c.UserAgent = ua }
170 }
171
172 // SetRateLimitFunc sets a Client instances' RateLimitFunc.
173 func SetRateLimitFunc(ratefunc func(rl RateLimit)) func(*Client) {
174 return func(c *Client) { c.RateLimitFunc = ratefunc }
175 }
176
177 // SetFollowPagination sets a Client instances' FollowPagination attribute.
178 func SetFollowPagination(shouldFollow bool) func(*Client) {
179 return func(c *Client) { c.FollowPagination = shouldFollow }
180 }
181
// Param is one URL query parameter passed to Client.Do. Do applies each one
// with url.Values.Set, so a later Param overrides an earlier one (or an
// existing query value) with the same Key.
type Param struct {
	Key   string
	Value string
}
186
187 // Do satisfies the Doer interface. resp will be nil if a non-HTTP error
188 // occurs, otherwise it is available for inspection when the error reflects a
189 // non-2XX response. It accepts a variadic number of optional URL parameters to
190 // supply to the request. URL parameters are of type `rest.Param`.
191 func (c Client) Do(req *http.Request, v interface{}, params ...Param) (*http.Response, error) {
192 q := req.URL.Query()
193 for _, p := range params {
194 q.Set(p.Key, p.Value)
195 }
196 req.URL.RawQuery = q.Encode()
197
198 resp, err := c.httpClient.Do(req)
199 if err != nil {
200 return nil, err
201 }
202 defer resp.Body.Close()
203
204 rl := parseRate(resp)
205 c.RateLimitFunc(rl)
206
207 err = CheckResponse(resp)
208 if err != nil {
209 return resp, err
210 }
211
212 if v != nil {
213 // For non-JSON responses, the desired destination might be a bytes buffer
214 if buf, ok := v.(*bytes.Buffer); ok {
215 if _, err := io.Copy(buf, resp.Body); err != nil {
216 return nil, err
217 }
218 return resp, err
219 }
220
221 // Try to unmarshal body into given type using streaming decoder.
222 if err := json.NewDecoder(resp.Body).Decode(&v); err != nil {
223 return nil, err
224 }
225 }
226
227 return resp, err
228 }
229
// NextFunc fetches the page at uri and decodes it into *dst.
// DoWithPagination calls it once per "next" link, passing a pointer to the
// destination it originally handed to Do.
type NextFunc func(dst *interface{}, uri string) (*http.Response, error)
232
233 // DoWithPagination Does, and follows Link headers for pagination. The returned
234 // Response is from the last URI visited - either the last page, or one that
235 // responded with a non-2XX status. If a non-HTTP error occurs, resp will be
236 // nil. It accepts a variadic number of optional URL parameters to supply to
237 // the underlying `.Do()` method request(s). URL parameters are of type
238 // `rest.Param`.
239 func (c Client) DoWithPagination(req *http.Request, v interface{}, f NextFunc, params ...Param) (*http.Response, error) {
240 resp, err := c.Do(req, v, params...)
241 if err != nil {
242 return resp, err
243 }
244
245 // See PLAT-188
246 forceHTTPS := c.Endpoint.Scheme == "https"
247
248 nextURI := ParseLink(resp.Header.Get("Link"), forceHTTPS).Next()
249 for nextURI != "" {
250 resp, err = f(&v, nextURI)
251 if err != nil {
252 return resp, err
253 }
254 nextURI = ParseLink(resp.Header.Get("Link"), forceHTTPS).Next()
255 }
256 return resp, nil
257 }
258
259 // NewRequest constructs and returns a http.Request.
260 func (c *Client) NewRequest(method, path string, body interface{}) (*http.Request, error) {
261 rel, err := url.Parse(path)
262 if err != nil {
263 return nil, err
264 }
265
266 uri := c.Endpoint.ResolveReference(rel)
267
268 // Encode body as json
269 buf := new(bytes.Buffer)
270 if body != nil {
271 err := json.NewEncoder(buf).Encode(body)
272 if err != nil {
273 return nil, err
274 }
275 }
276
277 req, err := http.NewRequest(method, uri.String(), buf)
278 if err != nil {
279 return nil, err
280 }
281
282 req.Header.Add(headerAuth, c.APIKey)
283 req.Header.Add("User-Agent", c.UserAgent)
284 return req, nil
285 }
286
// Response wraps the standard library *http.Response by embedding it, so all
// of its fields and methods are promoted onto Response.
type Response struct{ *http.Response }
291
// Error contains all http responses outside the 2xx range.
//
// Resp is the response that produced the error. Message is the error text
// extracted from the response body.
type Error struct {
	// Resp is excluded from JSON. Decoding an error body into an Error can
	// then never overwrite the response through a "resp" key, and marshalling
	// an Error does not attempt to encode *http.Response.
	Resp    *http.Response `json:"-"`
	Message string
}

// Error satisfies the standard library error interface. The result has the
// form "METHOD URL: STATUS MESSAGE". When the request is unavailable (a
// response without an attached Request, as hand-built responses from custom
// Doers may have), only "STATUS MESSAGE" is returned. A nil Resp yields just
// Message. Neither case dereferences a nil pointer.
func (re *Error) Error() string {
	if re.Resp == nil {
		return re.Message
	}
	if re.Resp.Request == nil {
		return fmt.Sprintf("%d %v", re.Resp.StatusCode, re.Message)
	}
	return fmt.Sprintf("%v %v: %d %v", re.Resp.Request.Method, re.Resp.Request.URL, re.Resp.StatusCode, re.Message)
}
302
303 // CheckResponse handles parsing of rest api errors. Returns nil if no error.
304 func CheckResponse(resp *http.Response) error {
305 if c := resp.StatusCode; c >= 200 && c <= 299 {
306 return nil
307 }
308
309 restErr := &Error{Resp: resp}
310
311 msgBody, err := io.ReadAll(resp.Body)
312 if err != nil {
313 return err
314 }
315 if len(msgBody) == 0 {
316 return restErr
317 }
318
319 err = json.Unmarshal(msgBody, restErr)
320 if err != nil {
321 restErr.Message = string(msgBody)
322 return restErr
323 }
324
325 return restErr
326 }
327
// resourceMissingMatch reports whether its argument contains " not found"
// (note the leading space). It is a helper for parsing API responses for a
// specific error: presumably callers run it over API error text to detect a
// missing resource (TODO confirm against the callers). Ideally this check
// would take place in CheckResponse rather than in each caller.
var resourceMissingMatch = regexp.MustCompile(` not found`).MatchString
332
333 // RateLimitFunc is rate limiting strategy for the Client instance.
334 type RateLimitFunc func(RateLimit)
335
// RateLimit holds the values of a response's X-Ratelimit-Limit,
// X-Ratelimit-Remaining and X-Ratelimit-Period headers. WaitTime and
// WaitTimeRemaining interpret Period as seconds.
type RateLimit struct {
	Limit, Remaining, Period int
}
342
343 var defaultRateLimitFunc = func(rl RateLimit) {}
344
345 // PercentageLeft returns the ratio of Remaining to Limit as a percentage
346 func (rl RateLimit) PercentageLeft() int {
347 return rl.Remaining * 100 / rl.Limit
348 }
349
350 // WaitTime returns the time.Duration ratio of Period to Limit
351 func (rl RateLimit) WaitTime() time.Duration {
352 if rl.Limit == 0 || rl.Period == 0 {
353 // rate-limit headers missing or corrupt, punt
354 return defaultRateLimitWaitTime
355 }
356 return (time.Second * time.Duration(rl.Period)) / time.Duration(rl.Limit)
357 }
358
359 // WaitTimeRemaining returns the time.Duration ratio of Period to Remaining
360 func (rl RateLimit) WaitTimeRemaining() time.Duration {
361 if rl.Remaining < 2 {
362 return time.Second * time.Duration(rl.Period)
363 }
364 return (time.Second * time.Duration(rl.Period)) / time.Duration(rl.Remaining)
365 }
366
367 // RateLimitStrategySleep sets RateLimitFunc to sleep by WaitTimeRemaining
368 func (c *Client) RateLimitStrategySleep() {
369 c.RateLimitFunc = func(rl RateLimit) {
370 remaining := rl.WaitTimeRemaining()
371 time.Sleep(remaining)
372 }
373 }
374
375 // RateLimitStrategyConcurrent sleeps for WaitTime * parallelism when
376 // remaining is less than or equal to parallelism.
377 func (c *Client) RateLimitStrategyConcurrent(parallelism int) {
378 c.RateLimitFunc = func(rl RateLimit) {
379 if rl.Remaining <= parallelism {
380 wait := rl.WaitTime() * time.Duration(parallelism)
381 time.Sleep(wait)
382 }
383 }
384 }
385
386 // parseRate parses rate related headers from http response.
387 func parseRate(resp *http.Response) RateLimit {
388 var rl RateLimit
389
390 if limit := resp.Header.Get(headerRateLimit); limit != "" {
391 rl.Limit, _ = strconv.Atoi(limit)
392 }
393 if remaining := resp.Header.Get(headerRateRemaining); remaining != "" {
394 rl.Remaining, _ = strconv.Atoi(remaining)
395 }
396 if period := resp.Header.Get(headerRatePeriod); period != "" {
397 rl.Period, _ = strconv.Atoi(period)
398 }
399
400 return rl
401 }
402
// SetTimeParam returns a function that sets query parameter key to t as
// whole seconds since the Unix epoch, in base 10.
//
// The int64 from t.Unix() is formatted directly. Converting it to int first
// would truncate timestamps past 2038 on platforms where int is 32 bits.
func SetTimeParam(key string, t time.Time) func(*url.Values) {
	return func(v *url.Values) { v.Set(key, strconv.FormatInt(t.Unix(), 10)) }
}
407
// SetBoolParam returns a function that sets query parameter key to "true" or
// "false", replacing any existing values for that key.
func SetBoolParam(key string, b bool) func(*url.Values) {
	formatted := strconv.FormatBool(b)
	return func(v *url.Values) {
		v.Set(key, formatted)
	}
}
412
// SetStringParam returns a function that sets query parameter key to val
// verbatim, replacing any existing values for that key.
func SetStringParam(key, val string) func(*url.Values) {
	setter := func(v *url.Values) {
		v.Set(key, val)
	}
	return setter
}
417
// SetIntParam returns a function that sets query parameter key to val in
// base 10, replacing any existing values for that key.
func SetIntParam(key string, val int) func(*url.Values) {
	formatted := strconv.Itoa(val)
	return func(v *url.Values) {
		v.Set(key, formatted)
	}
}
422
423 func (c *Client) getURI(v interface{}, uri string) (*http.Response, error) {
424 req, err := c.NewRequest("GET", uri, nil)
425 if err != nil {
426 return nil, err
427 }
428 // For non-2XX responses, Do returns the response as well as an error, for
429 // other errs, resp will be nil. Caller's responsibility to sort that out.
430 return c.Do(req, v)
431 }
432