polling.go raw

   1  // Copyright 2022-2025 The sacloud/packages-go Authors
   2  //
   3  // Licensed under the Apache License, Version 2.0 (the "License");
   4  // you may not use this file except in compliance with the License.
   5  // You may obtain a copy of the License at
   6  //
   7  //      http://www.apache.org/licenses/LICENSE-2.0
   8  //
   9  // Unless required by applicable law or agreed to in writing, software
  10  // distributed under the License is distributed on an "AS IS" BASIS,
  11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12  // See the License for the specific language governing permissions and
  13  // limitations under the License.
  14  
  15  package wait
  16  
  17  import (
  18  	"context"
  19  	"time"
  20  )
  21  
  22  var (
  23  	defaultPollingTimeout  = 20 * time.Minute
  24  	defaultPollingInterval = 5 * time.Second
  25  )
  26  
  27  // StateReadFunc PollingWaiterにより利用される、対象リソースの状態を取得するためのfunc
  28  type StateReadFunc func() (state interface{}, err error)
  29  
  30  // StateCheckFunc StateReadFuncで得たリソースの情報を元に待ちを継続するか判定するためのfunc
  31  //
  32  // completeがtrueの場合待ち処理を終了する
  33  type StateCheckFunc func(target interface{}) (complete bool, err error)
  34  
  35  // ComposeStateCheckFunc 指定のStateCheckFuncを順に適用するするStateCheckFuncを生成する
  36  //
  37  // completeがtrueまたはerrが非nilになったら即時リターンする。
  38  func ComposeStateCheckFunc(funcs ...StateCheckFunc) StateCheckFunc {
  39  	return func(target interface{}) (bool, error) {
  40  		for _, f := range funcs {
  41  			complete, err := f(target)
  42  			if err != nil {
  43  				return false, err
  44  			}
  45  			if complete {
  46  				return complete, nil
  47  			}
  48  		}
  49  		return false, nil
  50  	}
  51  }
  52  
  53  // PollingWaiter ポーリングでステート変更を検知するWaiter
  54  type PollingWaiter struct {
  55  	// ReadFunc 対象リソースの状態を取得するためのfunc
  56  	//
  57  	// ReadFuncがerrorを返した場合は待ち処理が即時リターンする。非エラーかつ非nilを返した場合のみStateCheckFuncでの判定処理を行う。
  58  	ReadFunc StateReadFunc
  59  
  60  	// StateCheckFunc ReadFuncで得たリソースの情報を元に待ちを継続するかの判定を行うためのfunc
  61  	StateCheckFunc StateCheckFunc
  62  
  63  	// Timeout タイムアウト
  64  	Timeout time.Duration // タイムアウト
  65  
  66  	// Interval ポーリング間隔
  67  	Interval time.Duration
  68  }
  69  
  70  // WaitForState リソースが指定の状態になるまで待つ
  71  func (w *PollingWaiter) WaitForState(ctx context.Context) (interface{}, error) {
  72  	c, p, e := w.WaitForStateAsync(ctx)
  73  	for {
  74  		select {
  75  		case <-ctx.Done():
  76  			return nil, ctx.Err()
  77  		case lastState := <-c:
  78  			return lastState, nil
  79  		case <-p:
  80  			// noop
  81  		case err := <-e:
  82  			return nil, err
  83  		}
  84  	}
  85  }
  86  
  87  // WaitForStateAsync リソースが指定の状態になるまで待つ
  88  func (w *PollingWaiter) WaitForStateAsync(ctx context.Context) (<-chan interface{}, <-chan interface{}, <-chan error) {
  89  	w.validateFields()
  90  	w.defaults()
  91  
  92  	compCh := make(chan interface{})
  93  	progressCh := make(chan interface{})
  94  	errCh := make(chan error)
  95  
  96  	ticker := time.NewTicker(w.Interval)
  97  
  98  	go func() {
  99  		ctx, cancel := context.WithTimeout(ctx, w.Timeout)
 100  		defer cancel()
 101  
 102  		defer ticker.Stop()
 103  
 104  		defer close(compCh)
 105  		defer close(progressCh)
 106  		defer close(errCh)
 107  
 108  		for {
 109  			select {
 110  			case <-ctx.Done():
 111  				errCh <- ctx.Err()
 112  				return
 113  			case <-ticker.C:
 114  				state, err := w.ReadFunc()
 115  				if err != nil {
 116  					errCh <- err
 117  					return
 118  				}
 119  				if state != nil {
 120  					complete, err := w.StateCheckFunc(state)
 121  					if complete {
 122  						compCh <- state
 123  						return
 124  					}
 125  
 126  					if err != nil {
 127  						errCh <- err
 128  						return
 129  					}
 130  				}
 131  				// note: nilの場合もあり得る
 132  				progressCh <- state
 133  			}
 134  		}
 135  	}()
 136  
 137  	return compCh, progressCh, errCh
 138  }
 139  
 140  func (w *PollingWaiter) validateFields() {
 141  	if w.ReadFunc == nil {
 142  		panic("required: PollingWaiter.ReadFunc")
 143  	}
 144  
 145  	if w.StateCheckFunc == nil {
 146  		panic("required: PollingWaiter.StateCheckFunc")
 147  	}
 148  }
 149  
 150  func (w *PollingWaiter) defaults() {
 151  	if w.Timeout == time.Duration(0) {
 152  		w.Timeout = defaultPollingTimeout
 153  	}
 154  	if w.Interval == time.Duration(0) {
 155  		w.Interval = defaultPollingInterval
 156  	}
 157  }
 158