|
- // Copyright 2017 Google LLC
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
-
- // Package pstest provides a fake Cloud PubSub service for testing. It implements a
- // simplified form of the service, suitable for unit tests. It may behave
- // differently from the actual service in ways in which the service is
- // non-deterministic or unspecified: timing, delivery order, etc.
- //
- // This package is EXPERIMENTAL and is subject to change without notice.
- //
- // See the example for usage.
- package pstest
-
- import (
- "context"
- "fmt"
- "io"
- "path"
- "sort"
- "strings"
- "sync"
- "sync/atomic"
- "time"
-
- "cloud.google.com/go/internal/testutil"
- "github.com/golang/protobuf/ptypes"
- durpb "github.com/golang/protobuf/ptypes/duration"
- emptypb "github.com/golang/protobuf/ptypes/empty"
- pb "google.golang.org/genproto/googleapis/pubsub/v1"
- "google.golang.org/grpc/codes"
- "google.golang.org/grpc/status"
- )
-
// now holds the clock function (a func() time.Time) that the fake server uses;
// tests may replace it. Swapping the stored function is atomic, but a call to
// the previously loaded function can still race with the swap. That can be a
// problem when tests run in parallel, or when concurrent parts of one test
// change the stored function.
var now atomic.Value
-
- func init() {
- now.Store(time.Now)
- ResetMinAckDeadline()
- }
-
- func timeNow() time.Time {
- return now.Load().(func() time.Time)()
- }
-
- // Server is a fake Pub/Sub server.
- type Server struct {
- srv *testutil.Server
- Addr string // The address that the server is listening on.
- GServer GServer // Not intended to be used directly.
- }
-
- // GServer is the underlying service implementor. It is not intended to be used
- // directly.
- type GServer struct {
- pb.PublisherServer
- pb.SubscriberServer
-
- mu sync.Mutex
- topics map[string]*topic
- subs map[string]*subscription
- msgs []*Message // all messages ever published
- msgsByID map[string]*Message
- wg sync.WaitGroup
- nextID int
- streamTimeout time.Duration
- }
-
- // NewServer creates a new fake server running in the current process.
- func NewServer() *Server {
- srv, err := testutil.NewServer()
- if err != nil {
- panic(fmt.Sprintf("pstest.NewServer: %v", err))
- }
- s := &Server{
- srv: srv,
- Addr: srv.Addr,
- GServer: GServer{
- topics: map[string]*topic{},
- subs: map[string]*subscription{},
- msgsByID: map[string]*Message{},
- },
- }
- pb.RegisterPublisherServer(srv.Gsrv, &s.GServer)
- pb.RegisterSubscriberServer(srv.Gsrv, &s.GServer)
- srv.Start()
- return s
- }
-
- // Publish behaves as if the Publish RPC was called with a message with the given
- // data and attrs. It returns the ID of the message.
- // The topic will be created if it doesn't exist.
- //
- // Publish panics if there is an error, which is appropriate for testing.
- func (s *Server) Publish(topic string, data []byte, attrs map[string]string) string {
- const topicPattern = "projects/*/topics/*"
- ok, err := path.Match(topicPattern, topic)
- if err != nil {
- panic(err)
- }
- if !ok {
- panic(fmt.Sprintf("topic name must be of the form %q", topicPattern))
- }
- _, _ = s.GServer.CreateTopic(context.TODO(), &pb.Topic{Name: topic})
- req := &pb.PublishRequest{
- Topic: topic,
- Messages: []*pb.PubsubMessage{{Data: data, Attributes: attrs}},
- }
- res, err := s.GServer.Publish(context.TODO(), req)
- if err != nil {
- panic(fmt.Sprintf("pstest.Server.Publish: %v", err))
- }
- return res.MessageIds[0]
- }
-
- // SetStreamTimeout sets the amount of time a stream will be active before it shuts
- // itself down. This mimics the real service's behavior of closing streams after 30
- // minutes. If SetStreamTimeout is never called or is passed zero, streams never shut
- // down.
- func (s *Server) SetStreamTimeout(d time.Duration) {
- s.GServer.mu.Lock()
- defer s.GServer.mu.Unlock()
- s.GServer.streamTimeout = d
- }
-
- // A Message is a message that was published to the server.
- type Message struct {
- ID string
- Data []byte
- Attributes map[string]string
- PublishTime time.Time
- Deliveries int // number of times delivery of the message was attempted
- Acks int // number of acks received from clients
-
- // protected by server mutex
- deliveries int
- acks int
- Modacks []Modack // modacks received by server for this message
-
- }
-
// Modack represents a modack (a ModifyAckDeadline request) sent to the server.
type Modack struct {
	AckID       string    // ack ID the modack applied to
	AckDeadline int32     // requested ack deadline, in seconds
	ReceivedAt  time.Time // when the server received the modack
}
-
- // Messages returns information about all messages ever published.
- func (s *Server) Messages() []*Message {
- s.GServer.mu.Lock()
- defer s.GServer.mu.Unlock()
-
- var msgs []*Message
- for _, m := range s.GServer.msgs {
- m.Deliveries = m.deliveries
- m.Acks = m.acks
- msgs = append(msgs, m)
- }
- return msgs
- }
-
- // Message returns the message with the given ID, or nil if no message
- // with that ID was published.
- func (s *Server) Message(id string) *Message {
- s.GServer.mu.Lock()
- defer s.GServer.mu.Unlock()
-
- m := s.GServer.msgsByID[id]
- if m != nil {
- m.Deliveries = m.deliveries
- m.Acks = m.acks
- }
- return m
- }
-
- // Wait blocks until all server activity has completed.
- func (s *Server) Wait() {
- s.GServer.wg.Wait()
- }
-
- // Close shuts down the server and releases all resources.
- func (s *Server) Close() error {
- s.srv.Close()
- s.GServer.mu.Lock()
- defer s.GServer.mu.Unlock()
- for _, sub := range s.GServer.subs {
- sub.stop()
- }
- return nil
- }
-
- func (s *GServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- if s.topics[t.Name] != nil {
- return nil, status.Errorf(codes.AlreadyExists, "topic %q", t.Name)
- }
- top := newTopic(t)
- s.topics[t.Name] = top
- return top.proto, nil
- }
-
- func (s *GServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- if t := s.topics[req.Topic]; t != nil {
- return t.proto, nil
- }
- return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
- }
-
- func (s *GServer) UpdateTopic(_ context.Context, req *pb.UpdateTopicRequest) (*pb.Topic, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- t := s.topics[req.Topic.Name]
- if t == nil {
- return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic.Name)
- }
- for _, path := range req.UpdateMask.Paths {
- switch path {
- case "labels":
- t.proto.Labels = req.Topic.Labels
- case "message_storage_policy": // "fetch" the policy
- t.proto.MessageStoragePolicy = &pb.MessageStoragePolicy{AllowedPersistenceRegions: []string{"US"}}
- default:
- return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path)
- }
- }
- return t.proto, nil
- }
-
- func (s *GServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- var names []string
- for n := range s.topics {
- if strings.HasPrefix(n, req.Project) {
- names = append(names, n)
- }
- }
- sort.Strings(names)
- from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
- if err != nil {
- return nil, err
- }
- res := &pb.ListTopicsResponse{NextPageToken: nextToken}
- for i := from; i < to; i++ {
- res.Topics = append(res.Topics, s.topics[names[i]].proto)
- }
- return res, nil
- }
-
- func (s *GServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- var names []string
- for name, sub := range s.subs {
- if sub.topic.proto.Name == req.Topic {
- names = append(names, name)
- }
- }
- sort.Strings(names)
- from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
- if err != nil {
- return nil, err
- }
- return &pb.ListTopicSubscriptionsResponse{
- Subscriptions: names[from:to],
- NextPageToken: nextToken,
- }, nil
- }
-
- func (s *GServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- t := s.topics[req.Topic]
- if t == nil {
- return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
- }
- t.stop()
- delete(s.topics, req.Topic)
- return &emptypb.Empty{}, nil
- }
-
- func (s *GServer) CreateSubscription(_ context.Context, ps *pb.Subscription) (*pb.Subscription, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- if ps.Name == "" {
- return nil, status.Errorf(codes.InvalidArgument, "missing name")
- }
- if s.subs[ps.Name] != nil {
- return nil, status.Errorf(codes.AlreadyExists, "subscription %q", ps.Name)
- }
- if ps.Topic == "" {
- return nil, status.Errorf(codes.InvalidArgument, "missing topic")
- }
- top := s.topics[ps.Topic]
- if top == nil {
- return nil, status.Errorf(codes.NotFound, "topic %q", ps.Topic)
- }
- if err := checkAckDeadline(ps.AckDeadlineSeconds); err != nil {
- return nil, err
- }
- if ps.MessageRetentionDuration == nil {
- ps.MessageRetentionDuration = defaultMessageRetentionDuration
- }
- if err := checkMRD(ps.MessageRetentionDuration); err != nil {
- return nil, err
- }
- if ps.PushConfig == nil {
- ps.PushConfig = &pb.PushConfig{}
- }
-
- sub := newSubscription(top, &s.mu, ps)
- top.subs[ps.Name] = sub
- s.subs[ps.Name] = sub
- sub.start(&s.wg)
- return ps, nil
- }
-
// minAckDeadlineSecs is the smallest ack deadline, in seconds, that
// checkAckDeadline accepts. Tests can change it with SetMinAckDeadline and
// restore the default with ResetMinAckDeadline.
var minAckDeadlineSecs int32
-
- // SetMinAckDeadline changes the minack deadline to n. Must be
- // greater than or equal to 1 second. Remember to reset this value
- // to the default after your test changes it. Example usage:
- // pstest.SetMinAckDeadlineSecs(1)
- // defer pstest.ResetMinAckDeadlineSecs()
- func SetMinAckDeadline(n time.Duration) {
- if n < time.Second {
- panic("SetMinAckDeadline expects a value greater than 1 second")
- }
-
- minAckDeadlineSecs = int32(n / time.Second)
- }
-
- // ResetMinAckDeadline resets the minack deadline to the default.
- func ResetMinAckDeadline() {
- minAckDeadlineSecs = 10
- }
-
- func checkAckDeadline(ads int32) error {
- if ads < minAckDeadlineSecs || ads > 600 {
- // PubSub service returns Unknown.
- return status.Errorf(codes.Unknown, "bad ack_deadline_seconds: %d", ads)
- }
- return nil
- }
-
// Bounds on a subscription's message retention duration, enforced by checkMRD.
const (
	minMessageRetentionDuration = 10 * time.Minute   // shortest accepted retention
	maxMessageRetentionDuration = 7 * 24 * time.Hour // longest accepted retention (168h)
)
-
- var defaultMessageRetentionDuration = ptypes.DurationProto(maxMessageRetentionDuration)
-
- func checkMRD(pmrd *durpb.Duration) error {
- mrd, err := ptypes.Duration(pmrd)
- if err != nil || mrd < minMessageRetentionDuration || mrd > maxMessageRetentionDuration {
- return status.Errorf(codes.InvalidArgument, "bad message_retention_duration %+v", pmrd)
- }
- return nil
- }
-
- func (s *GServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
- sub, err := s.findSubscription(req.Subscription)
- if err != nil {
- return nil, err
- }
- return sub.proto, nil
- }
-
- func (s *GServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) {
- if req.Subscription == nil {
- return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
- }
- s.mu.Lock()
- defer s.mu.Unlock()
- sub, err := s.findSubscription(req.Subscription.Name)
- if err != nil {
- return nil, err
- }
- for _, path := range req.UpdateMask.Paths {
- switch path {
- case "push_config":
- sub.proto.PushConfig = req.Subscription.PushConfig
-
- case "ack_deadline_seconds":
- a := req.Subscription.AckDeadlineSeconds
- if err := checkAckDeadline(a); err != nil {
- return nil, err
- }
- sub.proto.AckDeadlineSeconds = a
-
- case "retain_acked_messages":
- sub.proto.RetainAckedMessages = req.Subscription.RetainAckedMessages
-
- case "message_retention_duration":
- if err := checkMRD(req.Subscription.MessageRetentionDuration); err != nil {
- return nil, err
- }
- sub.proto.MessageRetentionDuration = req.Subscription.MessageRetentionDuration
-
- case "labels":
- sub.proto.Labels = req.Subscription.Labels
-
- default:
- return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path)
- }
- }
- return sub.proto, nil
- }
-
- func (s *GServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- var names []string
- for name := range s.subs {
- if strings.HasPrefix(name, req.Project) {
- names = append(names, name)
- }
- }
- sort.Strings(names)
- from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
- if err != nil {
- return nil, err
- }
- res := &pb.ListSubscriptionsResponse{NextPageToken: nextToken}
- for i := from; i < to; i++ {
- res.Subscriptions = append(res.Subscriptions, s.subs[names[i]].proto)
- }
- return res, nil
- }
-
- func (s *GServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
- sub, err := s.findSubscription(req.Subscription)
- if err != nil {
- return nil, err
- }
- sub.stop()
- delete(s.subs, req.Subscription)
- sub.topic.deleteSub(sub)
- return &emptypb.Empty{}, nil
- }
-
- func (s *GServer) Publish(_ context.Context, req *pb.PublishRequest) (*pb.PublishResponse, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- if req.Topic == "" {
- return nil, status.Errorf(codes.InvalidArgument, "missing topic")
- }
- top := s.topics[req.Topic]
- if top == nil {
- return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
- }
- var ids []string
- for _, pm := range req.Messages {
- id := fmt.Sprintf("m%d", s.nextID)
- s.nextID++
- pm.MessageId = id
- pubTime := timeNow()
- tsPubTime, err := ptypes.TimestampProto(pubTime)
- if err != nil {
- return nil, status.Errorf(codes.Internal, err.Error())
- }
- pm.PublishTime = tsPubTime
- m := &Message{
- ID: id,
- Data: pm.Data,
- Attributes: pm.Attributes,
- PublishTime: pubTime,
- }
- top.publish(pm, m)
- ids = append(ids, id)
- s.msgs = append(s.msgs, m)
- s.msgsByID[id] = m
- }
- return &pb.PublishResponse{MessageIds: ids}, nil
- }
-
- type topic struct {
- proto *pb.Topic
- subs map[string]*subscription
- }
-
- func newTopic(pt *pb.Topic) *topic {
- return &topic{
- proto: pt,
- subs: map[string]*subscription{},
- }
- }
-
- func (t *topic) stop() {
- for _, sub := range t.subs {
- sub.proto.Topic = "_deleted-topic_"
- sub.stop()
- }
- }
-
- func (t *topic) deleteSub(sub *subscription) {
- delete(t.subs, sub.proto.Name)
- }
-
- func (t *topic) publish(pm *pb.PubsubMessage, m *Message) {
- for _, s := range t.subs {
- s.msgs[pm.MessageId] = &message{
- publishTime: m.PublishTime,
- proto: &pb.ReceivedMessage{
- AckId: pm.MessageId,
- Message: pm,
- },
- deliveries: &m.deliveries,
- acks: &m.acks,
- streamIndex: -1,
- }
- }
- }
-
- type subscription struct {
- topic *topic
- mu *sync.Mutex // the server mutex, here for convenience
- proto *pb.Subscription
- ackTimeout time.Duration
- msgs map[string]*message // unacked messages by message ID
- streams []*stream
- done chan struct{}
- }
-
- func newSubscription(t *topic, mu *sync.Mutex, ps *pb.Subscription) *subscription {
- at := time.Duration(ps.AckDeadlineSeconds) * time.Second
- if at == 0 {
- at = 10 * time.Second
- }
- return &subscription{
- topic: t,
- mu: mu,
- proto: ps,
- ackTimeout: at,
- msgs: map[string]*message{},
- done: make(chan struct{}),
- }
- }
-
- func (s *subscription) start(wg *sync.WaitGroup) {
- wg.Add(1)
- go func() {
- defer wg.Done()
- for {
- select {
- case <-s.done:
- return
- case <-time.After(10 * time.Millisecond):
- s.deliver()
- }
- }
- }()
- }
-
- func (s *subscription) stop() {
- close(s.done)
- }
-
- func (s *GServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- sub, err := s.findSubscription(req.Subscription)
- if err != nil {
- return nil, err
- }
- for _, id := range req.AckIds {
- sub.ack(id)
- }
- return &emptypb.Empty{}, nil
- }
-
- func (s *GServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
- s.mu.Lock()
- defer s.mu.Unlock()
- sub, err := s.findSubscription(req.Subscription)
- if err != nil {
- return nil, err
- }
- now := time.Now()
- for _, id := range req.AckIds {
- s.msgsByID[id].Modacks = append(s.msgsByID[id].Modacks, Modack{AckID: id, AckDeadline: req.AckDeadlineSeconds, ReceivedAt: now})
- }
- dur := secsToDur(req.AckDeadlineSeconds)
- for _, id := range req.AckIds {
- sub.modifyAckDeadline(id, dur)
- }
- return &emptypb.Empty{}, nil
- }
-
- func (s *GServer) Pull(ctx context.Context, req *pb.PullRequest) (*pb.PullResponse, error) {
- s.mu.Lock()
- sub, err := s.findSubscription(req.Subscription)
- if err != nil {
- s.mu.Unlock()
- return nil, err
- }
- max := int(req.MaxMessages)
- if max < 0 {
- s.mu.Unlock()
- return nil, status.Error(codes.InvalidArgument, "MaxMessages cannot be negative")
- }
- if max == 0 { // MaxMessages not specified; use a default.
- max = 1000
- }
- msgs := sub.pull(max)
- s.mu.Unlock()
- // Implement the spec from the pubsub proto:
- // "If ReturnImmediately set to true, the system will respond immediately even if
- // it there are no messages available to return in the `Pull` response.
- // Otherwise, the system may wait (for a bounded amount of time) until at
- // least one message is available, rather than returning no messages."
- if len(msgs) == 0 && !req.ReturnImmediately {
- // Wait for a short amount of time for a message.
- // TODO: signal when a message arrives, so we don't wait the whole time.
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- case <-time.After(500 * time.Millisecond):
- s.mu.Lock()
- msgs = sub.pull(max)
- s.mu.Unlock()
- }
- }
- return &pb.PullResponse{ReceivedMessages: msgs}, nil
- }
-
- func (s *GServer) StreamingPull(sps pb.Subscriber_StreamingPullServer) error {
- // Receive initial message configuring the pull.
- req, err := sps.Recv()
- if err != nil {
- return err
- }
- s.mu.Lock()
- sub, err := s.findSubscription(req.Subscription)
- s.mu.Unlock()
- if err != nil {
- return err
- }
- // Create a new stream to handle the pull.
- st := sub.newStream(sps, s.streamTimeout)
- err = st.pull(&s.wg)
- sub.deleteStream(st)
- return err
- }
-
- func (s *GServer) Seek(ctx context.Context, req *pb.SeekRequest) (*pb.SeekResponse, error) {
- // Only handle time-based seeking for now.
- // This fake doesn't deal with snapshots.
- var target time.Time
- switch v := req.Target.(type) {
- case nil:
- return nil, status.Errorf(codes.InvalidArgument, "missing Seek target type")
- case *pb.SeekRequest_Time:
- var err error
- target, err = ptypes.Timestamp(v.Time)
- if err != nil {
- return nil, status.Errorf(codes.InvalidArgument, "bad Time target: %v", err)
- }
- default:
- return nil, status.Errorf(codes.Unimplemented, "unhandled Seek target type %T", v)
- }
-
- // The entire server must be locked while doing the work below,
- // because the messages don't have any other synchronization.
- s.mu.Lock()
- defer s.mu.Unlock()
- sub, err := s.findSubscription(req.Subscription)
- if err != nil {
- return nil, err
- }
- // Drop all messages from sub that were published before the target time.
- for id, m := range sub.msgs {
- if m.publishTime.Before(target) {
- delete(sub.msgs, id)
- (*m.acks)++
- }
- }
- // Un-ack any already-acked messages after this time;
- // redelivering them to the subscription is the closest analogue here.
- for _, m := range s.msgs {
- if m.PublishTime.Before(target) {
- continue
- }
- sub.msgs[m.ID] = &message{
- publishTime: m.PublishTime,
- proto: &pb.ReceivedMessage{
- AckId: m.ID,
- // This was not preserved!
- //Message: pm,
- },
- deliveries: &m.deliveries,
- acks: &m.acks,
- streamIndex: -1,
- }
- }
- return &pb.SeekResponse{}, nil
- }
-
- // Gets a subscription that must exist.
- // Must be called with the lock held.
- func (s *GServer) findSubscription(name string) (*subscription, error) {
- if name == "" {
- return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
- }
- sub := s.subs[name]
- if sub == nil {
- return nil, status.Errorf(codes.NotFound, "subscription %s", name)
- }
- return sub, nil
- }
-
- // Must be called with the lock held.
- func (s *subscription) pull(max int) []*pb.ReceivedMessage {
- now := timeNow()
- s.maintainMessages(now)
- var msgs []*pb.ReceivedMessage
- for _, m := range s.msgs {
- if m.outstanding() {
- continue
- }
- (*m.deliveries)++
- m.ackDeadline = now.Add(s.ackTimeout)
- msgs = append(msgs, m.proto)
- if len(msgs) >= max {
- break
- }
- }
- return msgs
- }
-
- func (s *subscription) deliver() {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- now := timeNow()
- s.maintainMessages(now)
- // Try to deliver each remaining message.
- curIndex := 0
- for _, m := range s.msgs {
- if m.outstanding() {
- continue
- }
- // If the message was never delivered before, start with the stream at
- // curIndex. If it was delivered before, start with the stream after the one
- // that owned it.
- if m.streamIndex < 0 {
- delIndex, ok := s.tryDeliverMessage(m, curIndex, now)
- if !ok {
- break
- }
- curIndex = delIndex + 1
- m.streamIndex = curIndex
- } else {
- delIndex, ok := s.tryDeliverMessage(m, m.streamIndex, now)
- if !ok {
- break
- }
- m.streamIndex = delIndex
- }
- }
- }
-
- // tryDeliverMessage attempts to deliver m to the stream at index i. If it can't, it
- // tries streams i+1, i+2, ..., wrapping around. Once it's tried all streams, it
- // exits.
- //
- // It returns the index of the stream it delivered the message to, or 0, false if
- // it didn't deliver the message.
- //
- // Must be called with the lock held.
- func (s *subscription) tryDeliverMessage(m *message, start int, now time.Time) (int, bool) {
- for i := 0; i < len(s.streams); i++ {
- idx := (i + start) % len(s.streams)
-
- st := s.streams[idx]
- select {
- case <-st.done:
- s.streams = deleteStreamAt(s.streams, idx)
- i--
-
- case st.msgc <- m.proto:
- (*m.deliveries)++
- m.ackDeadline = now.Add(st.ackTimeout)
- return idx, true
-
- default:
- }
- }
- return 0, false
- }
-
// retentionDuration is how long after publication a message that is not
// outstanding stays on a subscription before maintainMessages removes it.
var retentionDuration = 10 * time.Minute
-
- // Must be called with the lock held.
- func (s *subscription) maintainMessages(now time.Time) {
- for id, m := range s.msgs {
- // Mark a message as re-deliverable if its ack deadline has expired.
- if m.outstanding() && now.After(m.ackDeadline) {
- m.makeAvailable()
- }
- pubTime, err := ptypes.Timestamp(m.proto.Message.PublishTime)
- if err != nil {
- panic(err)
- }
- // Remove messages that have been undelivered for a long time.
- if !m.outstanding() && now.Sub(pubTime) > retentionDuration {
- delete(s.msgs, id)
- }
- }
- }
-
- func (s *subscription) newStream(gs pb.Subscriber_StreamingPullServer, timeout time.Duration) *stream {
- st := &stream{
- sub: s,
- done: make(chan struct{}),
- msgc: make(chan *pb.ReceivedMessage),
- gstream: gs,
- ackTimeout: s.ackTimeout,
- timeout: timeout,
- }
- s.mu.Lock()
- s.streams = append(s.streams, st)
- s.mu.Unlock()
- return st
- }
-
- func (s *subscription) deleteStream(st *stream) {
- s.mu.Lock()
- defer s.mu.Unlock()
- var i int
- for i = 0; i < len(s.streams); i++ {
- if s.streams[i] == st {
- break
- }
- }
- if i < len(s.streams) {
- s.streams = deleteStreamAt(s.streams, i)
- }
- }
- func deleteStreamAt(s []*stream, i int) []*stream {
- // Preserve order for round-robin delivery.
- return append(s[:i], s[i+1:]...)
- }
-
- type message struct {
- proto *pb.ReceivedMessage
- publishTime time.Time
- ackDeadline time.Time
- deliveries *int
- acks *int
- streamIndex int // index of stream that currently owns msg, for round-robin delivery
- }
-
- // A message is outstanding if it is owned by some stream.
- func (m *message) outstanding() bool {
- return !m.ackDeadline.IsZero()
- }
-
- func (m *message) makeAvailable() {
- m.ackDeadline = time.Time{}
- }
-
- type stream struct {
- sub *subscription
- done chan struct{} // closed when the stream is finished
- msgc chan *pb.ReceivedMessage
- gstream pb.Subscriber_StreamingPullServer
- ackTimeout time.Duration
- timeout time.Duration
- }
-
- // pull manages the StreamingPull interaction for the life of the stream.
- func (st *stream) pull(wg *sync.WaitGroup) error {
- errc := make(chan error, 2)
- wg.Add(2)
- go func() {
- defer wg.Done()
- errc <- st.sendLoop()
- }()
- go func() {
- defer wg.Done()
- errc <- st.recvLoop()
- }()
- var tchan <-chan time.Time
- if st.timeout > 0 {
- tchan = time.After(st.timeout)
- }
- // Wait until one of the goroutines returns an error, or we time out.
- var err error
- select {
- case err = <-errc:
- if err == io.EOF {
- err = nil
- }
- case <-tchan:
- }
- close(st.done) // stop the other goroutine
- return err
- }
-
- func (st *stream) sendLoop() error {
- for {
- select {
- case <-st.done:
- return nil
- case rm := <-st.msgc:
- res := &pb.StreamingPullResponse{ReceivedMessages: []*pb.ReceivedMessage{rm}}
- if err := st.gstream.Send(res); err != nil {
- return err
- }
- }
- }
- }
-
- func (st *stream) recvLoop() error {
- for {
- req, err := st.gstream.Recv()
- if err != nil {
- return err
- }
- st.sub.handleStreamingPullRequest(st, req)
- }
- }
-
- func (s *subscription) handleStreamingPullRequest(st *stream, req *pb.StreamingPullRequest) {
- // Lock the entire server.
- s.mu.Lock()
- defer s.mu.Unlock()
-
- for _, ackID := range req.AckIds {
- s.ack(ackID)
- }
- for i, id := range req.ModifyDeadlineAckIds {
- s.modifyAckDeadline(id, secsToDur(req.ModifyDeadlineSeconds[i]))
- }
- if req.StreamAckDeadlineSeconds > 0 {
- st.ackTimeout = secsToDur(req.StreamAckDeadlineSeconds)
- }
- }
-
- // Must be called with the lock held.
- func (s *subscription) ack(id string) {
- m := s.msgs[id]
- if m != nil {
- (*m.acks)++
- delete(s.msgs, id)
- }
- }
-
- // Must be called with the lock held.
- func (s *subscription) modifyAckDeadline(id string, d time.Duration) {
- m := s.msgs[id]
- if m == nil { // already acked: ignore.
- return
- }
- if d == 0 { // nack
- m.makeAvailable()
- } else { // extend the deadline by d
- m.ackDeadline = timeNow().Add(d)
- }
- }
-
// secsToDur converts a whole number of seconds into a time.Duration.
func secsToDur(secs int32) time.Duration {
	d := time.Duration(secs)
	return d * time.Second
}
|