client.go raw

   1  package scw
   2  
   3  import (
   4  	"context"
   5  	"crypto/tls"
   6  	"encoding/json"
   7  	"fmt"
   8  	"io"
   9  	"math"
  10  	"net/http"
  11  	"reflect"
  12  	"strconv"
  13  	"strings"
  14  	"sync"
  15  	"time"
  16  
  17  	"github.com/scaleway/scaleway-sdk-go/errors"
  18  	"github.com/scaleway/scaleway-sdk-go/internal/auth"
  19  	"github.com/scaleway/scaleway-sdk-go/internal/generic"
  20  	"github.com/scaleway/scaleway-sdk-go/logger"
  21  )
  22  
// Client is the Scaleway client which performs API requests.
//
// This client should be passed in the `NewApi` functions whenever an API instance is created.
// Creating a Client is done with the `NewClient` function.
type Client struct {
	// httpClient executes every HTTP request sent by do.
	httpClient            httpClient
	// auth is the default authentication; Do uses it for requests that do not
	// carry their own auth.
	auth                  auth.Auth
	// apiURL is the base URL handed to ScalewayRequest.getURL.
	apiURL                string
	// userAgent is handed to ScalewayRequest.getAllHeaders when the request
	// headers are built.
	userAgent             string
	// The following defaults are optional: nil means "not configured".
	// They are exposed through the matching GetDefault* accessors.
	defaultOrganizationID *string
	defaultProjectID      *string
	defaultRegion         *Region
	defaultZone           *Zone
	defaultPageSize       *uint32
}
  38  
  39  func defaultOptions() []ClientOption {
  40  	return []ClientOption{
  41  		WithoutAuth(),
  42  		WithAPIURL("https://api.scaleway.com"),
  43  		withDefaultUserAgent(userAgent),
  44  	}
  45  }
  46  
  47  // NewClient instantiate a new Client object.
  48  //
  49  // Zero or more ClientOption object can be passed as a parameter.
  50  // These options will then be applied to the client.
  51  func NewClient(opts ...ClientOption) (*Client, error) {
  52  	s := newSettings()
  53  
  54  	// apply options
  55  	s.apply(append(defaultOptions(), opts...))
  56  
  57  	// validate settings
  58  	err := s.validate()
  59  	if err != nil {
  60  		return nil, err
  61  	}
  62  
  63  	// dial the API
  64  	if s.httpClient == nil {
  65  		s.httpClient = newHTTPClient()
  66  	}
  67  
  68  	// insecure mode
  69  	if s.insecure {
  70  		logger.Debugf("client: using insecure mode\n")
  71  		setInsecureMode(s.httpClient)
  72  	}
  73  
  74  	if logger.ShouldLog(logger.LogLevelDebug) {
  75  		logger.Debugf("client: using request logger\n")
  76  		setRequestLogging(s.httpClient)
  77  	}
  78  
  79  	logger.Debugf("client: using sdk version " + getVersion() + "\n")
  80  
  81  	return &Client{
  82  		auth:                  s.token,
  83  		httpClient:            s.httpClient,
  84  		apiURL:                s.apiURL,
  85  		userAgent:             s.userAgent,
  86  		defaultOrganizationID: s.defaultOrganizationID,
  87  		defaultProjectID:      s.defaultProjectID,
  88  		defaultRegion:         s.defaultRegion,
  89  		defaultZone:           s.defaultZone,
  90  		defaultPageSize:       s.defaultPageSize,
  91  	}, nil
  92  }
  93  
  94  // GetDefaultOrganizationID returns the default organization ID
  95  // of the client. This value can be set in the client option
  96  // WithDefaultOrganizationID(). Be aware this value can be empty.
  97  func (c *Client) GetDefaultOrganizationID() (organizationID string, exists bool) {
  98  	if c.defaultOrganizationID != nil {
  99  		return *c.defaultOrganizationID, true
 100  	}
 101  	return "", false
 102  }
 103  
 104  // GetDefaultProjectID returns the default project ID
 105  // of the client. This value can be set in the client option
 106  // WithDefaultProjectID(). Be aware this value can be empty.
 107  func (c *Client) GetDefaultProjectID() (projectID string, exists bool) {
 108  	if c.defaultProjectID != nil {
 109  		return *c.defaultProjectID, true
 110  	}
 111  	return "", false
 112  }
 113  
 114  // GetDefaultRegion returns the default region of the client.
 115  // This value can be set in the client option
 116  // WithDefaultRegion(). Be aware this value can be empty.
 117  func (c *Client) GetDefaultRegion() (region Region, exists bool) {
 118  	if c.defaultRegion != nil {
 119  		return *c.defaultRegion, true
 120  	}
 121  	return Region(""), false
 122  }
 123  
 124  // GetDefaultZone returns the default zone of the client.
 125  // This value can be set in the client option
 126  // WithDefaultZone(). Be aware this value can be empty.
 127  func (c *Client) GetDefaultZone() (zone Zone, exists bool) {
 128  	if c.defaultZone != nil {
 129  		return *c.defaultZone, true
 130  	}
 131  	return Zone(""), false
 132  }
 133  
 134  func (c *Client) GetSecretKey() (secretKey string, exists bool) {
 135  	if token, isToken := c.auth.(*auth.Token); isToken {
 136  		return token.SecretKey, isToken
 137  	}
 138  	return "", false
 139  }
 140  
 141  func (c *Client) GetAccessKey() (accessKey string, exists bool) {
 142  	if token, isToken := c.auth.(*auth.Token); isToken {
 143  		return token.AccessKey, isToken
 144  	} else if token, isAccessKey := c.auth.(*auth.AccessKeyOnly); isAccessKey {
 145  		return token.AccessKey, isAccessKey
 146  	}
 147  
 148  	return "", false
 149  }
 150  
 151  // GetDefaultPageSize returns the default page size of the client.
 152  // This value can be set in the client option
 153  // WithDefaultPageSize(). Be aware this value can be empty.
 154  func (c *Client) GetDefaultPageSize() (pageSize uint32, exists bool) {
 155  	if c.defaultPageSize != nil {
 156  		return *c.defaultPageSize, true
 157  	}
 158  	return 0, false
 159  }
 160  
 161  // Do performs HTTP request(s) based on the ScalewayRequest object.
 162  // RequestOptions are applied prior to doing the request.
 163  func (c *Client) Do(req *ScalewayRequest, res any, opts ...RequestOption) (err error) {
 164  	// apply request options
 165  	req.apply(opts)
 166  
 167  	// validate request options
 168  	err = req.validate()
 169  	if err != nil {
 170  		return err
 171  	}
 172  
 173  	if req.auth == nil {
 174  		req.auth = c.auth
 175  	}
 176  
 177  	if req.zones != nil {
 178  		return c.doListZones(req, res, req.zones)
 179  	}
 180  	if req.regions != nil {
 181  		return c.doListRegions(req, res, req.regions)
 182  	}
 183  
 184  	if req.allPages {
 185  		return c.doListAll(req, res)
 186  	}
 187  
 188  	return c.do(req, res)
 189  }
 190  
 191  // do performs a single HTTP request based on the ScalewayRequest object.
 192  func (c *Client) do(req *ScalewayRequest, res any) (sdkErr error) {
 193  	if req == nil {
 194  		return errors.New("request must be non-nil")
 195  	}
 196  
 197  	// build url
 198  	url, sdkErr := req.getURL(c.apiURL)
 199  	if sdkErr != nil {
 200  		return sdkErr
 201  	}
 202  	logger.Debugf("creating %s request on %s\n", req.Method, url.String())
 203  
 204  	// build request
 205  	ctx := req.ctx
 206  	if ctx == nil {
 207  		ctx = context.Background()
 208  	}
 209  	httpRequest, err := http.NewRequestWithContext(ctx, req.Method, url.String(), req.Body)
 210  	if err != nil {
 211  		return errors.Wrap(err, "could not create request")
 212  	}
 213  
 214  	httpRequest.Header = req.getAllHeaders(req.auth, c.userAgent, false)
 215  
 216  	// execute request
 217  	httpResponse, err := c.httpClient.Do(httpRequest)
 218  	if err != nil {
 219  		return errors.Wrap(err, "error executing request")
 220  	}
 221  
 222  	defer func() {
 223  		closeErr := httpResponse.Body.Close()
 224  		if sdkErr == nil && closeErr != nil {
 225  			sdkErr = errors.Wrap(closeErr, "could not close http response")
 226  		}
 227  	}()
 228  
 229  	sdkErr = hasResponseError(httpResponse)
 230  	if sdkErr != nil {
 231  		return sdkErr
 232  	}
 233  
 234  	if res != nil && httpResponse.ContentLength != 0 {
 235  		contentType := httpResponse.Header.Get("Content-Type")
 236  
 237  		if strings.HasPrefix(contentType, "application/json") {
 238  			err = json.NewDecoder(httpResponse.Body).Decode(&res)
 239  			if err != nil {
 240  				return errors.Wrap(err, "could not parse %s response body", contentType)
 241  			}
 242  		} else {
 243  			buffer, isBuffer := res.(io.Writer)
 244  			if !isBuffer {
 245  				return errors.Wrap(err, "could not handle %s response body with %T result type", contentType, buffer)
 246  			}
 247  
 248  			_, err := io.Copy(buffer, httpResponse.Body)
 249  			if err != nil {
 250  				return errors.Wrap(err, "could not copy %s response body", contentType)
 251  			}
 252  		}
 253  
 254  		// Handle instance API X-Total-Count header
 255  		xTotalCountStr := httpResponse.Header.Get("X-Total-Count")
 256  		if legacyLister, isLegacyLister := res.(legacyLister); isLegacyLister && xTotalCountStr != "" {
 257  			xTotalCount, err := strconv.ParseInt(xTotalCountStr, 10, 32)
 258  			if err != nil {
 259  				return errors.Wrap(err, "could not parse X-Total-Count header")
 260  			}
 261  			legacyLister.UnsafeSetTotalCount(int(xTotalCount))
 262  		}
 263  	}
 264  
 265  	return nil
 266  }
 267  
 268  type lister interface {
 269  	UnsafeGetTotalCount() uint64
 270  	UnsafeAppend(any) (uint64, error)
 271  }
 272  
 273  // Old lister for uint32
 274  // Used for retro-compatibility with response that use uint32
 275  type lister32 interface {
 276  	UnsafeGetTotalCount() uint32
 277  	UnsafeAppend(any) (uint32, error)
 278  }
 279  
 280  type legacyLister interface {
 281  	UnsafeSetTotalCount(totalCount int)
 282  }
 283  
 284  func listerGetTotalCount(i any) uint64 {
 285  	if l, isLister := i.(lister); isLister {
 286  		return l.UnsafeGetTotalCount()
 287  	}
 288  	if l32, isLister32 := i.(lister32); isLister32 {
 289  		return uint64(l32.UnsafeGetTotalCount())
 290  	}
 291  	panic(fmt.Errorf("%T does not support pagination but checks failed, should not happen", i))
 292  }
 293  
 294  func listerAppend(recv any, elems any) (uint64, error) {
 295  	if l, isLister := recv.(lister); isLister {
 296  		return l.UnsafeAppend(elems)
 297  	} else if l32, isLister32 := recv.(lister32); isLister32 {
 298  		total, err := l32.UnsafeAppend(elems)
 299  		return uint64(total), err
 300  	}
 301  
 302  	panic(fmt.Errorf("%T does not support pagination but checks failed, should not happen", recv))
 303  }
 304  
 305  func isLister(i any) bool {
 306  	switch i.(type) {
 307  	case lister:
 308  		return true
 309  	case lister32:
 310  		return true
 311  	default:
 312  		return false
 313  	}
 314  }
 315  
 316  const maxPageCount uint64 = math.MaxUint32
 317  
 318  // doListAll collects all pages of a List request and aggregate all results on a single response.
 319  func (c *Client) doListAll(req *ScalewayRequest, res any) (err error) {
 320  	// check for lister interface
 321  	if isLister(res) {
 322  		pageCount := maxPageCount
 323  		for page := uint64(1); page <= pageCount; page++ {
 324  			// set current page
 325  			req.Query.Set("page", strconv.FormatUint(page, 10))
 326  
 327  			// request the next page
 328  			nextPage := newVariableFromType(res)
 329  			err := c.do(req, nextPage)
 330  			if err != nil {
 331  				return err
 332  			}
 333  
 334  			// append results
 335  			pageSize, err := listerAppend(res, nextPage)
 336  			if err != nil {
 337  				return err
 338  			}
 339  
 340  			if pageSize == 0 {
 341  				return nil
 342  			}
 343  
 344  			// set total count on first request
 345  			if pageCount == maxPageCount {
 346  				totalCount := listerGetTotalCount(nextPage)
 347  				pageCount = (totalCount + pageSize - 1) / pageSize
 348  			}
 349  		}
 350  		return nil
 351  	}
 352  
 353  	return errors.New("%T does not support pagination", res)
 354  }
 355  
 356  // doListLocalities collects all localities using multiple list requests and aggregate all results on a lister response
 357  // results is sorted by locality
 358  func (c *Client) doListLocalities(req *ScalewayRequest, res any, localities []string) (err error) {
 359  	path := req.Path
 360  	if !strings.Contains(path, "%locality%") {
 361  		return errors.New("request is not a valid locality request")
 362  	}
 363  	// Requests are parallelized
 364  	responseMutex := sync.Mutex{}
 365  	requestGroup := sync.WaitGroup{}
 366  	errChan := make(chan error, len(localities))
 367  
 368  	requestGroup.Add(len(localities))
 369  	for _, locality := range localities {
 370  		go func(locality string) {
 371  			defer requestGroup.Done()
 372  			// Request is cloned as doListAll will change header
 373  			// We remove zones as it would recurse in the same function
 374  			req := req.clone()
 375  			req.zones = []Zone(nil)
 376  			req.Path = strings.ReplaceAll(path, "%locality%", locality)
 377  
 378  			// We create a new response that we append to main response
 379  			zoneResponse := newVariableFromType(res)
 380  			err := c.Do(req, zoneResponse)
 381  			if err != nil {
 382  				errChan <- err
 383  			}
 384  			responseMutex.Lock()
 385  			_, err = listerAppend(res, zoneResponse)
 386  			responseMutex.Unlock()
 387  			if err != nil {
 388  				errChan <- err
 389  			}
 390  		}(locality)
 391  	}
 392  	requestGroup.Wait()
 393  
 394  L: // We gather potential errors and return them all together
 395  	for {
 396  		select {
 397  		case newErr := <-errChan:
 398  			err = errors.Wrap(err, "%s", newErr.Error())
 399  		default:
 400  			break L
 401  		}
 402  	}
 403  	close(errChan)
 404  	if err != nil {
 405  		return err
 406  	}
 407  	return nil
 408  }
 409  
 410  // doListZones collects all zones using multiple list requests and aggregate all results on a single response.
 411  // result is sorted by zone
 412  func (c *Client) doListZones(req *ScalewayRequest, res any, zones []Zone) (err error) {
 413  	if isLister(res) {
 414  		// Prepare request with %zone% that can be replaced with actual zone
 415  		for _, zone := range AllZones {
 416  			if strings.Contains(req.Path, string(zone)) {
 417  				req.Path = strings.ReplaceAll(req.Path, string(zone), "%locality%")
 418  				break
 419  			}
 420  		}
 421  		if !strings.Contains(req.Path, "%locality%") {
 422  			return errors.New("request is not a valid zoned request")
 423  		}
 424  		localities := make([]string, 0, len(zones))
 425  		for _, zone := range zones {
 426  			localities = append(localities, string(zone))
 427  		}
 428  
 429  		err := c.doListLocalities(req, res, localities)
 430  		if err != nil {
 431  			return fmt.Errorf("failed to list localities: %w", err)
 432  		}
 433  
 434  		sortResponseByZones(res, zones)
 435  		return nil
 436  	}
 437  
 438  	return errors.New("%T does not support pagination", res)
 439  }
 440  
 441  // doListRegions collects all regions using multiple list requests and aggregate all results on a single response.
 442  // result is sorted by region
 443  func (c *Client) doListRegions(req *ScalewayRequest, res any, regions []Region) (err error) {
 444  	if isLister(res) {
 445  		// Prepare request with %locality% that can be replaced with actual region
 446  		for _, region := range AllRegions {
 447  			if strings.Contains(req.Path, string(region)) {
 448  				req.Path = strings.ReplaceAll(req.Path, string(region), "%locality%")
 449  				break
 450  			}
 451  		}
 452  		if !strings.Contains(req.Path, "%locality%") {
 453  			return errors.New("request is not a valid zoned request")
 454  		}
 455  		localities := make([]string, 0, len(regions))
 456  		for _, region := range regions {
 457  			localities = append(localities, string(region))
 458  		}
 459  
 460  		err := c.doListLocalities(req, res, localities)
 461  		if err != nil {
 462  			return fmt.Errorf("failed to list localities: %w", err)
 463  		}
 464  
 465  		sortResponseByRegions(res, regions)
 466  		return nil
 467  	}
 468  
 469  	return errors.New("%T does not support pagination", res)
 470  }
 471  
 472  // sortSliceByZones sorts a slice of struct using a Zone field that should exist
 473  func sortSliceByZones(list any, zones []Zone) {
 474  	if !generic.HasField(list, "Zone") {
 475  		return
 476  	}
 477  
 478  	zoneMap := map[Zone]int{}
 479  	for i, zone := range zones {
 480  		zoneMap[zone] = i
 481  	}
 482  	generic.SortSliceByField(list, "Zone", func(i any, i2 any) bool {
 483  		return zoneMap[i.(Zone)] < zoneMap[i2.(Zone)]
 484  	})
 485  }
 486  
 487  // sortSliceByRegions sorts a slice of struct using a Region field that should exist
 488  func sortSliceByRegions(list any, regions []Region) {
 489  	if !generic.HasField(list, "Region") {
 490  		return
 491  	}
 492  
 493  	regionMap := map[Region]int{}
 494  	for i, region := range regions {
 495  		regionMap[region] = i
 496  	}
 497  	generic.SortSliceByField(list, "Region", func(i any, i2 any) bool {
 498  		return regionMap[i.(Region)] < regionMap[i2.(Region)]
 499  	})
 500  }
 501  
 502  // sortResponseByZones find first field that is a slice in a struct and sort it by zone
 503  func sortResponseByZones(res any, zones []Zone) {
 504  	// res may be ListServersResponse
 505  	//
 506  	// type ListServersResponse struct {
 507  	//	TotalCount uint32 `json:"total_count"`
 508  	//	Servers []*Server `json:"servers"`
 509  	// }
 510  	// We iterate over fields searching for the slice one to sort it
 511  	resType := reflect.TypeOf(res).Elem()
 512  	fields := reflect.VisibleFields(resType)
 513  	for _, field := range fields {
 514  		if field.Type.Kind() == reflect.Slice {
 515  			sortSliceByZones(reflect.ValueOf(res).Elem().FieldByName(field.Name).Interface(), zones)
 516  			return
 517  		}
 518  	}
 519  }
 520  
 521  // sortResponseByRegions find first field that is a slice in a struct and sort it by region
 522  func sortResponseByRegions(res any, regions []Region) {
 523  	// res may be ListServersResponse
 524  	//
 525  	// type ListServersResponse struct {
 526  	//	TotalCount uint32 `json:"total_count"`
 527  	//	Servers []*Server `json:"servers"`
 528  	// }
 529  	// We iterate over fields searching for the slice one to sort it
 530  	resType := reflect.TypeOf(res).Elem()
 531  	fields := reflect.VisibleFields(resType)
 532  	for _, field := range fields {
 533  		if field.Type.Kind() == reflect.Slice {
 534  			sortSliceByRegions(reflect.ValueOf(res).Elem().FieldByName(field.Name).Interface(), regions)
 535  			return
 536  		}
 537  	}
 538  }
 539  
 540  // newVariableFromType returns a variable set to the zero value of the given type
 541  func newVariableFromType(t any) any {
 542  	// reflect.New always create a pointer, that's why we use reflect.Indirect before
 543  	return reflect.New(reflect.Indirect(reflect.ValueOf(t)).Type()).Interface()
 544  }
 545  
 546  func newHTTPClient() *http.Client {
 547  	return &http.Client{
 548  		Timeout:   30 * time.Second,
 549  		Transport: http.DefaultTransport.(*http.Transport).Clone(),
 550  	}
 551  }
 552  
 553  func setInsecureMode(c httpClient) {
 554  	standardHTTPClient, ok := c.(*http.Client)
 555  	if !ok {
 556  		logger.Warningf("client: cannot use insecure mode with HTTP client of type %T", c)
 557  		return
 558  	}
 559  
 560  	altTransport, ok := standardHTTPClient.Transport.(interface {
 561  		SetInsecureTransport()
 562  	})
 563  	if ok {
 564  		altTransport.SetInsecureTransport()
 565  		return
 566  	}
 567  
 568  	transportClient, ok := standardHTTPClient.Transport.(*http.Transport)
 569  	if !ok {
 570  		logger.Warningf("client: cannot use insecure mode with Transport client of type %T", standardHTTPClient.Transport)
 571  		return
 572  	}
 573  	if transportClient.TLSClientConfig == nil {
 574  		transportClient.TLSClientConfig = &tls.Config{}
 575  	}
 576  	transportClient.TLSClientConfig.InsecureSkipVerify = true
 577  }
 578  
 579  func setRequestLogging(c httpClient) {
 580  	standardHTTPClient, ok := c.(*http.Client)
 581  	if !ok {
 582  		logger.Warningf("client: cannot use request logger with HTTP client of type %T", c)
 583  		return
 584  	}
 585  	// Do not wrap transport if it is already a logger
 586  	// As client is a pointer, changing transport will change given client
 587  	// If the same httpClient is used in multiple scwClient, it would add multiple logger transports
 588  	_, isLogger := standardHTTPClient.Transport.(*requestLoggerTransport)
 589  	if !isLogger {
 590  		standardHTTPClient.Transport = &requestLoggerTransport{rt: standardHTTPClient.Transport}
 591  	}
 592  }
 593