wintun.go raw

   1  //go:build windows
   2  
   3  /* SPDX-License-Identifier: MIT
   4   *
   5   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
   6   */
   7  
   8  package wintun
   9  
  10  import (
  11  	"log"
  12  	"runtime"
  13  	"syscall"
  14  	"unsafe"
  15  
  16  	"golang.org/x/sys/windows"
  17  )
  18  
  19  type loggerLevel int
  20  
  21  const (
  22  	logInfo loggerLevel = iota
  23  	logWarn
  24  	logErr
  25  )
  26  
// AdapterNameMax is the maximum adapter name length. NOTE(review): presumably
// measured in UTF-16 code units to match wintun.h's limit — confirm. Nothing in
// this file enforces it before calling into wintun.dll.
const AdapterNameMax = 128
  28  
// Adapter is an open Wintun adapter obtained from CreateAdapter or OpenAdapter.
// It should be released with Close; if it is not, a finalizer installed by
// those constructors closes it when the Adapter is garbage collected.
type Adapter struct {
	handle uintptr // non-zero handle returned by WintunCreateAdapter / WintunOpenAdapter
}
  32  
// modwintun is wintun.dll; setupLogger is handed to newLazyDLL, presumably as
// a hook run once the DLL is actually loaded so the logging callback gets
// installed (confirm in newLazyDLL). The procWintun* values are the DLL exports
// used below; most callers invoke Find before Addr so that a missing DLL or
// export is reported as an error.
var (
	modwintun                         = newLazyDLL("wintun.dll", setupLogger)
	procWintunCreateAdapter           = modwintun.NewProc("WintunCreateAdapter")
	procWintunOpenAdapter             = modwintun.NewProc("WintunOpenAdapter")
	procWintunCloseAdapter            = modwintun.NewProc("WintunCloseAdapter")
	procWintunDeleteDriver            = modwintun.NewProc("WintunDeleteDriver")
	procWintunGetAdapterLUID          = modwintun.NewProc("WintunGetAdapterLUID")
	procWintunGetRunningDriverVersion = modwintun.NewProc("WintunGetRunningDriverVersion")
)
  42  
// TimestampedWriter may be implemented by the writer of the standard library's
// default logger (log.Default().Writer()). When it is, wintun driver log
// messages are delivered through WriteWithTimestamp instead of log.Println:
// p is the logger prefix followed by the message, and ts is the event time in
// nanoseconds since the Unix epoch, converted from the Windows FILETIME
// timestamp supplied by the driver.
type TimestampedWriter interface {
	WriteWithTimestamp(p []byte, ts int64) (n int, err error)
}
  46  
  47  func logMessage(level loggerLevel, timestamp uint64, msg *uint16) int {
  48  	if tw, ok := log.Default().Writer().(TimestampedWriter); ok {
  49  		tw.WriteWithTimestamp([]byte(log.Default().Prefix()+windows.UTF16PtrToString(msg)), (int64(timestamp)-116444736000000000)*100)
  50  	} else {
  51  		log.Println(windows.UTF16PtrToString(msg))
  52  	}
  53  	return 0
  54  }
  55  
  56  func setupLogger(dll *lazyDLL) {
  57  	var callback uintptr
  58  	if runtime.GOARCH == "386" {
  59  		callback = windows.NewCallback(func(level loggerLevel, timestampLow, timestampHigh uint32, msg *uint16) int {
  60  			return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg)
  61  		})
  62  	} else if runtime.GOARCH == "arm" {
  63  		callback = windows.NewCallback(func(level loggerLevel, _, timestampLow, timestampHigh uint32, msg *uint16) int {
  64  			return logMessage(level, uint64(timestampHigh)<<32|uint64(timestampLow), msg)
  65  		})
  66  	} else if runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" {
  67  		callback = windows.NewCallback(logMessage)
  68  	}
  69  	syscall.Syscall(dll.NewProc("WintunSetLogger").Addr(), 1, callback, 0, 0)
  70  }
  71  
  72  func closeAdapter(wintun *Adapter) {
  73  	syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0)
  74  }
  75  
  76  // CreateAdapter creates a Wintun adapter. name is the cosmetic name of the adapter.
  77  // tunnelType represents the type of adapter and should be "Wintun". requestedGUID is
  78  // the GUID of the created network adapter, which then influences NLA generation
  79  // deterministically. If it is set to nil, the GUID is chosen by the system at random,
  80  // and hence a new NLA entry is created for each new adapter.
  81  func CreateAdapter(name string, tunnelType string, requestedGUID *windows.GUID) (wintun *Adapter, err error) {
  82  	var name16 *uint16
  83  	name16, err = windows.UTF16PtrFromString(name)
  84  	if err != nil {
  85  		return
  86  	}
  87  	var tunnelType16 *uint16
  88  	tunnelType16, err = windows.UTF16PtrFromString(tunnelType)
  89  	if err != nil {
  90  		return
  91  	}
  92  	if err := procWintunCreateAdapter.Find(); err != nil {
  93  		return nil, err
  94  	}
  95  	r0, _, e1 := syscall.Syscall(procWintunCreateAdapter.Addr(), 3, uintptr(unsafe.Pointer(name16)), uintptr(unsafe.Pointer(tunnelType16)), uintptr(unsafe.Pointer(requestedGUID)))
  96  	if r0 == 0 {
  97  		err = e1
  98  		return
  99  	}
 100  	wintun = &Adapter{handle: r0}
 101  	runtime.SetFinalizer(wintun, closeAdapter)
 102  	return
 103  }
 104  
 105  // OpenAdapter opens an existing Wintun adapter by name.
 106  func OpenAdapter(name string) (wintun *Adapter, err error) {
 107  	var name16 *uint16
 108  	name16, err = windows.UTF16PtrFromString(name)
 109  	if err != nil {
 110  		return
 111  	}
 112  	if err := procWintunOpenAdapter.Find(); err != nil {
 113  		return nil, err
 114  	}
 115  	r0, _, e1 := syscall.Syscall(procWintunOpenAdapter.Addr(), 1, uintptr(unsafe.Pointer(name16)), 0, 0)
 116  	if r0 == 0 {
 117  		err = e1
 118  		return
 119  	}
 120  	wintun = &Adapter{handle: r0}
 121  	runtime.SetFinalizer(wintun, closeAdapter)
 122  	return
 123  }
 124  
 125  // Close closes a Wintun adapter.
 126  func (wintun *Adapter) Close() (err error) {
 127  	if err := procWintunCloseAdapter.Find(); err != nil {
 128  		return err
 129  	}
 130  	runtime.SetFinalizer(wintun, nil)
 131  	r1, _, e1 := syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0)
 132  	if r1 == 0 {
 133  		err = e1
 134  	}
 135  	return
 136  }
 137  
 138  // Uninstall removes the driver from the system if no drivers are currently in use.
 139  func Uninstall() (err error) {
 140  	if err := procWintunDeleteDriver.Find(); err != nil {
 141  		return err
 142  	}
 143  	r1, _, e1 := syscall.Syscall(procWintunDeleteDriver.Addr(), 0, 0, 0, 0)
 144  	if r1 == 0 {
 145  		err = e1
 146  	}
 147  	return
 148  }
 149  
 150  // RunningVersion returns the version of the loaded driver.
 151  func RunningVersion() (version uint32, err error) {
 152  	if err := procWintunGetRunningDriverVersion.Find(); err != nil {
 153  		return 0, err
 154  	}
 155  	r0, _, e1 := syscall.Syscall(procWintunGetRunningDriverVersion.Addr(), 0, 0, 0, 0)
 156  	version = uint32(r0)
 157  	if version == 0 {
 158  		err = e1
 159  	}
 160  	return
 161  }
 162  
 163  // LUID returns the LUID of the adapter.
 164  func (wintun *Adapter) LUID() (luid uint64) {
 165  	syscall.Syscall(procWintunGetAdapterLUID.Addr(), 2, uintptr(wintun.handle), uintptr(unsafe.Pointer(&luid)), 0)
 166  	return
 167  }
 168