api.go raw

   1  package v3
   2  
   3  import (
   4  	"bytes"
   5  	"context"
   6  	"crypto/hmac"
   7  	"crypto/sha256"
   8  	"encoding/base64"
   9  	"encoding/json"
  10  	"errors"
  11  	"fmt"
  12  	"io"
  13  	"math"
  14  	"net/http"
  15  	"net/http/httputil"
  16  	"os"
  17  	"sort"
  18  	"strings"
  19  	"time"
  20  
  21  	"github.com/go-playground/validator/v10"
  22  	"github.com/google/uuid"
  23  )
  24  
  25  type UUID string
  26  
  27  func (u UUID) String() string {
  28  	return string(u)
  29  }
  30  
  31  func ParseUUID(s string) (UUID, error) {
  32  	id, err := uuid.Parse(s)
  33  	if err != nil {
  34  		return "", err
  35  	}
  36  
  37  	return UUID(id.String()), nil
  38  }
  39  
  40  // Wait is a helper that waits for async operation to reach the final state.
  41  // Final states are one of: failure, success, timeout.
  42  // If states argument are given, returns an error if the final state not match on of those.
  43  func (c Client) Wait(ctx context.Context, op *Operation, states ...OperationState) (*Operation, error) {
  44  	const abortErrorsCount = 5
  45  
  46  	if op == nil {
  47  		return nil, fmt.Errorf("operation is nil")
  48  	}
  49  
  50  	startTime := time.Now()
  51  
  52  	ticker := time.NewTicker(pollInterval(0))
  53  	defer ticker.Stop()
  54  
  55  	if op.State != OperationStatePending {
  56  		return op, nil
  57  	}
  58  
  59  	var subsequentErrors int
  60  	var operation *Operation
  61  polling:
  62  	for {
  63  		select {
  64  		case <-ticker.C:
  65  			runTime := time.Since(startTime)
  66  
  67  			if c.waitTimeout != 0 && runTime > c.waitTimeout {
  68  				return nil, fmt.Errorf("operation: %q: max wait timeout reached", op.ID)
  69  			}
  70  
  71  			newInterval := pollInterval(runTime)
  72  			ticker.Reset(newInterval)
  73  
  74  			o, err := c.GetOperation(ctx, op.ID)
  75  			if err != nil {
  76  				subsequentErrors++
  77  				if subsequentErrors >= abortErrorsCount {
  78  					return nil, err
  79  				}
  80  				continue
  81  			}
  82  			subsequentErrors = 0
  83  
  84  			if o.State == OperationStatePending {
  85  				continue
  86  			}
  87  
  88  			operation = o
  89  			break polling
  90  		case <-ctx.Done():
  91  			return nil, ctx.Err()
  92  		}
  93  	}
  94  
  95  	if len(states) == 0 {
  96  		return operation, nil
  97  	}
  98  
  99  	for _, st := range states {
 100  		if operation.State == st {
 101  			return operation, nil
 102  		}
 103  	}
 104  
 105  	var ref OperationReference
 106  	if operation.Reference != nil {
 107  		ref = *operation.Reference
 108  	}
 109  
 110  	return nil,
 111  		fmt.Errorf("operation: %q %v, state: %s, reason: %q, message: %q",
 112  			operation.ID,
 113  			ref,
 114  			operation.State,
 115  			operation.Reason,
 116  			operation.Message,
 117  		)
 118  }
 119  
 120  func String(s string) *string {
 121  	return &s
 122  }
 123  
 124  func Int64(i int64) *int64 {
 125  	return &i
 126  }
 127  
 128  func Bool(b bool) *bool {
 129  	return &b
 130  }
 131  
 132  func Ptr[T any](v T) *T {
 133  	return &v
 134  }
 135  
 136  // Validate any struct from schema or request
 137  func (c Client) Validate(s any) error {
 138  	err := c.validate.Struct(s)
 139  	if err == nil {
 140  		return nil
 141  	}
 142  
 143  	// Print better error messages
 144  	validationErrors := err.(validator.ValidationErrors)
 145  
 146  	if len(validationErrors) > 0 {
 147  		e := validationErrors[0]
 148  		errorString := fmt.Sprintf(
 149  			"request validation error: '%s' = '%v' does not validate ",
 150  			e.StructNamespace(),
 151  			e.Value(),
 152  		)
 153  		if e.Param() == "" {
 154  			errorString += fmt.Sprintf("'%s'", e.ActualTag())
 155  		} else {
 156  			errorString += fmt.Sprintf("'%s=%v'", e.ActualTag(), e.Param())
 157  		}
 158  		return errors.New(errorString)
 159  	}
 160  
 161  	return err
 162  }
 163  
 164  // pollInterval returns the wait interval (as a time.Duration) before the next poll, based on the current runtime of a job.
 165  // The polling frequency is:
 166  //   - every 3 seconds for the first 30 seconds
 167  //   - then increases linearly to reach 1 minute at 15 minutes of runtime
 168  //   - after 15 minutes, it stays at 1 minute intervals
 169  func pollInterval(runTime time.Duration) time.Duration {
 170  	runTimeSeconds := runTime.Seconds()
 171  
 172  	// Coefficients for the linear equation y = a * x + b
 173  	a := 57.0 / 870.0
 174  	b := 3.0 - 30.0*a
 175  
 176  	minWait := 3.0
 177  	maxWait := 60.0
 178  
 179  	interval := a*runTimeSeconds + b
 180  	interval = math.Max(minWait, interval)
 181  	interval = math.Min(maxWait, interval)
 182  
 183  	return time.Duration(interval) * time.Second
 184  }
 185  
 186  func prepareJSONBody(body any) (*bytes.Reader, error) {
 187  	buf, err := json.Marshal(body)
 188  	if err != nil {
 189  		return nil, err
 190  	}
 191  
 192  	return bytes.NewReader(buf), nil
 193  }
 194  
 195  func prepareJSONResponse(resp *http.Response, v any) error {
 196  	defer resp.Body.Close()
 197  
 198  	buf, err := io.ReadAll(resp.Body)
 199  	if err != nil {
 200  		return err
 201  	}
 202  	if err := json.Unmarshal(buf, v); err != nil {
 203  		return err
 204  	}
 205  
 206  	return nil
 207  }
 208  
 209  func (c Client) signRequest(req *http.Request) error {
 210  	var (
 211  		sigParts    []string
 212  		headerParts []string
 213  	)
 214  
 215  	var expiration = time.Now().UTC().Add(time.Minute * 10)
 216  
 217  	// Request method/URL path
 218  	sigParts = append(sigParts, fmt.Sprintf("%s %s", req.Method, req.URL.EscapedPath()))
 219  	headerParts = append(headerParts, "EXO2-HMAC-SHA256 credential="+c.apiKey)
 220  
 221  	// Request body if present
 222  	body := ""
 223  	if req.Body != nil {
 224  		data, err := io.ReadAll(req.Body)
 225  		if err != nil {
 226  			return err
 227  		}
 228  		err = req.Body.Close()
 229  		if err != nil {
 230  			return err
 231  		}
 232  		body = string(data)
 233  		req.Body = io.NopCloser(bytes.NewReader(data))
 234  	}
 235  	sigParts = append(sigParts, body)
 236  
 237  	// Request query string parameters
 238  	// Important: this is order-sensitive, we have to have to sort parameters alphabetically to ensure signed
 239  	// values match the names listed in the "signed-query-args=" signature pragma.
 240  	signedParams, paramsValues := extractRequestParameters(req)
 241  	sigParts = append(sigParts, paramsValues)
 242  	if len(signedParams) > 0 {
 243  		headerParts = append(headerParts, "signed-query-args="+strings.Join(signedParams, ";"))
 244  	}
 245  
 246  	// Request headers -- none at the moment
 247  	// Note: the same order-sensitive caution for query string parameters applies to headers.
 248  	sigParts = append(sigParts, "")
 249  
 250  	// Request expiration date (UNIX timestamp, no line return)
 251  	sigParts = append(sigParts, fmt.Sprint(expiration.Unix()))
 252  	headerParts = append(headerParts, "expires="+fmt.Sprint(expiration.Unix()))
 253  
 254  	h := hmac.New(sha256.New, []byte(c.apiSecret))
 255  	if _, err := h.Write([]byte(strings.Join(sigParts, "\n"))); err != nil {
 256  		return err
 257  	}
 258  	headerParts = append(headerParts, "signature="+base64.StdEncoding.EncodeToString(h.Sum(nil)))
 259  
 260  	req.Header.Set("Authorization", strings.Join(headerParts, ","))
 261  
 262  	return nil
 263  }
 264  
 265  // extractRequestParameters returns the list of request URL parameters names
 266  // and a strings concatenating the values of the parameters.
 267  func extractRequestParameters(req *http.Request) ([]string, string) {
 268  	var (
 269  		names  []string
 270  		values string
 271  	)
 272  
 273  	for param, values := range req.URL.Query() {
 274  		// Keep only parameters that hold exactly 1 value (i.e. no empty or multi-valued parameters)
 275  		if len(values) == 1 {
 276  			names = append(names, param)
 277  		}
 278  	}
 279  	sort.Strings(names)
 280  
 281  	for _, param := range names {
 282  		values += req.URL.Query().Get(param)
 283  	}
 284  
 285  	return names, values
 286  }
 287  
 288  func dumpRequest(req *http.Request, operationID string) {
 289  	if req != nil {
 290  		if dump, err := httputil.DumpRequest(req, true); err == nil {
 291  			fmt.Fprintf(os.Stderr, ">>> Operation: %s\n%s\n", operationID, dump)
 292  		}
 293  	}
 294  }
 295  
 296  func dumpResponse(resp *http.Response) {
 297  	if resp != nil {
 298  		if dump, err := httputil.DumpResponse(resp, true); err == nil {
 299  			fmt.Fprintf(os.Stderr, "<<< %s\n", dump)
 300  			fmt.Fprintln(os.Stderr, "----------------------------------------------------------------------")
 301  		}
 302  	}
 303  }
 304  
 305  // FindInstanceType attempts to find an InstanceType by id, or by a string or the form FAMILY.SIZE or SIZE,
 306  // where family defaults to "standard"
 307  func (l ListInstanceTypesResponse) FindInstanceTypeByIdOrFamilyAndSize(familyAndSizeOrId string) (InstanceType, error) {
 308  	var result []InstanceType
 309  
 310  	var typeFamily, typeSize string
 311  	parts := strings.SplitN(familyAndSizeOrId, ".", 2)
 312  	if l := len(parts); l > 0 {
 313  		if l == 1 {
 314  			typeFamily, typeSize = "standard", strings.ToLower(parts[0])
 315  		} else {
 316  			typeFamily, typeSize = strings.ToLower(parts[0]), strings.ToLower(parts[1])
 317  		}
 318  	}
 319  
 320  	for i, elem := range l.InstanceTypes {
 321  		if string(elem.ID) == familyAndSizeOrId || (string(elem.Size) == typeSize && string(elem.Family) == typeFamily) {
 322  			result = append(result, l.InstanceTypes[i])
 323  		}
 324  	}
 325  	if len(result) == 1 {
 326  		return result[0], nil
 327  	}
 328  
 329  	if len(result) > 1 {
 330  		return InstanceType{}, fmt.Errorf("%q too many found in ListInstanceTypesResponse: %w", familyAndSizeOrId, ErrConflict)
 331  	}
 332  
 333  	return InstanceType{}, fmt.Errorf("%q not found in ListInstanceTypesResponse: %w", familyAndSizeOrId, ErrNotFound)
 334  }
 335