mmsghdr_unix.go raw

   1  // Copyright 2017 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  //go:build aix || linux || netbsd
   6  
   7  package socket
   8  
   9  import (
  10  	"net"
  11  	"os"
  12  	"sync"
  13  	"syscall"
  14  )
  15  
  16  type mmsghdrs []mmsghdr
  17  
  18  func (hs mmsghdrs) unpack(ms []Message, parseFn func([]byte, string) (net.Addr, error), hint string) error {
  19  	for i := range hs {
  20  		ms[i].N = int(hs[i].Len)
  21  		ms[i].NN = hs[i].Hdr.controllen()
  22  		ms[i].Flags = hs[i].Hdr.flags()
  23  		if parseFn != nil {
  24  			var err error
  25  			ms[i].Addr, err = parseFn(hs[i].Hdr.name(), hint)
  26  			if err != nil {
  27  				return err
  28  			}
  29  		}
  30  	}
  31  	return nil
  32  }
  33  
  34  // mmsghdrsPacker packs Message-slices into mmsghdrs (re-)using pre-allocated buffers.
  35  type mmsghdrsPacker struct {
  36  	// hs are the pre-allocated mmsghdrs.
  37  	hs mmsghdrs
  38  	// sockaddrs is the pre-allocated buffer for the Hdr.Name buffers.
  39  	// We use one large buffer for all messages and slice it up.
  40  	sockaddrs []byte
  41  	// vs are the pre-allocated iovecs.
  42  	// We allocate one large buffer for all messages and slice it up. This allows to reuse the buffer
  43  	// if the number of buffers per message is distributed differently between calls.
  44  	vs []iovec
  45  }
  46  
  47  func (p *mmsghdrsPacker) prepare(ms []Message) {
  48  	n := len(ms)
  49  	if n <= cap(p.hs) {
  50  		p.hs = p.hs[:n]
  51  	} else {
  52  		p.hs = make(mmsghdrs, n)
  53  	}
  54  	if n*sizeofSockaddrInet6 <= cap(p.sockaddrs) {
  55  		p.sockaddrs = p.sockaddrs[:n*sizeofSockaddrInet6]
  56  	} else {
  57  		p.sockaddrs = make([]byte, n*sizeofSockaddrInet6)
  58  	}
  59  
  60  	nb := 0
  61  	for _, m := range ms {
  62  		nb += len(m.Buffers)
  63  	}
  64  	if nb <= cap(p.vs) {
  65  		p.vs = p.vs[:nb]
  66  	} else {
  67  		p.vs = make([]iovec, nb)
  68  	}
  69  }
  70  
  71  func (p *mmsghdrsPacker) pack(ms []Message, parseFn func([]byte, string) (net.Addr, error), marshalFn func(net.Addr, []byte) int) mmsghdrs {
  72  	p.prepare(ms)
  73  	hs := p.hs
  74  	vsRest := p.vs
  75  	saRest := p.sockaddrs
  76  	for i := range hs {
  77  		nvs := len(ms[i].Buffers)
  78  		vs := vsRest[:nvs]
  79  		vsRest = vsRest[nvs:]
  80  
  81  		var sa []byte
  82  		if parseFn != nil {
  83  			sa = saRest[:sizeofSockaddrInet6]
  84  			saRest = saRest[sizeofSockaddrInet6:]
  85  		} else if marshalFn != nil {
  86  			n := marshalFn(ms[i].Addr, saRest)
  87  			if n > 0 {
  88  				sa = saRest[:n]
  89  				saRest = saRest[n:]
  90  			}
  91  		}
  92  		hs[i].Hdr.pack(vs, ms[i].Buffers, ms[i].OOB, sa)
  93  	}
  94  	return hs
  95  }
  96  
  97  // syscaller is a helper to invoke recvmmsg and sendmmsg via the RawConn.Read/Write interface.
  98  // It is reusable, to amortize the overhead of allocating a closure for the function passed to
  99  // RawConn.Read/Write.
 100  type syscaller struct {
 101  	n     int
 102  	operr error
 103  	hs    mmsghdrs
 104  	flags int
 105  
 106  	boundRecvmmsgF func(uintptr) bool
 107  	boundSendmmsgF func(uintptr) bool
 108  }
 109  
 110  func (r *syscaller) init() {
 111  	r.boundRecvmmsgF = r.recvmmsgF
 112  	r.boundSendmmsgF = r.sendmmsgF
 113  }
 114  
 115  func (r *syscaller) recvmmsg(c syscall.RawConn, hs mmsghdrs, flags int) (int, error) {
 116  	r.n = 0
 117  	r.operr = nil
 118  	r.hs = hs
 119  	r.flags = flags
 120  	if err := c.Read(r.boundRecvmmsgF); err != nil {
 121  		return r.n, err
 122  	}
 123  	if r.operr != nil {
 124  		return r.n, os.NewSyscallError("recvmmsg", r.operr)
 125  	}
 126  	return r.n, nil
 127  }
 128  
 129  func (r *syscaller) recvmmsgF(s uintptr) bool {
 130  	r.n, r.operr = recvmmsg(s, r.hs, r.flags)
 131  	return ioComplete(r.flags, r.operr)
 132  }
 133  
 134  func (r *syscaller) sendmmsg(c syscall.RawConn, hs mmsghdrs, flags int) (int, error) {
 135  	r.n = 0
 136  	r.operr = nil
 137  	r.hs = hs
 138  	r.flags = flags
 139  	if err := c.Write(r.boundSendmmsgF); err != nil {
 140  		return r.n, err
 141  	}
 142  	if r.operr != nil {
 143  		return r.n, os.NewSyscallError("sendmmsg", r.operr)
 144  	}
 145  	return r.n, nil
 146  }
 147  
 148  func (r *syscaller) sendmmsgF(s uintptr) bool {
 149  	r.n, r.operr = sendmmsg(s, r.hs, r.flags)
 150  	return ioComplete(r.flags, r.operr)
 151  }
 152  
 153  // mmsgTmps holds reusable temporary helpers for recvmmsg and sendmmsg.
 154  type mmsgTmps struct {
 155  	packer    mmsghdrsPacker
 156  	syscaller syscaller
 157  }
 158  
 159  var defaultMmsgTmpsPool = mmsgTmpsPool{
 160  	p: sync.Pool{
 161  		New: func() interface{} {
 162  			tmps := new(mmsgTmps)
 163  			tmps.syscaller.init()
 164  			return tmps
 165  		},
 166  	},
 167  }
 168  
 169  type mmsgTmpsPool struct {
 170  	p sync.Pool
 171  }
 172  
 173  func (p *mmsgTmpsPool) Get() *mmsgTmps {
 174  	m := p.p.Get().(*mmsgTmps)
 175  	// Clear fields up to the len (not the cap) of the slice,
 176  	// assuming that the previous caller only used that many elements.
 177  	for i := range m.packer.sockaddrs {
 178  		m.packer.sockaddrs[i] = 0
 179  	}
 180  	m.packer.sockaddrs = m.packer.sockaddrs[:0]
 181  	for i := range m.packer.vs {
 182  		m.packer.vs[i] = iovec{}
 183  	}
 184  	m.packer.vs = m.packer.vs[:0]
 185  	for i := range m.packer.hs {
 186  		m.packer.hs[i].Len = 0
 187  		m.packer.hs[i].Hdr = msghdr{}
 188  	}
 189  	m.packer.hs = m.packer.hs[:0]
 190  	return m
 191  }
 192  
 193  func (p *mmsgTmpsPool) Put(tmps *mmsgTmps) {
 194  	p.p.Put(tmps)
 195  }
 196