You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

774 lines
17 KiB

  1. package redis
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "sync/atomic"
  7. "time"
  8. "github.com/go-redis/redis/v8/internal"
  9. "github.com/go-redis/redis/v8/internal/pool"
  10. "github.com/go-redis/redis/v8/internal/proto"
  11. )
  12. // Nil reply returned by Redis when key does not exist.
  13. const Nil = proto.Nil
  14. func SetLogger(logger internal.Logging) {
  15. internal.Logger = logger
  16. }
  17. //------------------------------------------------------------------------------
  18. type Hook interface {
  19. BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error)
  20. AfterProcess(ctx context.Context, cmd Cmder) error
  21. BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error)
  22. AfterProcessPipeline(ctx context.Context, cmds []Cmder) error
  23. }
  24. type hooks struct {
  25. hooks []Hook
  26. }
  27. func (hs *hooks) lock() {
  28. hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)]
  29. }
  30. func (hs hooks) clone() hooks {
  31. clone := hs
  32. clone.lock()
  33. return clone
  34. }
  35. func (hs *hooks) AddHook(hook Hook) {
  36. hs.hooks = append(hs.hooks, hook)
  37. }
  38. func (hs hooks) process(
  39. ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error,
  40. ) error {
  41. if len(hs.hooks) == 0 {
  42. err := fn(ctx, cmd)
  43. cmd.SetErr(err)
  44. return err
  45. }
  46. var hookIndex int
  47. var retErr error
  48. for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
  49. ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd)
  50. if retErr != nil {
  51. cmd.SetErr(retErr)
  52. }
  53. }
  54. if retErr == nil {
  55. retErr = fn(ctx, cmd)
  56. cmd.SetErr(retErr)
  57. }
  58. for hookIndex--; hookIndex >= 0; hookIndex-- {
  59. if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil {
  60. retErr = err
  61. cmd.SetErr(retErr)
  62. }
  63. }
  64. return retErr
  65. }
  66. func (hs hooks) processPipeline(
  67. ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
  68. ) error {
  69. if len(hs.hooks) == 0 {
  70. err := fn(ctx, cmds)
  71. return err
  72. }
  73. var hookIndex int
  74. var retErr error
  75. for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
  76. ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds)
  77. if retErr != nil {
  78. setCmdsErr(cmds, retErr)
  79. }
  80. }
  81. if retErr == nil {
  82. retErr = fn(ctx, cmds)
  83. }
  84. for hookIndex--; hookIndex >= 0; hookIndex-- {
  85. if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil {
  86. retErr = err
  87. setCmdsErr(cmds, retErr)
  88. }
  89. }
  90. return retErr
  91. }
  92. func (hs hooks) processTxPipeline(
  93. ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
  94. ) error {
  95. cmds = wrapMultiExec(ctx, cmds)
  96. return hs.processPipeline(ctx, cmds, fn)
  97. }
  98. //------------------------------------------------------------------------------
  99. type baseClient struct {
  100. opt *Options
  101. connPool pool.Pooler
  102. onClose func() error // hook called when client is closed
  103. }
  104. func newBaseClient(opt *Options, connPool pool.Pooler) *baseClient {
  105. return &baseClient{
  106. opt: opt,
  107. connPool: connPool,
  108. }
  109. }
  110. func (c *baseClient) clone() *baseClient {
  111. clone := *c
  112. return &clone
  113. }
  114. func (c *baseClient) withTimeout(timeout time.Duration) *baseClient {
  115. opt := c.opt.clone()
  116. opt.ReadTimeout = timeout
  117. opt.WriteTimeout = timeout
  118. clone := c.clone()
  119. clone.opt = opt
  120. return clone
  121. }
  122. func (c *baseClient) String() string {
  123. return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
  124. }
  125. func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
  126. cn, err := c.connPool.NewConn(ctx)
  127. if err != nil {
  128. return nil, err
  129. }
  130. err = c.initConn(ctx, cn)
  131. if err != nil {
  132. _ = c.connPool.CloseConn(cn)
  133. return nil, err
  134. }
  135. return cn, nil
  136. }
  137. func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
  138. if c.opt.Limiter != nil {
  139. err := c.opt.Limiter.Allow()
  140. if err != nil {
  141. return nil, err
  142. }
  143. }
  144. cn, err := c._getConn(ctx)
  145. if err != nil {
  146. if c.opt.Limiter != nil {
  147. c.opt.Limiter.ReportResult(err)
  148. }
  149. return nil, err
  150. }
  151. return cn, nil
  152. }
  153. func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
  154. cn, err := c.connPool.Get(ctx)
  155. if err != nil {
  156. return nil, err
  157. }
  158. if cn.Inited {
  159. return cn, nil
  160. }
  161. if err := c.initConn(ctx, cn); err != nil {
  162. c.connPool.Remove(ctx, cn, err)
  163. if err := errors.Unwrap(err); err != nil {
  164. return nil, err
  165. }
  166. return nil, err
  167. }
  168. return cn, nil
  169. }
  170. func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
  171. if cn.Inited {
  172. return nil
  173. }
  174. cn.Inited = true
  175. if c.opt.Password == "" &&
  176. c.opt.DB == 0 &&
  177. !c.opt.readOnly &&
  178. c.opt.OnConnect == nil {
  179. return nil
  180. }
  181. connPool := pool.NewSingleConnPool(c.connPool, cn)
  182. conn := newConn(ctx, c.opt, connPool)
  183. _, err := conn.Pipelined(ctx, func(pipe Pipeliner) error {
  184. if c.opt.Password != "" {
  185. if c.opt.Username != "" {
  186. pipe.AuthACL(ctx, c.opt.Username, c.opt.Password)
  187. } else {
  188. pipe.Auth(ctx, c.opt.Password)
  189. }
  190. }
  191. if c.opt.DB > 0 {
  192. pipe.Select(ctx, c.opt.DB)
  193. }
  194. if c.opt.readOnly {
  195. pipe.ReadOnly(ctx)
  196. }
  197. return nil
  198. })
  199. if err != nil {
  200. return err
  201. }
  202. if c.opt.OnConnect != nil {
  203. return c.opt.OnConnect(ctx, conn)
  204. }
  205. return nil
  206. }
  207. func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) {
  208. if c.opt.Limiter != nil {
  209. c.opt.Limiter.ReportResult(err)
  210. }
  211. if isBadConn(err, false, c.opt.Addr) {
  212. c.connPool.Remove(ctx, cn, err)
  213. } else {
  214. c.connPool.Put(ctx, cn)
  215. }
  216. }
  217. func (c *baseClient) withConn(
  218. ctx context.Context, fn func(context.Context, *pool.Conn) error,
  219. ) error {
  220. cn, err := c.getConn(ctx)
  221. if err != nil {
  222. return err
  223. }
  224. defer func() {
  225. c.releaseConn(ctx, cn, err)
  226. }()
  227. done := ctx.Done() //nolint:ifshort
  228. if done == nil {
  229. err = fn(ctx, cn)
  230. return err
  231. }
  232. errc := make(chan error, 1)
  233. go func() { errc <- fn(ctx, cn) }()
  234. select {
  235. case <-done:
  236. _ = cn.Close()
  237. // Wait for the goroutine to finish and send something.
  238. <-errc
  239. err = ctx.Err()
  240. return err
  241. case err = <-errc:
  242. return err
  243. }
  244. }
  245. func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
  246. var lastErr error
  247. for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
  248. attempt := attempt
  249. retry, err := c._process(ctx, cmd, attempt)
  250. if err == nil || !retry {
  251. return err
  252. }
  253. lastErr = err
  254. }
  255. return lastErr
  256. }
  257. func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool, error) {
  258. if attempt > 0 {
  259. if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
  260. return false, err
  261. }
  262. }
  263. retryTimeout := uint32(1)
  264. err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
  265. err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
  266. return writeCmd(wr, cmd)
  267. })
  268. if err != nil {
  269. return err
  270. }
  271. err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
  272. if err != nil {
  273. if cmd.readTimeout() == nil {
  274. atomic.StoreUint32(&retryTimeout, 1)
  275. }
  276. return err
  277. }
  278. return nil
  279. })
  280. if err == nil {
  281. return false, nil
  282. }
  283. retry := shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1)
  284. return retry, err
  285. }
  286. func (c *baseClient) retryBackoff(attempt int) time.Duration {
  287. return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
  288. }
  289. func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration {
  290. if timeout := cmd.readTimeout(); timeout != nil {
  291. t := *timeout
  292. if t == 0 {
  293. return 0
  294. }
  295. return t + 10*time.Second
  296. }
  297. return c.opt.ReadTimeout
  298. }
  299. // Close closes the client, releasing any open resources.
  300. //
  301. // It is rare to Close a Client, as the Client is meant to be
  302. // long-lived and shared between many goroutines.
  303. func (c *baseClient) Close() error {
  304. var firstErr error
  305. if c.onClose != nil {
  306. if err := c.onClose(); err != nil {
  307. firstErr = err
  308. }
  309. }
  310. if err := c.connPool.Close(); err != nil && firstErr == nil {
  311. firstErr = err
  312. }
  313. return firstErr
  314. }
  315. func (c *baseClient) getAddr() string {
  316. return c.opt.Addr
  317. }
  318. func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error {
  319. return c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds)
  320. }
  321. func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
  322. return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds)
  323. }
  324. type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error)
  325. func (c *baseClient) generalProcessPipeline(
  326. ctx context.Context, cmds []Cmder, p pipelineProcessor,
  327. ) error {
  328. err := c._generalProcessPipeline(ctx, cmds, p)
  329. if err != nil {
  330. setCmdsErr(cmds, err)
  331. return err
  332. }
  333. return cmdsFirstErr(cmds)
  334. }
  335. func (c *baseClient) _generalProcessPipeline(
  336. ctx context.Context, cmds []Cmder, p pipelineProcessor,
  337. ) error {
  338. var lastErr error
  339. for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
  340. if attempt > 0 {
  341. if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
  342. return err
  343. }
  344. }
  345. var canRetry bool
  346. lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
  347. var err error
  348. canRetry, err = p(ctx, cn, cmds)
  349. return err
  350. })
  351. if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) {
  352. return lastErr
  353. }
  354. }
  355. return lastErr
  356. }
  357. func (c *baseClient) pipelineProcessCmds(
  358. ctx context.Context, cn *pool.Conn, cmds []Cmder,
  359. ) (bool, error) {
  360. err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
  361. return writeCmds(wr, cmds)
  362. })
  363. if err != nil {
  364. return true, err
  365. }
  366. err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
  367. return pipelineReadCmds(rd, cmds)
  368. })
  369. return true, err
  370. }
  371. func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
  372. for _, cmd := range cmds {
  373. err := cmd.readReply(rd)
  374. cmd.SetErr(err)
  375. if err != nil && !isRedisError(err) {
  376. return err
  377. }
  378. }
  379. return nil
  380. }
  381. func (c *baseClient) txPipelineProcessCmds(
  382. ctx context.Context, cn *pool.Conn, cmds []Cmder,
  383. ) (bool, error) {
  384. err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
  385. return writeCmds(wr, cmds)
  386. })
  387. if err != nil {
  388. return true, err
  389. }
  390. err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
  391. statusCmd := cmds[0].(*StatusCmd)
  392. // Trim multi and exec.
  393. cmds = cmds[1 : len(cmds)-1]
  394. err := txPipelineReadQueued(rd, statusCmd, cmds)
  395. if err != nil {
  396. return err
  397. }
  398. return pipelineReadCmds(rd, cmds)
  399. })
  400. return false, err
  401. }
  402. func wrapMultiExec(ctx context.Context, cmds []Cmder) []Cmder {
  403. if len(cmds) == 0 {
  404. panic("not reached")
  405. }
  406. cmdCopy := make([]Cmder, len(cmds)+2)
  407. cmdCopy[0] = NewStatusCmd(ctx, "multi")
  408. copy(cmdCopy[1:], cmds)
  409. cmdCopy[len(cmdCopy)-1] = NewSliceCmd(ctx, "exec")
  410. return cmdCopy
  411. }
  412. func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error {
  413. // Parse queued replies.
  414. if err := statusCmd.readReply(rd); err != nil {
  415. return err
  416. }
  417. for range cmds {
  418. if err := statusCmd.readReply(rd); err != nil && !isRedisError(err) {
  419. return err
  420. }
  421. }
  422. // Parse number of replies.
  423. line, err := rd.ReadLine()
  424. if err != nil {
  425. if err == Nil {
  426. err = TxFailedErr
  427. }
  428. return err
  429. }
  430. switch line[0] {
  431. case proto.ErrorReply:
  432. return proto.ParseErrorReply(line)
  433. case proto.ArrayReply:
  434. // ok
  435. default:
  436. err := fmt.Errorf("redis: expected '*', but got line %q", line)
  437. return err
  438. }
  439. return nil
  440. }
  441. //------------------------------------------------------------------------------
  442. // Client is a Redis client representing a pool of zero or more
  443. // underlying connections. It's safe for concurrent use by multiple
  444. // goroutines.
  445. type Client struct {
  446. *baseClient
  447. cmdable
  448. hooks
  449. ctx context.Context
  450. }
  451. // NewClient returns a client to the Redis Server specified by Options.
  452. func NewClient(opt *Options) *Client {
  453. opt.init()
  454. c := Client{
  455. baseClient: newBaseClient(opt, newConnPool(opt)),
  456. ctx: context.Background(),
  457. }
  458. c.cmdable = c.Process
  459. return &c
  460. }
  461. func (c *Client) clone() *Client {
  462. clone := *c
  463. clone.cmdable = clone.Process
  464. clone.hooks.lock()
  465. return &clone
  466. }
  467. func (c *Client) WithTimeout(timeout time.Duration) *Client {
  468. clone := c.clone()
  469. clone.baseClient = c.baseClient.withTimeout(timeout)
  470. return clone
  471. }
  472. func (c *Client) Context() context.Context {
  473. return c.ctx
  474. }
  475. func (c *Client) WithContext(ctx context.Context) *Client {
  476. if ctx == nil {
  477. panic("nil context")
  478. }
  479. clone := c.clone()
  480. clone.ctx = ctx
  481. return clone
  482. }
  483. func (c *Client) Conn(ctx context.Context) *Conn {
  484. return newConn(ctx, c.opt, pool.NewStickyConnPool(c.connPool))
  485. }
  486. // Do creates a Cmd from the args and processes the cmd.
  487. func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd {
  488. cmd := NewCmd(ctx, args...)
  489. _ = c.Process(ctx, cmd)
  490. return cmd
  491. }
  492. func (c *Client) Process(ctx context.Context, cmd Cmder) error {
  493. return c.hooks.process(ctx, cmd, c.baseClient.process)
  494. }
  495. func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error {
  496. return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
  497. }
  498. func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error {
  499. return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline)
  500. }
  501. // Options returns read-only Options that were used to create the client.
  502. func (c *Client) Options() *Options {
  503. return c.opt
  504. }
  505. type PoolStats pool.Stats
  506. // PoolStats returns connection pool stats.
  507. func (c *Client) PoolStats() *PoolStats {
  508. stats := c.connPool.Stats()
  509. return (*PoolStats)(stats)
  510. }
  511. func (c *Client) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
  512. return c.Pipeline().Pipelined(ctx, fn)
  513. }
  514. func (c *Client) Pipeline() Pipeliner {
  515. pipe := Pipeline{
  516. ctx: c.ctx,
  517. exec: c.processPipeline,
  518. }
  519. pipe.init()
  520. return &pipe
  521. }
  522. func (c *Client) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
  523. return c.TxPipeline().Pipelined(ctx, fn)
  524. }
  525. // TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
  526. func (c *Client) TxPipeline() Pipeliner {
  527. pipe := Pipeline{
  528. ctx: c.ctx,
  529. exec: c.processTxPipeline,
  530. }
  531. pipe.init()
  532. return &pipe
  533. }
  534. func (c *Client) pubSub() *PubSub {
  535. pubsub := &PubSub{
  536. opt: c.opt,
  537. newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
  538. return c.newConn(ctx)
  539. },
  540. closeConn: c.connPool.CloseConn,
  541. }
  542. pubsub.init()
  543. return pubsub
  544. }
  545. // Subscribe subscribes the client to the specified channels.
  546. // Channels can be omitted to create empty subscription.
  547. // Note that this method does not wait on a response from Redis, so the
  548. // subscription may not be active immediately. To force the connection to wait,
  549. // you may call the Receive() method on the returned *PubSub like so:
  550. //
  551. // sub := client.Subscribe(queryResp)
  552. // iface, err := sub.Receive()
  553. // if err != nil {
  554. // // handle error
  555. // }
  556. //
  557. // // Should be *Subscription, but others are possible if other actions have been
  558. // // taken on sub since it was created.
  559. // switch iface.(type) {
  560. // case *Subscription:
  561. // // subscribe succeeded
  562. // case *Message:
  563. // // received first message
  564. // case *Pong:
  565. // // pong received
  566. // default:
  567. // // handle error
  568. // }
  569. //
  570. // ch := sub.Channel()
  571. func (c *Client) Subscribe(ctx context.Context, channels ...string) *PubSub {
  572. pubsub := c.pubSub()
  573. if len(channels) > 0 {
  574. _ = pubsub.Subscribe(ctx, channels...)
  575. }
  576. return pubsub
  577. }
  578. // PSubscribe subscribes the client to the given patterns.
  579. // Patterns can be omitted to create empty subscription.
  580. func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub {
  581. pubsub := c.pubSub()
  582. if len(channels) > 0 {
  583. _ = pubsub.PSubscribe(ctx, channels...)
  584. }
  585. return pubsub
  586. }
  587. //------------------------------------------------------------------------------
  588. type conn struct {
  589. baseClient
  590. cmdable
  591. statefulCmdable
  592. hooks // TODO: inherit hooks
  593. }
  594. // Conn represents a single Redis connection rather than a pool of connections.
  595. // Prefer running commands from Client unless there is a specific need
  596. // for a continuous single Redis connection.
  597. type Conn struct {
  598. *conn
  599. ctx context.Context
  600. }
  601. func newConn(ctx context.Context, opt *Options, connPool pool.Pooler) *Conn {
  602. c := Conn{
  603. conn: &conn{
  604. baseClient: baseClient{
  605. opt: opt,
  606. connPool: connPool,
  607. },
  608. },
  609. ctx: ctx,
  610. }
  611. c.cmdable = c.Process
  612. c.statefulCmdable = c.Process
  613. return &c
  614. }
  615. func (c *Conn) Process(ctx context.Context, cmd Cmder) error {
  616. return c.hooks.process(ctx, cmd, c.baseClient.process)
  617. }
  618. func (c *Conn) processPipeline(ctx context.Context, cmds []Cmder) error {
  619. return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
  620. }
  621. func (c *Conn) processTxPipeline(ctx context.Context, cmds []Cmder) error {
  622. return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline)
  623. }
  624. func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
  625. return c.Pipeline().Pipelined(ctx, fn)
  626. }
  627. func (c *Conn) Pipeline() Pipeliner {
  628. pipe := Pipeline{
  629. ctx: c.ctx,
  630. exec: c.processPipeline,
  631. }
  632. pipe.init()
  633. return &pipe
  634. }
  635. func (c *Conn) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
  636. return c.TxPipeline().Pipelined(ctx, fn)
  637. }
  638. // TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
  639. func (c *Conn) TxPipeline() Pipeliner {
  640. pipe := Pipeline{
  641. ctx: c.ctx,
  642. exec: c.processTxPipeline,
  643. }
  644. pipe.init()
  645. return &pipe
  646. }