server.go raw

   1  package smtp
   2  
   3  import (
   4  	"context"
   5  	"crypto/tls"
   6  	"errors"
   7  	"io"
   8  	"log"
   9  	"net"
  10  	"os"
  11  	"sync"
  12  	"time"
  13  )
  14  
  15  var ErrServerClosed = errors.New("smtp: server already closed")
  16  
  17  // Logger interface is used by Server to report unexpected internal errors.
  18  type Logger interface {
  19  	Printf(format string, v ...interface{})
  20  	Println(v ...interface{})
  21  }
  22  
  23  // A SMTP server.
  24  type Server struct {
  25  	// The type of network, "tcp" or "unix".
  26  	Network string
  27  	// TCP or Unix address to listen on.
  28  	Addr string
  29  	// The server TLS configuration.
  30  	TLSConfig *tls.Config
  31  	// Enable LMTP mode, as defined in RFC 2033.
  32  	LMTP bool
  33  
  34  	Domain            string
  35  	MaxRecipients     int
  36  	MaxMessageBytes   int64
  37  	MaxLineLength     int
  38  	AllowInsecureAuth bool
  39  	Debug             io.Writer
  40  	ErrorLog          Logger
  41  	ReadTimeout       time.Duration
  42  	WriteTimeout      time.Duration
  43  
  44  	// Advertise SMTPUTF8 (RFC 6531) capability.
  45  	// Should be used only if backend supports it.
  46  	EnableSMTPUTF8 bool
  47  
  48  	// Advertise REQUIRETLS (RFC 8689) capability.
  49  	// Should be used only if backend supports it.
  50  	EnableREQUIRETLS bool
  51  
  52  	// Advertise BINARYMIME (RFC 3030) capability.
  53  	// Should be used only if backend supports it.
  54  	EnableBINARYMIME bool
  55  
  56  	// Advertise DSN (RFC 3461) capability.
  57  	// Should be used only if backend supports it.
  58  	EnableDSN bool
  59  
  60  	// Advertise RRVS (RFC 7293) capability.
  61  	// Should be used only if backend supports it.
  62  	EnableRRVS bool
  63  
  64  	// Advertise DELIVERBY (RFC 2852) capability.
  65  	// Should be used only if backend supports it.
  66  	EnableDELIVERBY bool
  67  	// The minimum time, with seconds precision, that a client
  68  	// may specify in the BY argument with return mode.
  69  	// A zero value indicates no set minimum.
  70  	// Only use if DELIVERBY is enabled.
  71  	MinimumDeliverByTime time.Duration
  72  
  73  	// Advertise MT-PRIORITY (RFC 6710) capability.
  74  	// Should only be used if backend supports it.
  75  	EnableMTPRIORITY bool
  76  	// The priority profile mapping as defined
  77  	// in RFC 6710 section 10.2.
  78  	//
  79  	// Default value of NONE to advertise no specific profile.
  80  	MtPriorityProfile PriorityProfile
  81  
  82  	// The server backend.
  83  	Backend Backend
  84  
  85  	wg   sync.WaitGroup
  86  	done chan struct{}
  87  
  88  	locker    sync.Mutex
  89  	listeners []net.Listener
  90  	conns     map[*Conn]struct{}
  91  }
  92  
  93  // New creates a new SMTP server.
  94  func NewServer(be Backend) *Server {
  95  	return &Server{
  96  		// Doubled maximum line length per RFC 5321 (Section 4.5.3.1.6)
  97  		MaxLineLength: 2000,
  98  
  99  		Backend:  be,
 100  		done:     make(chan struct{}, 1),
 101  		ErrorLog: log.New(os.Stderr, "smtp/server ", log.LstdFlags),
 102  		conns:    make(map[*Conn]struct{}),
 103  	}
 104  }
 105  
 106  // Serve accepts incoming connections on the Listener l.
 107  func (s *Server) Serve(l net.Listener) error {
 108  	s.locker.Lock()
 109  	s.listeners = append(s.listeners, l)
 110  	s.locker.Unlock()
 111  
 112  	var tempDelay time.Duration // how long to sleep on accept failure
 113  
 114  	for {
 115  		c, err := l.Accept()
 116  		if err != nil {
 117  			select {
 118  			case <-s.done:
 119  				// we called Close()
 120  				return nil
 121  			default:
 122  			}
 123  			if ne, ok := err.(net.Error); ok && ne.Temporary() {
 124  				if tempDelay == 0 {
 125  					tempDelay = 5 * time.Millisecond
 126  				} else {
 127  					tempDelay *= 2
 128  				}
 129  				if max := 1 * time.Second; tempDelay > max {
 130  					tempDelay = max
 131  				}
 132  				s.ErrorLog.Printf("accept error: %s; retrying in %s", err, tempDelay)
 133  				time.Sleep(tempDelay)
 134  				continue
 135  			}
 136  			return err
 137  		}
 138  
 139  		s.wg.Add(1)
 140  		go func() {
 141  			defer s.wg.Done()
 142  
 143  			err := s.handleConn(newConn(c, s))
 144  			if err != nil {
 145  				s.ErrorLog.Printf("error handling %v: %s", c.RemoteAddr(), err)
 146  			}
 147  		}()
 148  	}
 149  }
 150  
 151  func (s *Server) handleConn(c *Conn) error {
 152  	s.locker.Lock()
 153  	s.conns[c] = struct{}{}
 154  	s.locker.Unlock()
 155  
 156  	defer func() {
 157  		c.Close()
 158  
 159  		s.locker.Lock()
 160  		delete(s.conns, c)
 161  		s.locker.Unlock()
 162  	}()
 163  
 164  	if tlsConn, ok := c.conn.(*tls.Conn); ok {
 165  		if d := s.ReadTimeout; d != 0 {
 166  			c.conn.SetReadDeadline(time.Now().Add(d))
 167  		}
 168  		if d := s.WriteTimeout; d != 0 {
 169  			c.conn.SetWriteDeadline(time.Now().Add(d))
 170  		}
 171  		if err := tlsConn.Handshake(); err != nil {
 172  			return err
 173  		}
 174  	}
 175  
 176  	c.greet()
 177  
 178  	for {
 179  		line, err := c.readLine()
 180  		if err == nil {
 181  			cmd, arg, err := parseCmd(line)
 182  			if err != nil {
 183  				c.protocolError(501, EnhancedCode{5, 5, 2}, "Bad command")
 184  				continue
 185  			}
 186  
 187  			c.handle(cmd, arg)
 188  		} else {
 189  			if err == io.EOF || errors.Is(err, net.ErrClosed) {
 190  				return nil
 191  			}
 192  			if err == ErrTooLongLine {
 193  				c.writeResponse(500, EnhancedCode{5, 4, 0}, "Too long line, closing connection")
 194  				return nil
 195  			}
 196  
 197  			if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
 198  				c.writeResponse(421, EnhancedCode{4, 4, 2}, "Idle timeout, bye bye")
 199  				return nil
 200  			}
 201  
 202  			c.writeResponse(421, EnhancedCode{4, 4, 0}, "Connection error, sorry")
 203  			return err
 204  		}
 205  	}
 206  }
 207  
 208  func (s *Server) network() string {
 209  	if s.Network != "" {
 210  		return s.Network
 211  	}
 212  	if s.LMTP {
 213  		return "unix"
 214  	}
 215  	return "tcp"
 216  }
 217  
 218  // ListenAndServe listens on the network address s.Addr and then calls Serve
 219  // to handle requests on incoming connections.
 220  //
 221  // If s.Addr is blank and LMTP is disabled, ":smtp" is used.
 222  func (s *Server) ListenAndServe() error {
 223  	network := s.network()
 224  
 225  	addr := s.Addr
 226  	if !s.LMTP && addr == "" {
 227  		addr = ":smtp"
 228  	}
 229  
 230  	l, err := net.Listen(network, addr)
 231  	if err != nil {
 232  		return err
 233  	}
 234  
 235  	return s.Serve(l)
 236  }
 237  
 238  // ListenAndServeTLS listens on the TCP network address s.Addr and then calls
 239  // Serve to handle requests on incoming TLS connections.
 240  //
 241  // If s.Addr is blank and LMTP is disabled, ":smtps" is used.
 242  func (s *Server) ListenAndServeTLS() error {
 243  	network := s.network()
 244  
 245  	addr := s.Addr
 246  	if !s.LMTP && addr == "" {
 247  		addr = ":smtps"
 248  	}
 249  
 250  	l, err := tls.Listen(network, addr, s.TLSConfig)
 251  	if err != nil {
 252  		return err
 253  	}
 254  
 255  	return s.Serve(l)
 256  }
 257  
 258  // Close immediately closes all active listeners and connections.
 259  //
 260  // Close returns any error returned from closing the server's underlying
 261  // listener(s).
 262  func (s *Server) Close() error {
 263  	select {
 264  	case <-s.done:
 265  		return ErrServerClosed
 266  	default:
 267  		close(s.done)
 268  	}
 269  
 270  	var err error
 271  	s.locker.Lock()
 272  	for _, l := range s.listeners {
 273  		if lerr := l.Close(); lerr != nil && err == nil {
 274  			err = lerr
 275  		}
 276  	}
 277  
 278  	for conn := range s.conns {
 279  		conn.Close()
 280  	}
 281  	s.locker.Unlock()
 282  
 283  	return err
 284  }
 285  
 286  // Shutdown gracefully shuts down the server without interrupting any
 287  // active connections. Shutdown works by first closing all open
 288  // listeners and then waiting indefinitely for connections to return to
 289  // idle and then shut down.
 290  // If the provided context expires before the shutdown is complete,
 291  // Shutdown returns the context's error, otherwise it returns any
 292  // error returned from closing the Server's underlying Listener(s).
 293  func (s *Server) Shutdown(ctx context.Context) error {
 294  	select {
 295  	case <-s.done:
 296  		return ErrServerClosed
 297  	default:
 298  		close(s.done)
 299  	}
 300  
 301  	var err error
 302  	s.locker.Lock()
 303  	for _, l := range s.listeners {
 304  		if lerr := l.Close(); lerr != nil && err == nil {
 305  			err = lerr
 306  		}
 307  	}
 308  	s.locker.Unlock()
 309  
 310  	connDone := make(chan struct{})
 311  	go func() {
 312  		defer close(connDone)
 313  		s.wg.Wait()
 314  	}()
 315  
 316  	select {
 317  	case <-ctx.Done():
 318  		return ctx.Err()
 319  	case <-connDone:
 320  		return err
 321  	}
 322  }
 323