proto.go raw

   1  package operation
   2  
   3  import (
   4  	"context"
   5  	"fmt"
   6  	"strconv"
   7  	"time"
   8  
   9  	"google.golang.org/grpc"
  10  	"google.golang.org/grpc/codes"
  11  	"google.golang.org/grpc/metadata"
  12  	"google.golang.org/grpc/status"
  13  	"google.golang.org/protobuf/proto"
  14  )
  15  
// pollIntervalMetadataKey is the gRPC response header (metadata) key under which a
// polling interval may be supplied. PollUntilDone parses the first value found under
// this key as an integer number of seconds and, when it parses successfully, uses it
// in place of the interval computed by the caller's PollIntervalFunc.
const pollIntervalMetadataKey = "x-operation-poll-interval"
  18  
  19  // Error extracts and returns the error from the provided YCOperation if it exists, otherwise returns nil.
  20  func Error(op YCOperation) error {
  21  	pb := op.GetError()
  22  	if pb == nil {
  23  		return nil
  24  	}
  25  
  26  	st := status.FromProto(pb)
  27  	if st == nil {
  28  		return nil
  29  	}
  30  
  31  	return st.Err()
  32  }
  33  
  34  // Response processes the response from a YCOperation and unmarshals it into a proto.Message.
  35  func Response(op YCOperation) (proto.Message, error) {
  36  	raw := op.GetResponse()
  37  	if raw == nil {
  38  		return nil, nil
  39  	}
  40  
  41  	return raw.UnmarshalNew()
  42  }
  43  
// PollFunc fetches the current state of the operation identified by operationId.
// It returns that state as a YCOperation, or an error if the state could not be obtained.
// PollUntilDone invokes it once per polling attempt, forwarding ctx and the gRPC call
// options in opts.
type PollFunc func(ctx context.Context, operationId string, opts ...grpc.CallOption) (YCOperation, error)
  46  
  47  // waitInterval polls the operation until its completion or context timeout, using a custom polling interval strategy.
  48  func waitInterval(
  49  	ctx context.Context,
  50  	operationId string,
  51  	poll PollFunc,
  52  	pollInterval PollIntervalFunc,
  53  	opts ...grpc.CallOption,
  54  ) (YCOperation, error) {
  55  	op, err := PollUntilDone(ctx, operationId, poll, pollInterval, opts...)
  56  	if err != nil {
  57  		return nil, err
  58  	}
  59  
  60  	err = Error(op)
  61  	if err != nil {
  62  		return nil, fmt.Errorf("operation (id=%s) failed: %w", op.GetId(), err)
  63  	}
  64  
  65  	return op, nil
  66  }
  67  
  68  // PollUntilDone repeatedly polls the status of a long-running operation until it is marked as done or the context is canceled.
  69  // ctx is the context for managing the polling lifecycle and cancellation.
  70  // operationId identifies the specific operation to be polled.
  71  // poll is a function used to fetch the current state of the operation.
  72  // pollInterval determines the duration to wait between polling attempts based on the attempt count.
  73  // opts allows for additional gRPC call options to be passed during polling.
  74  // Returns the completed YCOperation or an error if polling fails or the context is canceled.
  75  func PollUntilDone(
  76  	ctx context.Context,
  77  	operationId string,
  78  	poll PollFunc,
  79  	pollInterval PollIntervalFunc,
  80  	opts ...grpc.CallOption,
  81  ) (YCOperation, error) {
  82  	var headers metadata.MD
  83  	opts = append(opts, grpc.Header(&headers))
  84  
  85  	// Sometimes, the returned operation is not on all replicas yet,
  86  	// so we need to ignore first couple of NotFound errors.
  87  	const maxNotFoundRetry = 3
  88  
  89  	notFoundCount := 0
  90  	attempt := 0
  91  
  92  	for {
  93  		headers = metadata.MD{}
  94  
  95  		polledOperation, err := poll(ctx, operationId, opts...)
  96  		if err != nil {
  97  			st, ok := status.FromError(err)
  98  			if notFoundCount < maxNotFoundRetry && ok && st.Code() == codes.NotFound {
  99  				notFoundCount++
 100  			} else {
 101  				// Message needed to distinguish poll fail and operation error, which are both gRPC status.
 102  				return nil, fmt.Errorf("operation (id=%s) poll fail: %w", operationId, err)
 103  			}
 104  		} else {
 105  			if polledOperation.GetDone() {
 106  				return polledOperation, nil
 107  			}
 108  		}
 109  
 110  		interval := pollInterval(attempt)
 111  		attempt++
 112  
 113  		if values := headers.Get(pollIntervalMetadataKey); len(values) > 0 {
 114  			i, err := strconv.Atoi(values[0])
 115  			if err == nil {
 116  				interval = time.Duration(i) * time.Second
 117  			}
 118  		}
 119  
 120  		if interval <= 0 {
 121  			continue
 122  		}
 123  
 124  		waitTimer := time.NewTimer(interval)
 125  		select {
 126  		case <-waitTimer.C:
 127  		case <-ctx.Done():
 128  			waitTimer.Stop()
 129  			return nil, fmt.Errorf("operation (id=%s) wait context done: %w", operationId, ctx.Err())
 130  		}
 131  	}
 132  }
 133