pool.go raw
1 package grpc
2
import (
	"context"
	"errors"
	"fmt"
	"sync"

	"google.golang.org/grpc"
	"google.golang.org/grpc/status"

	"github.com/yandex-cloud/go-sdk/v2/pkg/endpoints"
	pkgerrors "github.com/yandex-cloud/go-sdk/v2/pkg/errors"
)
14
// ConnPool manages a pool of gRPC client connections with support for connection reuse and context-based lifecycle.
//
// Connections are cached per *endpoints.Endpoint pointer (the sync.Map key is the
// pointer itself), so two distinct Endpoint values with the same address get
// separate connections. Use NewConnPool to construct one; Shutdown closes it.
type ConnPool struct {
	// dialOptions are applied to every new connection, before the endpoint's own DialOptions.
	dialOptions []grpc.DialOption

	// ctx is the pool's own lifecycle context; cancel is invoked by Shutdown.
	ctx    context.Context
	cancel context.CancelFunc

	// mu guards closed and serializes storing new connections against Shutdown.
	// conns maps *endpoints.Endpoint -> *grpc.ClientConn.
	// closed is set by Shutdown; afterwards lookups report ErrConnContextClosed.
	mu     sync.RWMutex
	conns  sync.Map
	closed bool
}
26
27 // NewConnPool creates a new connection pool with the provided gRPC dial options. It initializes context and internal state.
28 func NewConnPool(opts []grpc.DialOption) *ConnPool {
29 ctx, cancel := context.WithCancel(context.Background())
30
31 return &ConnPool{
32 dialOptions: opts,
33 ctx: ctx,
34 cancel: cancel,
35 conns: sync.Map{},
36 }
37 }
38
39 // GetConn retrieves or establishes a gRPC connection for the specified endpoint, respecting the provided context constraints.
40 func (cc *ConnPool) GetConn(ctx context.Context, ep *endpoints.Endpoint) (*grpc.ClientConn, error) {
41 if err := ctx.Err(); err != nil {
42 // fail fast: don't even try
43 return nil, status.FromContextError(err).Err()
44 }
45
46 if conn, err := cc.pooledConn(ep); conn != nil || err != nil {
47 return conn, err
48 }
49
50 type dialResult struct {
51 conn *grpc.ClientConn
52 err error
53 }
54
55 resultCh := make(chan dialResult, 1)
56 go func() {
57 defer func() {
58 if r := recover(); r != nil {
59 resultCh <- dialResult{nil, r.(error)}
60 }
61 }()
62
63 v, err := cc.makeDial(ep)
64 resultCh <- dialResult{v, err}
65 }()
66
67 select {
68 case res := <-resultCh:
69 return res.conn, res.err
70 case <-ctx.Done():
71 return nil, status.FromContextError(ctx.Err()).Err()
72 }
73 }
74
75 // makeDial establishes a new gRPC client connection to the specified endpoint and stores it in the connection pool.
76 func (cc *ConnPool) makeDial(ep *endpoints.Endpoint) (*grpc.ClientConn, error) {
77 if conn, err := cc.pooledConn(ep); conn != nil || err != nil {
78 return conn, err
79 }
80
81 opts := make([]grpc.DialOption, 0, len(cc.dialOptions)+len(ep.DialOptions))
82 opts = append(opts, cc.dialOptions...)
83 opts = append(opts, ep.DialOptions...)
84
85 conn, err := grpc.NewClient(ep.Addr, opts...)
86 if err != nil {
87 if cc.ctx.Err() != nil {
88 return nil, pkgerrors.ErrConnContextClosed
89 }
90
91 return nil, &pkgerrors.DialError{Err: err, Addr: ep.Addr}
92 }
93
94 cc.mu.Lock()
95 defer cc.mu.Unlock()
96
97 if cc.closed {
98 _ = conn.Close()
99 return nil, nil
100 }
101
102 cc.conns.Store(ep, conn)
103
104 return conn, nil
105 }
106
107 // pooledConn retrieves a gRPC client connection from the pool for the given endpoint if available, else returns nil.
108 // It ensures thread-safe access and returns an error if the connection context is closed.
109 func (cc *ConnPool) pooledConn(ep *endpoints.Endpoint) (*grpc.ClientConn, error) {
110 cc.mu.RLock()
111 defer cc.mu.RUnlock()
112
113 if cc.closed {
114 return nil, pkgerrors.ErrConnContextClosed
115 }
116
117 value, ok := cc.conns.Load(ep)
118 if !ok {
119 return nil, nil
120 }
121
122 conn := value.(*grpc.ClientConn)
123
124 return conn, nil
125 }
126
127 // Shutdown closes all active connections in the pool, cancels the context, and marks the pool as closed. Returns errors if any occur during the connection closures.
128 func (cc *ConnPool) Shutdown(ctx context.Context) error {
129 cc.mu.Lock()
130 defer cc.mu.Unlock()
131
132 if cc.closed {
133 return nil
134 }
135
136 defer func() {
137 cc.closed = true
138 cc.cancel()
139 }()
140
141 var errs error
142
143 cc.conns.Range(
144 func(_, value any) bool {
145 conn := value.(*grpc.ClientConn)
146 if err := conn.Close(); err != nil {
147 errs = errors.Join(errs, err)
148 }
149
150 return true
151 },
152 )
153
154 return errs
155 }
156