1 package operation
2 3 import (
4 "context"
5 "fmt"
6 "strconv"
7 "time"
8 9 "google.golang.org/grpc"
10 "google.golang.org/grpc/codes"
11 "google.golang.org/grpc/metadata"
12 "google.golang.org/grpc/status"
13 "google.golang.org/protobuf/proto"
14 )

const (
	// pollIntervalMetadataKey names the gRPC response header through which the
	// server may override the client-side polling interval; PollUntilDone reads
	// its value as a whole number of seconds.
	pollIntervalMetadataKey = "x-operation-poll-interval"
)
18 19 // Error extracts and returns the error from the provided YCOperation if it exists, otherwise returns nil.
20 func Error(op YCOperation) error {
21 pb := op.GetError()
22 if pb == nil {
23 return nil
24 }
25 26 st := status.FromProto(pb)
27 if st == nil {
28 return nil
29 }
30 31 return st.Err()
32 }
33 34 // Response processes the response from a YCOperation and unmarshals it into a proto.Message.
35 func Response(op YCOperation) (proto.Message, error) {
36 raw := op.GetResponse()
37 if raw == nil {
38 return nil, nil
39 }
40 41 return raw.UnmarshalNew()
42 }
43 44 // PollFunc defines a function type used to poll the status of an operation using its ID and return a YCOperation or an error.
45 type PollFunc func(ctx context.Context, operationId string, opts ...grpc.CallOption) (YCOperation, error)
46 47 // waitInterval polls the operation until its completion or context timeout, using a custom polling interval strategy.
48 func waitInterval(
49 ctx context.Context,
50 operationId string,
51 poll PollFunc,
52 pollInterval PollIntervalFunc,
53 opts ...grpc.CallOption,
54 ) (YCOperation, error) {
55 op, err := PollUntilDone(ctx, operationId, poll, pollInterval, opts...)
56 if err != nil {
57 return nil, err
58 }
59 60 err = Error(op)
61 if err != nil {
62 return nil, fmt.Errorf("operation (id=%s) failed: %w", op.GetId(), err)
63 }
64 65 return op, nil
66 }
67 68 // PollUntilDone repeatedly polls the status of a long-running operation until it is marked as done or the context is canceled.
69 // ctx is the context for managing the polling lifecycle and cancellation.
70 // operationId identifies the specific operation to be polled.
71 // poll is a function used to fetch the current state of the operation.
72 // pollInterval determines the duration to wait between polling attempts based on the attempt count.
73 // opts allows for additional gRPC call options to be passed during polling.
74 // Returns the completed YCOperation or an error if polling fails or the context is canceled.
75 func PollUntilDone(
76 ctx context.Context,
77 operationId string,
78 poll PollFunc,
79 pollInterval PollIntervalFunc,
80 opts ...grpc.CallOption,
81 ) (YCOperation, error) {
82 var headers metadata.MD
83 opts = append(opts, grpc.Header(&headers))
84 85 // Sometimes, the returned operation is not on all replicas yet,
86 // so we need to ignore first couple of NotFound errors.
87 const maxNotFoundRetry = 3
88 89 notFoundCount := 0
90 attempt := 0
91 92 for {
93 headers = metadata.MD{}
94 95 polledOperation, err := poll(ctx, operationId, opts...)
96 if err != nil {
97 st, ok := status.FromError(err)
98 if notFoundCount < maxNotFoundRetry && ok && st.Code() == codes.NotFound {
99 notFoundCount++
100 } else {
101 // Message needed to distinguish poll fail and operation error, which are both gRPC status.
102 return nil, fmt.Errorf("operation (id=%s) poll fail: %w", operationId, err)
103 }
104 } else {
105 if polledOperation.GetDone() {
106 return polledOperation, nil
107 }
108 }
109 110 interval := pollInterval(attempt)
111 attempt++
112 113 if values := headers.Get(pollIntervalMetadataKey); len(values) > 0 {
114 i, err := strconv.Atoi(values[0])
115 if err == nil {
116 interval = time.Duration(i) * time.Second
117 }
118 }
119 120 if interval <= 0 {
121 continue
122 }
123 124 waitTimer := time.NewTimer(interval)
125 select {
126 case <-waitTimer.C:
127 case <-ctx.Done():
128 waitTimer.Stop()
129 return nil, fmt.Errorf("operation (id=%s) wait context done: %w", operationId, ctx.Err())
130 }
131 }
132 }
133