lookup_windows.mx raw

   1  // Copyright 2009 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package net
   6  
   7  import (
   8  	"context"
   9  	"internal/syscall/windows"
  10  	"os"
  11  	"runtime"
  12  	"syscall"
  13  	"time"
  14  	"unsafe"
  15  )
  16  
  17  // cgoAvailable set to true to indicate that the cgo resolver
  18  // is available on Windows. Note that on Windows the cgo resolver
  19  // does not actually use cgo.
  20  const cgoAvailable = true
  21  
  22  const (
  23  	_DNS_ERROR_RCODE_NAME_ERROR = syscall.Errno(9003)
  24  	_DNS_INFO_NO_RECORDS        = syscall.Errno(9501)
  25  
  26  	_WSAHOST_NOT_FOUND = syscall.Errno(11001)
  27  	_WSATRY_AGAIN      = syscall.Errno(11002)
  28  	_WSATYPE_NOT_FOUND = syscall.Errno(10109)
  29  )
  30  
  31  func winError(call string, err error) error {
  32  	switch err {
  33  	case _WSAHOST_NOT_FOUND, _DNS_ERROR_RCODE_NAME_ERROR, _DNS_INFO_NO_RECORDS:
  34  		return errNoSuchHost
  35  	}
  36  	return os.NewSyscallError(call, err)
  37  }
  38  
  39  func getprotobyname(name string) (proto int, err error) {
  40  	p, err := syscall.GetProtoByName(name)
  41  	if err != nil {
  42  		return 0, winError("getprotobyname", err)
  43  	}
  44  	return int(p.Proto), nil
  45  }
  46  
  47  // lookupProtocol looks up IP protocol name and returns correspondent protocol number.
  48  func lookupProtocol(ctx context.Context, name string) (int, error) {
  49  	// GetProtoByName return value is stored in thread local storage.
  50  	// Start new os thread before the call to prevent races.
  51  	type result struct {
  52  		proto int
  53  		err   error
  54  	}
  55  	ch := chan result{1} // buffer so that next goroutine never blocks
  56  	go func() {
  57  		if err := acquireThread(ctx); err != nil {
  58  			ch <- result{err: mapErr(err)}
  59  			return
  60  		}
  61  		defer releaseThread()
  62  		runtime.LockOSThread()
  63  		defer runtime.UnlockOSThread()
  64  		proto, err := getprotobyname(name)
  65  		ch <- result{proto: proto, err: err}
  66  	}()
  67  	select {
  68  	case r := <-ch:
  69  		if r.err != nil {
  70  			if proto, err := lookupProtocolMap(name); err == nil {
  71  				return proto, nil
  72  			}
  73  			r.err = newDNSError(r.err, name, "")
  74  		}
  75  		return r.proto, r.err
  76  	case <-ctx.Done():
  77  		return 0, newDNSError(mapErr(ctx.Err()), name, "")
  78  	}
  79  }
  80  
  81  func (r *Resolver) lookupHost(ctx context.Context, name string) ([][]byte, error) {
  82  	ips, err := r.lookupIP(ctx, "ip", name)
  83  	if err != nil {
  84  		return nil, err
  85  	}
  86  	addrs := [][]byte{:0:len(ips)}
  87  	for _, ip := range ips {
  88  		addrs = append(addrs, ip.String())
  89  	}
  90  	return addrs, nil
  91  }
  92  
  93  func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr, error) {
  94  	if order, conf := systemConf().hostLookupOrder(r, name); order != hostLookupCgo {
  95  		return r.goLookupIP(ctx, network, name, order, conf)
  96  	}
  97  
  98  	// TODO(bradfitz,brainman): use ctx more. See TODO below.
  99  
 100  	var family int32 = syscall.AF_UNSPEC
 101  	switch ipVersion(network) {
 102  	case '4':
 103  		family = syscall.AF_INET
 104  	case '6':
 105  		family = syscall.AF_INET6
 106  	}
 107  
 108  	getaddr := func() ([]IPAddr, error) {
 109  		if err := acquireThread(ctx); err != nil {
 110  			return nil, newDNSError(mapErr(err), name, "")
 111  		}
 112  		defer releaseThread()
 113  		hints := syscall.AddrinfoW{
 114  			Family:   family,
 115  			Socktype: syscall.SOCK_STREAM,
 116  			Protocol: syscall.IPPROTO_IP,
 117  		}
 118  		var result *syscall.AddrinfoW
 119  		name16p, err := syscall.UTF16PtrFromString(name)
 120  		if err != nil {
 121  			return nil, newDNSError(err, name, "")
 122  		}
 123  
 124  		dnsConf := getSystemDNSConfig()
 125  		start := time.Now()
 126  
 127  		var e error
 128  		for i := 0; i < dnsConf.attempts; i++ {
 129  			e = syscall.GetAddrInfoW(name16p, nil, &hints, &result)
 130  			if e == nil || e != _WSATRY_AGAIN || time.Since(start) > dnsConf.timeout {
 131  				break
 132  			}
 133  		}
 134  		if e != nil {
 135  			return nil, newDNSError(winError("getaddrinfow", e), name, "")
 136  		}
 137  		defer syscall.FreeAddrInfoW(result)
 138  		addrs := []IPAddr{:0:5}
 139  		for ; result != nil; result = result.Next {
 140  			addr := unsafe.Pointer(result.Addr)
 141  			switch result.Family {
 142  			case syscall.AF_INET:
 143  				a := (*syscall.RawSockaddrInet4)(addr).Addr
 144  				addrs = append(addrs, IPAddr{IP: copyIP(a[:])})
 145  			case syscall.AF_INET6:
 146  				a := (*syscall.RawSockaddrInet6)(addr).Addr
 147  				zone := zoneCache.name(int((*syscall.RawSockaddrInet6)(addr).Scope_id))
 148  				addrs = append(addrs, IPAddr{IP: copyIP(a[:]), Zone: zone})
 149  			default:
 150  				return nil, newDNSError(syscall.EWINDOWS, name, "")
 151  			}
 152  		}
 153  		return addrs, nil
 154  	}
 155  
 156  	type ret struct {
 157  		addrs []IPAddr
 158  		err   error
 159  	}
 160  
 161  	var ch chan ret
 162  	if ctx.Err() == nil {
 163  		ch = chan ret{1}
 164  		go func() {
 165  			addr, err := getaddr()
 166  			ch <- ret{addrs: addr, err: err}
 167  		}()
 168  	}
 169  
 170  	select {
 171  	case r := <-ch:
 172  		return r.addrs, r.err
 173  	case <-ctx.Done():
 174  		// TODO(bradfitz,brainman): cancel the ongoing
 175  		// GetAddrInfoW? It would require conditionally using
 176  		// GetAddrInfoEx with lpOverlapped, which requires
 177  		// Windows 8 or newer. I guess we'll need oldLookupIP,
 178  		// newLookupIP, and newerLookUP.
 179  		//
 180  		// For now we just let it finish and write to the
 181  		// buffered channel.
 182  		return nil, newDNSError(mapErr(ctx.Err()), name, "")
 183  	}
 184  }
 185  
 186  func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) {
 187  	if systemConf().mustUseGoResolver(r) {
 188  		return lookupPortMap(network, service)
 189  	}
 190  
 191  	// TODO(bradfitz): finish ctx plumbing
 192  	if err := acquireThread(ctx); err != nil {
 193  		return 0, newDNSError(mapErr(err), network+"/"+service, "")
 194  	}
 195  	defer releaseThread()
 196  
 197  	var hints syscall.AddrinfoW
 198  
 199  	switch network {
 200  	case "ip": // no hints
 201  	case "tcp", "tcp4", "tcp6":
 202  		hints.Socktype = syscall.SOCK_STREAM
 203  		hints.Protocol = syscall.IPPROTO_TCP
 204  	case "udp", "udp4", "udp6":
 205  		hints.Socktype = syscall.SOCK_DGRAM
 206  		hints.Protocol = syscall.IPPROTO_UDP
 207  	default:
 208  		return 0, &DNSError{Err: "unknown network", Name: network + "/" + service}
 209  	}
 210  
 211  	switch ipVersion(network) {
 212  	case '4':
 213  		hints.Family = syscall.AF_INET
 214  	case '6':
 215  		hints.Family = syscall.AF_INET6
 216  	}
 217  
 218  	var result *syscall.AddrinfoW
 219  	e := syscall.GetAddrInfoW(nil, syscall.StringToUTF16Ptr(service), &hints, &result)
 220  	if e != nil {
 221  		if port, err := lookupPortMap(network, service); err == nil {
 222  			return port, nil
 223  		}
 224  
 225  		// The _WSATYPE_NOT_FOUND error is returned by GetAddrInfoW
 226  		// when the service name is unknown. We are also checking
 227  		// for _WSAHOST_NOT_FOUND here to match the cgo (unix) version
 228  		// cgo_unix.go (cgoLookupServicePort).
 229  		if e == _WSATYPE_NOT_FOUND || e == _WSAHOST_NOT_FOUND {
 230  			return 0, newDNSError(errUnknownPort, network+"/"+service, "")
 231  		}
 232  		return 0, newDNSError(winError("getaddrinfow", e), network+"/"+service, "")
 233  	}
 234  	defer syscall.FreeAddrInfoW(result)
 235  	if result == nil {
 236  		return 0, newDNSError(syscall.EINVAL, network+"/"+service, "")
 237  	}
 238  	addr := unsafe.Pointer(result.Addr)
 239  	switch result.Family {
 240  	case syscall.AF_INET:
 241  		a := (*syscall.RawSockaddrInet4)(addr)
 242  		return int(syscall.Ntohs(a.Port)), nil
 243  	case syscall.AF_INET6:
 244  		a := (*syscall.RawSockaddrInet6)(addr)
 245  		return int(syscall.Ntohs(a.Port)), nil
 246  	}
 247  	return 0, newDNSError(syscall.EINVAL, network+"/"+service, "")
 248  }
 249  
 250  func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
 251  	if order, conf := systemConf().hostLookupOrder(r, name); order != hostLookupCgo {
 252  		return r.goLookupCNAME(ctx, name, order, conf)
 253  	}
 254  
 255  	// TODO(bradfitz): finish ctx plumbing
 256  	if err := acquireThread(ctx); err != nil {
 257  		return "", newDNSError(mapErr(err), name, "")
 258  	}
 259  	defer releaseThread()
 260  	var rec *syscall.DNSRecord
 261  	e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &rec, nil)
 262  	// windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s
 263  	if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS {
 264  		// if there are no aliases, the canonical name is the input name
 265  		return absDomainName(name), nil
 266  	}
 267  	if e != nil {
 268  		return "", newDNSError(winError("dnsquery", e), name, "")
 269  	}
 270  	defer syscall.DnsRecordListFree(rec, 1)
 271  
 272  	resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), rec)
 273  	cname := windows.UTF16PtrToString(resolved)
 274  	return absDomainName(cname), nil
 275  }
 276  
 277  func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
 278  	if systemConf().mustUseGoResolver(r) {
 279  		return r.goLookupSRV(ctx, service, proto, name)
 280  	}
 281  	// TODO(bradfitz): finish ctx plumbing
 282  	if err := acquireThread(ctx); err != nil {
 283  		return "", nil, newDNSError(mapErr(err), name, "")
 284  	}
 285  	defer releaseThread()
 286  	var target string
 287  	if service == "" && proto == "" {
 288  		target = name
 289  	} else {
 290  		target = "_" + service + "._" + proto + "." + name
 291  	}
 292  	var rec *syscall.DNSRecord
 293  	e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &rec, nil)
 294  	if e != nil {
 295  		return "", nil, newDNSError(winError("dnsquery", e), name, "")
 296  	}
 297  	defer syscall.DnsRecordListFree(rec, 1)
 298  
 299  	srvs := []*SRV{:0:10}
 300  	for _, p := range validRecs(rec, syscall.DNS_TYPE_SRV, target) {
 301  		v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0]))
 302  		srvs = append(srvs, &SRV{absDomainName(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:])), v.Port, v.Priority, v.Weight})
 303  	}
 304  	byPriorityWeight(srvs).sort()
 305  	return absDomainName(target), srvs, nil
 306  }
 307  
 308  func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
 309  	if systemConf().mustUseGoResolver(r) {
 310  		return r.goLookupMX(ctx, name)
 311  	}
 312  	// TODO(bradfitz): finish ctx plumbing.
 313  	if err := acquireThread(ctx); err != nil {
 314  		return nil, newDNSError(mapErr(err), name, "")
 315  	}
 316  	defer releaseThread()
 317  	var rec *syscall.DNSRecord
 318  	e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &rec, nil)
 319  	if e != nil {
 320  		return nil, newDNSError(winError("dnsquery", e), name, "")
 321  	}
 322  	defer syscall.DnsRecordListFree(rec, 1)
 323  
 324  	mxs := []*MX{:0:10}
 325  	for _, p := range validRecs(rec, syscall.DNS_TYPE_MX, name) {
 326  		v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0]))
 327  		mxs = append(mxs, &MX{absDomainName(windows.UTF16PtrToString(v.NameExchange)), v.Preference})
 328  	}
 329  	byPref(mxs).sort()
 330  	return mxs, nil
 331  }
 332  
 333  func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
 334  	if systemConf().mustUseGoResolver(r) {
 335  		return r.goLookupNS(ctx, name)
 336  	}
 337  	// TODO(bradfitz): finish ctx plumbing.
 338  	if err := acquireThread(ctx); err != nil {
 339  		return nil, newDNSError(mapErr(err), name, "")
 340  	}
 341  	defer releaseThread()
 342  	var rec *syscall.DNSRecord
 343  	e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &rec, nil)
 344  	if e != nil {
 345  		return nil, newDNSError(winError("dnsquery", e), name, "")
 346  	}
 347  	defer syscall.DnsRecordListFree(rec, 1)
 348  
 349  	nss := []*NS{:0:10}
 350  	for _, p := range validRecs(rec, syscall.DNS_TYPE_NS, name) {
 351  		v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
 352  		nss = append(nss, &NS{absDomainName(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))})
 353  	}
 354  	return nss, nil
 355  }
 356  
 357  func (r *Resolver) lookupTXT(ctx context.Context, name string) ([][]byte, error) {
 358  	if systemConf().mustUseGoResolver(r) {
 359  		return r.goLookupTXT(ctx, name)
 360  	}
 361  	// TODO(bradfitz): finish ctx plumbing.
 362  	if err := acquireThread(ctx); err != nil {
 363  		return nil, newDNSError(mapErr(err), name, "")
 364  	}
 365  	defer releaseThread()
 366  	var rec *syscall.DNSRecord
 367  	e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &rec, nil)
 368  	if e != nil {
 369  		return nil, newDNSError(winError("dnsquery", e), name, "")
 370  	}
 371  	defer syscall.DnsRecordListFree(rec, 1)
 372  
 373  	txts := [][]byte{:0:10}
 374  	for _, p := range validRecs(rec, syscall.DNS_TYPE_TEXT, name) {
 375  		d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0]))
 376  		s := ""
 377  		for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount:d.StringCount] {
 378  			s += windows.UTF16PtrToString(v)
 379  		}
 380  		txts = append(txts, s)
 381  	}
 382  	return txts, nil
 383  }
 384  
 385  func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([][]byte, error) {
 386  	if order, conf := systemConf().addrLookupOrder(r, addr); order != hostLookupCgo {
 387  		return r.goLookupPTR(ctx, addr, order, conf)
 388  	}
 389  
 390  	// TODO(bradfitz): finish ctx plumbing.
 391  	if err := acquireThread(ctx); err != nil {
 392  		return nil, newDNSError(mapErr(err), addr, "")
 393  	}
 394  	defer releaseThread()
 395  	arpa, err := reverseaddr(addr)
 396  	if err != nil {
 397  		return nil, err
 398  	}
 399  	var rec *syscall.DNSRecord
 400  	e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &rec, nil)
 401  	if e != nil {
 402  		return nil, newDNSError(winError("dnsquery", e), addr, "")
 403  	}
 404  	defer syscall.DnsRecordListFree(rec, 1)
 405  
 406  	ptrs := [][]byte{:0:10}
 407  	for _, p := range validRecs(rec, syscall.DNS_TYPE_PTR, arpa) {
 408  		v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
 409  		ptrs = append(ptrs, absDomainName(windows.UTF16PtrToString(v.Host)))
 410  	}
 411  	return ptrs, nil
 412  }
 413  
 414  const dnsSectionMask = 0x0003
 415  
 416  // returns only results applicable to name and resolves CNAME entries.
 417  func validRecs(r *syscall.DNSRecord, dnstype uint16, name string) []*syscall.DNSRecord {
 418  	cname := syscall.StringToUTF16Ptr(name)
 419  	if dnstype != syscall.DNS_TYPE_CNAME {
 420  		cname = resolveCNAME(cname, r)
 421  	}
 422  	rec := []*syscall.DNSRecord{:0:10}
 423  	for p := r; p != nil; p = p.Next {
 424  		// in case of a local machine, DNS records are returned with DNSREC_QUESTION flag instead of DNS_ANSWER
 425  		if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer && p.Dw&dnsSectionMask != syscall.DnsSectionQuestion {
 426  			continue
 427  		}
 428  		if p.Type != dnstype {
 429  			continue
 430  		}
 431  		if !syscall.DnsNameCompare(cname, p.Name) {
 432  			continue
 433  		}
 434  		rec = append(rec, p)
 435  	}
 436  	return rec
 437  }
 438  
 439  // returns the last CNAME in chain.
 440  func resolveCNAME(name *uint16, r *syscall.DNSRecord) *uint16 {
 441  	// limit cname resolving to 10 in case of an infinite CNAME loop
 442  Cname:
 443  	for cnameloop := 0; cnameloop < 10; cnameloop++ {
 444  		for p := r; p != nil; p = p.Next {
 445  			if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
 446  				continue
 447  			}
 448  			if p.Type != syscall.DNS_TYPE_CNAME {
 449  				continue
 450  			}
 451  			if !syscall.DnsNameCompare(name, p.Name) {
 452  				continue
 453  			}
 454  			name = (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0])).Host
 455  			continue Cname
 456  		}
 457  		break
 458  	}
 459  	return name
 460  }
 461  
 462  // concurrentThreadsLimit returns the number of threads we permit to
 463  // run concurrently doing DNS lookups.
 464  func concurrentThreadsLimit() int {
 465  	return 500
 466  }
 467