client.go raw

   1  package rest
   2  
   3  import (
   4  	"bytes"
   5  	"encoding/json"
   6  	"fmt"
   7  	"io"
   8  	"net/http"
   9  	"net/url"
  10  	"regexp"
  11  	"strconv"
  12  	"time"
  13  )
  14  
  15  const (
  16  	clientVersion = "2.16.0"
  17  
  18  	defaultBase                   = "https://api.nsone.net"
  19  	defaultEndpoint               = defaultBase + "/v1/"
  20  	defaultShouldFollowPagination = true
  21  	defaultUserAgent              = "go-ns1/" + clientVersion
  22  
  23  	headerAuth          = "X-NSONE-Key"
  24  	headerRateLimit     = "X-Ratelimit-Limit"
  25  	headerRateRemaining = "X-Ratelimit-Remaining"
  26  	headerRatePeriod    = "X-Ratelimit-Period"
  27  
  28  	defaultRateLimitWaitTime = time.Millisecond * 100
  29  )
  30  
  31  // Doer is a single method interface that allows a user to extend/augment an http.Client instance.
  32  // Note: http.Client satisfies the Doer interface.
  33  type Doer interface {
  34  	Do(*http.Request) (*http.Response, error)
  35  }
  36  
  37  // Client manages communication with the NS1 Rest API.
  38  type Client struct {
  39  	// httpClient handles all rest api communication,
  40  	// and expects an *http.Client.
  41  	httpClient Doer
  42  
  43  	// NS1 rest endpoint, overrides default if given.
  44  	Endpoint *url.URL
  45  
  46  	// NS1 api key (value for http request header 'X-NSONE-Key').
  47  	APIKey string
  48  
  49  	// NS1 go rest user agent (value for http request header 'User-Agent').
  50  	UserAgent string
  51  
  52  	// Func to call after response is returned in Do
  53  	RateLimitFunc func(RateLimit)
  54  
  55  	// Whether the client should handle paginated responses automatically.
  56  	FollowPagination bool
  57  
  58  	// From the excellent github-go client.
  59  	common service // Reuse a single struct instead of allocating one for each service on the heap.
  60  
  61  	// Services used for communicating with different components of the NS1 API.
  62  	APIKeys              *APIKeysService
  63  	DataFeeds            *DataFeedsService
  64  	DataSources          *DataSourcesService
  65  	Jobs                 *JobsService
  66  	MonitorRegions       *MonitorRegionsService
  67  	PulsarJobs           *PulsarJobsService
  68  	PulsarDecisions      *PulsarDecisionsService
  69  	Notifications        *NotificationsService
  70  	Records              *RecordsService
  71  	Applications         *ApplicationsService
  72  	RecordSearch         *RecordSearchService
  73  	ZoneSearch           *ZoneSearchService
  74  	Settings             *SettingsService
  75  	Stats                *StatsService
  76  	Teams                *TeamsService
  77  	Users                *UsersService
  78  	Warnings             *WarningsService
  79  	Zones                *ZonesService
  80  	Versions             *VersionsService
  81  	DNSSEC               *DNSSECService
  82  	TSIG                 *TsigService
  83  	View                 *DNSViewService
  84  	Network              *NetworkService
  85  	GlobalIPWhitelist    *GlobalIPWhitelistService
  86  	Datasets             *DatasetsService
  87  	Activity             *ActivityService
  88  	Redirects            *RedirectService
  89  	RedirectCertificates *RedirectCertificateService
  90  	Alerts               *AlertsService
  91  	BillingUsage         *BillingUsageService
  92  }
  93  
  94  // NewClient constructs and returns a reference to an instantiated Client.
  95  func NewClient(httpClient Doer, options ...func(*Client)) *Client {
  96  	endpoint, _ := url.Parse(defaultEndpoint)
  97  
  98  	if httpClient == nil {
  99  		httpClient = http.DefaultClient
 100  	}
 101  
 102  	c := &Client{
 103  		httpClient:       httpClient,
 104  		Endpoint:         endpoint,
 105  		RateLimitFunc:    defaultRateLimitFunc,
 106  		UserAgent:        defaultUserAgent,
 107  		FollowPagination: defaultShouldFollowPagination,
 108  	}
 109  
 110  	c.common.client = c
 111  	c.APIKeys = (*APIKeysService)(&c.common)
 112  	c.DataFeeds = (*DataFeedsService)(&c.common)
 113  	c.DataSources = (*DataSourcesService)(&c.common)
 114  	c.Jobs = (*JobsService)(&c.common)
 115  	c.MonitorRegions = (*MonitorRegionsService)(&c.common)
 116  	c.PulsarJobs = (*PulsarJobsService)(&c.common)
 117  	c.PulsarDecisions = (*PulsarDecisionsService)(&c.common)
 118  	c.Notifications = (*NotificationsService)(&c.common)
 119  	c.Records = (*RecordsService)(&c.common)
 120  	c.Applications = (*ApplicationsService)(&c.common)
 121  	c.RecordSearch = (*RecordSearchService)(&c.common)
 122  	c.ZoneSearch = (*ZoneSearchService)(&c.common)
 123  	c.Settings = (*SettingsService)(&c.common)
 124  	c.Stats = (*StatsService)(&c.common)
 125  	c.Teams = (*TeamsService)(&c.common)
 126  	c.Users = (*UsersService)(&c.common)
 127  	c.Warnings = (*WarningsService)(&c.common)
 128  	c.Zones = (*ZonesService)(&c.common)
 129  	c.Versions = (*VersionsService)(&c.common)
 130  	c.DNSSEC = (*DNSSECService)(&c.common)
 131  	c.TSIG = (*TsigService)(&c.common)
 132  	c.View = (*DNSViewService)(&c.common)
 133  	c.Network = (*NetworkService)(&c.common)
 134  	c.GlobalIPWhitelist = (*GlobalIPWhitelistService)(&c.common)
 135  	c.Datasets = (*DatasetsService)(&c.common)
 136  	c.Activity = (*ActivityService)(&c.common)
 137  	c.Redirects = (*RedirectService)(&c.common)
 138  	c.RedirectCertificates = (*RedirectCertificateService)(&c.common)
 139  	c.Alerts = (*AlertsService)(&c.common)
 140  	c.BillingUsage = (*BillingUsageService)(&c.common)
 141  
 142  	for _, option := range options {
 143  		option(c)
 144  	}
 145  	return c
 146  }
 147  
 148  type service struct {
 149  	client *Client
 150  }
 151  
 152  // SetHTTPClient sets a Client instances' httpClient.
 153  func SetHTTPClient(httpClient Doer) func(*Client) {
 154  	return func(c *Client) { c.httpClient = httpClient }
 155  }
 156  
 157  // SetAPIKey sets a Client instances' APIKey.
 158  func SetAPIKey(key string) func(*Client) {
 159  	return func(c *Client) { c.APIKey = key }
 160  }
 161  
 162  // SetEndpoint sets a Client instances' Endpoint.
 163  func SetEndpoint(endpoint string) func(*Client) {
 164  	return func(c *Client) { c.Endpoint, _ = url.Parse(endpoint) }
 165  }
 166  
 167  // SetUserAgent sets a Client instances' user agent.
 168  func SetUserAgent(ua string) func(*Client) {
 169  	return func(c *Client) { c.UserAgent = ua }
 170  }
 171  
 172  // SetRateLimitFunc sets a Client instances' RateLimitFunc.
 173  func SetRateLimitFunc(ratefunc func(rl RateLimit)) func(*Client) {
 174  	return func(c *Client) { c.RateLimitFunc = ratefunc }
 175  }
 176  
 177  // SetFollowPagination sets a Client instances' FollowPagination attribute.
 178  func SetFollowPagination(shouldFollow bool) func(*Client) {
 179  	return func(c *Client) { c.FollowPagination = shouldFollow }
 180  }
 181  
 182  // Param is a container struct which holds a `Key` and `Value` field corresponding to the values of a URL parameter.
 183  type Param struct {
 184  	Key, Value string
 185  }
 186  
 187  // Do satisfies the Doer interface. resp will be nil if a non-HTTP error
 188  // occurs, otherwise it is available for inspection when the error reflects a
 189  // non-2XX response. It accepts a variadic number of optional URL parameters to
 190  // supply to the request. URL parameters are of type `rest.Param`.
 191  func (c Client) Do(req *http.Request, v interface{}, params ...Param) (*http.Response, error) {
 192  	q := req.URL.Query()
 193  	for _, p := range params {
 194  		q.Set(p.Key, p.Value)
 195  	}
 196  	req.URL.RawQuery = q.Encode()
 197  
 198  	resp, err := c.httpClient.Do(req)
 199  	if err != nil {
 200  		return nil, err
 201  	}
 202  	defer resp.Body.Close()
 203  
 204  	rl := parseRate(resp)
 205  	c.RateLimitFunc(rl)
 206  
 207  	err = CheckResponse(resp)
 208  	if err != nil {
 209  		return resp, err
 210  	}
 211  
 212  	if v != nil {
 213  		// For non-JSON responses, the desired destination might be a bytes buffer
 214  		if buf, ok := v.(*bytes.Buffer); ok {
 215  			if _, err := io.Copy(buf, resp.Body); err != nil {
 216  				return nil, err
 217  			}
 218  			return resp, err
 219  		}
 220  
 221  		// Try to unmarshal body into given type using streaming decoder.
 222  		if err := json.NewDecoder(resp.Body).Decode(&v); err != nil {
 223  			return nil, err
 224  		}
 225  	}
 226  
 227  	return resp, err
 228  }
 229  
 230  // NextFunc knows how to get and parse additional info from uri into v.
 231  type NextFunc func(v *interface{}, uri string) (*http.Response, error)
 232  
 233  // DoWithPagination Does, and follows Link headers for pagination. The returned
 234  // Response is from the last URI visited - either the last page, or one that
 235  // responded with a non-2XX status. If a non-HTTP error occurs, resp will be
 236  // nil. It accepts a variadic number of optional URL parameters to supply to
 237  // the underlying `.Do()` method request(s). URL parameters are of type
 238  // `rest.Param`.
 239  func (c Client) DoWithPagination(req *http.Request, v interface{}, f NextFunc, params ...Param) (*http.Response, error) {
 240  	resp, err := c.Do(req, v, params...)
 241  	if err != nil {
 242  		return resp, err
 243  	}
 244  
 245  	// See PLAT-188
 246  	forceHTTPS := c.Endpoint.Scheme == "https"
 247  
 248  	nextURI := ParseLink(resp.Header.Get("Link"), forceHTTPS).Next()
 249  	for nextURI != "" {
 250  		resp, err = f(&v, nextURI)
 251  		if err != nil {
 252  			return resp, err
 253  		}
 254  		nextURI = ParseLink(resp.Header.Get("Link"), forceHTTPS).Next()
 255  	}
 256  	return resp, nil
 257  }
 258  
 259  // NewRequest constructs and returns a http.Request.
 260  func (c *Client) NewRequest(method, path string, body interface{}) (*http.Request, error) {
 261  	rel, err := url.Parse(path)
 262  	if err != nil {
 263  		return nil, err
 264  	}
 265  
 266  	uri := c.Endpoint.ResolveReference(rel)
 267  
 268  	// Encode body as json
 269  	buf := new(bytes.Buffer)
 270  	if body != nil {
 271  		err := json.NewEncoder(buf).Encode(body)
 272  		if err != nil {
 273  			return nil, err
 274  		}
 275  	}
 276  
 277  	req, err := http.NewRequest(method, uri.String(), buf)
 278  	if err != nil {
 279  		return nil, err
 280  	}
 281  
 282  	req.Header.Add(headerAuth, c.APIKey)
 283  	req.Header.Add("User-Agent", c.UserAgent)
 284  	return req, nil
 285  }
 286  
 287  // Response wraps stdlib http response.
 288  type Response struct {
 289  	*http.Response
 290  }
 291  
 292  // Error contains all http responses outside the 2xx range.
 293  type Error struct {
 294  	Resp    *http.Response
 295  	Message string
 296  }
 297  
 298  // Satisfy std lib error interface.
 299  func (re *Error) Error() string {
 300  	return fmt.Sprintf("%v %v: %d %v", re.Resp.Request.Method, re.Resp.Request.URL, re.Resp.StatusCode, re.Message)
 301  }
 302  
 303  // CheckResponse handles parsing of rest api errors. Returns nil if no error.
 304  func CheckResponse(resp *http.Response) error {
 305  	if c := resp.StatusCode; c >= 200 && c <= 299 {
 306  		return nil
 307  	}
 308  
 309  	restErr := &Error{Resp: resp}
 310  
 311  	msgBody, err := io.ReadAll(resp.Body)
 312  	if err != nil {
 313  		return err
 314  	}
 315  	if len(msgBody) == 0 {
 316  		return restErr
 317  	}
 318  
 319  	err = json.Unmarshal(msgBody, restErr)
 320  	if err != nil {
 321  		restErr.Message = string(msgBody)
 322  		return restErr
 323  	}
 324  
 325  	return restErr
 326  }
 327  
 328  // Helper function for parsing API responses for a specific error.
 329  // Ideally this would take place in CheckResponse above rather than
 330  // in each caller.
 331  var resourceMissingMatch = regexp.MustCompile(` not found`).MatchString
 332  
 333  // RateLimitFunc is rate limiting strategy for the Client instance.
 334  type RateLimitFunc func(RateLimit)
 335  
 336  // RateLimit stores X-Ratelimit-* headers
 337  type RateLimit struct {
 338  	Limit     int
 339  	Remaining int
 340  	Period    int
 341  }
 342  
 343  var defaultRateLimitFunc = func(rl RateLimit) {}
 344  
 345  // PercentageLeft returns the ratio of Remaining to Limit as a percentage
 346  func (rl RateLimit) PercentageLeft() int {
 347  	return rl.Remaining * 100 / rl.Limit
 348  }
 349  
 350  // WaitTime returns the time.Duration ratio of Period to Limit
 351  func (rl RateLimit) WaitTime() time.Duration {
 352  	if rl.Limit == 0 || rl.Period == 0 {
 353  		// rate-limit headers missing or corrupt, punt
 354  		return defaultRateLimitWaitTime
 355  	}
 356  	return (time.Second * time.Duration(rl.Period)) / time.Duration(rl.Limit)
 357  }
 358  
 359  // WaitTimeRemaining returns the time.Duration ratio of Period to Remaining
 360  func (rl RateLimit) WaitTimeRemaining() time.Duration {
 361  	if rl.Remaining < 2 {
 362  		return time.Second * time.Duration(rl.Period)
 363  	}
 364  	return (time.Second * time.Duration(rl.Period)) / time.Duration(rl.Remaining)
 365  }
 366  
 367  // RateLimitStrategySleep sets RateLimitFunc to sleep by WaitTimeRemaining
 368  func (c *Client) RateLimitStrategySleep() {
 369  	c.RateLimitFunc = func(rl RateLimit) {
 370  		remaining := rl.WaitTimeRemaining()
 371  		time.Sleep(remaining)
 372  	}
 373  }
 374  
 375  // RateLimitStrategyConcurrent sleeps for WaitTime * parallelism when
 376  // remaining is less than or equal to parallelism.
 377  func (c *Client) RateLimitStrategyConcurrent(parallelism int) {
 378  	c.RateLimitFunc = func(rl RateLimit) {
 379  		if rl.Remaining <= parallelism {
 380  			wait := rl.WaitTime() * time.Duration(parallelism)
 381  			time.Sleep(wait)
 382  		}
 383  	}
 384  }
 385  
 386  // parseRate parses rate related headers from http response.
 387  func parseRate(resp *http.Response) RateLimit {
 388  	var rl RateLimit
 389  
 390  	if limit := resp.Header.Get(headerRateLimit); limit != "" {
 391  		rl.Limit, _ = strconv.Atoi(limit)
 392  	}
 393  	if remaining := resp.Header.Get(headerRateRemaining); remaining != "" {
 394  		rl.Remaining, _ = strconv.Atoi(remaining)
 395  	}
 396  	if period := resp.Header.Get(headerRatePeriod); period != "" {
 397  		rl.Period, _ = strconv.Atoi(period)
 398  	}
 399  
 400  	return rl
 401  }
 402  
 403  // SetTimeParam sets a url timestamp query param given the parameters name.
 404  func SetTimeParam(key string, t time.Time) func(*url.Values) {
 405  	return func(v *url.Values) { v.Set(key, strconv.Itoa(int(t.Unix()))) }
 406  }
 407  
 408  // SetBoolParam sets a url boolean query param given the parameters name.
 409  func SetBoolParam(key string, b bool) func(*url.Values) {
 410  	return func(v *url.Values) { v.Set(key, strconv.FormatBool(b)) }
 411  }
 412  
 413  // SetStringParam sets a url string query param given the parameters name.
 414  func SetStringParam(key, val string) func(*url.Values) {
 415  	return func(v *url.Values) { v.Set(key, val) }
 416  }
 417  
 418  // SetIntParam sets a url integer query param given the parameters name.
 419  func SetIntParam(key string, val int) func(*url.Values) {
 420  	return func(v *url.Values) { v.Set(key, strconv.Itoa(val)) }
 421  }
 422  
 423  func (c *Client) getURI(v interface{}, uri string) (*http.Response, error) {
 424  	req, err := c.NewRequest("GET", uri, nil)
 425  	if err != nil {
 426  		return nil, err
 427  	}
 428  	// For non-2XX responses, Do returns the response as well as an error, for
 429  	// other errs, resp will be nil. Caller's responsibility to sort that out.
 430  	return c.Do(req, v)
 431  }
 432