|
- package pool
-
- import (
- "context"
- "errors"
- "fmt"
- "sync/atomic"
- )
-
// Lifecycle states of a StickyConnPool; the pool's state field holds one of
// these and is read and written with sync/atomic.
const (
	stateDefault = iota // no connection has been taken from the underlying pool yet
	stateInited         // a connection has been taken and is pinned to the pool
	stateClosed         // the pool has been closed; Get returns ErrClosed
)
-
// BadConnError is the error StickyConnPool records when its pinned connection
// is handed back through Remove. While it carries a non-nil reason, Get
// returns it instead of the connection. The optional wrapped field holds the
// reason passed to Remove.
type BadConnError struct {
	wrapped error
}

// Error and Unwrap use value receivers, so the value type itself is an error.
var _ error = BadConnError{}

// Error implements the error interface. When a reason is wrapped, its message
// is appended after a colon.
func (e BadConnError) Error() string {
	const msg = "redis: Conn is in a bad state"
	if e.wrapped == nil {
		return msg
	}
	return msg + ": " + e.wrapped.Error()
}

// Unwrap exposes the wrapped reason (possibly nil) to errors.Is and errors.As.
func (e BadConnError) Unwrap() error {
	return e.wrapped
}
-
- //------------------------------------------------------------------------------
-
- type StickyConnPool struct {
- pool Pooler
- shared int32 // atomic
-
- state uint32 // atomic
- ch chan *Conn
-
- _badConnError atomic.Value
- }
-
- var _ Pooler = (*StickyConnPool)(nil)
-
- func NewStickyConnPool(pool Pooler) *StickyConnPool {
- p, ok := pool.(*StickyConnPool)
- if !ok {
- p = &StickyConnPool{
- pool: pool,
- ch: make(chan *Conn, 1),
- }
- }
- atomic.AddInt32(&p.shared, 1)
- return p
- }
-
- func (p *StickyConnPool) NewConn(ctx context.Context) (*Conn, error) {
- return p.pool.NewConn(ctx)
- }
-
- func (p *StickyConnPool) CloseConn(cn *Conn) error {
- return p.pool.CloseConn(cn)
- }
-
- func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) {
- // In worst case this races with Close which is not a very common operation.
- for i := 0; i < 1000; i++ {
- switch atomic.LoadUint32(&p.state) {
- case stateDefault:
- cn, err := p.pool.Get(ctx)
- if err != nil {
- return nil, err
- }
- if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {
- return cn, nil
- }
- p.pool.Remove(ctx, cn, ErrClosed)
- case stateInited:
- if err := p.badConnError(); err != nil {
- return nil, err
- }
- cn, ok := <-p.ch
- if !ok {
- return nil, ErrClosed
- }
- return cn, nil
- case stateClosed:
- return nil, ErrClosed
- default:
- panic("not reached")
- }
- }
- return nil, fmt.Errorf("redis: StickyConnPool.Get: infinite loop")
- }
-
- func (p *StickyConnPool) Put(ctx context.Context, cn *Conn) {
- defer func() {
- if recover() != nil {
- p.freeConn(ctx, cn)
- }
- }()
- p.ch <- cn
- }
-
- func (p *StickyConnPool) freeConn(ctx context.Context, cn *Conn) {
- if err := p.badConnError(); err != nil {
- p.pool.Remove(ctx, cn, err)
- } else {
- p.pool.Put(ctx, cn)
- }
- }
-
- func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
- defer func() {
- if recover() != nil {
- p.pool.Remove(ctx, cn, ErrClosed)
- }
- }()
- p._badConnError.Store(BadConnError{wrapped: reason})
- p.ch <- cn
- }
-
- func (p *StickyConnPool) Close() error {
- if shared := atomic.AddInt32(&p.shared, -1); shared > 0 {
- return nil
- }
-
- for i := 0; i < 1000; i++ {
- state := atomic.LoadUint32(&p.state)
- if state == stateClosed {
- return ErrClosed
- }
- if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) {
- close(p.ch)
- cn, ok := <-p.ch
- if ok {
- p.freeConn(context.TODO(), cn)
- }
- return nil
- }
- }
-
- return errors.New("redis: StickyConnPool.Close: infinite loop")
- }
-
- func (p *StickyConnPool) Reset(ctx context.Context) error {
- if p.badConnError() == nil {
- return nil
- }
-
- select {
- case cn, ok := <-p.ch:
- if !ok {
- return ErrClosed
- }
- p.pool.Remove(ctx, cn, ErrClosed)
- p._badConnError.Store(BadConnError{wrapped: nil})
- default:
- return errors.New("redis: StickyConnPool does not have a Conn")
- }
-
- if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) {
- state := atomic.LoadUint32(&p.state)
- return fmt.Errorf("redis: invalid StickyConnPool state: %d", state)
- }
-
- return nil
- }
-
- func (p *StickyConnPool) badConnError() error {
- if v := p._badConnError.Load(); v != nil {
- if err := v.(BadConnError); err.wrapped != nil {
- return err
- }
- }
- return nil
- }
-
- func (p *StickyConnPool) Len() int {
- switch atomic.LoadUint32(&p.state) {
- case stateDefault:
- return 0
- case stateInited:
- return 1
- case stateClosed:
- return 0
- default:
- panic("not reached")
- }
- }
-
- func (p *StickyConnPool) IdleLen() int {
- return len(p.ch)
- }
-
- func (p *StickyConnPool) Stats() *Stats {
- return &Stats{}
- }
|