package redis

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"github.com/cespare/xxhash/v2"
	rendezvous "github.com/dgryski/go-rendezvous" //nolint
	"github.com/go-redis/redis/v8/internal"
	"github.com/go-redis/redis/v8/internal/hashtag"
	"github.com/go-redis/redis/v8/internal/pool"
	"github.com/go-redis/redis/v8/internal/rand"
)

var errRingShardsDown = errors.New("redis: all ring shards are down")

//------------------------------------------------------------------------------

type ConsistentHash interface {
	Get(string) string
}

type rendezvousWrapper struct {
	*rendezvous.Rendezvous
}

func (w rendezvousWrapper) Get(key string) string {
	return w.Lookup(key)
}

func newRendezvous(shards []string) ConsistentHash {
	return rendezvousWrapper{rendezvous.New(shards, xxhash.Sum64String)}
}
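
// A custom ConsistentHash can be plugged in via RingOptions.NewConsistentHash.
// Below is a minimal, illustrative sketch (not production-grade): it hashes
// the key with xxhash and picks a shard by modulo. Unlike rendezvous hashing,
// a plain modulo remaps most keys whenever the shard set changes.
//
//	type moduloHash struct{ shards []string }
//
//	func (h moduloHash) Get(key string) string {
//		if len(h.shards) == 0 {
//			return ""
//		}
//		return h.shards[xxhash.Sum64String(key)%uint64(len(h.shards))]
//	}
//
//	func newModuloHash(shards []string) ConsistentHash { return moduloHash{shards} }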

//------------------------------------------------------------------------------

// RingOptions are used to configure a ring client and should be
// passed to NewRing.
type RingOptions struct {
	// Map of name => host:port addresses of ring shards.
	Addrs map[string]string

	// NewClient creates a shard client with the provided name and options.
	NewClient func(name string, opt *Options) *Client

	// Frequency of PING commands sent to check shard availability.
	// A shard is considered down after 3 consecutive failed checks.
	HeartbeatFrequency time.Duration

	// NewConsistentHash returns a consistent hash that is used
	// to distribute keys across the shards.
	//
	// See https://medium.com/@dgryski/consistent-hashing-algorithmic-tradeoffs-ef6b8e2fcae8
	// for consistent hashing algorithmic tradeoffs.
	NewConsistentHash func(shards []string) ConsistentHash

	// The following options are copied from the Options struct.

	Dialer    func(ctx context.Context, network, addr string) (net.Conn, error)
	OnConnect func(ctx context.Context, cn *Conn) error

	Username string
	Password string
	DB       int

	MaxRetries      int
	MinRetryBackoff time.Duration
	MaxRetryBackoff time.Duration

	DialTimeout  time.Duration
	ReadTimeout  time.Duration
	WriteTimeout time.Duration

	// PoolFIFO uses FIFO mode for each node connection pool GET/PUT (default LIFO).
	PoolFIFO bool

	PoolSize           int
	MinIdleConns       int
	MaxConnAge         time.Duration
	PoolTimeout        time.Duration
	IdleTimeout        time.Duration
	IdleCheckFrequency time.Duration

	TLSConfig *tls.Config
	Limiter   Limiter
}

func (opt *RingOptions) init() {
	if opt.NewClient == nil {
		opt.NewClient = func(name string, opt *Options) *Client {
			return NewClient(opt)
		}
	}

	if opt.HeartbeatFrequency == 0 {
		opt.HeartbeatFrequency = 500 * time.Millisecond
	}

	if opt.NewConsistentHash == nil {
		opt.NewConsistentHash = newRendezvous
	}

	if opt.MaxRetries == -1 {
		opt.MaxRetries = 0
	} else if opt.MaxRetries == 0 {
		opt.MaxRetries = 3
	}
	switch opt.MinRetryBackoff {
	case -1:
		opt.MinRetryBackoff = 0
	case 0:
		opt.MinRetryBackoff = 8 * time.Millisecond
	}
	switch opt.MaxRetryBackoff {
	case -1:
		opt.MaxRetryBackoff = 0
	case 0:
		opt.MaxRetryBackoff = 512 * time.Millisecond
	}
}

func (opt *RingOptions) clientOptions() *Options {
	return &Options{
		Dialer:    opt.Dialer,
		OnConnect: opt.OnConnect,

		Username: opt.Username,
		Password: opt.Password,
		DB:       opt.DB,

		MaxRetries: -1,

		DialTimeout:  opt.DialTimeout,
		ReadTimeout:  opt.ReadTimeout,
		WriteTimeout: opt.WriteTimeout,

		PoolFIFO:           opt.PoolFIFO,
		PoolSize:           opt.PoolSize,
		MinIdleConns:       opt.MinIdleConns,
		MaxConnAge:         opt.MaxConnAge,
		PoolTimeout:        opt.PoolTimeout,
		IdleTimeout:        opt.IdleTimeout,
		IdleCheckFrequency: opt.IdleCheckFrequency,

		TLSConfig: opt.TLSConfig,
		Limiter:   opt.Limiter,
	}
}

//------------------------------------------------------------------------------

type ringShard struct {
	Client *Client
	down   int32
}

func newRingShard(opt *RingOptions, name, addr string) *ringShard {
	clopt := opt.clientOptions()
	clopt.Addr = addr

	return &ringShard{
		Client: opt.NewClient(name, clopt),
	}
}

func (shard *ringShard) String() string {
	var state string
	if shard.IsUp() {
		state = "up"
	} else {
		state = "down"
	}
	return fmt.Sprintf("%s is %s", shard.Client, state)
}

func (shard *ringShard) IsDown() bool {
	const threshold = 3
	return atomic.LoadInt32(&shard.down) >= threshold
}

func (shard *ringShard) IsUp() bool {
	return !shard.IsDown()
}

// Vote votes to set the shard state and returns true if the state was changed.
func (shard *ringShard) Vote(up bool) bool {
	if up {
		changed := shard.IsDown()
		atomic.StoreInt32(&shard.down, 0)
		return changed
	}

	if shard.IsDown() {
		return false
	}

	atomic.AddInt32(&shard.down, 1)
	return shard.IsDown()
}

//------------------------------------------------------------------------------

type ringShards struct {
	opt *RingOptions

	mu       sync.RWMutex
	hash     ConsistentHash
	shards   map[string]*ringShard // read only
	list     []*ringShard          // read only
	numShard int
	closed   bool
}

func newRingShards(opt *RingOptions) *ringShards {
	shards := make(map[string]*ringShard, len(opt.Addrs))
	list := make([]*ringShard, 0, len(shards))

	for name, addr := range opt.Addrs {
		shard := newRingShard(opt, name, addr)
		shards[name] = shard

		list = append(list, shard)
	}

	c := &ringShards{
		opt: opt,

		shards: shards,
		list:   list,
	}
	c.rebalance()

	return c
}

func (c *ringShards) List() []*ringShard {
	var list []*ringShard

	c.mu.RLock()
	if !c.closed {
		list = c.list
	}
	c.mu.RUnlock()

	return list
}

func (c *ringShards) Hash(key string) string {
	key = hashtag.Key(key)

	var hash string

	c.mu.RLock()
	if c.numShard > 0 {
		hash = c.hash.Get(key)
	}
	c.mu.RUnlock()

	return hash
}

func (c *ringShards) GetByKey(key string) (*ringShard, error) {
	key = hashtag.Key(key)

	c.mu.RLock()

	if c.closed {
		c.mu.RUnlock()
		return nil, pool.ErrClosed
	}

	if c.numShard == 0 {
		c.mu.RUnlock()
		return nil, errRingShardsDown
	}

	hash := c.hash.Get(key)
	if hash == "" {
		c.mu.RUnlock()
		return nil, errRingShardsDown
	}

	shard := c.shards[hash]
	c.mu.RUnlock()

	return shard, nil
}

func (c *ringShards) GetByName(shardName string) (*ringShard, error) {
	if shardName == "" {
		return c.Random()
	}

	c.mu.RLock()
	shard := c.shards[shardName]
	c.mu.RUnlock()
	return shard, nil
}

func (c *ringShards) Random() (*ringShard, error) {
	return c.GetByKey(strconv.Itoa(rand.Int()))
}

// Heartbeat monitors the state of each shard in the ring.
func (c *ringShards) Heartbeat(frequency time.Duration) {
	ticker := time.NewTicker(frequency)
	defer ticker.Stop()

	ctx := context.Background()
	for range ticker.C {
		var rebalance bool

		for _, shard := range c.List() {
			err := shard.Client.Ping(ctx).Err()
			isUp := err == nil || err == pool.ErrPoolTimeout
			if shard.Vote(isUp) {
				internal.Logger.Printf(context.Background(), "ring shard state changed: %s", shard)
				rebalance = true
			}
		}

		if rebalance {
			c.rebalance()
		}
	}
}

// rebalance removes dead shards from the Ring.
func (c *ringShards) rebalance() {
	c.mu.RLock()
	shards := c.shards
	c.mu.RUnlock()

	liveShards := make([]string, 0, len(shards))

	for name, shard := range shards {
		if shard.IsUp() {
			liveShards = append(liveShards, name)
		}
	}

	hash := c.opt.NewConsistentHash(liveShards)

	c.mu.Lock()
	c.hash = hash
	c.numShard = len(liveShards)
	c.mu.Unlock()
}

func (c *ringShards) Len() int {
	c.mu.RLock()
	l := c.numShard
	c.mu.RUnlock()
	return l
}

func (c *ringShards) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closed {
		return nil
	}
	c.closed = true

	var firstErr error
	for _, shard := range c.shards {
		if err := shard.Client.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	c.hash = nil
	c.shards = nil
	c.list = nil

	return firstErr
}

//------------------------------------------------------------------------------

type ring struct {
	opt           *RingOptions
	shards        *ringShards
	cmdsInfoCache *cmdsInfoCache //nolint:structcheck
}

// Ring is a Redis client that uses consistent hashing to distribute
// keys across multiple Redis servers (shards). It's safe for
// concurrent use by multiple goroutines.
//
// Ring monitors the state of each shard and removes dead shards from
// the ring. When a shard comes online it is added back to the ring. This
// gives you maximum availability and partition tolerance, but no
// consistency between different shards or even clients. Each client
// uses the shards that are available to it and does no coordination
// when shard state changes.
//
// Ring should be used when you need multiple Redis servers for caching
// and can tolerate losing data when one of the servers dies.
// Otherwise you should use Redis Cluster.
type Ring struct {
	*ring
	cmdable
	hooks
	ctx context.Context
}

func NewRing(opt *RingOptions) *Ring {
	opt.init()

	ring := Ring{
		ring: &ring{
			opt:    opt,
			shards: newRingShards(opt),
		},
		ctx: context.Background(),
	}

	ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo)
	ring.cmdable = ring.Process

	go ring.shards.Heartbeat(opt.HeartbeatFrequency)

	return &ring
}
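
// A minimal usage sketch (shard names, addresses, and keys are illustrative):
//
//	rdb := NewRing(&RingOptions{
//		Addrs: map[string]string{
//			"shard1": "localhost:7000",
//			"shard2": "localhost:7001",
//		},
//	})
//	ctx := context.Background()
//	if err := rdb.Set(ctx, "key", "value", 0).Err(); err != nil {
//		// handle error
//	}
//	val, err := rdb.Get(ctx, "key").Result()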

func (c *Ring) Context() context.Context {
	return c.ctx
}

func (c *Ring) WithContext(ctx context.Context) *Ring {
	if ctx == nil {
		panic("nil context")
	}
	clone := *c
	clone.cmdable = clone.Process
	clone.hooks.lock()
	clone.ctx = ctx
	return &clone
}

// Do creates a Cmd from the args and processes the cmd.
func (c *Ring) Do(ctx context.Context, args ...interface{}) *Cmd {
	cmd := NewCmd(ctx, args...)
	_ = c.Process(ctx, cmd)
	return cmd
}

func (c *Ring) Process(ctx context.Context, cmd Cmder) error {
	return c.hooks.process(ctx, cmd, c.process)
}

// Options returns read-only Options that were used to create the client.
func (c *Ring) Options() *RingOptions {
	return c.opt
}

func (c *Ring) retryBackoff(attempt int) time.Duration {
	return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
}

// PoolStats returns accumulated connection pool stats.
func (c *Ring) PoolStats() *PoolStats {
	shards := c.shards.List()
	var acc PoolStats
	for _, shard := range shards {
		s := shard.Client.connPool.Stats()
		acc.Hits += s.Hits
		acc.Misses += s.Misses
		acc.Timeouts += s.Timeouts
		acc.TotalConns += s.TotalConns
		acc.IdleConns += s.IdleConns
	}
	return &acc
}

// Len returns the current number of shards in the ring.
func (c *Ring) Len() int {
	return c.shards.Len()
}

// Subscribe subscribes the client to the specified channels.
func (c *Ring) Subscribe(ctx context.Context, channels ...string) *PubSub {
	if len(channels) == 0 {
		panic("at least one channel is required")
	}

	shard, err := c.shards.GetByKey(channels[0])
	if err != nil {
		// TODO: return PubSub with sticky error
		panic(err)
	}
	return shard.Client.Subscribe(ctx, channels...)
}

// PSubscribe subscribes the client to the given patterns.
func (c *Ring) PSubscribe(ctx context.Context, channels ...string) *PubSub {
	if len(channels) == 0 {
		panic("at least one channel is required")
	}

	shard, err := c.shards.GetByKey(channels[0])
	if err != nil {
		// TODO: return PubSub with sticky error
		panic(err)
	}
	return shard.Client.PSubscribe(ctx, channels...)
}

// ForEachShard concurrently calls fn on each live shard in the ring.
// It returns the first error, if any.
func (c *Ring) ForEachShard(
	ctx context.Context,
	fn func(ctx context.Context, client *Client) error,
) error {
	shards := c.shards.List()
	var wg sync.WaitGroup
	errCh := make(chan error, 1)
	for _, shard := range shards {
		if shard.IsDown() {
			continue
		}

		wg.Add(1)
		go func(shard *ringShard) {
			defer wg.Done()
			err := fn(ctx, shard.Client)
			if err != nil {
				select {
				case errCh <- err:
				default:
				}
			}
		}(shard)
	}
	wg.Wait()

	select {
	case err := <-errCh:
		return err
	default:
		return nil
	}
}
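
// A usage sketch for ForEachShard (the Ping here is illustrative; any
// per-shard command works):
//
//	err := rdb.ForEachShard(ctx, func(ctx context.Context, shard *Client) error {
//		return shard.Ping(ctx).Err()
//	})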

func (c *Ring) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, error) {
	shards := c.shards.List()
	var firstErr error
	for _, shard := range shards {
		cmdsInfo, err := shard.Client.Command(ctx).Result()
		if err == nil {
			return cmdsInfo, nil
		}
		if firstErr == nil {
			firstErr = err
		}
	}
	if firstErr == nil {
		return nil, errRingShardsDown
	}
	return nil, firstErr
}

func (c *Ring) cmdInfo(ctx context.Context, name string) *CommandInfo {
	cmdsInfo, err := c.cmdsInfoCache.Get(ctx)
	if err != nil {
		return nil
	}
	info := cmdsInfo[name]
	if info == nil {
		internal.Logger.Printf(c.Context(), "info for cmd=%s not found", name)
	}
	return info
}

func (c *Ring) cmdShard(ctx context.Context, cmd Cmder) (*ringShard, error) {
	cmdInfo := c.cmdInfo(ctx, cmd.Name())
	pos := cmdFirstKeyPos(cmd, cmdInfo)
	if pos == 0 {
		return c.shards.Random()
	}
	firstKey := cmd.stringArg(pos)
	return c.shards.GetByKey(firstKey)
}

func (c *Ring) process(ctx context.Context, cmd Cmder) error {
	var lastErr error
	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
		if attempt > 0 {
			if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
				return err
			}
		}

		shard, err := c.cmdShard(ctx, cmd)
		if err != nil {
			return err
		}

		lastErr = shard.Client.Process(ctx, cmd)
		if lastErr == nil || !shouldRetry(lastErr, cmd.readTimeout() == nil) {
			return lastErr
		}
	}
	return lastErr
}

func (c *Ring) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.Pipeline().Pipelined(ctx, fn)
}

func (c *Ring) Pipeline() Pipeliner {
	pipe := Pipeline{
		ctx:  c.ctx,
		exec: c.processPipeline,
	}
	pipe.init()
	return &pipe
}
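
// Pipelined commands are grouped by shard and executed as one pipeline per
// shard (see generalProcessPipeline below). A sketch with illustrative keys:
//
//	cmds, err := rdb.Pipelined(ctx, func(pipe Pipeliner) error {
//		pipe.Incr(ctx, "counter1")
//		pipe.Incr(ctx, "counter2")
//		return nil
//	})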

func (c *Ring) processPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
		return c.generalProcessPipeline(ctx, cmds, false)
	})
}

func (c *Ring) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.TxPipeline().Pipelined(ctx, fn)
}

func (c *Ring) TxPipeline() Pipeliner {
	pipe := Pipeline{
		ctx:  c.ctx,
		exec: c.processTxPipeline,
	}
	pipe.init()
	return &pipe
}

func (c *Ring) processTxPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
		return c.generalProcessPipeline(ctx, cmds, true)
	})
}

func (c *Ring) generalProcessPipeline(
	ctx context.Context, cmds []Cmder, tx bool,
) error {
	cmdsMap := make(map[string][]Cmder)
	for _, cmd := range cmds {
		cmdInfo := c.cmdInfo(ctx, cmd.Name())
		hash := cmd.stringArg(cmdFirstKeyPos(cmd, cmdInfo))
		if hash != "" {
			hash = c.shards.Hash(hash)
		}
		cmdsMap[hash] = append(cmdsMap[hash], cmd)
	}

	var wg sync.WaitGroup
	for hash, cmds := range cmdsMap {
		wg.Add(1)
		go func(hash string, cmds []Cmder) {
			defer wg.Done()

			_ = c.processShardPipeline(ctx, hash, cmds, tx)
		}(hash, cmds)
	}

	wg.Wait()
	return cmdsFirstErr(cmds)
}

func (c *Ring) processShardPipeline(
	ctx context.Context, hash string, cmds []Cmder, tx bool,
) error {
	// TODO: retry?
	shard, err := c.shards.GetByName(hash)
	if err != nil {
		setCmdsErr(cmds, err)
		return err
	}

	if tx {
		return shard.Client.processTxPipeline(ctx, cmds)
	}
	return shard.Client.processPipeline(ctx, cmds)
}

func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error {
	if len(keys) == 0 {
		return fmt.Errorf("redis: Watch requires at least one key")
	}

	var shards []*ringShard
	for _, key := range keys {
		if key != "" {
			shard, err := c.shards.GetByKey(hashtag.Key(key))
			if err != nil {
				return err
			}

			shards = append(shards, shard)
		}
	}

	if len(shards) == 0 {
		return fmt.Errorf("redis: Watch requires at least one shard")
	}

	if len(shards) > 1 {
		for _, shard := range shards[1:] {
			if shard.Client != shards[0].Client {
				err := fmt.Errorf("redis: Watch requires all keys to be in the same shard")
				return err
			}
		}
	}

	return shards[0].Client.Watch(ctx, fn, keys...)
}
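
// Watch only succeeds when all watched keys live on the same shard. Keys can
// be pinned together with hash tags: "{user1}.balance" and "{user1}.orders"
// both hash on "user1". A sketch with illustrative keys:
//
//	err := rdb.Watch(ctx, func(tx *Tx) error {
//		n, err := tx.Get(ctx, "{user1}.balance").Int64()
//		if err != nil && err != Nil {
//			return err
//		}
//		_, err = tx.TxPipelined(ctx, func(pipe Pipeliner) error {
//			pipe.Set(ctx, "{user1}.balance", n+1, 0)
//			return nil
//		})
//		return err
//	}, "{user1}.balance")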

// Close closes the ring client, releasing any open resources.
//
// It is rare to Close a Ring, as the Ring is meant to be long-lived
// and shared between many goroutines.
func (c *Ring) Close() error {
	return c.shards.Close()
}