diff options
Diffstat (limited to 'src/dma/vendor/github.com/go-redis/redis/pubsub.go')
-rw-r--r-- | src/dma/vendor/github.com/go-redis/redis/pubsub.go | 127 |
1 files changed, 70 insertions, 57 deletions
diff --git a/src/dma/vendor/github.com/go-redis/redis/pubsub.go b/src/dma/vendor/github.com/go-redis/redis/pubsub.go index 2cfcd150..0afb47cd 100644 --- a/src/dma/vendor/github.com/go-redis/redis/pubsub.go +++ b/src/dma/vendor/github.com/go-redis/redis/pubsub.go @@ -1,15 +1,19 @@ package redis import ( + "errors" "fmt" "sync" "time" "github.com/go-redis/redis/internal" "github.com/go-redis/redis/internal/pool" + "github.com/go-redis/redis/internal/proto" ) -// PubSub implements Pub/Sub commands as described in +var errPingTimeout = errors.New("redis: ping timeout") + +// PubSub implements Pub/Sub commands bas described in // http://redis.io/topics/pubsub. Message receiving is NOT safe // for concurrent use by multiple goroutines. // @@ -46,15 +50,17 @@ func (c *PubSub) conn() (*pool.Conn, error) { return cn, err } -func (c *PubSub) _conn(channels []string) (*pool.Conn, error) { +func (c *PubSub) _conn(newChannels []string) (*pool.Conn, error) { if c.closed { return nil, pool.ErrClosed } - if c.cn != nil { return c.cn, nil } + channels := mapKeys(c.channels) + channels = append(channels, newChannels...) + cn, err := c.newConn(channels) if err != nil { return nil, err @@ -69,20 +75,24 @@ func (c *PubSub) _conn(channels []string) (*pool.Conn, error) { return cn, nil } +func (c *PubSub) writeCmd(cn *pool.Conn, cmd Cmder) error { + return cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error { + return writeCmd(wr, cmd) + }) +} + func (c *PubSub) resubscribe(cn *pool.Conn) error { var firstErr error if len(c.channels) > 0 { - channels := mapKeys(c.channels) - err := c._subscribe(cn, "subscribe", channels...) + err := c._subscribe(cn, "subscribe", mapKeys(c.channels)) if err != nil && firstErr == nil { firstErr = err } } if len(c.patterns) > 0 { - patterns := mapKeys(c.patterns) - err := c._subscribe(cn, "psubscribe", patterns...) + err := c._subscribe(cn, "psubscribe", mapKeys(c.patterns)) if err != nil && firstErr == nil { firstErr = err } @@ -101,51 +111,48 @@ func mapKeys(m map[string]struct{}) []string { return s } -func (c *PubSub) _subscribe(cn *pool.Conn, redisCmd string, channels ...string) error { - args := make([]interface{}, 1+len(channels)) - args[0] = redisCmd - for i, channel := range channels { - args[1+i] = channel +func (c *PubSub) _subscribe( + cn *pool.Conn, redisCmd string, channels []string, +) error { + args := make([]interface{}, 0, 1+len(channels)) + args = append(args, redisCmd) + for _, channel := range channels { + args = append(args, channel) } cmd := NewSliceCmd(args...) - - cn.SetWriteTimeout(c.opt.WriteTimeout) - return writeCmd(cn, cmd) + return c.writeCmd(cn, cmd) } -func (c *PubSub) releaseConn(cn *pool.Conn, err error) { +func (c *PubSub) releaseConn(cn *pool.Conn, err error, allowTimeout bool) { c.mu.Lock() - c._releaseConn(cn, err) + c._releaseConn(cn, err, allowTimeout) c.mu.Unlock() } -func (c *PubSub) _releaseConn(cn *pool.Conn, err error) { +func (c *PubSub) _releaseConn(cn *pool.Conn, err error, allowTimeout bool) { if c.cn != cn { return } - if internal.IsBadConn(err, true) { - c._reconnect() + if internal.IsBadConn(err, allowTimeout) { + c._reconnect(err) } } -func (c *PubSub) _closeTheCn() error { - var err error - if c.cn != nil { - err = c.closeConn(c.cn) - c.cn = nil - } - return err -} - -func (c *PubSub) reconnect() { - c.mu.Lock() - c._reconnect() - c.mu.Unlock() +func (c *PubSub) _reconnect(reason error) { + _ = c._closeTheCn(reason) + _, _ = c._conn(nil) } -func (c *PubSub) _reconnect() { - _ = c._closeTheCn() - _, _ = c._conn(nil) +func (c *PubSub) _closeTheCn(reason error) error { + if c.cn == nil { + return nil + } + if !c.closed { + internal.Logf("redis: discarding bad PubSub connection: %s", reason) + } + err := c.closeConn(c.cn) + c.cn = nil + return err } func (c *PubSub) Close() error { @@ -158,7 +165,7 @@ func (c *PubSub) Close() error { c.closed = true close(c.exit) - err := c._closeTheCn() + err := c._closeTheCn(pool.ErrClosed) return err } @@ -172,8 +179,8 @@ func (c *PubSub) Subscribe(channels ...string) error { if c.channels == nil { c.channels = make(map[string]struct{}) } - for _, channel := range channels { - c.channels[channel] = struct{}{} + for _, s := range channels { + c.channels[s] = struct{}{} } return err } @@ -188,8 +195,8 @@ func (c *PubSub) PSubscribe(patterns ...string) error { if c.patterns == nil { c.patterns = make(map[string]struct{}) } - for _, pattern := range patterns { - c.patterns[pattern] = struct{}{} + for _, s := range patterns { + c.patterns[s] = struct{}{} } return err } @@ -200,10 +207,10 @@ func (c *PubSub) Unsubscribe(channels ...string) error { c.mu.Lock() defer c.mu.Unlock() - err := c.subscribe("unsubscribe", channels...) for _, channel := range channels { delete(c.channels, channel) } + err := c.subscribe("unsubscribe", channels...) return err } @@ -213,10 +220,10 @@ func (c *PubSub) PUnsubscribe(patterns ...string) error { c.mu.Lock() defer c.mu.Unlock() - err := c.subscribe("punsubscribe", patterns...) for _, pattern := range patterns { delete(c.patterns, pattern) } + err := c.subscribe("punsubscribe", patterns...) return err } @@ -226,8 +233,8 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error { return err } - err = c._subscribe(cn, redisCmd, channels...) - c._releaseConn(cn, err) + err = c._subscribe(cn, redisCmd, channels) + c._releaseConn(cn, err, false) return err } @@ -243,9 +250,8 @@ func (c *PubSub) Ping(payload ...string) error { return err } - cn.SetWriteTimeout(c.opt.WriteTimeout) - err = writeCmd(cn, cmd) - c.releaseConn(cn, err) + err = c.writeCmd(cn, cmd) + c.releaseConn(cn, err, false) return err } @@ -336,9 +342,11 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { return nil, err } - cn.SetReadTimeout(timeout) - err = c.cmd.readReply(cn) - c.releaseConn(cn, err) + err = cn.WithReader(timeout, func(rd *proto.Reader) error { + return c.cmd.readReply(rd) + }) + + c.releaseConn(cn, err, timeout > 0) if err != nil { return nil, err } @@ -432,21 +440,26 @@ func (c *PubSub) initChannel() { timer := time.NewTimer(timeout) timer.Stop() - var hasPing bool + healthy := true for { timer.Reset(timeout) select { case <-c.ping: - hasPing = true + healthy = true if !timer.Stop() { <-timer.C } case <-timer.C: - if hasPing { - hasPing = false - _ = c.Ping() + pingErr := c.Ping() + if healthy { + healthy = false } else { - c.reconnect() + if pingErr == nil { + pingErr = errPingTimeout + } + c.mu.Lock() + c._reconnect(pingErr) + c.mu.Unlock() } case <-c.exit: return |