You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

669 lines
15 KiB

  1. package redis
  2. import (
  3. "context"
  4. "fmt"
  5. "strings"
  6. "sync"
  7. "time"
  8. "github.com/go-redis/redis/v8/internal"
  9. "github.com/go-redis/redis/v8/internal/pool"
  10. "github.com/go-redis/redis/v8/internal/proto"
  11. )
  // PubSub implements Pub/Sub commands as described in
// https://redis.io/topics/pubsub. Message receiving is NOT safe
// for concurrent use by multiple goroutines.
//
// PubSub automatically reconnects to Redis Server and resubscribes
// to the channels in case of network errors.
type PubSub struct {
	// opt holds the client options; WriteTimeout and Addr are read here.
	opt *Options

	// newConn dials a connection. It is given the names of the channels the
	// connection is about to be used for (known channels plus new ones).
	newConn func(ctx context.Context, channels []string) (*pool.Conn, error)
	// closeConn closes a connection previously obtained from newConn.
	closeConn func(*pool.Conn) error

	// mu guards cn, channels, patterns and closed.
	mu sync.Mutex
	// cn is the current connection; nil before first use and after a bad
	// connection has been discarded.
	cn *pool.Conn
	// channels and patterns are re-sent on every fresh connection.
	channels map[string]struct{}
	patterns map[string]struct{}
	// closed is set by Close; conn returns pool.ErrClosed afterwards.
	closed bool
	// exit is closed by Close to stop the health-check goroutine.
	exit chan struct{}

	// cmd is reused by ReceiveTimeout to read replies (not guarded by mu).
	cmd *Cmd

	// chOnce ensures only one of msgCh (Channel) or allCh
	// (ChannelWithSubscriptions) is ever created.
	chOnce sync.Once
	msgCh  *channel
	allCh  *channel
}
  33. func (c *PubSub) init() {
  34. c.exit = make(chan struct{})
  35. }
  36. func (c *PubSub) String() string {
  37. channels := mapKeys(c.channels)
  38. channels = append(channels, mapKeys(c.patterns)...)
  39. return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", "))
  40. }
  41. func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) {
  42. c.mu.Lock()
  43. cn, err := c.conn(ctx, nil)
  44. c.mu.Unlock()
  45. return cn, err
  46. }
  47. func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, error) {
  48. if c.closed {
  49. return nil, pool.ErrClosed
  50. }
  51. if c.cn != nil {
  52. return c.cn, nil
  53. }
  54. channels := mapKeys(c.channels)
  55. channels = append(channels, newChannels...)
  56. cn, err := c.newConn(ctx, channels)
  57. if err != nil {
  58. return nil, err
  59. }
  60. if err := c.resubscribe(ctx, cn); err != nil {
  61. _ = c.closeConn(cn)
  62. return nil, err
  63. }
  64. c.cn = cn
  65. return cn, nil
  66. }
  67. func (c *PubSub) writeCmd(ctx context.Context, cn *pool.Conn, cmd Cmder) error {
  68. return cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
  69. return writeCmd(wr, cmd)
  70. })
  71. }
  72. func (c *PubSub) resubscribe(ctx context.Context, cn *pool.Conn) error {
  73. var firstErr error
  74. if len(c.channels) > 0 {
  75. firstErr = c._subscribe(ctx, cn, "subscribe", mapKeys(c.channels))
  76. }
  77. if len(c.patterns) > 0 {
  78. err := c._subscribe(ctx, cn, "psubscribe", mapKeys(c.patterns))
  79. if err != nil && firstErr == nil {
  80. firstErr = err
  81. }
  82. }
  83. return firstErr
  84. }
  // mapKeys returns the keys of m in map-iteration (unspecified) order.
// The returned slice is non-nil even when m is nil or empty.
func mapKeys(m map[string]struct{}) []string {
	keys := make([]string, 0, len(m))
	for key := range m {
		keys = append(keys, key)
	}
	return keys
}
  94. func (c *PubSub) _subscribe(
  95. ctx context.Context, cn *pool.Conn, redisCmd string, channels []string,
  96. ) error {
  97. args := make([]interface{}, 0, 1+len(channels))
  98. args = append(args, redisCmd)
  99. for _, channel := range channels {
  100. args = append(args, channel)
  101. }
  102. cmd := NewSliceCmd(ctx, args...)
  103. return c.writeCmd(ctx, cn, cmd)
  104. }
  105. func (c *PubSub) releaseConnWithLock(
  106. ctx context.Context,
  107. cn *pool.Conn,
  108. err error,
  109. allowTimeout bool,
  110. ) {
  111. c.mu.Lock()
  112. c.releaseConn(ctx, cn, err, allowTimeout)
  113. c.mu.Unlock()
  114. }
  115. func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allowTimeout bool) {
  116. if c.cn != cn {
  117. return
  118. }
  119. if isBadConn(err, allowTimeout, c.opt.Addr) {
  120. c.reconnect(ctx, err)
  121. }
  122. }
  123. func (c *PubSub) reconnect(ctx context.Context, reason error) {
  124. _ = c.closeTheCn(reason)
  125. _, _ = c.conn(ctx, nil)
  126. }
  127. func (c *PubSub) closeTheCn(reason error) error {
  128. if c.cn == nil {
  129. return nil
  130. }
  131. if !c.closed {
  132. internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection: %s", reason)
  133. }
  134. err := c.closeConn(c.cn)
  135. c.cn = nil
  136. return err
  137. }
  138. func (c *PubSub) Close() error {
  139. c.mu.Lock()
  140. defer c.mu.Unlock()
  141. if c.closed {
  142. return pool.ErrClosed
  143. }
  144. c.closed = true
  145. close(c.exit)
  146. return c.closeTheCn(pool.ErrClosed)
  147. }
  148. // Subscribe the client to the specified channels. It returns
  149. // empty subscription if there are no channels.
  150. func (c *PubSub) Subscribe(ctx context.Context, channels ...string) error {
  151. c.mu.Lock()
  152. defer c.mu.Unlock()
  153. err := c.subscribe(ctx, "subscribe", channels...)
  154. if c.channels == nil {
  155. c.channels = make(map[string]struct{})
  156. }
  157. for _, s := range channels {
  158. c.channels[s] = struct{}{}
  159. }
  160. return err
  161. }
  162. // PSubscribe the client to the given patterns. It returns
  163. // empty subscription if there are no patterns.
  164. func (c *PubSub) PSubscribe(ctx context.Context, patterns ...string) error {
  165. c.mu.Lock()
  166. defer c.mu.Unlock()
  167. err := c.subscribe(ctx, "psubscribe", patterns...)
  168. if c.patterns == nil {
  169. c.patterns = make(map[string]struct{})
  170. }
  171. for _, s := range patterns {
  172. c.patterns[s] = struct{}{}
  173. }
  174. return err
  175. }
  176. // Unsubscribe the client from the given channels, or from all of
  177. // them if none is given.
  178. func (c *PubSub) Unsubscribe(ctx context.Context, channels ...string) error {
  179. c.mu.Lock()
  180. defer c.mu.Unlock()
  181. for _, channel := range channels {
  182. delete(c.channels, channel)
  183. }
  184. err := c.subscribe(ctx, "unsubscribe", channels...)
  185. return err
  186. }
  187. // PUnsubscribe the client from the given patterns, or from all of
  188. // them if none is given.
  189. func (c *PubSub) PUnsubscribe(ctx context.Context, patterns ...string) error {
  190. c.mu.Lock()
  191. defer c.mu.Unlock()
  192. for _, pattern := range patterns {
  193. delete(c.patterns, pattern)
  194. }
  195. err := c.subscribe(ctx, "punsubscribe", patterns...)
  196. return err
  197. }
  198. func (c *PubSub) subscribe(ctx context.Context, redisCmd string, channels ...string) error {
  199. cn, err := c.conn(ctx, channels)
  200. if err != nil {
  201. return err
  202. }
  203. err = c._subscribe(ctx, cn, redisCmd, channels)
  204. c.releaseConn(ctx, cn, err, false)
  205. return err
  206. }
  207. func (c *PubSub) Ping(ctx context.Context, payload ...string) error {
  208. args := []interface{}{"ping"}
  209. if len(payload) == 1 {
  210. args = append(args, payload[0])
  211. }
  212. cmd := NewCmd(ctx, args...)
  213. c.mu.Lock()
  214. defer c.mu.Unlock()
  215. cn, err := c.conn(ctx, nil)
  216. if err != nil {
  217. return err
  218. }
  219. err = c.writeCmd(ctx, cn, cmd)
  220. c.releaseConn(ctx, cn, err, false)
  221. return err
  222. }
  // Subscription is received after the server acknowledges a SUBSCRIBE,
// UNSUBSCRIBE, PSUBSCRIBE or PUNSUBSCRIBE command.
type Subscription struct {
	// Kind is "subscribe", "unsubscribe", "psubscribe" or "punsubscribe".
	Kind string
	// Channel is the channel name from the acknowledgement; it may be empty
	// for "unsubscribe".
	Channel string
	// Count is the number of channels we are currently subscribed to.
	Count int
}

// String formats the subscription as "<kind>: <channel>".
func (m *Subscription) String() string {
	return m.Kind + ": " + m.Channel
}
  // Message is received as the result of a PUBLISH command issued by another
// client.
type Message struct {
	// Channel is the channel the message was published to.
	Channel string
	// Pattern is the matching pattern; set only for pattern subscriptions.
	Pattern string
	// Payload is the message body when the server sent a single string.
	Payload string
	// PayloadSlice is the message body when the server sent an array.
	PayloadSlice []string
}

// String formats the message as "Message<channel: payload>". Only Payload
// is shown, so PayloadSlice-only messages print an empty payload.
func (m *Message) String() string {
	return "Message<" + m.Channel + ": " + m.Payload + ">"
}
  // Pong is received as the reply to a PING sent on this PubSub's own
// connection (see PubSub.Ping). It is not triggered by other clients.
type Pong struct {
	// Payload is the text carried by the reply; it may be empty.
	Payload string
}

// String returns "Pong" when there is no payload, otherwise "Pong<payload>".
func (p *Pong) String() string {
	if p.Payload == "" {
		return "Pong"
	}
	return "Pong<" + p.Payload + ">"
}
  255. func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
  256. switch reply := reply.(type) {
  257. case string:
  258. return &Pong{
  259. Payload: reply,
  260. }, nil
  261. case []interface{}:
  262. switch kind := reply[0].(string); kind {
  263. case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
  264. // Can be nil in case of "unsubscribe".
  265. channel, _ := reply[1].(string)
  266. return &Subscription{
  267. Kind: kind,
  268. Channel: channel,
  269. Count: int(reply[2].(int64)),
  270. }, nil
  271. case "message":
  272. switch payload := reply[2].(type) {
  273. case string:
  274. return &Message{
  275. Channel: reply[1].(string),
  276. Payload: payload,
  277. }, nil
  278. case []interface{}:
  279. ss := make([]string, len(payload))
  280. for i, s := range payload {
  281. ss[i] = s.(string)
  282. }
  283. return &Message{
  284. Channel: reply[1].(string),
  285. PayloadSlice: ss,
  286. }, nil
  287. default:
  288. return nil, fmt.Errorf("redis: unsupported pubsub message payload: %T", payload)
  289. }
  290. case "pmessage":
  291. return &Message{
  292. Pattern: reply[1].(string),
  293. Channel: reply[2].(string),
  294. Payload: reply[3].(string),
  295. }, nil
  296. case "pong":
  297. return &Pong{
  298. Payload: reply[1].(string),
  299. }, nil
  300. default:
  301. return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind)
  302. }
  303. default:
  304. return nil, fmt.Errorf("redis: unsupported pubsub message: %#v", reply)
  305. }
  306. }
  307. // ReceiveTimeout acts like Receive but returns an error if message
  308. // is not received in time. This is low-level API and in most cases
  309. // Channel should be used instead.
  310. func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (interface{}, error) {
  311. if c.cmd == nil {
  312. c.cmd = NewCmd(ctx)
  313. }
  314. // Don't hold the lock to allow subscriptions and pings.
  315. cn, err := c.connWithLock(ctx)
  316. if err != nil {
  317. return nil, err
  318. }
  319. err = cn.WithReader(ctx, timeout, func(rd *proto.Reader) error {
  320. return c.cmd.readReply(rd)
  321. })
  322. c.releaseConnWithLock(ctx, cn, err, timeout > 0)
  323. if err != nil {
  324. return nil, err
  325. }
  326. return c.newMessage(c.cmd.Val())
  327. }
  328. // Receive returns a message as a Subscription, Message, Pong or error.
  329. // See PubSub example for details. This is low-level API and in most cases
  330. // Channel should be used instead.
  331. func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
  332. return c.ReceiveTimeout(ctx, 0)
  333. }
  334. // ReceiveMessage returns a Message or error ignoring Subscription and Pong
  335. // messages. This is low-level API and in most cases Channel should be used
  336. // instead.
  337. func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) {
  338. for {
  339. msg, err := c.Receive(ctx)
  340. if err != nil {
  341. return nil, err
  342. }
  343. switch msg := msg.(type) {
  344. case *Subscription:
  345. // Ignore.
  346. case *Pong:
  347. // Ignore.
  348. case *Message:
  349. return msg, nil
  350. default:
  351. err := fmt.Errorf("redis: unknown message: %T", msg)
  352. return nil, err
  353. }
  354. }
  355. }
  356. func (c *PubSub) getContext() context.Context {
  357. if c.cmd != nil {
  358. return c.cmd.ctx
  359. }
  360. return context.Background()
  361. }
  362. //------------------------------------------------------------------------------
  363. // Channel returns a Go channel for concurrently receiving messages.
  364. // The channel is closed together with the PubSub. If the Go channel
  365. // is blocked full for 30 seconds the message is dropped.
  366. // Receive* APIs can not be used after channel is created.
  367. //
  368. // go-redis periodically sends ping messages to test connection health
  369. // and re-subscribes if ping can not not received for 30 seconds.
  370. func (c *PubSub) Channel(opts ...ChannelOption) <-chan *Message {
  371. c.chOnce.Do(func() {
  372. c.msgCh = newChannel(c, opts...)
  373. c.msgCh.initMsgChan()
  374. })
  375. if c.msgCh == nil {
  376. err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions")
  377. panic(err)
  378. }
  379. return c.msgCh.msgCh
  380. }
  381. // ChannelSize is like Channel, but creates a Go channel
  382. // with specified buffer size.
  383. //
  384. // Deprecated: use Channel(WithChannelSize(size)), remove in v9.
  385. func (c *PubSub) ChannelSize(size int) <-chan *Message {
  386. return c.Channel(WithChannelSize(size))
  387. }
  388. // ChannelWithSubscriptions is like Channel, but message type can be either
  389. // *Subscription or *Message. Subscription messages can be used to detect
  390. // reconnections.
  391. //
  392. // ChannelWithSubscriptions can not be used together with Channel or ChannelSize.
  393. func (c *PubSub) ChannelWithSubscriptions(_ context.Context, size int) <-chan interface{} {
  394. c.chOnce.Do(func() {
  395. c.allCh = newChannel(c, WithChannelSize(size))
  396. c.allCh.initAllChan()
  397. })
  398. if c.allCh == nil {
  399. err := fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel")
  400. panic(err)
  401. }
  402. return c.allCh.allCh
  403. }
  404. type ChannelOption func(c *channel)
  405. // WithChannelSize specifies the Go chan size that is used to buffer incoming messages.
  406. //
  407. // The default is 100 messages.
  408. func WithChannelSize(size int) ChannelOption {
  409. return func(c *channel) {
  410. c.chanSize = size
  411. }
  412. }
  413. // WithChannelHealthCheckInterval specifies the health check interval.
  414. // PubSub will ping Redis Server if it does not receive any messages within the interval.
  415. // To disable health check, use zero interval.
  416. //
  417. // The default is 3 seconds.
  418. func WithChannelHealthCheckInterval(d time.Duration) ChannelOption {
  419. return func(c *channel) {
  420. c.checkInterval = d
  421. }
  422. }
  423. // WithChannelSendTimeout specifies the channel send timeout after which
  424. // the message is dropped.
  425. //
  426. // The default is 60 seconds.
  427. func WithChannelSendTimeout(d time.Duration) ChannelOption {
  428. return func(c *channel) {
  429. c.chanSendTimeout = d
  430. }
  431. }
  432. type channel struct {
  433. pubSub *PubSub
  434. msgCh chan *Message
  435. allCh chan interface{}
  436. ping chan struct{}
  437. chanSize int
  438. chanSendTimeout time.Duration
  439. checkInterval time.Duration
  440. }
  441. func newChannel(pubSub *PubSub, opts ...ChannelOption) *channel {
  442. c := &channel{
  443. pubSub: pubSub,
  444. chanSize: 100,
  445. chanSendTimeout: time.Minute,
  446. checkInterval: 3 * time.Second,
  447. }
  448. for _, opt := range opts {
  449. opt(c)
  450. }
  451. if c.checkInterval > 0 {
  452. c.initHealthCheck()
  453. }
  454. return c
  455. }
  // initHealthCheck creates the ping signal channel and starts a goroutine
// that calls PubSub.Ping whenever no signal has arrived on c.ping for
// checkInterval. If Ping returns an error, the connection is discarded and
// re-dialed via reconnect while holding c.pubSub.mu. The goroutine exits
// when c.pubSub.exit is closed.
func (c *channel) initHealthCheck() {
	ctx := context.TODO()
	// Capacity 1 lets the receive loops' non-blocking send leave one pending
	// signal without blocking.
	c.ping = make(chan struct{}, 1)
	go func() {
		// Start from a stopped timer; it is re-armed at the top of each loop.
		timer := time.NewTimer(time.Minute)
		timer.Stop()
		for {
			timer.Reset(c.checkInterval)
			select {
			case <-c.ping:
				// Activity seen: stop the timer and drain it if it already
				// fired, so the next Reset starts a clean interval.
				if !timer.Stop() {
					<-timer.C
				}
			case <-timer.C:
				// Idle for a whole interval: probe the connection.
				if pingErr := c.pubSub.Ping(ctx); pingErr != nil {
					c.pubSub.mu.Lock()
					c.pubSub.reconnect(ctx, pingErr)
					c.pubSub.mu.Unlock()
				}
			case <-c.pubSub.exit:
				return
			}
		}
	}()
}
  481. // initMsgChan must be in sync with initAllChan.
  482. func (c *channel) initMsgChan() {
  483. ctx := context.TODO()
  484. c.msgCh = make(chan *Message, c.chanSize)
  485. go func() {
  486. timer := time.NewTimer(time.Minute)
  487. timer.Stop()
  488. var errCount int
  489. for {
  490. msg, err := c.pubSub.Receive(ctx)
  491. if err != nil {
  492. if err == pool.ErrClosed {
  493. close(c.msgCh)
  494. return
  495. }
  496. if errCount > 0 {
  497. time.Sleep(100 * time.Millisecond)
  498. }
  499. errCount++
  500. continue
  501. }
  502. errCount = 0
  503. // Any message is as good as a ping.
  504. select {
  505. case c.ping <- struct{}{}:
  506. default:
  507. }
  508. switch msg := msg.(type) {
  509. case *Subscription:
  510. // Ignore.
  511. case *Pong:
  512. // Ignore.
  513. case *Message:
  514. timer.Reset(c.chanSendTimeout)
  515. select {
  516. case c.msgCh <- msg:
  517. if !timer.Stop() {
  518. <-timer.C
  519. }
  520. case <-timer.C:
  521. internal.Logger.Printf(
  522. ctx, "redis: %s channel is full for %s (message is dropped)",
  523. c, c.chanSendTimeout)
  524. }
  525. default:
  526. internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
  527. }
  528. }
  529. }()
  530. }
  531. // initAllChan must be in sync with initMsgChan.
  532. func (c *channel) initAllChan() {
  533. ctx := context.TODO()
  534. c.allCh = make(chan interface{}, c.chanSize)
  535. go func() {
  536. timer := time.NewTimer(time.Minute)
  537. timer.Stop()
  538. var errCount int
  539. for {
  540. msg, err := c.pubSub.Receive(ctx)
  541. if err != nil {
  542. if err == pool.ErrClosed {
  543. close(c.allCh)
  544. return
  545. }
  546. if errCount > 0 {
  547. time.Sleep(100 * time.Millisecond)
  548. }
  549. errCount++
  550. continue
  551. }
  552. errCount = 0
  553. // Any message is as good as a ping.
  554. select {
  555. case c.ping <- struct{}{}:
  556. default:
  557. }
  558. switch msg := msg.(type) {
  559. case *Pong:
  560. // Ignore.
  561. case *Subscription, *Message:
  562. timer.Reset(c.chanSendTimeout)
  563. select {
  564. case c.allCh <- msg:
  565. if !timer.Stop() {
  566. <-timer.C
  567. }
  568. case <-timer.C:
  569. internal.Logger.Printf(
  570. ctx, "redis: %s channel is full for %s (message is dropped)",
  571. c, c.chanSendTimeout)
  572. }
  573. default:
  574. internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
  575. }
  576. }
  577. }()
  578. }