浏览代码

resiliency features

track consecutive fails on a client, and check errors on Reset() when we
do choose to reuse one.

these changes should make it impossible for the pool to eventually
become populated with clients in a bad state that doesn't get caught by
shouldReuse()
Travis J Parker 10 年之前
父节点
当前提交
6e7b67953c
共有 1 个文件被更改,包括 95 次插入47 次删除
  1. 95 47
      pool.go

+ 95 - 47
pool.go

@@ -17,27 +17,34 @@ type Pool struct {
 	auth    smtp.Auth
 	max     int
 	created int
-	ch      chan *smtp.Client
-	decs    chan struct{}
+	clients chan *client
+	rebuild chan struct{}
 	mut     *sync.Mutex
 }
 
+type client struct {
+	*smtp.Client
+	failCount int
+}
+
+const maxFails = 4
+
 var ErrTimeout = errors.New("timed out")
 
-func NewPool(address string, auth smtp.Auth, count int) *Pool {
+func NewPool(address string, count int, auth smtp.Auth) *Pool {
 	return &Pool{
-		addr: address,
-		auth: auth,
-		max:  count,
-		ch:   make(chan *smtp.Client, count),
-		decs: make(chan struct{}),
-		mut:  &sync.Mutex{},
+		addr:    address,
+		auth:    auth,
+		max:     count,
+		clients: make(chan *client, count),
+		rebuild: make(chan struct{}),
+		mut:     &sync.Mutex{},
 	}
 }
 
-func (p *Pool) get(timeout time.Duration) *smtp.Client {
+func (p *Pool) get(timeout time.Duration) *client {
 	select {
-	case c := <-p.ch:
+	case c := <-p.clients:
 		return c
 	default:
 	}
@@ -49,9 +56,9 @@ func (p *Pool) get(timeout time.Duration) *smtp.Client {
 	deadline := time.After(timeout)
 	for {
 		select {
-		case c := <-p.ch:
+		case c := <-p.clients:
 			return c
-		case <-p.decs:
+		case <-p.rebuild:
 			p.makeOne()
 		case <-deadline:
 			return nil
@@ -60,7 +67,7 @@ func (p *Pool) get(timeout time.Duration) *smtp.Client {
 }
 
 func shouldReuse(err error) bool {
-	// probably needs tweaking, but might be close:
+	// certainly not perfect, but might be close:
 	//  - textproto.Errors were valid SMTP over a valid connection,
 	//    but resulted from an SMTP error response
 	//  - textproto.ProtocolErrors result from connections going down,
@@ -69,6 +76,10 @@ func shouldReuse(err error) bool {
 	//    passed straight through by textproto instead of becoming a
 	//    ProtocolError
 	//  - if we don't recognize the error, don't reuse the connection
+	// A false positive will probably fail on the Reset(), and even if
+	// not will eventually hit maxFails.
+	// A false negative will knock over (and trigger replacement of) a
+	// conn that might have still worked.
 	switch err.(type) {
 	case *textproto.Error:
 		return true
@@ -81,8 +92,8 @@ func shouldReuse(err error) bool {
 	}
 }
 
-func (p *Pool) replace(c *smtp.Client) {
-	p.ch <- c
+func (p *Pool) replace(c *client) {
+	p.clients <- c
 }
 
 func (p *Pool) inc() bool {
@@ -106,7 +117,7 @@ func (p *Pool) dec() {
 	p.mut.Unlock()
 
 	select {
-	case p.decs <- struct{}{}:
+	case p.rebuild <- struct{}{}:
 	default:
 	}
 }
@@ -115,7 +126,7 @@ func (p *Pool) makeOne() {
 	go func() {
 		if p.inc() {
 			if c, err := p.build(); err == nil {
-				p.ch <- c
+				p.clients <- c
 			} else {
 				p.dec()
 			}
@@ -123,37 +134,85 @@ func (p *Pool) makeOne() {
 	}()
 }
 
-func (p *Pool) build() (*smtp.Client, error) {
+func startTLS(c *smtp.Client, addr string) (bool, error) {
+	if ok, _ := c.Extension("STARTTLS"); !ok {
+		return false, nil
+	}
+
+	host, _, err := net.SplitHostPort(addr)
+	if err != nil {
+		return false, err
+	}
+
+	if err := c.StartTLS(&tls.Config{ServerName: host}); err != nil {
+		return false, err
+	}
+
+	return true, nil
+}
+
+func addAuth(c *smtp.Client, auth smtp.Auth) (bool, error) {
+	if ok, _ := c.Extension("AUTH"); !ok {
+		return false, nil
+	}
+
+	if err := c.Auth(auth); err != nil {
+		return false, err
+	}
+
+	return true, nil
+}
+
+func (p *Pool) build() (*client, error) {
 	c, err := smtp.Dial(p.addr)
 	if err != nil {
 		return nil, err
 	}
 
-	onErr := func(err error) error {
+	if _, err := startTLS(c, p.addr); err != nil {
 		c.Quit()
 		c.Close()
-		return err
+		return nil, err
 	}
 
-	if ok, _ := c.Extension("STARTTLS"); ok {
-		host, _, err := net.SplitHostPort(p.addr)
-		if err != nil {
-			return nil, onErr(err)
-		}
-		if err = c.StartTLS(&tls.Config{ServerName: host}); err != nil {
-			return nil, onErr(err)
+	if p.auth != nil {
+		if _, err := addAuth(c, p.auth); err != nil {
+			c.Quit()
+			c.Close()
+			return nil, err
 		}
 	}
 
-	if p.auth != nil {
-		if ok, _ := c.Extension("AUTH"); ok {
-			if err := c.Auth(p.auth); err != nil {
-				return nil, onErr(err)
-			}
-		}
+	return &client{c, 0}, nil
+}
+
+func (p *Pool) maybeReplace(err error, c *client) {
+	if err == nil {
+		c.failCount = 0
+		p.replace(c)
+		return
 	}
 
-	return c, nil
+	c.failCount++
+	if c.failCount >= maxFails {
+		goto shutdown
+	}
+
+	if !shouldReuse(err) {
+		goto shutdown
+	}
+
+	if err := c.Reset(); err != nil {
+		goto shutdown
+	}
+
+	p.replace(c)
+	return
+
+	shutdown:
+	p.dec()
+	c.Quit()
+	c.Close()
 }
 
 func (p *Pool) Send(e *Email, timeout time.Duration) (err error) {
@@ -163,18 +222,7 @@ func (p *Pool) Send(e *Email, timeout time.Duration) (err error) {
 	}
 
 	defer func() {
-		if err != nil {
-			if shouldReuse(err) {
-				c.Reset()
-				p.replace(c)
-			} else {
-				p.dec()
-				c.Quit()
-				c.Close()
-			}
-		} else {
-			p.replace(c)
-		}
+		p.maybeReplace(err, c)
 	}()
 
 	recipients, err := addressLists(e.To, e.Cc, e.Bcc)