// Copyright 2017 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package pubsub // This file provides a fake/mock in-memory pubsub server. import ( "io" "sort" "strings" "sync" "time" "cloud.google.com/go/internal/testutil" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" durpb "github.com/golang/protobuf/ptypes/duration" emptypb "github.com/golang/protobuf/ptypes/empty" "golang.org/x/net/context" pb "google.golang.org/genproto/googleapis/pubsub/v1" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) type fakeServer struct { pb.PublisherServer pb.SubscriberServer Addr string mu sync.Mutex Acked map[string]bool // acked message IDs Deadlines map[string]int32 // deadlines by message ID pullResponses []*pullResponse wg sync.WaitGroup subs map[string]*pb.Subscription topics map[string]*pb.Topic } type pullResponse struct { msgs []*pb.ReceivedMessage err error } func newFakeServer() (*fakeServer, error) { srv, err := testutil.NewServer() if err != nil { return nil, err } fake := &fakeServer{ Addr: srv.Addr, Acked: map[string]bool{}, Deadlines: map[string]int32{}, subs: map[string]*pb.Subscription{}, topics: map[string]*pb.Topic{}, } pb.RegisterPublisherServer(srv.Gsrv, fake) pb.RegisterSubscriberServer(srv.Gsrv, fake) srv.Start() return fake, nil } // Each call to addStreamingPullMessages results in one StreamingPullResponse. func (s *fakeServer) addStreamingPullMessages(msgs []*pb.ReceivedMessage) { s.pullResponses = append(s.pullResponses, &pullResponse{msgs, nil}) } func (s *fakeServer) addStreamingPullError(err error) { s.pullResponses = append(s.pullResponses, &pullResponse{nil, err}) } func (s *fakeServer) wait() { s.wg.Wait() } func (s *fakeServer) StreamingPull(stream pb.Subscriber_StreamingPullServer) error { s.wg.Add(1) defer s.wg.Done() errc := make(chan error, 1) s.wg.Add(1) go func() { defer s.wg.Done() for { req, err := stream.Recv() if err != nil { errc <- err return } s.mu.Lock() for _, id := range req.AckIds { s.Acked[id] = true } for i, id := range req.ModifyDeadlineAckIds { s.Deadlines[id] = req.ModifyDeadlineSeconds[i] } s.mu.Unlock() } }() // Send responses. for { s.mu.Lock() if len(s.pullResponses) == 0 { s.mu.Unlock() // Nothing to send, so wait for the client to shut down the stream. err := <-errc // a real error, or at least EOF if err == io.EOF { return nil } return err } pr := s.pullResponses[0] s.pullResponses = s.pullResponses[1:] s.mu.Unlock() if pr.err != nil { // Add a slight delay to ensure the server receives any // messages en route from the client before shutting down the stream. // This reduces flakiness of tests involving retry. time.Sleep(200 * time.Millisecond) } if pr.err == io.EOF { return nil } if pr.err != nil { return pr.err } // Return any error from Recv. select { case err := <-errc: return err default: } res := &pb.StreamingPullResponse{ReceivedMessages: pr.msgs} if err := stream.Send(res); err != nil { return err } } } const ( minMessageRetentionDuration = 10 * time.Minute maxMessageRetentionDuration = 168 * time.Hour ) var defaultMessageRetentionDuration = ptypes.DurationProto(maxMessageRetentionDuration) func checkMRD(pmrd *durpb.Duration) error { mrd, err := ptypes.Duration(pmrd) if err != nil || mrd < minMessageRetentionDuration || mrd > maxMessageRetentionDuration { return status.Errorf(codes.InvalidArgument, "bad message_retention_duration %+v", pmrd) } return nil } func checkAckDeadline(ads int32) error { if ads < 10 || ads > 600 { // PubSub service returns Unknown. return status.Errorf(codes.Unknown, "bad ack_deadline_seconds: %d", ads) } return nil } func (s *fakeServer) Acknowledge(ctx context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) { for _, id := range req.AckIds { s.Acked[id] = true } return &emptypb.Empty{}, nil } func (s *fakeServer) ModifyAckDeadline(ctx context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) { for _, id := range req.AckIds { s.Deadlines[id] = req.AckDeadlineSeconds } return &emptypb.Empty{}, nil } func (s *fakeServer) CreateSubscription(ctx context.Context, sub *pb.Subscription) (*pb.Subscription, error) { if s.subs[sub.Name] != nil { return nil, status.Errorf(codes.AlreadyExists, "subscription %q", sub.Name) } sub2 := proto.Clone(sub).(*pb.Subscription) if err := checkAckDeadline(sub.AckDeadlineSeconds); err != nil { return nil, err } if sub.MessageRetentionDuration == nil { sub2.MessageRetentionDuration = defaultMessageRetentionDuration } if err := checkMRD(sub2.MessageRetentionDuration); err != nil { return nil, err } if sub.PushConfig == nil { sub2.PushConfig = &pb.PushConfig{} } s.subs[sub.Name] = sub2 return sub2, nil } func (s *fakeServer) GetSubscription(ctx context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) { if sub := s.subs[req.Subscription]; sub != nil { return sub, nil } return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription) } func (s *fakeServer) UpdateSubscription(ctx context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) { sub := s.subs[req.Subscription.Name] if sub == nil { return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription.Name) } for _, path := range req.UpdateMask.Paths { switch path { case "push_config": sub.PushConfig = req.Subscription.PushConfig case "ack_deadline_seconds": a := req.Subscription.AckDeadlineSeconds if err := checkAckDeadline(a); err != nil { return nil, err } sub.AckDeadlineSeconds = a case "retain_acked_messages": sub.RetainAckedMessages = req.Subscription.RetainAckedMessages case "message_retention_duration": if err := checkMRD(req.Subscription.MessageRetentionDuration); err != nil { return nil, err } sub.MessageRetentionDuration = req.Subscription.MessageRetentionDuration // TODO(jba): labels default: return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path) } } return sub, nil } func (s *fakeServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) { if s.subs[req.Subscription] == nil { return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription) } delete(s.subs, req.Subscription) return &emptypb.Empty{}, nil } func (s *fakeServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) { if s.topics[t.Name] != nil { return nil, status.Errorf(codes.AlreadyExists, "topic %q", t.Name) } t2 := proto.Clone(t).(*pb.Topic) s.topics[t.Name] = t2 return t2, nil } func (s *fakeServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) { if t := s.topics[req.Topic]; t != nil { return t, nil } return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic) } func (s *fakeServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) { if s.topics[req.Topic] == nil { return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic) } delete(s.topics, req.Topic) return &emptypb.Empty{}, nil } func (s *fakeServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) { var names []string for n := range s.topics { if strings.HasPrefix(n, req.Project) { names = append(names, n) } } sort.Strings(names) from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names)) if err != nil { return nil, err } res := &pb.ListTopicsResponse{NextPageToken: nextToken} for i := from; i < to; i++ { res.Topics = append(res.Topics, s.topics[names[i]]) } return res, nil } func (s *fakeServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) { var names []string for _, sub := range s.subs { if strings.HasPrefix(sub.Name, req.Project) { names = append(names, sub.Name) } } sort.Strings(names) from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names)) if err != nil { return nil, err } res := &pb.ListSubscriptionsResponse{NextPageToken: nextToken} for i := from; i < to; i++ { res.Subscriptions = append(res.Subscriptions, s.subs[names[i]]) } return res, nil } func (s *fakeServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) { var names []string for _, sub := range s.subs { if sub.Topic == req.Topic { names = append(names, sub.Name) } } sort.Strings(names) from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names)) if err != nil { return nil, err } return &pb.ListTopicSubscriptionsResponse{ Subscriptions: names[from:to], NextPageToken: nextToken, }, nil }