You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

202 lines
3.9 KiB

  1. package pool
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "sync/atomic"
  7. )
  8. const (
  9. stateDefault = 0
  10. stateInited = 1
  11. stateClosed = 2
  12. )
  13. type BadConnError struct {
  14. wrapped error
  15. }
  16. var _ error = (*BadConnError)(nil)
  17. func (e BadConnError) Error() string {
  18. s := "redis: Conn is in a bad state"
  19. if e.wrapped != nil {
  20. s += ": " + e.wrapped.Error()
  21. }
  22. return s
  23. }
  24. func (e BadConnError) Unwrap() error {
  25. return e.wrapped
  26. }
  27. //------------------------------------------------------------------------------
  28. type StickyConnPool struct {
  29. pool Pooler
  30. shared int32 // atomic
  31. state uint32 // atomic
  32. ch chan *Conn
  33. _badConnError atomic.Value
  34. }
  35. var _ Pooler = (*StickyConnPool)(nil)
  36. func NewStickyConnPool(pool Pooler) *StickyConnPool {
  37. p, ok := pool.(*StickyConnPool)
  38. if !ok {
  39. p = &StickyConnPool{
  40. pool: pool,
  41. ch: make(chan *Conn, 1),
  42. }
  43. }
  44. atomic.AddInt32(&p.shared, 1)
  45. return p
  46. }
  47. func (p *StickyConnPool) NewConn(ctx context.Context) (*Conn, error) {
  48. return p.pool.NewConn(ctx)
  49. }
  50. func (p *StickyConnPool) CloseConn(cn *Conn) error {
  51. return p.pool.CloseConn(cn)
  52. }
  53. func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) {
  54. // In worst case this races with Close which is not a very common operation.
  55. for i := 0; i < 1000; i++ {
  56. switch atomic.LoadUint32(&p.state) {
  57. case stateDefault:
  58. cn, err := p.pool.Get(ctx)
  59. if err != nil {
  60. return nil, err
  61. }
  62. if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {
  63. return cn, nil
  64. }
  65. p.pool.Remove(ctx, cn, ErrClosed)
  66. case stateInited:
  67. if err := p.badConnError(); err != nil {
  68. return nil, err
  69. }
  70. cn, ok := <-p.ch
  71. if !ok {
  72. return nil, ErrClosed
  73. }
  74. return cn, nil
  75. case stateClosed:
  76. return nil, ErrClosed
  77. default:
  78. panic("not reached")
  79. }
  80. }
  81. return nil, fmt.Errorf("redis: StickyConnPool.Get: infinite loop")
  82. }
  83. func (p *StickyConnPool) Put(ctx context.Context, cn *Conn) {
  84. defer func() {
  85. if recover() != nil {
  86. p.freeConn(ctx, cn)
  87. }
  88. }()
  89. p.ch <- cn
  90. }
  91. func (p *StickyConnPool) freeConn(ctx context.Context, cn *Conn) {
  92. if err := p.badConnError(); err != nil {
  93. p.pool.Remove(ctx, cn, err)
  94. } else {
  95. p.pool.Put(ctx, cn)
  96. }
  97. }
  98. func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
  99. defer func() {
  100. if recover() != nil {
  101. p.pool.Remove(ctx, cn, ErrClosed)
  102. }
  103. }()
  104. p._badConnError.Store(BadConnError{wrapped: reason})
  105. p.ch <- cn
  106. }
  107. func (p *StickyConnPool) Close() error {
  108. if shared := atomic.AddInt32(&p.shared, -1); shared > 0 {
  109. return nil
  110. }
  111. for i := 0; i < 1000; i++ {
  112. state := atomic.LoadUint32(&p.state)
  113. if state == stateClosed {
  114. return ErrClosed
  115. }
  116. if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) {
  117. close(p.ch)
  118. cn, ok := <-p.ch
  119. if ok {
  120. p.freeConn(context.TODO(), cn)
  121. }
  122. return nil
  123. }
  124. }
  125. return errors.New("redis: StickyConnPool.Close: infinite loop")
  126. }
  127. func (p *StickyConnPool) Reset(ctx context.Context) error {
  128. if p.badConnError() == nil {
  129. return nil
  130. }
  131. select {
  132. case cn, ok := <-p.ch:
  133. if !ok {
  134. return ErrClosed
  135. }
  136. p.pool.Remove(ctx, cn, ErrClosed)
  137. p._badConnError.Store(BadConnError{wrapped: nil})
  138. default:
  139. return errors.New("redis: StickyConnPool does not have a Conn")
  140. }
  141. if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) {
  142. state := atomic.LoadUint32(&p.state)
  143. return fmt.Errorf("redis: invalid StickyConnPool state: %d", state)
  144. }
  145. return nil
  146. }
  147. func (p *StickyConnPool) badConnError() error {
  148. if v := p._badConnError.Load(); v != nil {
  149. if err := v.(BadConnError); err.wrapped != nil {
  150. return err
  151. }
  152. }
  153. return nil
  154. }
  155. func (p *StickyConnPool) Len() int {
  156. switch atomic.LoadUint32(&p.state) {
  157. case stateDefault:
  158. return 0
  159. case stateInited:
  160. return 1
  161. case stateClosed:
  162. return 0
  163. default:
  164. panic("not reached")
  165. }
  166. }
  167. func (p *StickyConnPool) IdleLen() int {
  168. return len(p.ch)
  169. }
  170. func (p *StickyConnPool) Stats() *Stats {
  171. return &Stats{}
  172. }