dns_resolver.go raw

   1  /*
   2   *
   3   * Copyright 2018 gRPC authors.
   4   *
   5   * Licensed under the Apache License, Version 2.0 (the "License");
   6   * you may not use this file except in compliance with the License.
   7   * You may obtain a copy of the License at
   8   *
   9   *     http://www.apache.org/licenses/LICENSE-2.0
  10   *
  11   * Unless required by applicable law or agreed to in writing, software
  12   * distributed under the License is distributed on an "AS IS" BASIS,
  13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14   * See the License for the specific language governing permissions and
  15   * limitations under the License.
  16   *
  17   */
  18  
  19  // Package dns implements a dns resolver to be installed as the default resolver
  20  // in grpc.
  21  package dns
  22  
  23  import (
  24  	"context"
  25  	"encoding/json"
  26  	"fmt"
  27  	rand "math/rand/v2"
  28  	"net"
  29  	"net/netip"
  30  	"os"
  31  	"strconv"
  32  	"strings"
  33  	"sync"
  34  	"time"
  35  
  36  	grpclbstate "google.golang.org/grpc/balancer/grpclb/state"
  37  	"google.golang.org/grpc/grpclog"
  38  	"google.golang.org/grpc/internal/backoff"
  39  	"google.golang.org/grpc/internal/envconfig"
  40  	"google.golang.org/grpc/internal/resolver/dns/internal"
  41  	"google.golang.org/grpc/resolver"
  42  	"google.golang.org/grpc/serviceconfig"
  43  )
  44  
var (
	// EnableSRVLookups controls whether the DNS resolver attempts to fetch gRPCLB
	// addresses from SRV records. When it is false, lookupSRV returns no
	// balancer addresses without issuing any query. Must not be changed after
	// init time.
	EnableSRVLookups = false

	// MinResolutionInterval is the minimum interval at which re-resolutions are
	// allowed. This helps to prevent excessive re-resolution: after a
	// successful resolution the watcher does not resolve again until this much
	// time has passed, even if ResolveNow is called earlier.
	MinResolutionInterval = 30 * time.Second

	// ResolvingTimeout specifies the maximum duration for a DNS resolution
	// request. It bounds one complete lookup pass (SRV, host and TXT queries
	// share the same deadline). If the timeout expires before a response is
	// received, the request will be canceled.
	//
	// It is recommended to set this value at application startup. Avoid
	// modifying this variable after initialization, as concurrent modification
	// is not thread-safe.
	ResolvingTimeout = 30 * time.Second

	// logger is the grpclog component logger used for all messages emitted by
	// this package.
	logger = grpclog.Component("dns")
)
  63  
  64  func init() {
  65  	resolver.Register(NewBuilder())
  66  	internal.TimeAfterFunc = time.After
  67  	internal.TimeNowFunc = time.Now
  68  	internal.TimeUntilFunc = time.Until
  69  	internal.NewNetResolver = newNetResolver
  70  	internal.AddressDialer = addressDialer
  71  }
  72  
const (
	// defaultPort is the port assumed for a target that does not specify one.
	// defaultDNSSvrPort is the port assumed for a custom DNS server authority
	// that does not specify one. golang is the client-language tag that
	// canaryingSC looks for in a service config choice's clientLanguage list.
	defaultPort       = "443"
	defaultDNSSvrPort = "53"
	golang            = "GO"
	// txtPrefix is the prefix string to be prepended to the host name for txt
	// record lookup.
	txtPrefix = "_grpc_config."
	// In DNS, service config is encoded in a TXT record via the mechanism
	// described in RFC-1464 using the attribute name grpc_config. A TXT record
	// whose joined contents do not start with this attribute is not treated as
	// a service config.
	txtAttribute = "grpc_config="
)
  84  
  85  var addressDialer = func(address string) func(context.Context, string, string) (net.Conn, error) {
  86  	return func(ctx context.Context, network, _ string) (net.Conn, error) {
  87  		var dialer net.Dialer
  88  		return dialer.DialContext(ctx, network, address)
  89  	}
  90  }
  91  
  92  var newNetResolver = func(authority string) (internal.NetResolver, error) {
  93  	if authority == "" {
  94  		return net.DefaultResolver, nil
  95  	}
  96  
  97  	host, port, err := parseTarget(authority, defaultDNSSvrPort)
  98  	if err != nil {
  99  		return nil, err
 100  	}
 101  
 102  	authorityWithPort := net.JoinHostPort(host, port)
 103  
 104  	return &net.Resolver{
 105  		PreferGo: true,
 106  		Dial:     internal.AddressDialer(authorityWithPort),
 107  	}, nil
 108  }
 109  
 110  // NewBuilder creates a dnsBuilder which is used to factory DNS resolvers.
 111  func NewBuilder() resolver.Builder {
 112  	return &dnsBuilder{}
 113  }
 114  
// dnsBuilder is the stateless resolver.Builder for the "dns" scheme.
type dnsBuilder struct{}
 116  
 117  // Build creates and starts a DNS resolver that watches the name resolution of
 118  // the target.
 119  func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) {
 120  	host, port, err := parseTarget(target.Endpoint(), defaultPort)
 121  	if err != nil {
 122  		return nil, err
 123  	}
 124  
 125  	// IP address.
 126  	if ipAddr, err := formatIP(host); err == nil {
 127  		addr := []resolver.Address{{Addr: ipAddr + ":" + port}}
 128  		cc.UpdateState(resolver.State{Addresses: addr})
 129  		return deadResolver{}, nil
 130  	}
 131  
 132  	// DNS address (non-IP).
 133  	ctx, cancel := context.WithCancel(context.Background())
 134  	d := &dnsResolver{
 135  		host:                host,
 136  		port:                port,
 137  		ctx:                 ctx,
 138  		cancel:              cancel,
 139  		cc:                  cc,
 140  		rn:                  make(chan struct{}, 1),
 141  		enableServiceConfig: envconfig.EnableTXTServiceConfig && !opts.DisableServiceConfig,
 142  	}
 143  
 144  	d.resolver, err = internal.NewNetResolver(target.URL.Host)
 145  	if err != nil {
 146  		return nil, err
 147  	}
 148  
 149  	d.wg.Add(1)
 150  	go d.watcher()
 151  	return d, nil
 152  }
 153  
 154  // Scheme returns the naming scheme of this resolver builder, which is "dns".
 155  func (b *dnsBuilder) Scheme() string {
 156  	return "dns"
 157  }
 158  
// deadResolver is a resolver that does nothing. Build returns it for targets
// whose host is an IP literal: the address is pushed to the ClientConn once at
// build time, so there is nothing to re-resolve or to shut down.
type deadResolver struct{}

// ResolveNow is a no-op.
func (deadResolver) ResolveNow(resolver.ResolveNowOptions) {}

// Close is a no-op.
func (deadResolver) Close() {}
 165  
// dnsResolver watches for the name resolution update for a non-IP target.
//
// enableServiceConfig, fixed at Build time, controls whether each lookup also
// fetches the service config from the target's TXT record.
type dnsResolver struct {
	// host and port are the parsed target; port is appended to every address
	// returned by the host lookup. resolver performs the DNS queries. ctx is
	// the parent of every lookup's context and is cancelled via cancel when
	// the resolver is closed. cc receives resolved state and errors.
	host     string
	port     string
	resolver internal.NetResolver
	ctx      context.Context
	cancel   context.CancelFunc
	cc       resolver.ClientConn
	// rn channel is used by ResolveNow() to force an immediate resolution of the
	// target.
	rn chan struct{}
	// wg is used to make Close() return only after the watcher() goroutine has
	// finished. Otherwise, a data race would be possible. [Race Example] in
	// dns_resolver_test we replace the real lookup functions with mocked ones to
	// facilitate testing. If Close() doesn't wait for the watcher() goroutine
	// to finish, the race detector will sometimes warn that lookup (READ of the
	// lookup function pointers) inside the watcher() goroutine races with
	// replaceNetFunc (WRITE of the lookup function pointers).
	wg                  sync.WaitGroup
	enableServiceConfig bool
}
 187  
 188  // ResolveNow invoke an immediate resolution of the target that this
 189  // dnsResolver watches.
 190  func (d *dnsResolver) ResolveNow(resolver.ResolveNowOptions) {
 191  	select {
 192  	case d.rn <- struct{}{}:
 193  	default:
 194  	}
 195  }
 196  
// Close closes the dnsResolver. It cancels the resolver's context, which
// unblocks the watcher goroutine, and then waits for that goroutine to exit
// before returning.
func (d *dnsResolver) Close() {
	d.cancel()
	d.wg.Wait()
}
 202  
// watcher is the resolver's background loop. Each iteration performs one
// lookup and hands the result (or the error) to the ClientConn, then waits
// before the next iteration:
//
//   - on success it blocks until ResolveNow is called, and additionally until
//     MinResolutionInterval has elapsed since the resolution completed;
//   - on failure (of the lookup or of the ClientConn update) it retries after
//     an exponentially growing backoff.
//
// The loop exits, and signals wg, as soon as the resolver's context is
// cancelled.
func (d *dnsResolver) watcher() {
	defer d.wg.Done()
	backoffIndex := 1
	for {
		state, err := d.lookup()
		if err != nil {
			// Report error to the underlying grpc.ClientConn.
			d.cc.ReportError(err)
		} else {
			// err now carries the ClientConn's verdict on the new state, so a
			// rejected update is retried with backoff just like a failed lookup.
			err = d.cc.UpdateState(*state)
		}

		var nextResolutionTime time.Time
		if err == nil {
			// Success resolving, wait for the next ResolveNow. However, also wait
			// MinResolutionInterval at the very least to prevent constantly
			// re-resolving. The earliest next resolution time is computed before
			// blocking, so the interval is measured from this resolution rather
			// than from the ResolveNow call.
			backoffIndex = 1
			nextResolutionTime = internal.TimeNowFunc().Add(MinResolutionInterval)
			select {
			case <-d.ctx.Done():
				return
			case <-d.rn:
			}
		} else {
			// Poll on an error found in DNS Resolver or an error received from
			// ClientConn.
			nextResolutionTime = internal.TimeNowFunc().Add(backoff.DefaultExponential.Backoff(backoffIndex))
			backoffIndex++
		}
		// Sleep until the next allowed resolution time, or stop if the resolver
		// is closed first.
		select {
		case <-d.ctx.Done():
			return
		case <-internal.TimeAfterFunc(internal.TimeUntilFunc(nextResolutionTime)):
		}
	}
}
 239  
 240  func (d *dnsResolver) lookupSRV(ctx context.Context) ([]resolver.Address, error) {
 241  	// Skip this particular host to avoid timeouts with some versions of
 242  	// systemd-resolved.
 243  	if !EnableSRVLookups || d.host == "metadata.google.internal." {
 244  		return nil, nil
 245  	}
 246  	var newAddrs []resolver.Address
 247  	_, srvs, err := d.resolver.LookupSRV(ctx, "grpclb", "tcp", d.host)
 248  	if err != nil {
 249  		err = handleDNSError(err, "SRV") // may become nil
 250  		return nil, err
 251  	}
 252  	for _, s := range srvs {
 253  		lbAddrs, err := d.resolver.LookupHost(ctx, s.Target)
 254  		if err != nil {
 255  			err = handleDNSError(err, "A") // may become nil
 256  			if err == nil {
 257  				// If there are other SRV records, look them up and ignore this
 258  				// one that does not exist.
 259  				continue
 260  			}
 261  			return nil, err
 262  		}
 263  		for _, a := range lbAddrs {
 264  			ip, err := formatIP(a)
 265  			if err != nil {
 266  				return nil, fmt.Errorf("dns: error parsing A record IP address %v: %v", a, err)
 267  			}
 268  			addr := ip + ":" + strconv.Itoa(int(s.Port))
 269  			newAddrs = append(newAddrs, resolver.Address{Addr: addr, ServerName: s.Target})
 270  		}
 271  	}
 272  	return newAddrs, nil
 273  }
 274  
 275  func handleDNSError(err error, lookupType string) error {
 276  	dnsErr, ok := err.(*net.DNSError)
 277  	if ok && !dnsErr.IsTimeout && !dnsErr.IsTemporary {
 278  		// Timeouts and temporary errors should be communicated to gRPC to
 279  		// attempt another DNS query (with backoff).  Other errors should be
 280  		// suppressed (they may represent the absence of a TXT record).
 281  		return nil
 282  	}
 283  	if err != nil {
 284  		err = fmt.Errorf("dns: %v record lookup error: %v", lookupType, err)
 285  		logger.Info(err)
 286  	}
 287  	return err
 288  }
 289  
 290  func (d *dnsResolver) lookupTXT(ctx context.Context) *serviceconfig.ParseResult {
 291  	ss, err := d.resolver.LookupTXT(ctx, txtPrefix+d.host)
 292  	if err != nil {
 293  		if envconfig.TXTErrIgnore {
 294  			return nil
 295  		}
 296  		if err = handleDNSError(err, "TXT"); err != nil {
 297  			return &serviceconfig.ParseResult{Err: err}
 298  		}
 299  		return nil
 300  	}
 301  	var res string
 302  	for _, s := range ss {
 303  		res += s
 304  	}
 305  
 306  	// TXT record must have "grpc_config=" attribute in order to be used as
 307  	// service config.
 308  	if !strings.HasPrefix(res, txtAttribute) {
 309  		logger.Warningf("dns: TXT record %v missing %v attribute", res, txtAttribute)
 310  		// This is not an error; it is the equivalent of not having a service
 311  		// config.
 312  		return nil
 313  	}
 314  	sc := canaryingSC(strings.TrimPrefix(res, txtAttribute))
 315  	return d.cc.ParseServiceConfig(sc)
 316  }
 317  
 318  func (d *dnsResolver) lookupHost(ctx context.Context) ([]resolver.Address, error) {
 319  	addrs, err := d.resolver.LookupHost(ctx, d.host)
 320  	if err != nil {
 321  		err = handleDNSError(err, "A")
 322  		return nil, err
 323  	}
 324  	newAddrs := make([]resolver.Address, 0, len(addrs))
 325  	for _, a := range addrs {
 326  		ip, err := formatIP(a)
 327  		if err != nil {
 328  			return nil, fmt.Errorf("dns: error parsing A record IP address %v: %v", a, err)
 329  		}
 330  		addr := ip + ":" + d.port
 331  		newAddrs = append(newAddrs, resolver.Address{Addr: addr})
 332  	}
 333  	return newAddrs, nil
 334  }
 335  
 336  func (d *dnsResolver) lookup() (*resolver.State, error) {
 337  	ctx, cancel := context.WithTimeout(d.ctx, ResolvingTimeout)
 338  	defer cancel()
 339  	srv, srvErr := d.lookupSRV(ctx)
 340  	addrs, hostErr := d.lookupHost(ctx)
 341  	if hostErr != nil && (srvErr != nil || len(srv) == 0) {
 342  		return nil, hostErr
 343  	}
 344  
 345  	state := resolver.State{Addresses: addrs}
 346  	if len(srv) > 0 {
 347  		state = grpclbstate.Set(state, &grpclbstate.State{BalancerAddresses: srv})
 348  	}
 349  	if d.enableServiceConfig {
 350  		state.ServiceConfig = d.lookupTXT(ctx)
 351  	}
 352  	return &state, nil
 353  }
 354  
 355  // formatIP returns an error if addr is not a valid textual representation of
 356  // an IP address. If addr is an IPv4 address, return the addr and error = nil.
 357  // If addr is an IPv6 address, return the addr enclosed in square brackets and
 358  // error = nil.
 359  func formatIP(addr string) (string, error) {
 360  	ip, err := netip.ParseAddr(addr)
 361  	if err != nil {
 362  		return "", err
 363  	}
 364  	if ip.Is4() {
 365  		return addr, nil
 366  	}
 367  	return "[" + addr + "]", nil
 368  }
 369  
 370  // parseTarget takes the user input target string and default port, returns
 371  // formatted host and port info. If target doesn't specify a port, set the port
 372  // to be the defaultPort. If target is in IPv6 format and host-name is enclosed
 373  // in square brackets, brackets are stripped when setting the host.
 374  // examples:
 375  // target: "www.google.com" defaultPort: "443" returns host: "www.google.com", port: "443"
 376  // target: "ipv4-host:80" defaultPort: "443" returns host: "ipv4-host", port: "80"
 377  // target: "[ipv6-host]" defaultPort: "443" returns host: "ipv6-host", port: "443"
 378  // target: ":80" defaultPort: "443" returns host: "localhost", port: "80"
 379  func parseTarget(target, defaultPort string) (host, port string, err error) {
 380  	if target == "" {
 381  		return "", "", internal.ErrMissingAddr
 382  	}
 383  	if _, err := netip.ParseAddr(target); err == nil {
 384  		// target is an IPv4 or IPv6(without brackets) address
 385  		return target, defaultPort, nil
 386  	}
 387  	if host, port, err = net.SplitHostPort(target); err == nil {
 388  		if port == "" {
 389  			// If the port field is empty (target ends with colon), e.g. "[::1]:",
 390  			// this is an error.
 391  			return "", "", internal.ErrEndsWithColon
 392  		}
 393  		// target has port, i.e ipv4-host:port, [ipv6-host]:port, host-name:port
 394  		if host == "" {
 395  			// Keep consistent with net.Dial(): If the host is empty, as in ":80",
 396  			// the local system is assumed.
 397  			host = "localhost"
 398  		}
 399  		return host, port, nil
 400  	}
 401  	if host, port, err = net.SplitHostPort(target + ":" + defaultPort); err == nil {
 402  		// target doesn't have port
 403  		return host, port, nil
 404  	}
 405  	return "", "", fmt.Errorf("invalid target address %v, error info: %v", target, err)
 406  }
 407  
// rawChoice is one element of the JSON array that canaryingSC decodes from
// the service config TXT record. ClientLanguage, Percentage and
// ClientHostName are optional selectors: a selector that is absent (nil)
// places no restriction on the client. ServiceConfig holds the still-encoded
// service config for the choice; a choice without one is never selected.
type rawChoice struct {
	ClientLanguage *[]string        `json:"clientLanguage,omitempty"`
	Percentage     *int             `json:"percentage,omitempty"`
	ClientHostName *[]string        `json:"clientHostName,omitempty"`
	ServiceConfig  *json.RawMessage `json:"serviceConfig,omitempty"`
}
 414  
 415  func containsString(a *[]string, b string) bool {
 416  	if a == nil {
 417  		return true
 418  	}
 419  	for _, c := range *a {
 420  		if c == b {
 421  			return true
 422  		}
 423  	}
 424  	return false
 425  }
 426  
 427  func chosenByPercentage(a *int) bool {
 428  	if a == nil {
 429  		return true
 430  	}
 431  	return rand.IntN(100)+1 <= *a
 432  }
 433  
 434  func canaryingSC(js string) string {
 435  	if js == "" {
 436  		return ""
 437  	}
 438  	var rcs []rawChoice
 439  	err := json.Unmarshal([]byte(js), &rcs)
 440  	if err != nil {
 441  		logger.Warningf("dns: error parsing service config json: %v", err)
 442  		return ""
 443  	}
 444  	cliHostname, err := os.Hostname()
 445  	if err != nil {
 446  		logger.Warningf("dns: error getting client hostname: %v", err)
 447  		return ""
 448  	}
 449  	var sc string
 450  	for _, c := range rcs {
 451  		if !containsString(c.ClientLanguage, golang) ||
 452  			!chosenByPercentage(c.Percentage) ||
 453  			!containsString(c.ClientHostName, cliHostname) ||
 454  			c.ServiceConfig == nil {
 455  			continue
 456  		}
 457  		sc = string(*c.ServiceConfig)
 458  		break
 459  	}
 460  	return sc
 461  }
 462