瀏覽代碼

Merge pull request #18 from teepark/longlived-conns

Pool implementation with connection reuse
Jordan Wright 9 年之前
父節點
當前提交
fd703108da
共有 2 個文件被更改,包括 367 次插入0 次删除
  1. 17 0
      README.md
  2. 350 0
      pool.go

+ 17 - 0
README.md

@@ -60,6 +60,23 @@ e := NewEmail()
 e.AttachFile("test.txt")
 ```
 
+#### A Pool of Reusable Connections
+```
+(var ch <-chan *email.Email)
+p := email.NewPool(
+	"smtp.gmail.com:587",
+	4,
+	smtp.PlainAuth("", "test@gmail.com", "password123", "smtp.gmail.com"),
+)
+for i := 0; i < 4; i++ {
+	go func() {
+		for e := range ch {
+			p.Send(e, 10 * time.Second)
+		}
+	}
+}
+```
+
 ### Documentation
 [http://godoc.org/github.com/jordan-wright/email](http://godoc.org/github.com/jordan-wright/email)
 

+ 350 - 0
pool.go

@@ -0,0 +1,350 @@
+package email
+
+import (
+	"crypto/tls"
+	"errors"
+	"io"
+	"net"
+	"net/mail"
+	"net/smtp"
+	"net/textproto"
+	"sync"
+	"syscall"
+	"time"
+)
+
+type Pool struct {
+	addr         string
+	auth         smtp.Auth
+	max          int
+	created      int
+	clients      chan *client
+	rebuild      chan struct{}
+	mut          *sync.Mutex
+	lastBuildErr *timestampedErr
+	closing      chan struct{}
+}
+
+type client struct {
+	*smtp.Client
+	failCount int
+}
+
+type timestampedErr struct {
+	err error
+	ts  time.Time
+}
+
+const maxFails = 4
+
+var (
+	ErrClosed  = errors.New("pool closed")
+	ErrTimeout = errors.New("timed out")
+)
+
+func NewPool(address string, count int, auth smtp.Auth) *Pool {
+	return &Pool{
+		addr:    address,
+		auth:    auth,
+		max:     count,
+		clients: make(chan *client, count),
+		rebuild: make(chan struct{}),
+		closing: make(chan struct{}),
+		mut:     &sync.Mutex{},
+	}
+}
+
+// go1.1 didn't have this method
+func (c *client) Close() error {
+	return c.Text.Close()
+}
+
+func (p *Pool) get(timeout time.Duration) *client {
+	select {
+	case c := <-p.clients:
+		return c
+	default:
+	}
+
+	if p.created < p.max {
+		p.makeOne()
+	}
+
+	var deadline <-chan time.Time
+	if timeout < 0 {
+		deadline = nil
+	} else {
+		deadline = time.After(timeout)
+	}
+
+	for {
+		select {
+		case c := <-p.clients:
+			return c
+		case <-p.rebuild:
+			p.makeOne()
+		case <-deadline:
+			return nil
+		case <-p.closing:
+			return nil
+		}
+	}
+}
+
+func shouldReuse(err error) bool {
+	// certainly not perfect, but might be close:
+	//  - EOF: clearly, the connection went down
+	//  - textproto.Errors were valid SMTP over a valid connection,
+	//    but resulted from an SMTP error response
+	//  - textproto.ProtocolErrors result from connections going down,
+	//    invalid SMTP, that sort of thing
+	//  - syscall.Errno is probably down connection/bad pipe, but
+	//    passed straight through by textproto instead of becoming a
+	//    ProtocolError
+	//  - if we don't recognize the error, don't reuse the connection
+	// A false positive will probably fail on the Reset(), and even if
+	// not will eventually hit maxFails.
+	// A false negative will knock over (and trigger replacement of) a
+	// conn that might have still worked.
+	if err == io.EOF {
+		return false
+	}
+	switch err.(type) {
+	case *textproto.Error:
+		return true
+	case *textproto.ProtocolError, textproto.ProtocolError:
+		return false
+	case syscall.Errno:
+		return false
+	default:
+		return false
+	}
+}
+
+func (p *Pool) replace(c *client) {
+	p.clients <- c
+}
+
+func (p *Pool) inc() bool {
+	if p.created >= p.max {
+		return false
+	}
+
+	p.mut.Lock()
+	defer p.mut.Unlock()
+
+	if p.created >= p.max {
+		return false
+	}
+	p.created++
+	return true
+}
+
+func (p *Pool) dec() {
+	p.mut.Lock()
+	p.created--
+	p.mut.Unlock()
+
+	select {
+	case p.rebuild <- struct{}{}:
+	default:
+	}
+}
+
+func (p *Pool) makeOne() {
+	go func() {
+		if p.inc() {
+			if c, err := p.build(); err == nil {
+				p.clients <- c
+			} else {
+				p.lastBuildErr = &timestampedErr{err, time.Now()}
+				p.dec()
+			}
+		}
+	}()
+}
+
+func startTLS(c *client, addr string) (bool, error) {
+	if ok, _ := c.Extension("STARTTLS"); !ok {
+		return false, nil
+	}
+
+	host, _, err := net.SplitHostPort(addr)
+	if err != nil {
+		return false, err
+	}
+
+	if err := c.StartTLS(&tls.Config{ServerName: host}); err != nil {
+		return false, err
+	}
+
+	return true, nil
+}
+
+func addAuth(c *client, auth smtp.Auth) (bool, error) {
+	if ok, _ := c.Extension("AUTH"); !ok {
+		return false, nil
+	}
+
+	if err := c.Auth(auth); err != nil {
+		return false, err
+	}
+
+	return true, nil
+}
+
+func (p *Pool) build() (*client, error) {
+	cl, err := smtp.Dial(p.addr)
+	if err != nil {
+		return nil, err
+	}
+	c := &client{cl, 0}
+
+	if _, err := startTLS(c, p.addr); err != nil {
+		c.Close()
+		return nil, err
+	}
+
+	if p.auth != nil {
+		if _, err := addAuth(c, p.auth); err != nil {
+			c.Close()
+			return nil, err
+		}
+	}
+
+	return c, nil
+}
+
+func (p *Pool) maybeReplace(err error, c *client) {
+	if err == nil {
+		c.failCount = 0
+		p.replace(c)
+		return
+	}
+
+	c.failCount++
+	if c.failCount >= maxFails {
+		goto shutdown
+	}
+
+	if !shouldReuse(err) {
+		goto shutdown
+	}
+
+	if err := c.Reset(); err != nil {
+		goto shutdown
+	}
+
+	p.replace(c)
+	return
+
+shutdown:
+	p.dec()
+	c.Close()
+}
+
+func (p *Pool) failedToGet(startTime time.Time) error {
+	select {
+	case <-p.closing:
+		return ErrClosed
+	default:
+	}
+
+	if p.lastBuildErr != nil && startTime.Before(p.lastBuildErr.ts) {
+		return p.lastBuildErr.err
+	}
+
+	return ErrTimeout
+}
+
+// Send sends an email via a connection pulled from the Pool. The timeout may
+// be <0 to indicate no timeout. Otherwise reaching the timeout will produce
+// and error building a connection that occurred while we were waiting, or
+// otherwise ErrTimeout.
+func (p *Pool) Send(e *Email, timeout time.Duration) (err error) {
+	start := time.Now()
+	c := p.get(timeout)
+	if c == nil {
+		return p.failedToGet(start)
+	}
+
+	defer func() {
+		p.maybeReplace(err, c)
+	}()
+
+	recipients, err := addressLists(e.To, e.Cc, e.Bcc)
+	if err != nil {
+		return
+	}
+
+	msg, err := e.Bytes()
+	if err != nil {
+		return
+	}
+
+	from, err := emailOnly(e.From)
+	if err != nil {
+		return
+	}
+	if err = c.Mail(from); err != nil {
+		return
+	}
+
+	for _, recip := range recipients {
+		if err = c.Rcpt(recip); err != nil {
+			return
+		}
+	}
+
+	w, err := c.Data()
+	if err != nil {
+		return
+	}
+	if _, err = w.Write(msg); err != nil {
+		return
+	}
+
+	err = w.Close()
+
+	return
+}
+
+func emailOnly(full string) (string, error) {
+	addr, err := mail.ParseAddress(full)
+	if err != nil {
+		return "", err
+	}
+	return addr.Address, nil
+}
+
+func addressLists(lists ...[]string) ([]string, error) {
+	length := 0
+	for _, lst := range lists {
+		length += len(lst)
+	}
+	combined := make([]string, 0, length)
+
+	for _, lst := range lists {
+		for _, full := range lst {
+			addr, err := emailOnly(full)
+			if err != nil {
+				return nil, err
+			}
+			combined = append(combined, addr)
+		}
+	}
+
+	return combined, nil
+}
+
+// Close immediately changes the pool's state so no new connections will be
+// created, then gets and closes the existing ones as they become available.
+func (p *Pool) Close() {
+	close(p.closing)
+
+	for p.created > 0 {
+		c := <-p.clients
+		c.Quit()
+		p.dec()
+	}
+}