pool.go raw

   1  package grpc
   2  
import (
	"context"
	"errors"
	"fmt"
	"sync"

	"google.golang.org/grpc"
	"google.golang.org/grpc/status"

	"github.com/yandex-cloud/go-sdk/v2/pkg/endpoints"
	pkgerrors "github.com/yandex-cloud/go-sdk/v2/pkg/errors"
)
  14  
  15  // ConnPool manages a pool of gRPC client connections with support for connection reuse and context-based lifecycle.
  16  type ConnPool struct {
  17  	dialOptions []grpc.DialOption
  18  
  19  	ctx    context.Context
  20  	cancel context.CancelFunc
  21  
  22  	mu     sync.RWMutex
  23  	conns  sync.Map
  24  	closed bool
  25  }
  26  
  27  // NewConnPool creates a new connection pool with the provided gRPC dial options. It initializes context and internal state.
  28  func NewConnPool(opts []grpc.DialOption) *ConnPool {
  29  	ctx, cancel := context.WithCancel(context.Background())
  30  
  31  	return &ConnPool{
  32  		dialOptions: opts,
  33  		ctx:         ctx,
  34  		cancel:      cancel,
  35  		conns:       sync.Map{},
  36  	}
  37  }
  38  
  39  // GetConn retrieves or establishes a gRPC connection for the specified endpoint, respecting the provided context constraints.
  40  func (cc *ConnPool) GetConn(ctx context.Context, ep *endpoints.Endpoint) (*grpc.ClientConn, error) {
  41  	if err := ctx.Err(); err != nil {
  42  		// fail fast: don't even try
  43  		return nil, status.FromContextError(err).Err()
  44  	}
  45  
  46  	if conn, err := cc.pooledConn(ep); conn != nil || err != nil {
  47  		return conn, err
  48  	}
  49  
  50  	type dialResult struct {
  51  		conn *grpc.ClientConn
  52  		err  error
  53  	}
  54  
  55  	resultCh := make(chan dialResult, 1)
  56  	go func() {
  57  		defer func() {
  58  			if r := recover(); r != nil {
  59  				resultCh <- dialResult{nil, r.(error)}
  60  			}
  61  		}()
  62  
  63  		v, err := cc.makeDial(ep)
  64  		resultCh <- dialResult{v, err}
  65  	}()
  66  
  67  	select {
  68  	case res := <-resultCh:
  69  		return res.conn, res.err
  70  	case <-ctx.Done():
  71  		return nil, status.FromContextError(ctx.Err()).Err()
  72  	}
  73  }
  74  
  75  // makeDial establishes a new gRPC client connection to the specified endpoint and stores it in the connection pool.
  76  func (cc *ConnPool) makeDial(ep *endpoints.Endpoint) (*grpc.ClientConn, error) {
  77  	if conn, err := cc.pooledConn(ep); conn != nil || err != nil {
  78  		return conn, err
  79  	}
  80  
  81  	opts := make([]grpc.DialOption, 0, len(cc.dialOptions)+len(ep.DialOptions))
  82  	opts = append(opts, cc.dialOptions...)
  83  	opts = append(opts, ep.DialOptions...)
  84  
  85  	conn, err := grpc.NewClient(ep.Addr, opts...)
  86  	if err != nil {
  87  		if cc.ctx.Err() != nil {
  88  			return nil, pkgerrors.ErrConnContextClosed
  89  		}
  90  
  91  		return nil, &pkgerrors.DialError{Err: err, Addr: ep.Addr}
  92  	}
  93  
  94  	cc.mu.Lock()
  95  	defer cc.mu.Unlock()
  96  
  97  	if cc.closed {
  98  		_ = conn.Close()
  99  		return nil, nil
 100  	}
 101  
 102  	cc.conns.Store(ep, conn)
 103  
 104  	return conn, nil
 105  }
 106  
 107  // pooledConn retrieves a gRPC client connection from the pool for the given endpoint if available, else returns nil.
 108  // It ensures thread-safe access and returns an error if the connection context is closed.
 109  func (cc *ConnPool) pooledConn(ep *endpoints.Endpoint) (*grpc.ClientConn, error) {
 110  	cc.mu.RLock()
 111  	defer cc.mu.RUnlock()
 112  
 113  	if cc.closed {
 114  		return nil, pkgerrors.ErrConnContextClosed
 115  	}
 116  
 117  	value, ok := cc.conns.Load(ep)
 118  	if !ok {
 119  		return nil, nil
 120  	}
 121  
 122  	conn := value.(*grpc.ClientConn)
 123  
 124  	return conn, nil
 125  }
 126  
 127  // Shutdown closes all active connections in the pool, cancels the context, and marks the pool as closed. Returns errors if any occur during the connection closures.
 128  func (cc *ConnPool) Shutdown(ctx context.Context) error {
 129  	cc.mu.Lock()
 130  	defer cc.mu.Unlock()
 131  
 132  	if cc.closed {
 133  		return nil
 134  	}
 135  
 136  	defer func() {
 137  		cc.closed = true
 138  		cc.cancel()
 139  	}()
 140  
 141  	var errs error
 142  
 143  	cc.conns.Range(
 144  		func(_, value any) bool {
 145  			conn := value.(*grpc.ClientConn)
 146  			if err := conn.Close(); err != nil {
 147  				errs = errors.Join(errs, err)
 148  			}
 149  
 150  			return true
 151  		},
 152  	)
 153  
 154  	return errs
 155  }
 156