No puede seleccionar más de 25 temas Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.
 
 
 

799 líneas
20 KiB

  1. // Copyright 2017 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Package pstest provides a fake Cloud PubSub service for testing. It implements a
  15. // simplified form of the service, suitable for unit tests. It may behave
  16. // differently from the actual service in ways in which the service is
  17. // non-deterministic or unspecified: timing, delivery order, etc.
  18. //
  19. // This package is EXPERIMENTAL and is subject to change without notice.
  20. //
  21. // See the example for usage.
  22. package pstest
  23. import (
  24. "fmt"
  25. "io"
  26. "path"
  27. "sort"
  28. "strings"
  29. "sync"
  30. "sync/atomic"
  31. "time"
  32. "cloud.google.com/go/internal/testutil"
  33. "github.com/golang/protobuf/ptypes"
  34. durpb "github.com/golang/protobuf/ptypes/duration"
  35. emptypb "github.com/golang/protobuf/ptypes/empty"
  36. "golang.org/x/net/context"
  37. pb "google.golang.org/genproto/googleapis/pubsub/v1"
  38. "google.golang.org/grpc/codes"
  39. "google.golang.org/grpc/status"
  40. )
  41. // For testing. Note that even though changes to the now variable are atomic, a call
  42. // to the stored function can race with a change to that function. This could be a
  43. // problem if tests are run in parallel, or even if concurrent parts of the same test
  44. // change the value of the variable.
  45. var now atomic.Value
  46. func init() {
  47. now.Store(time.Now)
  48. }
  49. func timeNow() time.Time {
  50. return now.Load().(func() time.Time)()
  51. }
// Server is a fake Pub/Sub server. It pairs the listening address with the
// gServer value that holds the fake service's state.
type Server struct {
	Addr    string // The address that the server is listening on.
	gServer gServer
}
// gServer holds the state of the fake service. It embeds the generated
// PublisherServer and SubscriberServer interfaces; the methods defined on
// *gServer in this file supply the RPCs the fake implements. The RPC methods
// lock mu while they read or modify the fields below it.
type gServer struct {
	pb.PublisherServer
	pb.SubscriberServer
	mu            sync.Mutex
	topics        map[string]*topic        // keyed by topic name
	subs          map[string]*subscription // keyed by subscription name
	msgs          []*Message               // all messages ever published
	msgsByID      map[string]*Message
	wg            sync.WaitGroup // handed to subscription and stream goroutines; awaited by Server.Wait
	nextID        int            // number used to form the next message ID
	streamTimeout time.Duration  // lifetime given to newly created streams
}
  68. // NewServer creates a new fake server running in the current process.
  69. func NewServer() *Server {
  70. srv, err := testutil.NewServer()
  71. if err != nil {
  72. panic(fmt.Sprintf("pstest.NewServer: %v", err))
  73. }
  74. s := &Server{
  75. Addr: srv.Addr,
  76. gServer: gServer{
  77. topics: map[string]*topic{},
  78. subs: map[string]*subscription{},
  79. msgsByID: map[string]*Message{},
  80. },
  81. }
  82. pb.RegisterPublisherServer(srv.Gsrv, &s.gServer)
  83. pb.RegisterSubscriberServer(srv.Gsrv, &s.gServer)
  84. srv.Start()
  85. return s
  86. }
  87. // Publish behaves as if the Publish RPC was called with a message with the given
  88. // data and attrs. It returns the ID of the message.
  89. // The topic will be created if it doesn't exist.
  90. //
  91. // Publish panics if there is an error, which is appropriate for testing.
  92. func (s *Server) Publish(topic string, data []byte, attrs map[string]string) string {
  93. const topicPattern = "projects/*/topics/*"
  94. ok, err := path.Match(topicPattern, topic)
  95. if err != nil {
  96. panic(err)
  97. }
  98. if !ok {
  99. panic(fmt.Sprintf("topic name must be of the form %q", topicPattern))
  100. }
  101. _, _ = s.gServer.CreateTopic(nil, &pb.Topic{Name: topic})
  102. req := &pb.PublishRequest{
  103. Topic: topic,
  104. Messages: []*pb.PubsubMessage{{Data: data, Attributes: attrs}},
  105. }
  106. res, err := s.gServer.Publish(nil, req)
  107. if err != nil {
  108. panic(fmt.Sprintf("pstest.Server.Publish: %v", err))
  109. }
  110. return res.MessageIds[0]
  111. }
  112. // SetStreamTimeout sets the amount of time a stream will be active before it shuts
  113. // itself down. This mimics the real service's behavior of closing streams after 30
  114. // minutes. If SetStreamTimeout is never called or is passed zero, streams never shut
  115. // down.
  116. func (s *Server) SetStreamTimeout(d time.Duration) {
  117. s.gServer.mu.Lock()
  118. defer s.gServer.mu.Unlock()
  119. s.gServer.streamTimeout = d
  120. }
// A Message is a message that was published to the server.
//
// Deliveries and Acks are snapshots: they are copied from the unexported
// counters each time the message is returned by Server.Messages or
// Server.Message.
type Message struct {
	ID          string // assigned by the server when the message is published
	Data        []byte
	Attributes  map[string]string
	PublishTime time.Time // server clock reading taken when the message was published
	Deliveries  int       // number of times delivery of the message was attempted
	Acks        int       // number of acks received from clients
	// protected by server mutex
	deliveries int
	acks       int
}
  133. // Messages returns information about all messages ever published.
  134. func (s *Server) Messages() []*Message {
  135. s.gServer.mu.Lock()
  136. defer s.gServer.mu.Unlock()
  137. var msgs []*Message
  138. for _, m := range s.gServer.msgs {
  139. m.Deliveries = m.deliveries
  140. m.Acks = m.acks
  141. msgs = append(msgs, m)
  142. }
  143. return msgs
  144. }
  145. // Message returns the message with the given ID, or nil if no message
  146. // with that ID was published.
  147. func (s *Server) Message(id string) *Message {
  148. s.gServer.mu.Lock()
  149. defer s.gServer.mu.Unlock()
  150. m := s.gServer.msgsByID[id]
  151. if m != nil {
  152. m.Deliveries = m.deliveries
  153. m.Acks = m.acks
  154. }
  155. return m
  156. }
  157. // Wait blocks until all server activity has completed.
  158. func (s *Server) Wait() {
  159. s.gServer.wg.Wait()
  160. }
  161. func (s *gServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) {
  162. s.mu.Lock()
  163. defer s.mu.Unlock()
  164. if s.topics[t.Name] != nil {
  165. return nil, status.Errorf(codes.AlreadyExists, "topic %q", t.Name)
  166. }
  167. top := newTopic(t)
  168. s.topics[t.Name] = top
  169. return top.proto, nil
  170. }
  171. func (s *gServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) {
  172. s.mu.Lock()
  173. defer s.mu.Unlock()
  174. if t := s.topics[req.Topic]; t != nil {
  175. return t.proto, nil
  176. }
  177. return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
  178. }
  179. func (s *gServer) UpdateTopic(_ context.Context, req *pb.UpdateTopicRequest) (*pb.Topic, error) {
  180. return nil, status.Errorf(codes.Unimplemented, "unimplemented")
  181. }
  182. func (s *gServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) {
  183. s.mu.Lock()
  184. defer s.mu.Unlock()
  185. var names []string
  186. for n := range s.topics {
  187. if strings.HasPrefix(n, req.Project) {
  188. names = append(names, n)
  189. }
  190. }
  191. sort.Strings(names)
  192. from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
  193. if err != nil {
  194. return nil, err
  195. }
  196. res := &pb.ListTopicsResponse{NextPageToken: nextToken}
  197. for i := from; i < to; i++ {
  198. res.Topics = append(res.Topics, s.topics[names[i]].proto)
  199. }
  200. return res, nil
  201. }
  202. func (s *gServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) {
  203. s.mu.Lock()
  204. defer s.mu.Unlock()
  205. var names []string
  206. for name, sub := range s.subs {
  207. if sub.topic.proto.Name == req.Topic {
  208. names = append(names, name)
  209. }
  210. }
  211. sort.Strings(names)
  212. from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
  213. if err != nil {
  214. return nil, err
  215. }
  216. return &pb.ListTopicSubscriptionsResponse{
  217. Subscriptions: names[from:to],
  218. NextPageToken: nextToken,
  219. }, nil
  220. }
  221. func (s *gServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) {
  222. s.mu.Lock()
  223. defer s.mu.Unlock()
  224. t := s.topics[req.Topic]
  225. if t == nil {
  226. return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
  227. }
  228. t.stop()
  229. delete(s.topics, req.Topic)
  230. return &emptypb.Empty{}, nil
  231. }
  232. func (s *gServer) CreateSubscription(_ context.Context, ps *pb.Subscription) (*pb.Subscription, error) {
  233. s.mu.Lock()
  234. defer s.mu.Unlock()
  235. if ps.Name == "" {
  236. return nil, status.Errorf(codes.InvalidArgument, "missing name")
  237. }
  238. if s.subs[ps.Name] != nil {
  239. return nil, status.Errorf(codes.AlreadyExists, "subscription %q", ps.Name)
  240. }
  241. if ps.Topic == "" {
  242. return nil, status.Errorf(codes.InvalidArgument, "missing topic")
  243. }
  244. top := s.topics[ps.Topic]
  245. if top == nil {
  246. return nil, status.Errorf(codes.NotFound, "topic %q", ps.Topic)
  247. }
  248. if err := checkAckDeadline(ps.AckDeadlineSeconds); err != nil {
  249. return nil, err
  250. }
  251. if ps.MessageRetentionDuration == nil {
  252. ps.MessageRetentionDuration = defaultMessageRetentionDuration
  253. }
  254. if err := checkMRD(ps.MessageRetentionDuration); err != nil {
  255. return nil, err
  256. }
  257. if ps.PushConfig == nil {
  258. ps.PushConfig = &pb.PushConfig{}
  259. }
  260. sub := newSubscription(top, &s.mu, ps)
  261. top.subs[ps.Name] = sub
  262. s.subs[ps.Name] = sub
  263. sub.start(&s.wg)
  264. return ps, nil
  265. }
  266. // Can be set for testing.
  267. var minAckDeadlineSecs int32 = 10
  268. func checkAckDeadline(ads int32) error {
  269. if ads < minAckDeadlineSecs || ads > 600 {
  270. // PubSub service returns Unknown.
  271. return status.Errorf(codes.Unknown, "bad ack_deadline_seconds: %d", ads)
  272. }
  273. return nil
  274. }
  275. const (
  276. minMessageRetentionDuration = 10 * time.Minute
  277. maxMessageRetentionDuration = 168 * time.Hour
  278. )
  279. var defaultMessageRetentionDuration = ptypes.DurationProto(maxMessageRetentionDuration)
  280. func checkMRD(pmrd *durpb.Duration) error {
  281. mrd, err := ptypes.Duration(pmrd)
  282. if err != nil || mrd < minMessageRetentionDuration || mrd > maxMessageRetentionDuration {
  283. return status.Errorf(codes.InvalidArgument, "bad message_retention_duration %+v", pmrd)
  284. }
  285. return nil
  286. }
  287. func (s *gServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
  288. s.mu.Lock()
  289. defer s.mu.Unlock()
  290. if sub := s.subs[req.Subscription]; sub != nil {
  291. return sub.proto, nil
  292. }
  293. return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription)
  294. }
  295. func (s *gServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) {
  296. s.mu.Lock()
  297. defer s.mu.Unlock()
  298. sub := s.subs[req.Subscription.Name]
  299. if sub == nil {
  300. return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription.Name)
  301. }
  302. for _, path := range req.UpdateMask.Paths {
  303. switch path {
  304. case "push_config":
  305. sub.proto.PushConfig = req.Subscription.PushConfig
  306. case "ack_deadline_seconds":
  307. a := req.Subscription.AckDeadlineSeconds
  308. if err := checkAckDeadline(a); err != nil {
  309. return nil, err
  310. }
  311. sub.proto.AckDeadlineSeconds = a
  312. case "retain_acked_messages":
  313. sub.proto.RetainAckedMessages = req.Subscription.RetainAckedMessages
  314. case "message_retention_duration":
  315. if err := checkMRD(req.Subscription.MessageRetentionDuration); err != nil {
  316. return nil, err
  317. }
  318. sub.proto.MessageRetentionDuration = req.Subscription.MessageRetentionDuration
  319. // TODO(jba): labels
  320. default:
  321. return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path)
  322. }
  323. }
  324. return sub.proto, nil
  325. }
  326. func (s *gServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) {
  327. s.mu.Lock()
  328. defer s.mu.Unlock()
  329. var names []string
  330. for name := range s.subs {
  331. if strings.HasPrefix(name, req.Project) {
  332. names = append(names, name)
  333. }
  334. }
  335. sort.Strings(names)
  336. from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
  337. if err != nil {
  338. return nil, err
  339. }
  340. res := &pb.ListSubscriptionsResponse{NextPageToken: nextToken}
  341. for i := from; i < to; i++ {
  342. res.Subscriptions = append(res.Subscriptions, s.subs[names[i]].proto)
  343. }
  344. return res, nil
  345. }
  346. func (s *gServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) {
  347. s.mu.Lock()
  348. defer s.mu.Unlock()
  349. sub := s.subs[req.Subscription]
  350. if sub == nil {
  351. return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription)
  352. }
  353. sub.stop()
  354. delete(s.subs, req.Subscription)
  355. sub.topic.deleteSub(sub)
  356. return &emptypb.Empty{}, nil
  357. }
  358. func (s *gServer) Publish(_ context.Context, req *pb.PublishRequest) (*pb.PublishResponse, error) {
  359. s.mu.Lock()
  360. defer s.mu.Unlock()
  361. if req.Topic == "" {
  362. return nil, status.Errorf(codes.InvalidArgument, "missing topic")
  363. }
  364. top := s.topics[req.Topic]
  365. if top == nil {
  366. return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
  367. }
  368. var ids []string
  369. for _, pm := range req.Messages {
  370. id := fmt.Sprintf("m%d", s.nextID)
  371. s.nextID++
  372. pm.MessageId = id
  373. pubTime := timeNow()
  374. tsPubTime, err := ptypes.TimestampProto(pubTime)
  375. if err != nil {
  376. return nil, status.Errorf(codes.Internal, err.Error())
  377. }
  378. pm.PublishTime = tsPubTime
  379. m := &Message{
  380. ID: id,
  381. Data: pm.Data,
  382. Attributes: pm.Attributes,
  383. PublishTime: pubTime,
  384. }
  385. top.publish(pm, m)
  386. ids = append(ids, id)
  387. s.msgs = append(s.msgs, m)
  388. s.msgsByID[id] = m
  389. }
  390. return &pb.PublishResponse{MessageIds: ids}, nil
  391. }
// topic is the server-side state for one topic.
type topic struct {
	proto *pb.Topic
	subs  map[string]*subscription // subscriptions attached to this topic, keyed by name
}
  396. func newTopic(pt *pb.Topic) *topic {
  397. return &topic{
  398. proto: pt,
  399. subs: map[string]*subscription{},
  400. }
  401. }
  402. func (t *topic) stop() {
  403. for _, sub := range t.subs {
  404. sub.proto.Topic = "_deleted-topic_"
  405. sub.stop()
  406. }
  407. }
  408. func (t *topic) deleteSub(sub *subscription) {
  409. delete(t.subs, sub.proto.Name)
  410. }
  411. func (t *topic) publish(pm *pb.PubsubMessage, m *Message) {
  412. for _, s := range t.subs {
  413. s.msgs[pm.MessageId] = &message{
  414. publishTime: m.PublishTime,
  415. proto: &pb.ReceivedMessage{
  416. AckId: pm.MessageId,
  417. Message: pm,
  418. },
  419. deliveries: &m.deliveries,
  420. acks: &m.acks,
  421. streamIndex: -1,
  422. }
  423. }
  424. }
// subscription is the server-side state for one subscription.
type subscription struct {
	topic      *topic
	mu         *sync.Mutex // the server's mutex, handed in when the subscription is created
	proto      *pb.Subscription
	ackTimeout time.Duration       // ack deadline copied into each new stream
	msgs       map[string]*message // unacked messages by message ID
	streams    []*stream
	done       chan struct{} // closed by stop to end the delivery goroutine
}
  434. func newSubscription(t *topic, mu *sync.Mutex, ps *pb.Subscription) *subscription {
  435. at := time.Duration(ps.AckDeadlineSeconds) * time.Second
  436. if at == 0 {
  437. at = 10 * time.Second
  438. }
  439. return &subscription{
  440. topic: t,
  441. mu: mu,
  442. proto: ps,
  443. ackTimeout: at,
  444. msgs: map[string]*message{},
  445. done: make(chan struct{}),
  446. }
  447. }
  448. func (s *subscription) start(wg *sync.WaitGroup) {
  449. wg.Add(1)
  450. go func() {
  451. defer wg.Done()
  452. for {
  453. select {
  454. case <-s.done:
  455. return
  456. case <-time.After(10 * time.Millisecond):
  457. s.deliver()
  458. }
  459. }
  460. }()
  461. }
  462. func (s *subscription) stop() {
  463. close(s.done)
  464. }
  465. func (s *gServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
  466. s.mu.Lock()
  467. defer s.mu.Unlock()
  468. if req.Subscription == "" {
  469. return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
  470. }
  471. sub := s.subs[req.Subscription]
  472. for _, id := range req.AckIds {
  473. sub.ack(id)
  474. }
  475. return &emptypb.Empty{}, nil
  476. }
  477. func (s *gServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
  478. s.mu.Lock()
  479. defer s.mu.Unlock()
  480. if req.Subscription == "" {
  481. return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
  482. }
  483. sub := s.subs[req.Subscription]
  484. dur := secsToDur(req.AckDeadlineSeconds)
  485. for _, id := range req.AckIds {
  486. sub.modifyAckDeadline(id, dur)
  487. }
  488. return &emptypb.Empty{}, nil
  489. }
  490. func (s *gServer) StreamingPull(sps pb.Subscriber_StreamingPullServer) error {
  491. // Receive initial message configuring the pull.
  492. req, err := sps.Recv()
  493. if err != nil {
  494. return err
  495. }
  496. if req.Subscription == "" {
  497. return status.Errorf(codes.InvalidArgument, "missing subscription")
  498. }
  499. s.mu.Lock()
  500. sub := s.subs[req.Subscription]
  501. s.mu.Unlock()
  502. if sub == nil {
  503. return status.Errorf(codes.NotFound, "subscription %s", req.Subscription)
  504. }
  505. // Create a new stream to handle the pull.
  506. st := sub.newStream(sps, s.streamTimeout)
  507. err = st.pull(&s.wg)
  508. sub.deleteStream(st)
  509. return err
  510. }
  511. var retentionDuration = 10 * time.Minute
  512. func (s *subscription) deliver() {
  513. s.mu.Lock()
  514. defer s.mu.Unlock()
  515. tNow := timeNow()
  516. for id, m := range s.msgs {
  517. // Mark a message as re-deliverable if its ack deadline has expired.
  518. if m.outstanding() && tNow.After(m.ackDeadline) {
  519. m.makeAvailable()
  520. }
  521. pubTime, err := ptypes.Timestamp(m.proto.Message.PublishTime)
  522. if err != nil {
  523. panic(err)
  524. }
  525. // Remove messages that have been undelivered for a long time.
  526. if !m.outstanding() && tNow.Sub(pubTime) > retentionDuration {
  527. delete(s.msgs, id)
  528. }
  529. }
  530. // Try to deliver each remaining message.
  531. curIndex := 0
  532. for _, m := range s.msgs {
  533. if m.outstanding() {
  534. continue
  535. }
  536. // If the message was never delivered before, start with the stream at
  537. // curIndex. If it was delivered before, start with the stream after the one
  538. // that owned it.
  539. if m.streamIndex < 0 {
  540. delIndex, ok := s.deliverMessage(m, curIndex, tNow)
  541. if !ok {
  542. break
  543. }
  544. curIndex = delIndex + 1
  545. m.streamIndex = curIndex
  546. } else {
  547. delIndex, ok := s.deliverMessage(m, m.streamIndex, tNow)
  548. if !ok {
  549. break
  550. }
  551. m.streamIndex = delIndex
  552. }
  553. }
  554. }
  555. // deliverMessage attempts to deliver m to the stream at index i. If it can't, it
  556. // tries streams i+1, i+2, ..., wrapping around. It returns the index of the stream
  557. // it delivered the message to, or 0, false if it didn't deliver the message because
  558. // there are no active streams.
  559. func (s *subscription) deliverMessage(m *message, i int, tNow time.Time) (int, bool) {
  560. for len(s.streams) > 0 {
  561. if i >= len(s.streams) {
  562. i = 0
  563. }
  564. st := s.streams[i]
  565. select {
  566. case <-st.done:
  567. s.streams = deleteStreamAt(s.streams, i)
  568. case st.msgc <- m.proto:
  569. (*m.deliveries)++
  570. m.ackDeadline = tNow.Add(st.ackTimeout)
  571. return i, true
  572. }
  573. }
  574. return 0, false
  575. }
  576. func (s *subscription) newStream(gs pb.Subscriber_StreamingPullServer, timeout time.Duration) *stream {
  577. st := &stream{
  578. sub: s,
  579. done: make(chan struct{}),
  580. msgc: make(chan *pb.ReceivedMessage),
  581. gstream: gs,
  582. ackTimeout: s.ackTimeout,
  583. timeout: timeout,
  584. }
  585. s.mu.Lock()
  586. s.streams = append(s.streams, st)
  587. s.mu.Unlock()
  588. return st
  589. }
  590. func (s *subscription) deleteStream(st *stream) {
  591. s.mu.Lock()
  592. defer s.mu.Unlock()
  593. var i int
  594. for i = 0; i < len(s.streams); i++ {
  595. if s.streams[i] == st {
  596. break
  597. }
  598. }
  599. if i < len(s.streams) {
  600. s.streams = deleteStreamAt(s.streams, i)
  601. }
  602. }
  603. func deleteStreamAt(s []*stream, i int) []*stream {
  604. // Preserve order for round-robin delivery.
  605. return append(s[:i], s[i+1:]...)
  606. }
// message is a subscription's record of one unacked message.
type message struct {
	proto       *pb.ReceivedMessage
	publishTime time.Time
	ackDeadline time.Time // zero when no stream currently owns the message
	deliveries  *int      // points at the deliveries counter of the published Message
	acks        *int      // points at the acks counter of the published Message
	streamIndex int       // index of stream that currently owns msg, for round-robin delivery
}
  615. // A message is outstanding if it is owned by some stream.
  616. func (m *message) outstanding() bool {
  617. return !m.ackDeadline.IsZero()
  618. }
  619. func (m *message) makeAvailable() {
  620. m.ackDeadline = time.Time{}
  621. }
// stream is the server-side state of one StreamingPull call.
type stream struct {
	sub        *subscription
	done       chan struct{}            // closed when the stream is finished
	msgc       chan *pb.ReceivedMessage // carries messages from deliverMessage to sendLoop
	gstream    pb.Subscriber_StreamingPullServer
	ackTimeout time.Duration // ack deadline applied to messages delivered on this stream
	timeout    time.Duration // lifetime limit of the stream; not enforced unless positive
}
  630. // pull manages the StreamingPull interaction for the life of the stream.
  631. func (st *stream) pull(wg *sync.WaitGroup) error {
  632. errc := make(chan error, 2)
  633. wg.Add(2)
  634. go func() {
  635. defer wg.Done()
  636. errc <- st.sendLoop()
  637. }()
  638. go func() {
  639. defer wg.Done()
  640. errc <- st.recvLoop()
  641. }()
  642. var tchan <-chan time.Time
  643. if st.timeout > 0 {
  644. tchan = time.After(st.timeout)
  645. }
  646. // Wait until one of the goroutines returns an error, or we time out.
  647. var err error
  648. select {
  649. case err = <-errc:
  650. if err == io.EOF {
  651. err = nil
  652. }
  653. case <-tchan:
  654. }
  655. close(st.done) // stop the other goroutine
  656. return err
  657. }
  658. func (st *stream) sendLoop() error {
  659. for {
  660. select {
  661. case <-st.done:
  662. return nil
  663. case rm := <-st.msgc:
  664. res := &pb.StreamingPullResponse{ReceivedMessages: []*pb.ReceivedMessage{rm}}
  665. if err := st.gstream.Send(res); err != nil {
  666. return err
  667. }
  668. }
  669. }
  670. }
  671. func (st *stream) recvLoop() error {
  672. for {
  673. req, err := st.gstream.Recv()
  674. if err != nil {
  675. return err
  676. }
  677. st.sub.handleStreamingPullRequest(st, req)
  678. }
  679. }
  680. func (s *subscription) handleStreamingPullRequest(st *stream, req *pb.StreamingPullRequest) {
  681. // Lock the entire server.
  682. s.mu.Lock()
  683. defer s.mu.Unlock()
  684. for _, ackID := range req.AckIds {
  685. s.ack(ackID)
  686. }
  687. for i, id := range req.ModifyDeadlineAckIds {
  688. s.modifyAckDeadline(id, secsToDur(req.ModifyDeadlineSeconds[i]))
  689. }
  690. if req.StreamAckDeadlineSeconds > 0 {
  691. st.ackTimeout = secsToDur(req.StreamAckDeadlineSeconds)
  692. }
  693. }
  694. func (s *subscription) ack(id string) {
  695. m := s.msgs[id]
  696. if m != nil {
  697. (*m.acks)++
  698. delete(s.msgs, id)
  699. }
  700. }
  701. func (s *subscription) modifyAckDeadline(id string, d time.Duration) {
  702. m := s.msgs[id]
  703. if m == nil { // already acked: ignore.
  704. return
  705. }
  706. if d == 0 { // nack
  707. m.makeAvailable()
  708. } else { // extend the deadline by d
  709. m.ackDeadline = timeNow().Add(d)
  710. }
  711. }
  712. func secsToDur(secs int32) time.Duration {
  713. return time.Duration(secs) * time.Second
  714. }