You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

999 lines
26 KiB

  1. // Copyright 2017 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Package pstest provides a fake Cloud PubSub service for testing. It implements a
  15. // simplified form of the service, suitable for unit tests. It may behave
  16. // differently from the actual service in ways in which the service is
  17. // non-deterministic or unspecified: timing, delivery order, etc.
  18. //
  19. // This package is EXPERIMENTAL and is subject to change without notice.
  20. //
  21. // See the example for usage.
  22. package pstest
  23. import (
  24. "context"
  25. "fmt"
  26. "io"
  27. "path"
  28. "sort"
  29. "strings"
  30. "sync"
  31. "sync/atomic"
  32. "time"
  33. "cloud.google.com/go/internal/testutil"
  34. "github.com/golang/protobuf/ptypes"
  35. durpb "github.com/golang/protobuf/ptypes/duration"
  36. emptypb "github.com/golang/protobuf/ptypes/empty"
  37. pb "google.golang.org/genproto/googleapis/pubsub/v1"
  38. "google.golang.org/grpc/codes"
  39. "google.golang.org/grpc/status"
  40. )
  41. // For testing. Note that even though changes to the now variable are atomic, a call
  42. // to the stored function can race with a change to that function. This could be a
  43. // problem if tests are run in parallel, or even if concurrent parts of the same test
  44. // change the value of the variable.
  45. var now atomic.Value
  46. func init() {
  47. now.Store(time.Now)
  48. ResetMinAckDeadline()
  49. }
  50. func timeNow() time.Time {
  51. return now.Load().(func() time.Time)()
  52. }
// Server is a fake Pub/Sub server.
//
// Create one with NewServer, point clients at Addr, and call Close when done.
type Server struct {
	srv     *testutil.Server // test server hosting the gRPC services; closed by Close
	Addr    string           // The address that the server is listening on.
	GServer GServer          // Not intended to be used directly.
}
// GServer is the underlying service implementor. It is not intended to be used
// directly.
//
// The maps, the message log and nextID are read and written with mu held.
// Every subscription is given a pointer to this same mutex, so one lock
// guards the whole fake.
type GServer struct {
	// NOTE(review): these embedded interfaces are never assigned, presumably
	// so GServer satisfies the full generated service interfaces. An RPC not
	// implemented in this file would be dispatched to a nil interface.
	pb.PublisherServer
	pb.SubscriberServer

	mu            sync.Mutex
	topics        map[string]*topic        // topics keyed by full resource name
	subs          map[string]*subscription // subscriptions keyed by full resource name
	msgs          []*Message               // all messages ever published
	msgsByID      map[string]*Message      // msgs indexed by message ID
	wg            sync.WaitGroup           // background goroutines (delivery and stream loops); see Server.Wait
	nextID        int                      // counter used to mint message IDs "m0", "m1", ...
	streamTimeout time.Duration            // lifetime of new StreamingPull streams; 0 means unlimited
}
  73. // NewServer creates a new fake server running in the current process.
  74. func NewServer() *Server {
  75. srv, err := testutil.NewServer()
  76. if err != nil {
  77. panic(fmt.Sprintf("pstest.NewServer: %v", err))
  78. }
  79. s := &Server{
  80. srv: srv,
  81. Addr: srv.Addr,
  82. GServer: GServer{
  83. topics: map[string]*topic{},
  84. subs: map[string]*subscription{},
  85. msgsByID: map[string]*Message{},
  86. },
  87. }
  88. pb.RegisterPublisherServer(srv.Gsrv, &s.GServer)
  89. pb.RegisterSubscriberServer(srv.Gsrv, &s.GServer)
  90. srv.Start()
  91. return s
  92. }
  93. // Publish behaves as if the Publish RPC was called with a message with the given
  94. // data and attrs. It returns the ID of the message.
  95. // The topic will be created if it doesn't exist.
  96. //
  97. // Publish panics if there is an error, which is appropriate for testing.
  98. func (s *Server) Publish(topic string, data []byte, attrs map[string]string) string {
  99. const topicPattern = "projects/*/topics/*"
  100. ok, err := path.Match(topicPattern, topic)
  101. if err != nil {
  102. panic(err)
  103. }
  104. if !ok {
  105. panic(fmt.Sprintf("topic name must be of the form %q", topicPattern))
  106. }
  107. _, _ = s.GServer.CreateTopic(context.TODO(), &pb.Topic{Name: topic})
  108. req := &pb.PublishRequest{
  109. Topic: topic,
  110. Messages: []*pb.PubsubMessage{{Data: data, Attributes: attrs}},
  111. }
  112. res, err := s.GServer.Publish(context.TODO(), req)
  113. if err != nil {
  114. panic(fmt.Sprintf("pstest.Server.Publish: %v", err))
  115. }
  116. return res.MessageIds[0]
  117. }
  118. // SetStreamTimeout sets the amount of time a stream will be active before it shuts
  119. // itself down. This mimics the real service's behavior of closing streams after 30
  120. // minutes. If SetStreamTimeout is never called or is passed zero, streams never shut
  121. // down.
  122. func (s *Server) SetStreamTimeout(d time.Duration) {
  123. s.GServer.mu.Lock()
  124. defer s.GServer.mu.Unlock()
  125. s.GServer.streamTimeout = d
  126. }
// A Message is a message that was published to the server.
//
// Server.Messages and Server.Message return pointers to the server's own
// records. The exported Deliveries and Acks fields are snapshots copied from
// the internal counters each time one of those methods is called.
type Message struct {
	ID          string
	Data        []byte
	Attributes  map[string]string
	PublishTime time.Time
	Deliveries  int // number of times delivery of the message was attempted
	Acks        int // number of acks received from clients

	// protected by server mutex
	deliveries int      // shared by every subscription's record of this message
	acks       int      // shared by every subscription's record of this message
	Modacks    []Modack // modacks received by server for this message
}

// Modack represents a modack sent to the server.
// Only ModifyAckDeadline RPCs are recorded; deadline changes sent on a
// StreamingPull stream are not.
type Modack struct {
	AckID       string    // ack ID named in the request
	AckDeadline int32     // requested deadline, in seconds (0 is a nack)
	ReceivedAt  time.Time // when the server handled the request
}
  146. // Messages returns information about all messages ever published.
  147. func (s *Server) Messages() []*Message {
  148. s.GServer.mu.Lock()
  149. defer s.GServer.mu.Unlock()
  150. var msgs []*Message
  151. for _, m := range s.GServer.msgs {
  152. m.Deliveries = m.deliveries
  153. m.Acks = m.acks
  154. msgs = append(msgs, m)
  155. }
  156. return msgs
  157. }
  158. // Message returns the message with the given ID, or nil if no message
  159. // with that ID was published.
  160. func (s *Server) Message(id string) *Message {
  161. s.GServer.mu.Lock()
  162. defer s.GServer.mu.Unlock()
  163. m := s.GServer.msgsByID[id]
  164. if m != nil {
  165. m.Deliveries = m.deliveries
  166. m.Acks = m.acks
  167. }
  168. return m
  169. }
  170. // Wait blocks until all server activity has completed.
  171. func (s *Server) Wait() {
  172. s.GServer.wg.Wait()
  173. }
  174. // Close shuts down the server and releases all resources.
  175. func (s *Server) Close() error {
  176. s.srv.Close()
  177. s.GServer.mu.Lock()
  178. defer s.GServer.mu.Unlock()
  179. for _, sub := range s.GServer.subs {
  180. sub.stop()
  181. }
  182. return nil
  183. }
  184. func (s *GServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) {
  185. s.mu.Lock()
  186. defer s.mu.Unlock()
  187. if s.topics[t.Name] != nil {
  188. return nil, status.Errorf(codes.AlreadyExists, "topic %q", t.Name)
  189. }
  190. top := newTopic(t)
  191. s.topics[t.Name] = top
  192. return top.proto, nil
  193. }
  194. func (s *GServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) {
  195. s.mu.Lock()
  196. defer s.mu.Unlock()
  197. if t := s.topics[req.Topic]; t != nil {
  198. return t.proto, nil
  199. }
  200. return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
  201. }
  202. func (s *GServer) UpdateTopic(_ context.Context, req *pb.UpdateTopicRequest) (*pb.Topic, error) {
  203. s.mu.Lock()
  204. defer s.mu.Unlock()
  205. t := s.topics[req.Topic.Name]
  206. if t == nil {
  207. return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic.Name)
  208. }
  209. for _, path := range req.UpdateMask.Paths {
  210. switch path {
  211. case "labels":
  212. t.proto.Labels = req.Topic.Labels
  213. case "message_storage_policy": // "fetch" the policy
  214. t.proto.MessageStoragePolicy = &pb.MessageStoragePolicy{AllowedPersistenceRegions: []string{"US"}}
  215. default:
  216. return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path)
  217. }
  218. }
  219. return t.proto, nil
  220. }
  221. func (s *GServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) {
  222. s.mu.Lock()
  223. defer s.mu.Unlock()
  224. var names []string
  225. for n := range s.topics {
  226. if strings.HasPrefix(n, req.Project) {
  227. names = append(names, n)
  228. }
  229. }
  230. sort.Strings(names)
  231. from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
  232. if err != nil {
  233. return nil, err
  234. }
  235. res := &pb.ListTopicsResponse{NextPageToken: nextToken}
  236. for i := from; i < to; i++ {
  237. res.Topics = append(res.Topics, s.topics[names[i]].proto)
  238. }
  239. return res, nil
  240. }
  241. func (s *GServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) {
  242. s.mu.Lock()
  243. defer s.mu.Unlock()
  244. var names []string
  245. for name, sub := range s.subs {
  246. if sub.topic.proto.Name == req.Topic {
  247. names = append(names, name)
  248. }
  249. }
  250. sort.Strings(names)
  251. from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
  252. if err != nil {
  253. return nil, err
  254. }
  255. return &pb.ListTopicSubscriptionsResponse{
  256. Subscriptions: names[from:to],
  257. NextPageToken: nextToken,
  258. }, nil
  259. }
  260. func (s *GServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) {
  261. s.mu.Lock()
  262. defer s.mu.Unlock()
  263. t := s.topics[req.Topic]
  264. if t == nil {
  265. return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
  266. }
  267. t.stop()
  268. delete(s.topics, req.Topic)
  269. return &emptypb.Empty{}, nil
  270. }
  271. func (s *GServer) CreateSubscription(_ context.Context, ps *pb.Subscription) (*pb.Subscription, error) {
  272. s.mu.Lock()
  273. defer s.mu.Unlock()
  274. if ps.Name == "" {
  275. return nil, status.Errorf(codes.InvalidArgument, "missing name")
  276. }
  277. if s.subs[ps.Name] != nil {
  278. return nil, status.Errorf(codes.AlreadyExists, "subscription %q", ps.Name)
  279. }
  280. if ps.Topic == "" {
  281. return nil, status.Errorf(codes.InvalidArgument, "missing topic")
  282. }
  283. top := s.topics[ps.Topic]
  284. if top == nil {
  285. return nil, status.Errorf(codes.NotFound, "topic %q", ps.Topic)
  286. }
  287. if err := checkAckDeadline(ps.AckDeadlineSeconds); err != nil {
  288. return nil, err
  289. }
  290. if ps.MessageRetentionDuration == nil {
  291. ps.MessageRetentionDuration = defaultMessageRetentionDuration
  292. }
  293. if err := checkMRD(ps.MessageRetentionDuration); err != nil {
  294. return nil, err
  295. }
  296. if ps.PushConfig == nil {
  297. ps.PushConfig = &pb.PushConfig{}
  298. }
  299. sub := newSubscription(top, &s.mu, ps)
  300. top.subs[ps.Name] = sub
  301. s.subs[ps.Name] = sub
  302. sub.start(&s.wg)
  303. return ps, nil
  304. }
  305. // Can be set for testing.
  306. var minAckDeadlineSecs int32
  307. // SetMinAckDeadline changes the minack deadline to n. Must be
  308. // greater than or equal to 1 second. Remember to reset this value
  309. // to the default after your test changes it. Example usage:
  310. // pstest.SetMinAckDeadlineSecs(1)
  311. // defer pstest.ResetMinAckDeadlineSecs()
  312. func SetMinAckDeadline(n time.Duration) {
  313. if n < time.Second {
  314. panic("SetMinAckDeadline expects a value greater than 1 second")
  315. }
  316. minAckDeadlineSecs = int32(n / time.Second)
  317. }
  318. // ResetMinAckDeadline resets the minack deadline to the default.
  319. func ResetMinAckDeadline() {
  320. minAckDeadlineSecs = 10
  321. }
  322. func checkAckDeadline(ads int32) error {
  323. if ads < minAckDeadlineSecs || ads > 600 {
  324. // PubSub service returns Unknown.
  325. return status.Errorf(codes.Unknown, "bad ack_deadline_seconds: %d", ads)
  326. }
  327. return nil
  328. }
  329. const (
  330. minMessageRetentionDuration = 10 * time.Minute
  331. maxMessageRetentionDuration = 168 * time.Hour
  332. )
  333. var defaultMessageRetentionDuration = ptypes.DurationProto(maxMessageRetentionDuration)
  334. func checkMRD(pmrd *durpb.Duration) error {
  335. mrd, err := ptypes.Duration(pmrd)
  336. if err != nil || mrd < minMessageRetentionDuration || mrd > maxMessageRetentionDuration {
  337. return status.Errorf(codes.InvalidArgument, "bad message_retention_duration %+v", pmrd)
  338. }
  339. return nil
  340. }
  341. func (s *GServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
  342. s.mu.Lock()
  343. defer s.mu.Unlock()
  344. sub, err := s.findSubscription(req.Subscription)
  345. if err != nil {
  346. return nil, err
  347. }
  348. return sub.proto, nil
  349. }
  350. func (s *GServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) {
  351. if req.Subscription == nil {
  352. return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
  353. }
  354. s.mu.Lock()
  355. defer s.mu.Unlock()
  356. sub, err := s.findSubscription(req.Subscription.Name)
  357. if err != nil {
  358. return nil, err
  359. }
  360. for _, path := range req.UpdateMask.Paths {
  361. switch path {
  362. case "push_config":
  363. sub.proto.PushConfig = req.Subscription.PushConfig
  364. case "ack_deadline_seconds":
  365. a := req.Subscription.AckDeadlineSeconds
  366. if err := checkAckDeadline(a); err != nil {
  367. return nil, err
  368. }
  369. sub.proto.AckDeadlineSeconds = a
  370. case "retain_acked_messages":
  371. sub.proto.RetainAckedMessages = req.Subscription.RetainAckedMessages
  372. case "message_retention_duration":
  373. if err := checkMRD(req.Subscription.MessageRetentionDuration); err != nil {
  374. return nil, err
  375. }
  376. sub.proto.MessageRetentionDuration = req.Subscription.MessageRetentionDuration
  377. case "labels":
  378. sub.proto.Labels = req.Subscription.Labels
  379. default:
  380. return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path)
  381. }
  382. }
  383. return sub.proto, nil
  384. }
  385. func (s *GServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) {
  386. s.mu.Lock()
  387. defer s.mu.Unlock()
  388. var names []string
  389. for name := range s.subs {
  390. if strings.HasPrefix(name, req.Project) {
  391. names = append(names, name)
  392. }
  393. }
  394. sort.Strings(names)
  395. from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
  396. if err != nil {
  397. return nil, err
  398. }
  399. res := &pb.ListSubscriptionsResponse{NextPageToken: nextToken}
  400. for i := from; i < to; i++ {
  401. res.Subscriptions = append(res.Subscriptions, s.subs[names[i]].proto)
  402. }
  403. return res, nil
  404. }
  405. func (s *GServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) {
  406. s.mu.Lock()
  407. defer s.mu.Unlock()
  408. sub, err := s.findSubscription(req.Subscription)
  409. if err != nil {
  410. return nil, err
  411. }
  412. sub.stop()
  413. delete(s.subs, req.Subscription)
  414. sub.topic.deleteSub(sub)
  415. return &emptypb.Empty{}, nil
  416. }
  417. func (s *GServer) Publish(_ context.Context, req *pb.PublishRequest) (*pb.PublishResponse, error) {
  418. s.mu.Lock()
  419. defer s.mu.Unlock()
  420. if req.Topic == "" {
  421. return nil, status.Errorf(codes.InvalidArgument, "missing topic")
  422. }
  423. top := s.topics[req.Topic]
  424. if top == nil {
  425. return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
  426. }
  427. var ids []string
  428. for _, pm := range req.Messages {
  429. id := fmt.Sprintf("m%d", s.nextID)
  430. s.nextID++
  431. pm.MessageId = id
  432. pubTime := timeNow()
  433. tsPubTime, err := ptypes.TimestampProto(pubTime)
  434. if err != nil {
  435. return nil, status.Errorf(codes.Internal, err.Error())
  436. }
  437. pm.PublishTime = tsPubTime
  438. m := &Message{
  439. ID: id,
  440. Data: pm.Data,
  441. Attributes: pm.Attributes,
  442. PublishTime: pubTime,
  443. }
  444. top.publish(pm, m)
  445. ids = append(ids, id)
  446. s.msgs = append(s.msgs, m)
  447. s.msgsByID[id] = m
  448. }
  449. return &pb.PublishResponse{MessageIds: ids}, nil
  450. }
  451. type topic struct {
  452. proto *pb.Topic
  453. subs map[string]*subscription
  454. }
  455. func newTopic(pt *pb.Topic) *topic {
  456. return &topic{
  457. proto: pt,
  458. subs: map[string]*subscription{},
  459. }
  460. }
  461. func (t *topic) stop() {
  462. for _, sub := range t.subs {
  463. sub.proto.Topic = "_deleted-topic_"
  464. sub.stop()
  465. }
  466. }
  467. func (t *topic) deleteSub(sub *subscription) {
  468. delete(t.subs, sub.proto.Name)
  469. }
  470. func (t *topic) publish(pm *pb.PubsubMessage, m *Message) {
  471. for _, s := range t.subs {
  472. s.msgs[pm.MessageId] = &message{
  473. publishTime: m.PublishTime,
  474. proto: &pb.ReceivedMessage{
  475. AckId: pm.MessageId,
  476. Message: pm,
  477. },
  478. deliveries: &m.deliveries,
  479. acks: &m.acks,
  480. streamIndex: -1,
  481. }
  482. }
  483. }
// subscription is the server-side state of a Pub/Sub subscription. Its fields
// are accessed with the server mutex (mu) held.
type subscription struct {
	topic      *topic
	mu         *sync.Mutex // the server mutex, here for convenience
	proto      *pb.Subscription
	ackTimeout time.Duration        // lease used by Pull; also the initial ack timeout of new streams
	msgs       map[string]*message // unacked messages by message ID
	streams    []*stream           // open StreamingPull streams, in round-robin delivery order
	done       chan struct{}       // closed by stop to end the loop started by start
}
  493. func newSubscription(t *topic, mu *sync.Mutex, ps *pb.Subscription) *subscription {
  494. at := time.Duration(ps.AckDeadlineSeconds) * time.Second
  495. if at == 0 {
  496. at = 10 * time.Second
  497. }
  498. return &subscription{
  499. topic: t,
  500. mu: mu,
  501. proto: ps,
  502. ackTimeout: at,
  503. msgs: map[string]*message{},
  504. done: make(chan struct{}),
  505. }
  506. }
  507. func (s *subscription) start(wg *sync.WaitGroup) {
  508. wg.Add(1)
  509. go func() {
  510. defer wg.Done()
  511. for {
  512. select {
  513. case <-s.done:
  514. return
  515. case <-time.After(10 * time.Millisecond):
  516. s.deliver()
  517. }
  518. }
  519. }()
  520. }
  521. func (s *subscription) stop() {
  522. close(s.done)
  523. }
  524. func (s *GServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
  525. s.mu.Lock()
  526. defer s.mu.Unlock()
  527. sub, err := s.findSubscription(req.Subscription)
  528. if err != nil {
  529. return nil, err
  530. }
  531. for _, id := range req.AckIds {
  532. sub.ack(id)
  533. }
  534. return &emptypb.Empty{}, nil
  535. }
  536. func (s *GServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
  537. s.mu.Lock()
  538. defer s.mu.Unlock()
  539. sub, err := s.findSubscription(req.Subscription)
  540. if err != nil {
  541. return nil, err
  542. }
  543. now := time.Now()
  544. for _, id := range req.AckIds {
  545. s.msgsByID[id].Modacks = append(s.msgsByID[id].Modacks, Modack{AckID: id, AckDeadline: req.AckDeadlineSeconds, ReceivedAt: now})
  546. }
  547. dur := secsToDur(req.AckDeadlineSeconds)
  548. for _, id := range req.AckIds {
  549. sub.modifyAckDeadline(id, dur)
  550. }
  551. return &emptypb.Empty{}, nil
  552. }
  553. func (s *GServer) Pull(ctx context.Context, req *pb.PullRequest) (*pb.PullResponse, error) {
  554. s.mu.Lock()
  555. sub, err := s.findSubscription(req.Subscription)
  556. if err != nil {
  557. s.mu.Unlock()
  558. return nil, err
  559. }
  560. max := int(req.MaxMessages)
  561. if max < 0 {
  562. s.mu.Unlock()
  563. return nil, status.Error(codes.InvalidArgument, "MaxMessages cannot be negative")
  564. }
  565. if max == 0 { // MaxMessages not specified; use a default.
  566. max = 1000
  567. }
  568. msgs := sub.pull(max)
  569. s.mu.Unlock()
  570. // Implement the spec from the pubsub proto:
  571. // "If ReturnImmediately set to true, the system will respond immediately even if
  572. // it there are no messages available to return in the `Pull` response.
  573. // Otherwise, the system may wait (for a bounded amount of time) until at
  574. // least one message is available, rather than returning no messages."
  575. if len(msgs) == 0 && !req.ReturnImmediately {
  576. // Wait for a short amount of time for a message.
  577. // TODO: signal when a message arrives, so we don't wait the whole time.
  578. select {
  579. case <-ctx.Done():
  580. return nil, ctx.Err()
  581. case <-time.After(500 * time.Millisecond):
  582. s.mu.Lock()
  583. msgs = sub.pull(max)
  584. s.mu.Unlock()
  585. }
  586. }
  587. return &pb.PullResponse{ReceivedMessages: msgs}, nil
  588. }
  589. func (s *GServer) StreamingPull(sps pb.Subscriber_StreamingPullServer) error {
  590. // Receive initial message configuring the pull.
  591. req, err := sps.Recv()
  592. if err != nil {
  593. return err
  594. }
  595. s.mu.Lock()
  596. sub, err := s.findSubscription(req.Subscription)
  597. s.mu.Unlock()
  598. if err != nil {
  599. return err
  600. }
  601. // Create a new stream to handle the pull.
  602. st := sub.newStream(sps, s.streamTimeout)
  603. err = st.pull(&s.wg)
  604. sub.deleteStream(st)
  605. return err
  606. }
  607. func (s *GServer) Seek(ctx context.Context, req *pb.SeekRequest) (*pb.SeekResponse, error) {
  608. // Only handle time-based seeking for now.
  609. // This fake doesn't deal with snapshots.
  610. var target time.Time
  611. switch v := req.Target.(type) {
  612. case nil:
  613. return nil, status.Errorf(codes.InvalidArgument, "missing Seek target type")
  614. case *pb.SeekRequest_Time:
  615. var err error
  616. target, err = ptypes.Timestamp(v.Time)
  617. if err != nil {
  618. return nil, status.Errorf(codes.InvalidArgument, "bad Time target: %v", err)
  619. }
  620. default:
  621. return nil, status.Errorf(codes.Unimplemented, "unhandled Seek target type %T", v)
  622. }
  623. // The entire server must be locked while doing the work below,
  624. // because the messages don't have any other synchronization.
  625. s.mu.Lock()
  626. defer s.mu.Unlock()
  627. sub, err := s.findSubscription(req.Subscription)
  628. if err != nil {
  629. return nil, err
  630. }
  631. // Drop all messages from sub that were published before the target time.
  632. for id, m := range sub.msgs {
  633. if m.publishTime.Before(target) {
  634. delete(sub.msgs, id)
  635. (*m.acks)++
  636. }
  637. }
  638. // Un-ack any already-acked messages after this time;
  639. // redelivering them to the subscription is the closest analogue here.
  640. for _, m := range s.msgs {
  641. if m.PublishTime.Before(target) {
  642. continue
  643. }
  644. sub.msgs[m.ID] = &message{
  645. publishTime: m.PublishTime,
  646. proto: &pb.ReceivedMessage{
  647. AckId: m.ID,
  648. // This was not preserved!
  649. //Message: pm,
  650. },
  651. deliveries: &m.deliveries,
  652. acks: &m.acks,
  653. streamIndex: -1,
  654. }
  655. }
  656. return &pb.SeekResponse{}, nil
  657. }
  658. // Gets a subscription that must exist.
  659. // Must be called with the lock held.
  660. func (s *GServer) findSubscription(name string) (*subscription, error) {
  661. if name == "" {
  662. return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
  663. }
  664. sub := s.subs[name]
  665. if sub == nil {
  666. return nil, status.Errorf(codes.NotFound, "subscription %s", name)
  667. }
  668. return sub, nil
  669. }
  670. // Must be called with the lock held.
  671. func (s *subscription) pull(max int) []*pb.ReceivedMessage {
  672. now := timeNow()
  673. s.maintainMessages(now)
  674. var msgs []*pb.ReceivedMessage
  675. for _, m := range s.msgs {
  676. if m.outstanding() {
  677. continue
  678. }
  679. (*m.deliveries)++
  680. m.ackDeadline = now.Add(s.ackTimeout)
  681. msgs = append(msgs, m.proto)
  682. if len(msgs) >= max {
  683. break
  684. }
  685. }
  686. return msgs
  687. }
// deliver makes one pass over the subscription's available (not outstanding)
// messages. It pushes each one to an open StreamingPull stream via
// tryDeliverMessage, spreading them round-robin, and stops at the first message
// that no stream accepts. The loop launched by start calls it periodically. It
// takes the server lock itself.
func (s *subscription) deliver() {
	s.mu.Lock()
	defer s.mu.Unlock()
	now := timeNow()
	// Release expired leases and drop stale messages before delivering.
	s.maintainMessages(now)
	// Try to deliver each remaining message.
	curIndex := 0
	for _, m := range s.msgs {
		if m.outstanding() {
			continue
		}
		// If the message was never delivered before, start with the stream at
		// curIndex. If it was delivered before, start with the stream after the one
		// that owned it.
		if m.streamIndex < 0 {
			delIndex, ok := s.tryDeliverMessage(m, curIndex, now)
			if !ok {
				break
			}
			curIndex = delIndex + 1
			// NOTE(review): this stores the index *after* the owning stream, so a
			// later redelivery starts with the next stream. The redelivery branch
			// below stores the owner itself, so a second redelivery starts with
			// that same stream. The two branches disagree about what streamIndex
			// holds.
			m.streamIndex = curIndex
		} else {
			delIndex, ok := s.tryDeliverMessage(m, m.streamIndex, now)
			if !ok {
				break
			}
			m.streamIndex = delIndex
		}
	}
}
// tryDeliverMessage attempts to deliver m to the stream at index start. If it
// can't, it tries streams start+1, start+2, ..., wrapping around. Once it's
// tried all streams, it exits.
//
// It returns the index of the stream it delivered the message to, or 0, false if
// it didn't deliver the message (including when there are no streams).
//
// Sends are non-blocking: a stream whose send loop is not ready to receive
// right now is skipped. Streams found to be finished (done closed) are removed
// from s.streams along the way. A successful send counts as a delivery and
// leases m until now plus that stream's ack timeout.
//
// Must be called with the lock held.
func (s *subscription) tryDeliverMessage(m *message, start int, now time.Time) (int, bool) {
	for i := 0; i < len(s.streams); i++ {
		idx := (i + start) % len(s.streams)
		st := s.streams[idx]
		select {
		case <-st.done:
			// The stream has finished: drop it. s.streams is now one shorter,
			// so undo this iteration's increment and re-examine position i.
			s.streams = deleteStreamAt(s.streams, idx)
			i--
		case st.msgc <- m.proto:
			(*m.deliveries)++
			m.ackDeadline = now.Add(st.ackTimeout)
			return idx, true
		default:
			// The stream's send loop is busy; try the next stream.
		}
	}
	return 0, false
}
  743. var retentionDuration = 10 * time.Minute
  744. // Must be called with the lock held.
  745. func (s *subscription) maintainMessages(now time.Time) {
  746. for id, m := range s.msgs {
  747. // Mark a message as re-deliverable if its ack deadline has expired.
  748. if m.outstanding() && now.After(m.ackDeadline) {
  749. m.makeAvailable()
  750. }
  751. pubTime, err := ptypes.Timestamp(m.proto.Message.PublishTime)
  752. if err != nil {
  753. panic(err)
  754. }
  755. // Remove messages that have been undelivered for a long time.
  756. if !m.outstanding() && now.Sub(pubTime) > retentionDuration {
  757. delete(s.msgs, id)
  758. }
  759. }
  760. }
  761. func (s *subscription) newStream(gs pb.Subscriber_StreamingPullServer, timeout time.Duration) *stream {
  762. st := &stream{
  763. sub: s,
  764. done: make(chan struct{}),
  765. msgc: make(chan *pb.ReceivedMessage),
  766. gstream: gs,
  767. ackTimeout: s.ackTimeout,
  768. timeout: timeout,
  769. }
  770. s.mu.Lock()
  771. s.streams = append(s.streams, st)
  772. s.mu.Unlock()
  773. return st
  774. }
  775. func (s *subscription) deleteStream(st *stream) {
  776. s.mu.Lock()
  777. defer s.mu.Unlock()
  778. var i int
  779. for i = 0; i < len(s.streams); i++ {
  780. if s.streams[i] == st {
  781. break
  782. }
  783. }
  784. if i < len(s.streams) {
  785. s.streams = deleteStreamAt(s.streams, i)
  786. }
  787. }
  788. func deleteStreamAt(s []*stream, i int) []*stream {
  789. // Preserve order for round-robin delivery.
  790. return append(s[:i], s[i+1:]...)
  791. }
// message is a subscription's record of one published message that has not
// been acked yet.
type message struct {
	proto       *pb.ReceivedMessage // what is sent to clients; its AckId is the message ID
	publishTime time.Time
	ackDeadline time.Time // lease expiry while outstanding; zero when available for delivery
	deliveries  *int      // the published Message's delivery counter, shared across subscriptions
	acks        *int      // the published Message's ack counter, shared across subscriptions
	streamIndex int       // index of stream that currently owns msg, for round-robin delivery; -1 if never streamed
}
  800. // A message is outstanding if it is owned by some stream.
  801. func (m *message) outstanding() bool {
  802. return !m.ackDeadline.IsZero()
  803. }
  804. func (m *message) makeAvailable() {
  805. m.ackDeadline = time.Time{}
  806. }
// stream is the server-side state of one StreamingPull call.
type stream struct {
	sub        *subscription
	done       chan struct{}            // closed when the stream is finished
	msgc       chan *pb.ReceivedMessage // unbuffered; deliver sends here and sendLoop forwards to the client
	gstream    pb.Subscriber_StreamingPullServer
	ackTimeout time.Duration // lease for messages delivered on this stream; the client may change it
	timeout    time.Duration // maximum lifetime of the stream; 0 means no limit
}
  815. // pull manages the StreamingPull interaction for the life of the stream.
  816. func (st *stream) pull(wg *sync.WaitGroup) error {
  817. errc := make(chan error, 2)
  818. wg.Add(2)
  819. go func() {
  820. defer wg.Done()
  821. errc <- st.sendLoop()
  822. }()
  823. go func() {
  824. defer wg.Done()
  825. errc <- st.recvLoop()
  826. }()
  827. var tchan <-chan time.Time
  828. if st.timeout > 0 {
  829. tchan = time.After(st.timeout)
  830. }
  831. // Wait until one of the goroutines returns an error, or we time out.
  832. var err error
  833. select {
  834. case err = <-errc:
  835. if err == io.EOF {
  836. err = nil
  837. }
  838. case <-tchan:
  839. }
  840. close(st.done) // stop the other goroutine
  841. return err
  842. }
  843. func (st *stream) sendLoop() error {
  844. for {
  845. select {
  846. case <-st.done:
  847. return nil
  848. case rm := <-st.msgc:
  849. res := &pb.StreamingPullResponse{ReceivedMessages: []*pb.ReceivedMessage{rm}}
  850. if err := st.gstream.Send(res); err != nil {
  851. return err
  852. }
  853. }
  854. }
  855. }
  856. func (st *stream) recvLoop() error {
  857. for {
  858. req, err := st.gstream.Recv()
  859. if err != nil {
  860. return err
  861. }
  862. st.sub.handleStreamingPullRequest(st, req)
  863. }
  864. }
  865. func (s *subscription) handleStreamingPullRequest(st *stream, req *pb.StreamingPullRequest) {
  866. // Lock the entire server.
  867. s.mu.Lock()
  868. defer s.mu.Unlock()
  869. for _, ackID := range req.AckIds {
  870. s.ack(ackID)
  871. }
  872. for i, id := range req.ModifyDeadlineAckIds {
  873. s.modifyAckDeadline(id, secsToDur(req.ModifyDeadlineSeconds[i]))
  874. }
  875. if req.StreamAckDeadlineSeconds > 0 {
  876. st.ackTimeout = secsToDur(req.StreamAckDeadlineSeconds)
  877. }
  878. }
  879. // Must be called with the lock held.
  880. func (s *subscription) ack(id string) {
  881. m := s.msgs[id]
  882. if m != nil {
  883. (*m.acks)++
  884. delete(s.msgs, id)
  885. }
  886. }
  887. // Must be called with the lock held.
  888. func (s *subscription) modifyAckDeadline(id string, d time.Duration) {
  889. m := s.msgs[id]
  890. if m == nil { // already acked: ignore.
  891. return
  892. }
  893. if d == 0 { // nack
  894. m.makeAvailable()
  895. } else { // extend the deadline by d
  896. m.ackDeadline = timeNow().Add(d)
  897. }
  898. }
  899. func secsToDur(secs int32) time.Duration {
  900. return time.Duration(secs) * time.Second
  901. }