client.go raw
1 package scw
2
3 import (
4 "context"
5 "crypto/tls"
6 "encoding/json"
7 "fmt"
8 "io"
9 "math"
10 "net/http"
11 "reflect"
12 "strconv"
13 "strings"
14 "sync"
15 "time"
16
17 "github.com/scaleway/scaleway-sdk-go/errors"
18 "github.com/scaleway/scaleway-sdk-go/internal/auth"
19 "github.com/scaleway/scaleway-sdk-go/internal/generic"
20 "github.com/scaleway/scaleway-sdk-go/logger"
21 )
22
23 // Client is the Scaleway client which performs API requests.
24 //
25 // This client should be passed in the `NewApi` functions whenever an API instance is created.
26 // Creating a Client is done with the `NewClient` function.
27 type Client struct {
28 httpClient httpClient
29 auth auth.Auth
30 apiURL string
31 userAgent string
32 defaultOrganizationID *string
33 defaultProjectID *string
34 defaultRegion *Region
35 defaultZone *Zone
36 defaultPageSize *uint32
37 }
38
39 func defaultOptions() []ClientOption {
40 return []ClientOption{
41 WithoutAuth(),
42 WithAPIURL("https://api.scaleway.com"),
43 withDefaultUserAgent(userAgent),
44 }
45 }
46
47 // NewClient instantiate a new Client object.
48 //
49 // Zero or more ClientOption object can be passed as a parameter.
50 // These options will then be applied to the client.
51 func NewClient(opts ...ClientOption) (*Client, error) {
52 s := newSettings()
53
54 // apply options
55 s.apply(append(defaultOptions(), opts...))
56
57 // validate settings
58 err := s.validate()
59 if err != nil {
60 return nil, err
61 }
62
63 // dial the API
64 if s.httpClient == nil {
65 s.httpClient = newHTTPClient()
66 }
67
68 // insecure mode
69 if s.insecure {
70 logger.Debugf("client: using insecure mode\n")
71 setInsecureMode(s.httpClient)
72 }
73
74 if logger.ShouldLog(logger.LogLevelDebug) {
75 logger.Debugf("client: using request logger\n")
76 setRequestLogging(s.httpClient)
77 }
78
79 logger.Debugf("client: using sdk version " + getVersion() + "\n")
80
81 return &Client{
82 auth: s.token,
83 httpClient: s.httpClient,
84 apiURL: s.apiURL,
85 userAgent: s.userAgent,
86 defaultOrganizationID: s.defaultOrganizationID,
87 defaultProjectID: s.defaultProjectID,
88 defaultRegion: s.defaultRegion,
89 defaultZone: s.defaultZone,
90 defaultPageSize: s.defaultPageSize,
91 }, nil
92 }
93
94 // GetDefaultOrganizationID returns the default organization ID
95 // of the client. This value can be set in the client option
96 // WithDefaultOrganizationID(). Be aware this value can be empty.
97 func (c *Client) GetDefaultOrganizationID() (organizationID string, exists bool) {
98 if c.defaultOrganizationID != nil {
99 return *c.defaultOrganizationID, true
100 }
101 return "", false
102 }
103
104 // GetDefaultProjectID returns the default project ID
105 // of the client. This value can be set in the client option
106 // WithDefaultProjectID(). Be aware this value can be empty.
107 func (c *Client) GetDefaultProjectID() (projectID string, exists bool) {
108 if c.defaultProjectID != nil {
109 return *c.defaultProjectID, true
110 }
111 return "", false
112 }
113
114 // GetDefaultRegion returns the default region of the client.
115 // This value can be set in the client option
116 // WithDefaultRegion(). Be aware this value can be empty.
117 func (c *Client) GetDefaultRegion() (region Region, exists bool) {
118 if c.defaultRegion != nil {
119 return *c.defaultRegion, true
120 }
121 return Region(""), false
122 }
123
124 // GetDefaultZone returns the default zone of the client.
125 // This value can be set in the client option
126 // WithDefaultZone(). Be aware this value can be empty.
127 func (c *Client) GetDefaultZone() (zone Zone, exists bool) {
128 if c.defaultZone != nil {
129 return *c.defaultZone, true
130 }
131 return Zone(""), false
132 }
133
134 func (c *Client) GetSecretKey() (secretKey string, exists bool) {
135 if token, isToken := c.auth.(*auth.Token); isToken {
136 return token.SecretKey, isToken
137 }
138 return "", false
139 }
140
141 func (c *Client) GetAccessKey() (accessKey string, exists bool) {
142 if token, isToken := c.auth.(*auth.Token); isToken {
143 return token.AccessKey, isToken
144 } else if token, isAccessKey := c.auth.(*auth.AccessKeyOnly); isAccessKey {
145 return token.AccessKey, isAccessKey
146 }
147
148 return "", false
149 }
150
151 // GetDefaultPageSize returns the default page size of the client.
152 // This value can be set in the client option
153 // WithDefaultPageSize(). Be aware this value can be empty.
154 func (c *Client) GetDefaultPageSize() (pageSize uint32, exists bool) {
155 if c.defaultPageSize != nil {
156 return *c.defaultPageSize, true
157 }
158 return 0, false
159 }
160
161 // Do performs HTTP request(s) based on the ScalewayRequest object.
162 // RequestOptions are applied prior to doing the request.
163 func (c *Client) Do(req *ScalewayRequest, res any, opts ...RequestOption) (err error) {
164 // apply request options
165 req.apply(opts)
166
167 // validate request options
168 err = req.validate()
169 if err != nil {
170 return err
171 }
172
173 if req.auth == nil {
174 req.auth = c.auth
175 }
176
177 if req.zones != nil {
178 return c.doListZones(req, res, req.zones)
179 }
180 if req.regions != nil {
181 return c.doListRegions(req, res, req.regions)
182 }
183
184 if req.allPages {
185 return c.doListAll(req, res)
186 }
187
188 return c.do(req, res)
189 }
190
191 // do performs a single HTTP request based on the ScalewayRequest object.
192 func (c *Client) do(req *ScalewayRequest, res any) (sdkErr error) {
193 if req == nil {
194 return errors.New("request must be non-nil")
195 }
196
197 // build url
198 url, sdkErr := req.getURL(c.apiURL)
199 if sdkErr != nil {
200 return sdkErr
201 }
202 logger.Debugf("creating %s request on %s\n", req.Method, url.String())
203
204 // build request
205 ctx := req.ctx
206 if ctx == nil {
207 ctx = context.Background()
208 }
209 httpRequest, err := http.NewRequestWithContext(ctx, req.Method, url.String(), req.Body)
210 if err != nil {
211 return errors.Wrap(err, "could not create request")
212 }
213
214 httpRequest.Header = req.getAllHeaders(req.auth, c.userAgent, false)
215
216 // execute request
217 httpResponse, err := c.httpClient.Do(httpRequest)
218 if err != nil {
219 return errors.Wrap(err, "error executing request")
220 }
221
222 defer func() {
223 closeErr := httpResponse.Body.Close()
224 if sdkErr == nil && closeErr != nil {
225 sdkErr = errors.Wrap(closeErr, "could not close http response")
226 }
227 }()
228
229 sdkErr = hasResponseError(httpResponse)
230 if sdkErr != nil {
231 return sdkErr
232 }
233
234 if res != nil && httpResponse.ContentLength != 0 {
235 contentType := httpResponse.Header.Get("Content-Type")
236
237 if strings.HasPrefix(contentType, "application/json") {
238 err = json.NewDecoder(httpResponse.Body).Decode(&res)
239 if err != nil {
240 return errors.Wrap(err, "could not parse %s response body", contentType)
241 }
242 } else {
243 buffer, isBuffer := res.(io.Writer)
244 if !isBuffer {
245 return errors.Wrap(err, "could not handle %s response body with %T result type", contentType, buffer)
246 }
247
248 _, err := io.Copy(buffer, httpResponse.Body)
249 if err != nil {
250 return errors.Wrap(err, "could not copy %s response body", contentType)
251 }
252 }
253
254 // Handle instance API X-Total-Count header
255 xTotalCountStr := httpResponse.Header.Get("X-Total-Count")
256 if legacyLister, isLegacyLister := res.(legacyLister); isLegacyLister && xTotalCountStr != "" {
257 xTotalCount, err := strconv.ParseInt(xTotalCountStr, 10, 32)
258 if err != nil {
259 return errors.Wrap(err, "could not parse X-Total-Count header")
260 }
261 legacyLister.UnsafeSetTotalCount(int(xTotalCount))
262 }
263 }
264
265 return nil
266 }
267
type (
	// lister is implemented by paginated list responses whose counts are
	// uint64. UnsafeAppend merges another response of the same type into the
	// receiver and returns a count that doListAll uses as the page size.
	lister interface {
		UnsafeGetTotalCount() uint64
		UnsafeAppend(any) (uint64, error)
	}

	// lister32 is the older uint32 flavour of lister, kept for
	// retro-compatibility with responses that still use uint32 counts.
	lister32 interface {
		UnsafeGetTotalCount() uint32
		UnsafeAppend(any) (uint32, error)
	}

	// legacyLister is implemented by responses whose total count is not part
	// of the body; do() feeds it from the X-Total-Count response header.
	legacyLister interface {
		UnsafeSetTotalCount(totalCount int)
	}
)
283
284 func listerGetTotalCount(i any) uint64 {
285 if l, isLister := i.(lister); isLister {
286 return l.UnsafeGetTotalCount()
287 }
288 if l32, isLister32 := i.(lister32); isLister32 {
289 return uint64(l32.UnsafeGetTotalCount())
290 }
291 panic(fmt.Errorf("%T does not support pagination but checks failed, should not happen", i))
292 }
293
294 func listerAppend(recv any, elems any) (uint64, error) {
295 if l, isLister := recv.(lister); isLister {
296 return l.UnsafeAppend(elems)
297 } else if l32, isLister32 := recv.(lister32); isLister32 {
298 total, err := l32.UnsafeAppend(elems)
299 return uint64(total), err
300 }
301
302 panic(fmt.Errorf("%T does not support pagination but checks failed, should not happen", recv))
303 }
304
305 func isLister(i any) bool {
306 switch i.(type) {
307 case lister:
308 return true
309 case lister32:
310 return true
311 default:
312 return false
313 }
314 }
315
316 const maxPageCount uint64 = math.MaxUint32
317
318 // doListAll collects all pages of a List request and aggregate all results on a single response.
319 func (c *Client) doListAll(req *ScalewayRequest, res any) (err error) {
320 // check for lister interface
321 if isLister(res) {
322 pageCount := maxPageCount
323 for page := uint64(1); page <= pageCount; page++ {
324 // set current page
325 req.Query.Set("page", strconv.FormatUint(page, 10))
326
327 // request the next page
328 nextPage := newVariableFromType(res)
329 err := c.do(req, nextPage)
330 if err != nil {
331 return err
332 }
333
334 // append results
335 pageSize, err := listerAppend(res, nextPage)
336 if err != nil {
337 return err
338 }
339
340 if pageSize == 0 {
341 return nil
342 }
343
344 // set total count on first request
345 if pageCount == maxPageCount {
346 totalCount := listerGetTotalCount(nextPage)
347 pageCount = (totalCount + pageSize - 1) / pageSize
348 }
349 }
350 return nil
351 }
352
353 return errors.New("%T does not support pagination", res)
354 }
355
356 // doListLocalities collects all localities using multiple list requests and aggregate all results on a lister response
357 // results is sorted by locality
358 func (c *Client) doListLocalities(req *ScalewayRequest, res any, localities []string) (err error) {
359 path := req.Path
360 if !strings.Contains(path, "%locality%") {
361 return errors.New("request is not a valid locality request")
362 }
363 // Requests are parallelized
364 responseMutex := sync.Mutex{}
365 requestGroup := sync.WaitGroup{}
366 errChan := make(chan error, len(localities))
367
368 requestGroup.Add(len(localities))
369 for _, locality := range localities {
370 go func(locality string) {
371 defer requestGroup.Done()
372 // Request is cloned as doListAll will change header
373 // We remove zones as it would recurse in the same function
374 req := req.clone()
375 req.zones = []Zone(nil)
376 req.Path = strings.ReplaceAll(path, "%locality%", locality)
377
378 // We create a new response that we append to main response
379 zoneResponse := newVariableFromType(res)
380 err := c.Do(req, zoneResponse)
381 if err != nil {
382 errChan <- err
383 }
384 responseMutex.Lock()
385 _, err = listerAppend(res, zoneResponse)
386 responseMutex.Unlock()
387 if err != nil {
388 errChan <- err
389 }
390 }(locality)
391 }
392 requestGroup.Wait()
393
394 L: // We gather potential errors and return them all together
395 for {
396 select {
397 case newErr := <-errChan:
398 err = errors.Wrap(err, "%s", newErr.Error())
399 default:
400 break L
401 }
402 }
403 close(errChan)
404 if err != nil {
405 return err
406 }
407 return nil
408 }
409
410 // doListZones collects all zones using multiple list requests and aggregate all results on a single response.
411 // result is sorted by zone
412 func (c *Client) doListZones(req *ScalewayRequest, res any, zones []Zone) (err error) {
413 if isLister(res) {
414 // Prepare request with %zone% that can be replaced with actual zone
415 for _, zone := range AllZones {
416 if strings.Contains(req.Path, string(zone)) {
417 req.Path = strings.ReplaceAll(req.Path, string(zone), "%locality%")
418 break
419 }
420 }
421 if !strings.Contains(req.Path, "%locality%") {
422 return errors.New("request is not a valid zoned request")
423 }
424 localities := make([]string, 0, len(zones))
425 for _, zone := range zones {
426 localities = append(localities, string(zone))
427 }
428
429 err := c.doListLocalities(req, res, localities)
430 if err != nil {
431 return fmt.Errorf("failed to list localities: %w", err)
432 }
433
434 sortResponseByZones(res, zones)
435 return nil
436 }
437
438 return errors.New("%T does not support pagination", res)
439 }
440
441 // doListRegions collects all regions using multiple list requests and aggregate all results on a single response.
442 // result is sorted by region
443 func (c *Client) doListRegions(req *ScalewayRequest, res any, regions []Region) (err error) {
444 if isLister(res) {
445 // Prepare request with %locality% that can be replaced with actual region
446 for _, region := range AllRegions {
447 if strings.Contains(req.Path, string(region)) {
448 req.Path = strings.ReplaceAll(req.Path, string(region), "%locality%")
449 break
450 }
451 }
452 if !strings.Contains(req.Path, "%locality%") {
453 return errors.New("request is not a valid zoned request")
454 }
455 localities := make([]string, 0, len(regions))
456 for _, region := range regions {
457 localities = append(localities, string(region))
458 }
459
460 err := c.doListLocalities(req, res, localities)
461 if err != nil {
462 return fmt.Errorf("failed to list localities: %w", err)
463 }
464
465 sortResponseByRegions(res, regions)
466 return nil
467 }
468
469 return errors.New("%T does not support pagination", res)
470 }
471
472 // sortSliceByZones sorts a slice of struct using a Zone field that should exist
473 func sortSliceByZones(list any, zones []Zone) {
474 if !generic.HasField(list, "Zone") {
475 return
476 }
477
478 zoneMap := map[Zone]int{}
479 for i, zone := range zones {
480 zoneMap[zone] = i
481 }
482 generic.SortSliceByField(list, "Zone", func(i any, i2 any) bool {
483 return zoneMap[i.(Zone)] < zoneMap[i2.(Zone)]
484 })
485 }
486
487 // sortSliceByRegions sorts a slice of struct using a Region field that should exist
488 func sortSliceByRegions(list any, regions []Region) {
489 if !generic.HasField(list, "Region") {
490 return
491 }
492
493 regionMap := map[Region]int{}
494 for i, region := range regions {
495 regionMap[region] = i
496 }
497 generic.SortSliceByField(list, "Region", func(i any, i2 any) bool {
498 return regionMap[i.(Region)] < regionMap[i2.(Region)]
499 })
500 }
501
502 // sortResponseByZones find first field that is a slice in a struct and sort it by zone
503 func sortResponseByZones(res any, zones []Zone) {
504 // res may be ListServersResponse
505 //
506 // type ListServersResponse struct {
507 // TotalCount uint32 `json:"total_count"`
508 // Servers []*Server `json:"servers"`
509 // }
510 // We iterate over fields searching for the slice one to sort it
511 resType := reflect.TypeOf(res).Elem()
512 fields := reflect.VisibleFields(resType)
513 for _, field := range fields {
514 if field.Type.Kind() == reflect.Slice {
515 sortSliceByZones(reflect.ValueOf(res).Elem().FieldByName(field.Name).Interface(), zones)
516 return
517 }
518 }
519 }
520
521 // sortResponseByRegions find first field that is a slice in a struct and sort it by region
522 func sortResponseByRegions(res any, regions []Region) {
523 // res may be ListServersResponse
524 //
525 // type ListServersResponse struct {
526 // TotalCount uint32 `json:"total_count"`
527 // Servers []*Server `json:"servers"`
528 // }
529 // We iterate over fields searching for the slice one to sort it
530 resType := reflect.TypeOf(res).Elem()
531 fields := reflect.VisibleFields(resType)
532 for _, field := range fields {
533 if field.Type.Kind() == reflect.Slice {
534 sortSliceByRegions(reflect.ValueOf(res).Elem().FieldByName(field.Name).Interface(), regions)
535 return
536 }
537 }
538 }
539
// newVariableFromType returns a pointer to a fresh zero value of t's type.
// When t is a pointer, the pointed-to type is used, so passing *T yields a new *T.
func newVariableFromType(t any) any {
	// reflect.New always returns a pointer, hence the Indirect beforehand.
	elemType := reflect.Indirect(reflect.ValueOf(t)).Type()
	fresh := reflect.New(elemType)
	return fresh.Interface()
}
545
// newHTTPClient builds the HTTP client used when none was supplied: a 30-second
// overall timeout and a private clone of http.DefaultTransport, so in-place
// tweaks such as setInsecureMode's TLS change never reach the shared default
// transport.
func newHTTPClient() *http.Client {
	transport := http.DefaultTransport.(*http.Transport).Clone()
	client := &http.Client{
		Transport: transport,
		Timeout:   30 * time.Second,
	}
	return client
}
552
553 func setInsecureMode(c httpClient) {
554 standardHTTPClient, ok := c.(*http.Client)
555 if !ok {
556 logger.Warningf("client: cannot use insecure mode with HTTP client of type %T", c)
557 return
558 }
559
560 altTransport, ok := standardHTTPClient.Transport.(interface {
561 SetInsecureTransport()
562 })
563 if ok {
564 altTransport.SetInsecureTransport()
565 return
566 }
567
568 transportClient, ok := standardHTTPClient.Transport.(*http.Transport)
569 if !ok {
570 logger.Warningf("client: cannot use insecure mode with Transport client of type %T", standardHTTPClient.Transport)
571 return
572 }
573 if transportClient.TLSClientConfig == nil {
574 transportClient.TLSClientConfig = &tls.Config{}
575 }
576 transportClient.TLSClientConfig.InsecureSkipVerify = true
577 }
578
579 func setRequestLogging(c httpClient) {
580 standardHTTPClient, ok := c.(*http.Client)
581 if !ok {
582 logger.Warningf("client: cannot use request logger with HTTP client of type %T", c)
583 return
584 }
585 // Do not wrap transport if it is already a logger
586 // As client is a pointer, changing transport will change given client
587 // If the same httpClient is used in multiple scwClient, it would add multiple logger transports
588 _, isLogger := standardHTTPClient.Transport.(*requestLoggerTransport)
589 if !isLogger {
590 standardHTTPClient.Transport = &requestLoggerTransport{rt: standardHTTPClient.Transport}
591 }
592 }
593